经典深度学习模型——ResNet(详细解释 + 代码)
ResNet 是由微软研究院的 He Kaiming 等人在 2015 年提出的一种深度卷积神经网络,论文题目是《Deep Residual Learning for Image Recognition》。它主要解决了随着网络加深,训练困难以及精度饱和甚至下降的问题。ResNet 在 2015 年的 ImageNet 图像分类比赛中取得了突破性成绩,并且极大推动了深度学习的发展。方面说明提出时间2
一、什么是 ResNet?
ResNet 是由微软研究院的 He Kaiming 等人在 2015 年提出的一种深度卷积神经网络,论文题目是《Deep Residual Learning for Image Recognition》。它主要解决了随着网络加深,训练困难以及精度饱和甚至下降的问题。ResNet 在 2015 年的 ImageNet 图像分类比赛中取得了突破性成绩,并且极大推动了深度学习的发展。
二、ResNet 的核心原理
1. 深度网络难训练的问题
传统的深度神经网络,随着层数增加,会出现梯度消失或梯度爆炸,导致网络难以训练,且训练误差不降反升,这叫做退化问题(Degradation Problem)。
2. 残差学习(Residual Learning)
ResNet 的核心思想是引入残差块(Residual Block),让网络学习残差函数(Residual function),而不是直接学习期望映射,其核心就是一个加法,很多论文里面的 Adding 其实就是残差学习。
-
假设映射为F,映射后的输出为 H(x)H(x)H(x)
-
如果不用残差的话,H(x)=F(x)H(x) = F(x)H(x)=F(x)。而在 ResNet 中
H(x)=F(x)+x H(x) = F(x) + x H(x)=F(x)+x
这样做的动机是,在反向传播更新参数的时候,不至于因为梯度为零而无法更新。且如果恒等映射是最优的,网络只需要学习 F(x)=0F(x) = 0F(x)=0 的残差,避免了直接拟合复杂函数的难度。
三、数学公式
一个基本的残差块可以表示为:
y=F(x,{Wi})+x \mathbf{y} = \mathcal{F}(\mathbf{x}, \{W_i\}) + \mathbf{x} y=F(x,{Wi})+x
其中:
- x\mathbf{x}x 是输入向量
- F(x,{Wi})\mathcal{F}(\mathbf{x}, \{W_i\})F(x,{Wi}) 是残差函数,通常是两个或三个卷积层堆叠后的映射
- y\mathbf{y}y 是残差块的输出
- 加法操作称为“跳跃连接(skip connection)”或“快捷连接(shortcut connection)”
具体结构:
比如一个典型的残差块是两层卷积 + BN + ReLU:
F(x)=ReLU(BN(W2∗ReLU(BN(W1∗x)))) \mathcal{F}(\mathbf{x}) = \text{ReLU}(BN(W_2 * \text{ReLU}(BN(W_1 * \mathbf{x})))) F(x)=ReLU(BN(W2∗ReLU(BN(W1∗x))))
然后加上输入:
y=F(x)+x \mathbf{y} = \mathcal{F}(\mathbf{x}) + \mathbf{x} y=F(x)+x
如果维度不匹配,输入 x\mathbf{x}x 会通过 1×1 卷积变换以匹配残差函数输出的维度。
四、ResNet 的作用与优势
- 缓解梯度消失和退化问题:残差连接保证了梯度可以直接反传,便于训练非常深的网络。
- 支持训练非常深的网络:ResNet 能训练超过 100 层甚至更深的网络,如 ResNet-50、ResNet-101、ResNet-152。
- 提升网络性能:提升了图像识别、目标检测、语义分割等任务的性能。
- 促进网络设计多样性:后续许多网络架构(如 DenseNet、EfficientNet)都借鉴了残差思想。
五、常用的 ResNet 变体和结构
- ResNet-18 / ResNet-34:较浅的网络,适合资源受限或较简单任务
- ResNet-50 / ResNet-101 / ResNet-152:瓶颈结构,较深的网络,适合大型数据集和复杂任务
- ResNeXt、SE-ResNet:在 ResNet 基础上做了改进,提高性能
六、ResNet 常用场景
- 图像分类(Image Classification):ImageNet 等大规模数据集的分类任务
- 目标检测(Object Detection):作为特征提取骨干网络,配合 Faster R-CNN、YOLO 等检测框架
- 语义分割(Semantic Segmentation):如 DeepLab、FCN 等分割模型的主干网络
- 图像生成和风格迁移
- 医学图像分析
- 视频分析
- 自然语言处理中部分任务的特征提取
七、总结
| 方面 | 说明 |
|---|---|
| 提出时间 | 2015 年 |
| 关键创新 | 残差连接(skip connection) |
| 核心公式 | y=F(x)+x\mathbf{y} = \mathcal{F}(\mathbf{x}) + \mathbf{x}y=F(x)+x |
| 作用 | 解决深层网络训练困难,促进训练更深网络 |
| 应用领域 | 图像分类、目标检测、语义分割等视觉任务 |
八、代码
代码基于 PyTorch 的 ResNet ,使用官方 torchvision 库里自带的 CIFAR-10 数据集来训练一个 ResNet-18。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
def main():
# 设备配置(GPU 优先)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# 数据预处理和增强
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4), # 随机裁剪
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010))
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010))
])
# 下载并加载 CIFAR-10 训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,
shuffle=False, num_workers=2)
# 加载预定义的 ResNet-18 模型,修改最后一层全连接层以适应 CIFAR-10(10 类)
model = resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 10) # CIFAR-10 有 10 个类别
model = model.to(device)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) # 学习率衰减
# 训练函数
def train(epoch):
model.train()
running_loss = 0.0
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
running_loss += loss.item()
if batch_idx % 100 == 99: # 每 100 个 batch 输出一次日志
print(f"[Epoch {epoch+1}, Batch {batch_idx+1}] loss: {running_loss/100:.3f}")
running_loss = 0.0
# 测试函数
def test():
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, targets in testloader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
acc = 100. * correct / total
print(f"Test Accuracy: {acc:.2f}%")
return acc
# 主训练循环
best_acc = 0
for epoch in range(1, 51): # 训练 50 个 epoch
train(epoch)
acc = test()
scheduler.step()
# 保存最优模型
if acc > best_acc:
best_acc = acc
torch.save(model.state_dict(), "best_resnet18_cifar10.pth")
print(f"Best Test Accuracy: {best_acc:.2f}%")
if __name__ == '__main__':
main()
更多推荐

所有评论(0)