梯度累积策略详解
显存换时间:在显存有限时,模拟大 batch 训练。稳定性提升:更稳定的梯度估计,加速收敛。灵活性高:适用于 NLP、CV、大模型训练等场景。“梯度累积 = 小步快跑,积少成多,最终一步到位更新模型。
梯度累积(Gradient Accumulation)是一种在显存受限的情况下,通过分批计算梯度并延迟更新参数的技术,模拟更大批量(batch size)训练的效果。以下是其核心原理、数学公式、生活类比及注意事项的详细分析。
一、梯度累积的核心思想
当显存不足以加载一个完整的批量(batch)时,可以将批量拆分为多个小批次(mini-batch),逐个计算梯度并累积,最终一次性更新参数。这相当于:
“小步快跑,积少成多”:每次只处理小批量数据,但通过多次累积,最终达到与大批次训练等效的效果。
二、数学原理
假设目标批量大小为 BB,但显存只能支持小批量大小 bb,则需要累积 k=Bbk=bB 次梯度后更新参数。
-
损失函数:
对于第 iii 个小批次,损失为 LiL_iLi,其梯度为 ∇θLi∇_θL_i∇θLi。
累积 kk 次后的总梯度为:∇θacc=∑i=1k∇θLi∇_θ^{acc}=∑_{i=1}^k∇_θL_i∇θacc=∑i=1k∇θLi
参数更新公式为:
θ←θ−η⋅(1k∑i=1k∇θLi)θ←θ−η⋅(\frac{1}{k}∑_{i=1}^k∇_θL_i)θ←θ−η⋅(k1∑i=1k∇θLi)
其中 ηηη 是学习率,1k\frac{1}{k}k1 表示梯度平均(等效于大 batch 的梯度)。
-
损失缩放:
在代码实现中,通常对损失函数进行缩放:Liscaled=LikL_i^{scaled}=\frac{L_i}{k}Liscaled=kLi
这样反向传播得到的梯度会自动满足 ∇θacc=∑i=1k∇θLiscaled∇_θ^{acc}=∑_{i=1}^k∇_θL_i^{scaled}∇θacc=∑i=1k∇θLiscaled,无需手动归一化。
三、工作流程(步骤分解)
-
初始化:
清空梯度:optimizer.zero_grad()。 -
前向传播:
对每个小批次 iii,计算输出 yi=f(xi)y_i=f(x_i)yi=f(xi) 和损失 LiL_iLi。 -
反向传播:
计算梯度:∇θLi∇_θL_i∇θLi,并累加到全局梯度中:global_grad+=∇θLi+=∇_θL_i+=∇θLi
-
参数更新:
当累积满 kk 次后,执行一次参数更新:θ←θ−η⋅θ←θ−η⋅θ←θ−η⋅global_grad
并重置梯度:
optimizer.zero_grad()。
四、生活类比(简单易懂)
例1:搬砖盖房
- 目标:一次搬运一整车砖(大 batch),但小推车(显存)装不下。
- 解决方案:分4次搬运,每次搬1/4车砖(小 batch),但先不砌墙(不更新参数),等4次都搬完后,再统一砌墙(更新参数)。
- 效果:虽然每次搬得少,但最终砌墙的效果等价于一次搬运整车。
例2:调制鸡尾酒
- 目标:用一整瓶酒调制一杯鸡尾酒(大 batch),但调酒壶太小(显存不足)。
- 解决方案:将酒分成4小瓶(小 batch),每次倒一瓶进壶里调味道(前向+反向),但不倒出(不更新参数)。4次都调完后,再根据总味道统一调出一杯酒(参数更新)。
- 效果:分次调制,最终味道与一次性调制一致。
五、代码实现(PyTorch示例)
model = MyModel()
optimizer = Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()
accumulation_steps = 4 # 相当于 batch_size = 4 * mini_batch_size
optimizer.zero_grad() # 初始化累积梯度
for i, (inputs, labels) in enumerate(dataloader):
outputs = model(inputs)
loss = loss_fn(outputs, labels) / accumulation_steps # 损失缩放
loss.backward() # 反向传播,梯度自动累加
if (i + 1) % accumulation_steps == 0:
optimizer.step() # 更新参数
optimizer.zero_grad() # 清零累积梯度
六、注意事项与最佳实践
-
学习率调整:
使用梯度累积后,等效 batch size 增大为 k×bk×b。通常需要线性缩放学习率:ηnew=η×kη_{new}=η×kηnew=η×k
-
Batch Normalization(BN)问题:
BN 依赖 batch 统计,若小 batch 太小(如 b=1b=1b=1),统计量会不准确。可改用 GroupNorm 或 SyncBN。 -
训练时间增加:
虽然显存降低,但训练时间会增加约 kkk 倍(更多前向/反向传播)。 -
梯度裁剪:
可在累积完成后进行梯度裁剪,防止梯度爆炸。
七、总结
梯度累积的核心优势在于:
- 显存换时间:在显存有限时,模拟大 batch 训练。
- 稳定性提升:更稳定的梯度估计,加速收敛。
- 灵活性高:适用于 NLP、CV、大模型训练等场景。
一句话总结:
“梯度累积 = 小步快跑,积少成多,最终一步到位更新模型。”
八、扩展思考
- 与分布式训练结合:梯度累积可与多设备并行训练结合,进一步降低单设备显存占用。
- 动态调整累积步数:根据显存剩余动态调整 kkk,平衡训练速度与显存占用。
更多推荐



所有评论(0)