Dissecting ResNet18 Layer by Layer in PyTorch: Understanding Residual Connections Through Tensor Flow
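When you first saw the ResNet18 architecture diagram, did the criss-crossing arrows leave you confused? Residual networks (ResNets) are a landmark architecture in computer vision. By introducing skip connections, ResNet made very deep networks practical to train and eased the optimization difficulties, including weak gradient flow, that held back earlier deep models. Reading about it only gets you so far, though. This article builds ResNet18 from scratch in PyTorch and prints the tensor shape after every layer, so you can see exactly how data flows through the network.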
1. The Core Idea of Residual Networks
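In a conventional convolutional network, accuracy tends to saturate and then degrade as more layers are added. This is the degradation problem. It is not caused by overfitting, because the training error of the deeper network also rises. Instead, deeper plain networks are simply harder to optimize. ResNet's key contribution is the residual learning framework: rather than having a stack of layers fit the desired mapping H(x) directly, let it learn the residual function F(x) = H(x) - x, so the original mapping becomes F(x) + x.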
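In code, the residual formulation is a wrapper that adds the input back to whatever the inner layers compute. The following is a minimal sketch that uses a hypothetical `Residual` module (not part of PyTorch), which wraps any shape-preserving layer `fn` acting as F. When F outputs exactly zero, the wrapper reduces to the identity.

```python
import torch
import torch.nn as nn


class Residual(nn.Module):
    """Hypothetical wrapper: computes H(x) = F(x) + x for a shape-preserving F."""

    def __init__(self, fn: nn.Module):
        super().__init__()
        self.fn = fn  # the residual branch F

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fn(x) + x  # F(x) + x


torch.manual_seed(0)
x = torch.randn(4, 8)

# A generic residual branch: the block only has to learn the difference H(x) - x
block = Residual(nn.Linear(8, 8))
y = block(x)

# If F outputs exactly zero, the block is an identity mapping: H(x) = x
zero_branch = nn.Linear(8, 8)
nn.init.zeros_(zero_branch.weight)
nn.init.zeros_(zero_branch.bias)
identity_block = Residual(zero_branch)
print(torch.equal(identity_block(x), x))  # True
```

This wrapper only works when `fn` preserves the shape of `x`. When the shape changes, the shortcut itself needs a projection, which is what the `BasicBlock` in this article handles.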
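This design is elegant for several reasons:
- Identity shortcut: when the residual F(x) approaches 0, the block reduces to an identity mapping. In principle, this means a deeper network should be no harder to train than its shallower counterpart.
- Gradient highway: the skip connection gives backpropagation a direct path, which helps alleviate vanishing gradients.
- Feature reuse: features from shallow layers can pass directly to deeper layers instead of being relearned.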
ResNet18 is built from `BasicBlock`s. Each block has two 3×3 convolutions on the residual branch, plus a shortcut. The `print` calls trace the tensor shape at each step:

```python
import torch
import torch.nn as nn


class BasicBlock(nn.Module):
    expansion = 1  # BasicBlock does not widen its output channels

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Residual branch F(x): 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut: identity by default ("solid line")
        self.shortcut = nn.Sequential()
        # Projection shortcut ("dashed line") when the shape changes
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        print(f"Input shape: {x.shape}")
        residual = x
        out = torch.relu(self.bn1(self.conv1(x)))
        print(f"After first conv: {out.shape}")
        out = self.bn2(self.conv2(out))
        print(f"After second conv: {out.shape}")
        out += self.shortcut(residual)  # F(x) + x
        print(f"After residual add: {out.shape}")
        out = torch.relu(out)
        return out
```

2. The Hierarchical Structure of ResNet18
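ResNet18 has three parts: an initial convolution layer, four stages of residual blocks, and a head made of global average pooling followed by a fully connected layer. The sections below walk through each stage and track how the feature-map size and channel count change.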
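Where does the "18" come from? The count includes only layers with learnable weights on the main path: the convolutions and the fully connected layer. It excludes BN layers and the 1×1 shortcut projections. Under that convention, a quick tally gives:

```python
# Blocks per stage in ResNet18 (stages 1-4)
blocks_per_stage = [2, 2, 2, 2]
convs_per_basic_block = 2  # two 3x3 convolutions on the residual branch

stem_convs = 1  # the 7x7 initial convolution
block_convs = sum(blocks_per_stage) * convs_per_basic_block  # 8 blocks * 2 = 16
fc_layers = 1   # the final classifier

depth = stem_convs + block_convs + fc_layers
print(depth)  # 18
```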
2.1 Initial Convolution Layer
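The input image first passes through a 7×7 convolution with stride 2, followed by a 3×3 max pool with stride 2. The large kernel captures wider spatial context early on, and the two stride-2 operations together shrink the resolution by a factor of 4: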
```python
import torch
import torch.nn as nn

# Assume a 3-channel 224x224 input image (batch size 1)
x = torch.randn(1, 3, 224, 224)

conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
bn1 = nn.BatchNorm2d(64)
maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

out = conv1(x)
print(f"After initial conv: {out.shape}")  # torch.Size([1, 64, 112, 112])
out = bn1(out)
out = torch.relu(out)
out = maxpool(out)
print(f"After max pooling: {out.shape}")  # torch.Size([1, 64, 56, 56])
```

2.2 Residual Block Stages
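ResNet18 has four main stages, each made of two BasicBlocks. Note the special handling of the "dashed-line" blocks. These are the first blocks of stages 2 to 4, and their shortcut must change the tensor's shape: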
| Stage | Block type | Blocks | Output size (224×224 input) | Channels |
|---|---|---|---|---|
| 1 | BasicBlock | 2 | 56×56 | 64→64 |
| 2 | BasicBlock | 2 | 28×28 | 64→128 |
| 3 | BasicBlock | 2 | 14×14 | 128→256 |
| 4 | BasicBlock | 2 | 7×7 | 256→512 |
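The output sizes in the table follow from the standard convolution/pooling size formula: out = floor((in + 2·padding - kernel) / stride) + 1. The pure-Python sketch below recomputes the spatial size at each step for a 224×224 input. The helper `out_size` is our own and assumes dilation 1. Each stage's resolution is taken from the first 3×3 convolution (padding 1) of the stage's first block. The sketch also checks that the 1×1 shortcut projection lands on the same resolution.

```python
def out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a conv/pool layer (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 224
size = out_size(size, kernel=7, stride=2, padding=3)  # stem 7x7 conv -> 112
size = out_size(size, kernel=3, stride=2, padding=1)  # 3x3 max pool   -> 56

sizes = []
for stride in [1, 2, 2, 2]:  # stride of the first block in stages 1-4
    prev = size
    # The first 3x3 conv of the stage's first block sets the stage's resolution
    size = out_size(prev, kernel=3, stride=stride, padding=1)
    # The 1x1 shortcut (padding 0) must produce the same size for the addition
    shortcut = out_size(prev, kernel=1, stride=stride, padding=0)
    assert shortcut == size
    sizes.append(size)

print(sizes)  # [56, 28, 14, 7]
```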
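```python
# Continues from the previous snippet: `out` has shape [1, 64, 56, 56]

# Stage 1 example: no downsampling
layer1 = nn.Sequential(
    BasicBlock(64, 64),
    BasicBlock(64, 64),
)
out = layer1(out)
print(f"Stage 1 output: {out.shape}")  # torch.Size([1, 64, 56, 56])

# Stage 2 example: with downsampling
layer2 = nn.Sequential(
    BasicBlock(64, 128, stride=2),  # note stride=2: halves height and width
    BasicBlock(128, 128),
)
out = layer2(out)
print(f"Stage 2 output: {out.shape}")  # torch.Size([1, 128, 28, 28])
```

2.3 Skip Connection Implementation Details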
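A residual block handles its skip connection in one of two ways:
- Solid-line (identity) shortcut: when input and output have the same channel count and spatial size (stride 1), the input is added directly.
- Dashed-line (projection) shortcut: when the channel count or spatial size changes, a 1×1 convolution with the block's stride, followed by BN, brings the input to the output's shape before the addition.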
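The projection is necessary because element-wise addition requires compatible shapes. The minimal sketch below uses plain PyTorch and does not depend on `BasicBlock`. A random tensor stands in for the residual branch output F(x). Adding the raw 64-channel, 56×56 input to it fails, while a 1×1 convolution with stride 2 makes the shapes match.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 56, 56)     # block input
f_x = torch.randn(1, 128, 28, 28)  # stand-in for the residual branch output F(x)

# Identity shortcut fails: channels and spatial size both differ
try:
    _ = f_x + x
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)  # True

# Projection shortcut: 1x1 conv (stride 2) + BN maps x to [1, 128, 28, 28]
projection = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(128),
)
out = f_x + projection(x)
print(out.shape)  # torch.Size([1, 128, 28, 28])
```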
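```python
# Solid-line shortcut: shapes match, so x is added directly
x = torch.randn(1, 64, 56, 56)
block = BasicBlock(64, 64)
out = block(x)  # output: torch.Size([1, 64, 56, 56])

# Dashed-line shortcut: channels 64 -> 128 and stride 2
x = torch.randn(1, 64, 56, 56)
block = BasicBlock(64, 128, stride=2)
out = block(x)  # shortcut uses a 1x1 conv; output: torch.Size([1, 128, 28, 28])
```

3. Full ResNet18 Implementation and Verification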
Now we put all the components together, build the complete ResNet18, and verify its structure:
```python
class ResNet18(nn.Module):
    def __init__(self, num_classes=1000):
        super().__init__()
        self.in_channels = 64  # input channels for the next block to be built

        # Stem: 7x7 conv (stride 2) + 3x3 max pool (stride 2)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages of two BasicBlocks each; stages 2-4 downsample by 2
        self.layer1 = self._make_layer(64, 2, stride=1)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.layer4 = self._make_layer(512, 2, stride=2)

        # Head: global average pooling + fully connected classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * BasicBlock.expansion, num_classes)

    def _make_layer(self, out_channels, blocks, stride):
        layers = []
        # Only the first block of a stage changes stride/channels
        layers.append(BasicBlock(self.in_channels, out_channels, stride))
        self.in_channels = out_channels * BasicBlock.expansion
        for _ in range(1, blocks):
            layers.append(BasicBlock(self.in_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        print("\n=== Initial conv ===")
        x = self.conv1(x)
        print(f"Initial conv output: {x.shape}")
        x = self.bn1(x)
        x = torch.relu(x)

        print("\n=== Max pooling ===")
        x = self.maxpool(x)
        print(f"After pooling: {x.shape}")

        print("\n=== Stage 1 ===")
        x = self.layer1(x)
        print(f"Stage 1 output: {x.shape}")

        print("\n=== Stage 2 ===")
        x = self.layer2(x)
        print(f"Stage 2 output: {x.shape}")

        print("\n=== Stage 3 ===")
        x = self.layer3(x)
        print(f"Stage 3 output: {x.shape}")

        print("\n=== Stage 4 ===")
        x = self.layer4(x)
        print(f"Stage 4 output: {x.shape}")

        x = self.avgpool(x)      # [1, 512, 1, 1]
        x = torch.flatten(x, 1)  # [1, 512]
        x = self.fc(x)           # [1, num_classes]
        return x


# Verify the network structure
model = ResNet18()
input_tensor = torch.randn(1, 3, 224, 224)
output = model(input_tensor)
print(f"Final output: {output.shape}")  # torch.Size([1, 1000])
print(f"Parameters: {sum(p.numel() for p in model.parameters())}")  # 11689512
```

4. Visual Analysis of Residual Connections
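To see what the residual connection does, compare gradient flow with and without a skip connection. Let $y$ be the block output and $\mathcal{L}$ the loss.

With a residual connection, $y = F(x) + x$, so the gradient is:

$$\frac{\partial \mathcal{L}}{\partial x} = \frac{\partial \mathcal{L}}{\partial y}\,\frac{\partial F(x)}{\partial x} + \frac{\partial \mathcal{L}}{\partial y}$$

Without a residual connection, $y = F(x)$, so the gradient is:

$$\frac{\partial \mathcal{L}}{\partial x} = \frac{\partial \mathcal{L}}{\partial y}\,\frac{\partial F(x)}{\partial x}$$

With the skip connection, $\partial F(x)/\partial x$ may become very small deep in the network, but the term $\partial \mathcal{L}/\partial y$ still reaches shallower layers directly. This extra term is why residual connections mitigate vanishing gradients.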
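The effect compounds with depth. The following is a minimal autograd sketch of a toy scalar "network" with 10 stacked layers. Each layer's branch is simply F(x) = w·x with w = 0.1. These are hypothetical values chosen for illustration, not taken from ResNet. Without skips, the gradient with respect to the input is w^10 = 1e-10. With skips, each layer contributes a factor of (1 + w), giving 1.1^10 (about 2.59).

```python
import torch

depth, w = 10, 0.1  # hypothetical toy values


def plain(x):
    for _ in range(depth):
        x = w * x      # y = F(x)
    return x


def residual(x):
    for _ in range(depth):
        x = w * x + x  # y = F(x) + x
    return x


grads = {}
for name, net in [("plain", plain), ("residual", residual)]:
    x = torch.tensor(1.0, dtype=torch.float64, requires_grad=True)
    net(x).backward()  # d(output)/dx accumulated into x.grad
    grads[name] = x.grad.item()

print(grads)
```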
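When debugging ResNet in real projects, a few practical tips are worth keeping in mind:
- Initialize γ of the last BN layer in each residual block to 0. At initialization, the residual branch then outputs 0, so every block starts as an identity mapping and the network initially behaves like a much shallower one.
- Downsample in the first block of each stage. The remaining blocks in the stage can then focus on feature extraction rather than changing the tensor's shape.
- Use pretrained weights where appropriate. Especially on small datasets, an ImageNet-pretrained feature extractor is often very effective.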
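The first tip is easy to check. The sketch below assumes a compact, print-free copy of the article's identity-shortcut block, named `QuietBasicBlock` purely for illustration. After zeroing the γ (`weight`) of `bn2`, the residual branch outputs only β, which defaults to 0. For a non-negative input, such as the output of a preceding ReLU, the block then returns its input unchanged. torchvision applies the same trick through the `zero_init_residual=True` argument of its ResNet constructors.

```python
import torch
import torch.nn as nn


class QuietBasicBlock(nn.Module):
    """Hypothetical print-free BasicBlock with an identity shortcut (stride 1)."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))  # residual branch F(x)
        return torch.relu(out + x)       # F(x) + x, then ReLU


block = QuietBasicBlock(64)
nn.init.zeros_(block.bn2.weight)  # gamma = 0, so bn2 outputs beta (= 0 by default)

x = torch.relu(torch.randn(2, 64, 8, 8))  # non-negative, like a post-ReLU activation
y = block(x)
print(torch.equal(y, x))  # True
```

For the third tip, torchvision 0.13 and later load ImageNet weights with `torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)`.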