一、什么是 ResNet?

ResNet 是由微软研究院的 He Kaiming 等人在 2015 年提出的一种深度卷积神经网络,论文题目是《Deep Residual Learning for Image Recognition》。它主要解决了随着网络加深,训练困难以及精度饱和甚至下降的问题。ResNet 在 2015 年的 ImageNet 图像分类比赛中取得了突破性成绩,并且极大推动了深度学习的发展。


二、ResNet 的核心原理

1. 深度网络难训练的问题

传统的深度神经网络,随着层数增加,会出现梯度消失或梯度爆炸,导致网络难以训练,且训练误差不降反升,这叫做退化问题(Degradation Problem)

2. 残差学习(Residual Learning)

ResNet 的核心思想是引入残差块(Residual Block),让网络学习残差函数(Residual function),而不是直接学习期望映射,其核心就是一个加法,很多论文里面的 Adding 其实就是残差学习。
在这里插入图片描述

  • 假设映射为F,映射后的输出为 H(x)H(x)H(x)

  • 如果不用残差的话,H(x)=F(x)H(x) = F(x)H(x)=F(x)。而在 ResNet 中

    H(x)=F(x)+x H(x) = F(x) + x H(x)=F(x)+x

这样做的动机是,在反向传播更新参数的时候,不至于因为梯度为零而无法更新。且如果恒等映射是最优的,网络只需要学习 F(x)=0F(x) = 0F(x)=0 的残差,避免了直接拟合复杂函数的难度。


三、数学公式

一个基本的残差块可以表示为:

y=F(x,{Wi})+x \mathbf{y} = \mathcal{F}(\mathbf{x}, \{W_i\}) + \mathbf{x} y=F(x,{Wi})+x

其中:

  • x\mathbf{x}x 是输入向量
  • F(x,{Wi})\mathcal{F}(\mathbf{x}, \{W_i\})F(x,{Wi}) 是残差函数,通常是两个或三个卷积层堆叠后的映射
  • y\mathbf{y}y 是残差块的输出
  • 加法操作称为“跳跃连接(skip connection)”或“快捷连接(shortcut connection)”

具体结构:

比如一个典型的残差块是两层卷积 + BN + ReLU:

F(x)=ReLU(BN(W2∗ReLU(BN(W1∗x)))) \mathcal{F}(\mathbf{x}) = \text{ReLU}(BN(W_2 * \text{ReLU}(BN(W_1 * \mathbf{x})))) F(x)=ReLU(BN(W2ReLU(BN(W1x))))

然后加上输入:

y=F(x)+x \mathbf{y} = \mathcal{F}(\mathbf{x}) + \mathbf{x} y=F(x)+x

如果维度不匹配,输入 x\mathbf{x}x 会通过 1×1 卷积变换以匹配残差函数输出的维度。


四、ResNet 的作用与优势

  • 缓解梯度消失和退化问题:残差连接保证了梯度可以直接反传,便于训练非常深的网络。
  • 支持训练非常深的网络:ResNet 能训练超过 100 层甚至更深的网络,如 ResNet-50、ResNet-101、ResNet-152。
  • 提升网络性能:提升了图像识别、目标检测、语义分割等任务的性能。
  • 促进网络设计多样性:后续许多网络架构(如 DenseNet、EfficientNet)都借鉴了残差思想。

五、常用的 ResNet 变体和结构

  • ResNet-18 / ResNet-34:较浅的网络,适合资源受限或较简单任务
  • ResNet-50 / ResNet-101 / ResNet-152:瓶颈结构,较深的网络,适合大型数据集和复杂任务
  • ResNeXt、SE-ResNet:在 ResNet 基础上做了改进,提高性能

六、ResNet 常用场景

  • 图像分类(Image Classification):ImageNet 等大规模数据集的分类任务
  • 目标检测(Object Detection):作为特征提取骨干网络,配合 Faster R-CNN、YOLO 等检测框架
  • 语义分割(Semantic Segmentation):如 DeepLab、FCN 等分割模型的主干网络
  • 图像生成和风格迁移
  • 医学图像分析
  • 视频分析
  • 自然语言处理中部分任务的特征提取

七、总结

方面 说明
提出时间 2015 年
关键创新 残差连接(skip connection)
核心公式 y=F(x)+x\mathbf{y} = \mathcal{F}(\mathbf{x}) + \mathbf{x}y=F(x)+x
作用 解决深层网络训练困难,促进训练更深网络
应用领域 图像分类、目标检测、语义分割等视觉任务

八、代码

代码基于 PyTorch 的 ResNet ,使用官方 torchvision 库里自带的 CIFAR-10 数据集来训练一个 ResNet-18。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18

def main():
    # 设备配置(GPU 优先)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # 数据预处理和增强
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),  # 随机裁剪
        transforms.RandomHorizontalFlip(),     # 随机水平翻转
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), 
                             (0.2023, 0.1994, 0.2010))
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), 
                             (0.2023, 0.1994, 0.2010))
    ])

    # 下载并加载 CIFAR-10 训练集和测试集
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                              shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                             shuffle=False, num_workers=2)

    # 加载预定义的 ResNet-18 模型,修改最后一层全连接层以适应 CIFAR-10(10 类)
    model = resnet18(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, 10)  # CIFAR-10 有 10 个类别
    model = model.to(device)

    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)  # 学习率衰减

    # 训练函数
    def train(epoch):
        model.train()
        running_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if batch_idx % 100 == 99:  # 每 100 个 batch 输出一次日志
                print(f"[Epoch {epoch+1}, Batch {batch_idx+1}] loss: {running_loss/100:.3f}")
                running_loss = 0.0

    # 测试函数
    def test():
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in testloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        acc = 100. * correct / total
        print(f"Test Accuracy: {acc:.2f}%")
        return acc

    # 主训练循环
    best_acc = 0
    for epoch in range(1, 51):  # 训练 50 个 epoch
        train(epoch)
        acc = test()
        scheduler.step()

        # 保存最优模型
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), "best_resnet18_cifar10.pth")

    print(f"Best Test Accuracy: {best_acc:.2f}%")

if __name__ == '__main__':
    main()

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐