Speeding up PyTorch DDP: setting up gradient accumulation
PyTorch DDP
Reference: https://zhuanlan.zhihu.com/p/250471767
Efficient GPU communication (Ring AllReduce): https://www.zhihu.com/question/57799212/answer/612786337
Gradient accumulation: https://www.zhihu.com/question/303070254/answer/573037166
gradient accumulation
With gradient accumulation, suppose each accumulation cycle consists of accumulation_steps steps. Under DDP, every such cycle performs accumulation_steps all_reduce operations, yet it only calls optimizer.step() once, i.e. applies only one parameter update. This means a single gradient all_reduce per cycle is actually enough, so accumulation_steps - 1 of the all_reduce calls are wasted, and each all_reduce is fairly expensive. The fix is to disable gradient synchronization for the first accumulation_steps - 1 steps. DDP provides a context manager for exactly this, no_sync() (see the source code): inside this context, DDP does not synchronize gradients.
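For reference, the basic usage pattern of no_sync() looks like the sketch below; ddp_model, inputs, and final_batch are placeholder names, not from the original post. Gradients accumulate locally inside the context and are only all_reduced by the first backward() call outside it.

with ddp_model.no_sync():
    for batch in inputs:
        ddp_model(batch).backward()   # no synchronization; gradients accumulate locally
ddp_model(final_batch).backward()     # gradients are all_reduced on this backward()

The full training loop below applies the same idea with an explicit step counter.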
accumulation_count = 0
for epoch in range(epoches):
    for j, data in enumerate(train_loader):
        accumulation_count += 1
        # For the first accumulation_steps - 1 steps of each accumulation
        # cycle, skip gradient synchronization and just accumulate gradients.
        if accumulation_count % accumulation_steps != 0:
            with model.no_sync():
                loss = model(data)
                loss = loss / accumulation_steps
                loss.backward()
        else:
            # On the accumulation_steps-th step: synchronize gradients
            # (all_reduce), clip them, and apply the parameter update.
            loss = model(data)
            loss = loss / accumulation_steps
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            model_optimizer.step()
            if model_scheduler is not None:
                model_scheduler.step()
            model_optimizer.zero_grad()
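For completeness, here is a minimal sketch of the setup the loop above assumes; the NCCL backend, MyModel, train_dataset, AdamW, the batch size, and epoches are placeholders rather than details from the original post:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# One process per GPU; torchrun sets LOCAL_RANK for each spawned process.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = MyModel().cuda(local_rank)            # placeholder model
model = DDP(model, device_ids=[local_rank])
model_optimizer = torch.optim.AdamW(model.parameters())

# DistributedSampler gives each process a disjoint shard of the data.
sampler = DistributedSampler(train_dataset)   # placeholder dataset
train_loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)

for epoch in range(epoches):
    sampler.set_epoch(epoch)                  # reshuffle shards each epoch
    # ... accumulation loop from above ...

Dividing the loss by accumulation_steps keeps the magnitude of the accumulated gradient equivalent to that of a single large batch.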
A more elegant version that works in both single-GPU and DDP mode:
from contextlib import nullcontext
# For Python versions below 3.7, use this instead:
# from contextlib import suppress as nullcontext

if local_rank != -1:
    model = DDP(model)

optimizer.zero_grad()
accumulation_count = 0
for epoch in range(epoches):
    for i, data in enumerate(train_loader):
        accumulation_count += 1
        # Use no_sync only in DDP mode and only on steps that are not a
        # multiple of accumulation_steps; otherwise use a no-op context.
        mcontext = model.no_sync if local_rank != -1 and accumulation_count % accumulation_steps != 0 else nullcontext
        with mcontext():
            loss = model(data)
            loss = loss / accumulation_steps
            loss.backward()
        # Every accumulation_steps steps the gradients have just been
        # synchronized: apply the parameter update and clear them.
        if accumulation_count % accumulation_steps == 0:
            optimizer.step()
            if model_scheduler is not None:
                model_scheduler.step()
            optimizer.zero_grad()
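A note on local_rank (a common convention, not spelled out in the original post): -1 usually means the script is not running under DDP. When launching with torchrun, each process can read its rank from the LOCAL_RANK environment variable, e.g.:

import os

# -1 means single-GPU / non-distributed; torchrun sets LOCAL_RANK for every
# process it spawns, so the same script works in both modes.
local_rank = int(os.environ.get("LOCAL_RANK", -1))

The script would then be launched as, for example, torchrun --nproc_per_node=4 train.py (train.py being a hypothetical file name), or run directly with python train.py for the single-GPU path.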