梯度累积(Gradient Accumulation)是一种在显存受限的情况下,通过分批计算梯度并延迟更新参数的技术,模拟更大批量(batch size)训练的效果。以下是其核心原理、数学公式、生活类比及注意事项的详细分析。


一、梯度累积的核心思想

当显存不足以加载一个完整的批量(batch)时,可以将批量拆分为多个小批次(mini-batch),逐个计算梯度并累积,最终一次性更新参数。这相当于:

“小步快跑,积少成多”:每次只处理小批量数据,但通过多次累积,最终达到与大批次训练等效的效果。


二、数学原理

假设目标批量大小为 BB,但显存只能支持小批量大小 bb,则需要累积 k=Bbk=bB​ 次梯度后更新参数。

  1. 损失函数
    对于第 iii 个小批次,损失为 LiL_iLi,其梯度为 ∇θLi∇_θL_iθLi
    累积 kk 次后的总梯度为:

    ∇θacc=∑i=1k∇θLi∇_θ^{acc}=∑_{i=1}^k∇_θL_iθacc=i=1kθLi

    参数更新公式为:

    θ←θ−η⋅(1k∑i=1k∇θLi)θ←θ−η⋅(\frac{1}{k}∑_{i=1}^k∇_θL_i)θθη(k1i=1kθLi)

    其中 ηηη 是学习率,1k\frac{1}{k}k1​ 表示梯度平均(等效于大 batch 的梯度)。

  2. 损失缩放
    在代码实现中,通常对损失函数进行缩放:

    Liscaled=LikL_i^{scaled}=\frac{L_i}{k}Liscaled=kLi

    这样反向传播得到的梯度会自动满足 ∇θacc=∑i=1k∇θLiscaled∇_θ^{acc}=∑_{i=1}^k∇_θL_i^{scaled}θacc=i=1kθLiscaled,无需手动归一化。


三、工作流程(步骤分解)
  1. 初始化
    清空梯度:optimizer.zero_grad()

  2. 前向传播
    对每个小批次 iii,计算输出 yi=f(xi)y_i=f(x_i)yi=f(xi) 和损失 LiL_iLi

  3. 反向传播
    计算梯度:∇θLi​∇_θL_i​θLi,并累加到全局梯度中:

    global_grad+=∇θLi+=∇_θL_i+=θLi

  4. 参数更新
    当累积满 kk 次后,执行一次参数更新:

    θ←θ−η⋅θ←θ−η⋅θθηglobal_grad

    并重置梯度:optimizer.zero_grad()


四、生活类比(简单易懂)
例1:搬砖盖房
  • 目标:一次搬运一整车砖(大 batch),但小推车(显存)装不下。
  • 解决方案:分4次搬运,每次搬1/4车砖(小 batch),但先不砌墙(不更新参数),等4次都搬完后,再统一砌墙(更新参数)。
  • 效果:虽然每次搬得少,但最终砌墙的效果等价于一次搬运整车。
例2:调制鸡尾酒
  • 目标:用一整瓶酒调制一杯鸡尾酒(大 batch),但调酒壶太小(显存不足)。
  • 解决方案:将酒分成4小瓶(小 batch),每次倒一瓶进壶里调味道(前向+反向),但不倒出(不更新参数)。4次都调完后,再根据总味道统一调出一杯酒(参数更新)。
  • 效果:分次调制,最终味道与一次性调制一致。

五、代码实现(PyTorch示例)
model = MyModel()
optimizer = Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

accumulation_steps = 4  # 相当于 batch_size = 4 * mini_batch_size
optimizer.zero_grad()   # 初始化累积梯度

for i, (inputs, labels) in enumerate(dataloader):
    outputs = model(inputs)
    loss = loss_fn(outputs, labels) / accumulation_steps  # 损失缩放
    loss.backward()  # 反向传播,梯度自动累加

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()        # 更新参数
        optimizer.zero_grad()   # 清零累积梯度

六、注意事项与最佳实践
  1. 学习率调整
    使用梯度累积后,等效 batch size 增大为 k×bk×b。通常需要线性缩放学习率

    ηnew=η×kη_{new}=η×kηnew=η×k

  2. Batch Normalization(BN)问题
    BN 依赖 batch 统计,若小 batch 太小(如 b=1b=1b=1),统计量会不准确。可改用 GroupNormSyncBN

  3. 训练时间增加
    虽然显存降低,但训练时间会增加约 kkk 倍(更多前向/反向传播)。

  4. 梯度裁剪
    可在累积完成后进行梯度裁剪,防止梯度爆炸。


七、总结

梯度累积的核心优势在于:

  • 显存换时间:在显存有限时,模拟大 batch 训练。
  • 稳定性提升:更稳定的梯度估计,加速收敛。
  • 灵活性高:适用于 NLP、CV、大模型训练等场景。

一句话总结:

“梯度累积 = 小步快跑,积少成多,最终一步到位更新模型。”


八、扩展思考
  • 与分布式训练结合:梯度累积可与多设备并行训练结合,进一步降低单设备显存占用。
  • 动态调整累积步数:根据显存剩余动态调整 kkk,平衡训练速度与显存占用。
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐