别再死记ResNet18结构图了!用PyTorch代码逐层拆解,搞懂残差连接到底怎么跑的
2026/6/14 7:56:24 网站建设 项目流程

用PyTorch代码逐层拆解ResNet18:从张量流动理解残差连接

当你第一次看到ResNet18的结构图时,是否曾被那些交错连接的箭头弄得晕头转向?作为计算机视觉领域的里程碑式架构,残差网络(ResNet)通过引入跳跃连接(skip connection)解决了深度神经网络中的梯度消失问题。但纸上得来终觉浅,本文将带你用PyTorch从零构建ResNet18,通过打印每一层的张量形状变化,直观理解数据在网络中的流动路径。

1. 残差网络的核心思想

在传统的卷积神经网络中,随着网络层数的增加,模型性能往往会达到饱和甚至下降。这种现象被称为"退化问题"(degradation problem),并非由过拟合引起,而是因为深层网络难以优化。ResNet创造性地提出了残差学习框架——与其让网络直接拟合目标映射H(x),不如让它学习残差函数F(x) = H(x) - x,这样原始映射就变为F(x) + x。

这种设计的精妙之处在于:

  • 恒等映射的捷径:当残差F(x)趋近于0时,该层仅需执行恒等映射,这使得深层网络的训练至少不会比浅层网络更困难
  • 梯度高速公路:跳跃连接为反向传播提供了直达路径,有效缓解了梯度消失问题
  • 特征复用:浅层特征可以直接传递到深层,避免了重复学习
import torch import torch.nn as nn class BasicBlock(nn.Module): expansion = 1 def __init__(self, in_channels, out_channels, stride=1): super().__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(out_channels) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(out_channels) self.shortcut = nn.Sequential() if stride != 1 or in_channels != out_channels: self.shortcut = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(out_channels) ) def forward(self, x): print(f"输入形状: {x.shape}") residual = x out = torch.relu(self.bn1(self.conv1(x))) print(f"第一个卷积后形状: {out.shape}") out = self.bn2(self.conv2(out)) print(f"第二个卷积后形状: {out.shape}") out += self.shortcut(residual) print(f"残差连接后形状: {out.shape}") out = torch.relu(out) return out

2. ResNet18的层次结构解析

ResNet18由初始卷积层、四个残差块阶段和全连接层组成。让我们分解每个阶段的数据变化过程,重点关注特征图尺寸和通道数的变化规律。

2.1 初始卷积层

输入图像首先经过一个7×7的大卷积核进行初步特征提取,这有助于在早期捕获更大范围的视觉特征:

# 假设输入为3通道的224x224图像 x = torch.randn(1, 3, 224, 224) conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) bn1 = nn.BatchNorm2d(64) maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) out = conv1(x) print(f"初始卷积后形状: {out.shape}") # torch.Size([1, 64, 112, 112]) out = bn1(out) out = torch.relu(out) out = maxpool(out) print(f"池化后形状: {out.shape}") # torch.Size([1, 64, 56, 56])

2.2 残差块阶段

ResNet18包含四个主要阶段,每个阶段由多个BasicBlock组成。观察虚线残差块的特殊处理:

阶段块类型重复次数输出尺寸通道变化
1BasicBlock256×5664→64
2BasicBlock228×2864→128
3BasicBlock214×14128→256
4BasicBlock27×7256→512
# 阶段1示例 - 无下采样 layer1 = nn.Sequential( BasicBlock(64, 64), BasicBlock(64, 64) ) out = layer1(out) print(f"阶段1输出形状: {out.shape}") # 阶段2示例 - 带下采样 layer2 = nn.Sequential( BasicBlock(64, 128, stride=2), # 注意stride=2 BasicBlock(128, 128) ) out = layer2(out) print(f"阶段2输出形状: {out.shape}")

2.3 跳跃连接实现细节

残差块中的跳跃连接处理分为两种情况:

  1. 实线连接:当输入输出通道数相同时,直接相加
  2. 虚线连接:当通道数变化时,使用1×1卷积调整通道和尺寸
# 实线连接示例 x = torch.randn(1, 64, 56, 56) block = BasicBlock(64, 64) out = block(x) # 直接相加 # 虚线连接示例 x = torch.randn(1, 64, 56, 56) block = BasicBlock(64, 128, stride=2) out = block(x) # 使用1×1卷积调整

3. 完整ResNet18实现与验证

现在我们将所有组件组合起来,构建完整的ResNet18并验证其结构:

class ResNet18(nn.Module): def __init__(self, num_classes=1000): super().__init__() self.in_channels = 64 self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(64, 2, stride=1) self.layer2 = self._make_layer(128, 2, stride=2) self.layer3 = self._make_layer(256, 2, stride=2) self.layer4 = self._make_layer(512, 2, stride=2) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512 * BasicBlock.expansion, num_classes) def _make_layer(self, out_channels, blocks, stride): layers = [] layers.append(BasicBlock(self.in_channels, out_channels, stride)) self.in_channels = out_channels * BasicBlock.expansion for _ in range(1, blocks): layers.append(BasicBlock(self.in_channels, out_channels)) return nn.Sequential(*layers) def forward(self, x): print("\n=== 初始卷积 ===") x = self.conv1(x) print(f"初始卷积输出: {x.shape}") x = self.bn1(x) x = torch.relu(x) print("\n=== 最大池化 ===") x = self.maxpool(x) print(f"池化后: {x.shape}") print("\n=== 阶段1 ===") x = self.layer1(x) print(f"阶段1输出: {x.shape}") print("\n=== 阶段2 ===") x = self.layer2(x) print(f"阶段2输出: {x.shape}") print("\n=== 阶段3 ===") x = self.layer3(x) print(f"阶段3输出: {x.shape}") print("\n=== 阶段4 ===") x = self.layer4(x) print(f"阶段4输出: {x.shape}") x = self.avgpool(x) x = torch.flatten(x, 1) x = self.fc(x) return x # 验证网络结构 model = ResNet18() input_tensor = torch.randn(1, 3, 224, 224) output = model(input_tensor)

4. 残差连接的可视化分析

为了更直观地理解残差连接的作用,我们可以对比有无跳跃连接时的梯度流动:

有残差连接时的梯度计算

∂loss/∂x = ∂loss/∂F(x) * ∂F(x)/∂x + ∂loss/∂F(x)

无残差连接时的梯度计算

∂loss/∂x = ∂loss/∂F(x) * ∂F(x)/∂x

这种设计确保了即使深层网络的梯度很小,至少能有∂loss/∂F(x)这一项直接传递到浅层,避免了梯度消失。

在实际项目中调试ResNet时,有几个实用技巧值得注意:

  • 初始化残差块最后一层BN的γ为0:这样初始状态下残差块输出为0,网络从浅层开始学习
  • 下采样放在第一个残差块:这样后续块可以专注于特征提取而非尺寸调整
  • 适当使用预训练权重:特别是当数据集较小时,ImageNet预训练的特征提取器非常有效

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询