【RL 】slime Train 函数的流程
好的,让我们来描绘一个train函数的完整、详细的执行过程。我们将设定一个具体的场景,并一步步追踪代码的执行流,结合我们之前讨论的所有概念:流水线并行、数据并行、DDP包装器、1F1B调度等。
好的,让我们来描绘一个train函数的完整、详细的执行过程。我们将设定一个具体的场景,并一步步追踪代码的执行流,结合我们之前讨论的所有概念:流水线并行、数据并行、DDP包装器、1F1B调度等。
场景设定
- 硬件: 4 个 GPU (Rank 0, 1, 2, 3)
- 模型: 一个大型语言模型,总共有 24 层。
- 并行策略:
- 2 路流水线并行 (PP): 模型被切成 2 个阶段 (Stage)。
- 2 路数据并行 (DP): 每个阶段有 2 个副本。
- 调度策略:
- 非交错流水线 (为了简化说明,我们先用最简单的调度)。
- 微批次数量: 4 个 (Micro-batches)。
GPU 分配如下:
| 流水线阶段 (PP Stage) | 数据并行组 (DP Group) | GPU Ranks | 模型块 |
|---|---|---|---|
| Stage 0 | DP Group 0 | [Rank 0, Rank 2] |
模型的 1-12 层 |
| Stage 1 | DP Group 1 | [Rank 1, Rank 3] |
模型的 13-24 层 |
model: Sequence[DDP] 参数的样子:
- 在 Rank 0 和 Rank 2 上:
model = [DDP(model_chunk_stage_0)](列表里只有1个元素) - 在 Rank 1 和 Rank 3 上:
model = [DDP(model_chunk_stage_1)](列表里也只有1个元素)
train 函数执行全过程追踪
我们主要站在 Rank 0 的视角,并描述与其他 Rank 的交互。train 函数被调用时,假设 rollout_id=0, num_microbatches=[4]。
第 1 步: train 函数入口和准备工作
def train(..., model: Sequence[DDP], ...):
args = get_args()
# 1. 重置数据迭代器 (在所有 Rank 上)
for iterator in data_iterator:
iterator.reset()
# 2. 设置为训练模式 (在所有 Rank 上)
# 对于 Rank 0, 这会调用 model[0].train(), 即 DDP(stage_0).train()
for model_module in model:
model_module.train()
# 3. 配置 Megatron 引擎 (在所有 Rank 上)
# 将 DDP 实例的梯度同步/缩放等方法与引擎绑定
config = get_model_config(model[0])
config.grad_scale_func = optimizer.scale_loss
config.no_sync_func = [model_chunk.no_sync for model_chunk in model] # 提供临时关闭梯度同步的方法
# ...
# 4. 手动 GC (如果启用)
gc.disable()
gc.collect()
# 5. 循环处理所有训练步骤 (这里只有一个步骤,step_id=0)
for step_id in range(num_steps_per_rollout):
# ... 进入 train_one_step ...
第 2 步: train_one_step - 流水线执行
train_one_step 的核心是调用 get_forward_backward_func() 返回的 forward_backward_step 函数。这个函数是整个过程的大脑。
时间线 (T 表示一个微批次在一个阶段上的计算时间):
| 时间 | Rank 0 (S0, DP0) & Rank 2 (S0, DP0) | Rank 1 (S1, DP1) & Rank 3 (S1, DP1) |
|---|---|---|
| T1 (启动) | [Fwd] 计算 mb1 (微批次1) 的前向传播。[Comm] 计算完后,Rank 0 -> Rank 1 发送 mb1 的激活值,Rank 2 -> Rank 3 发送 mb1 的激活值。 |
[Idle] 等待数据。 |
| T2 (启动) | [Fwd] 计算 mb2 的前向传播。[Comm] 发送给 Rank 1/3。 |
[Fwd] 接收到 mb1 数据,计算 mb1 的前向传播。[Comm] (无,因为是最后阶段)。 |
| T3 (稳态) | [Fwd] 计算 mb3 的前向传播。[Comm] 发送。 |
[Fwd] 接收到 mb2 数据,计算 mb2 的前向传播。 |
| T4 (稳态) | [Fwd] 计算 mb4 的前向传播。[Comm] 发送。 |
[Fwd] 接收到 mb3 数据,计算 mb3 的前向传播。 [Bwd] 在 Rank 1/3 上,mb1 已完成前向,开始计算其损失,并进行反向传播。 |
| T5 (稳态) | [Idle Fwd] 没有新的前向任务。[Bwd] 接收到 Rank 1/3 发来的 mb1 的梯度,开始计算 mb1 的反向传播。 梯度同步: 当 mb1 在 Rank 0/2 上的梯度计算完毕,DDP 的钩子触发,Rank 0 和 Rank 2 之间进行梯度的 All-Reduce。 |
[Fwd] 接收 mb4 数据,计算前向。[Bwd] 计算 mb2 的反向传播。 |
| T6 (排空) | [Idle Fwd]。[Bwd] 接收 mb2 梯度,计算反向,并与 Rank 2 同步梯度。 |
[Idle Fwd]。[Bwd] 计算 mb3 反向传播。[Comm] 将 mb3 梯度发给 Rank 0/2。 |
| T7 (排空) | [Idle]。[Bwd] 接收 mb3 梯度,计算反向,并与 Rank 2 同步梯度。 |
[Idle]。[Bwd] 计算 mb4 反向传播。[Comm] 将 mb4 梯度发给 Rank 0/2。 |
| T8 (排空) | [Idle]。[Bwd] 接收 mb4 梯度,计算反向,并与 Rank 2 同步梯度。 |
[Idle] 所有任务完成。 |
关键细节解释:
- [Fwd]: 调用
train_one_step中定义的forward_step函数,它会获取数据,执行模型块的前向计算。 - [Bwd]: Megatron 引擎自动处理反向传播。当一个模型块的输出的
.backward()被调用时,链式法则会触发该模型块参数的梯度计算。 - 梯度同步:
DDP包装器是关键。它内部的钩子在梯度计算完成后自动触发,并启动在**数据并行组(DP Group)**内的 All-Reduce 通信。例如,Rank 0和Rank 2在DP Group 0内同步梯度;Rank 1和Rank 3在DP Group 1内同步梯度。
第 3 步: train_one_step - 梯度处理和优化器步骤
当所有微批次的反向传播和梯度同步都完成后(T8之后):
-
梯度有效性检查 (
valid_step): 检查所有 Rank 上的梯度是否存在inf或NaN。这是一个同步点,所有 Rank 需要就valid_step的结果达成一致。 -
优化器步骤 (
optimizer.step()):- 如果
valid_step为True,则所有 Rank 上的优化器都会执行step()。 - 优化器使用已经 All-Reduce 过的、同步好的梯度来更新它所管理的模型参数。
- 重要:
Rank 0和Rank 2的优化器只更新 Stage 0 的参数;Rank 1和Rank 3的优化器只更新 Stage 1 的参数。 - 由于梯度在各自的 DP 组内是相同的,所以
Rank 0和Rank 2上的 Stage 0 参数更新后仍然保持一致。Rank 1和Rank 3上的 Stage 1 参数也一样。
- 如果
-
学习率调度器步骤 (
opt_param_scheduler.step()):- 更新学习率,为下一个训练迭代做准备。
-
梯度清零 (
optimizer.zero_grad()):- 清除
DDP缓冲区中的梯度,为下一个迭代做准备。
- 清除
第 4 步: train_one_step 结束和 train 函数日志记录
train_one_step 返回 loss_dict 和 grad_norm。
# ... 回到 train 函数的循环中 ...
loss_dict, grad_norm = train_one_step(...)
# 5. 日志记录
# 只有特定的 Rank (通常是 DP rank 0, TP rank 0, PP last stage) 负责打印日志
if mpu.get_data_parallel_rank(...) == 0 and ...:
# log_dict 包含了 'train/pg_loss', 'train/entropy_loss' 等
log_dict = {...}
log_dict["train/step"] = accumulated_step_id
# 使用 wandb 或其他工具记录日志
tracking_utils.log(args, log_dict, step_key="train/step")
logger.info(f"step {accumulated_step_id}: {log_dict}")
因为我们的场景中只有一个 step_id=0,所以 for 循环结束。
第 5 步: train 函数收尾
# ...
# 6. 禁用前向钩子 (如果之前启用了)
if pre_hook_enabled:
disable_forward_pre_hook(model)
这确保了在 train 函数之外,模型处于一个干净、可预测的状态。
至此,一次完整的 train 函数调用执行完毕。它完成了一个训练迭代(iteration),处理了 num_microbatches 个微批次的数据,对模型进行了一次权重更新。整个过程是数据并行、流水线并行和复杂调度逻辑协同工作的结果。
好的,让我们来详细解释 num_microbatches 是什么,以及它为什么在流水线并行中如此关键。
num_microbatches 是什么?
num_microbatches(微批次数量)是指在一次完整的训练迭代(也称为一个 “global batch”)中,我们将一个大的训练批次(batch)拆分成的小数据块的数量。
简单来说:
Global Batch = num_microbatches × micro_batch_size
global_batch_size: 这是我们从用户的角度看到的批次大小,即模型在进行一次权重更新(optimizer step)前所看到的样本总数。micro_batch_size: 这是每次送入流水线进行处理的最小数据单元的大小。通常这个值很小(比如 1, 2, 4, 8),小到可以轻松放入单个 GPU 的显存中。num_microbatches: 就是为了凑够一个global_batch,我们需要处理多少个micro_batch。
为什么需要 num_microbatches?—— 为了“喂饱”流水线
在没有流水线并行的情况下,我们可以直接处理一个大的 batch。但一旦引入了流水线并行,情况就变了。
回想一下我们的汽车装配流水线比喻。如果我们一次只造一辆车(num_microbatches = 1),会发生什么?
| 时间步 | GPU 0 (S0) | GPU 1 (S1) | GPU 2 (S2) | GPU 3 (S3) |
|---|---|---|---|---|
| 1 | F_0_1 | 空闲 | 空闲 | 空闲 |
| 2 | 空闲 | F_1_1 | 空闲 | 空闲 |
| 3 | 空闲 | 空闲 | F_2_1 | 空闲 |
| 4 | 空闲 | 空闲 | 空闲 | F_3_1 |
在任何一个时间点,只有一个 GPU 在工作! 其他所有 GPU 都在空闲等待。这完全失去了并行的意义,硬件利用率极低。
num_microbatches 的作用就是让流水线“流动”起来。
通过将大 batch 切分成多个 micro-batch,并连续地将它们送入流水线,我们就可以让所有 GPU 同时处理不同 micro-batch 的不同阶段,从而实现真正的并行计算。
num_microbatches 的值越大,流水线中的“稳态”(所有 GPU 都在忙碌)持续的时间就越长,启动和排空阶段的“气泡”所占的比例就越小,整体效率就越高。
举例说明
假设我们的训练配置如下:
global_batch_size= 32micro_batch_size= 4- 流水线并行度 (PP size) = 4
- 数据并行度 (DP size) = 8
- 总 GPU 数量 = 32
1. 计算 num_microbatches
num_microbatches = global_batch_size / micro_batch_sizenum_microbatches = 32 / 4 = 8
这意味着,为了完成一个 global_batch 的处理,我们需要将 32 个样本分成 8 个微批次,每个微批次包含 4 个样本。
2. train 函数接收的参数
在 slime/core/trainers.py 的代码中,num_microbatches 参数的形式是一个列表,例如 [8] 或者在更复杂的场景下可能是 [4, 4]。在我们的例子里,train 函数会接收到 num_microbatches=[8]。
这个列表 [8] 告诉 train 函数,它需要执行一个包含 8 个微批次的训练步骤。
3. 训练过程
- 数据准备: 训练框架会准备一个包含 32 个样本的
global_batch。 - 数据切分:
forward_backward_step函数在执行时,会从这个global_batch中依次取出 8 个micro-batch,每个micro-batch包含 4 个样本。 - 送入流水线:
- 第 1 个
micro-batch(样本 1-4) 进入流水线。 - 紧接着,第 2 个
micro-batch(样本 5-8) 进入流水线。 - …
- 第 8 个
micro-batch(样本 29-32) 进入流水线。
- 第 1 个
- 梯度累积:
- 每当一个
micro-batch完成了前向和反向传播,它计算出的梯度会被累积起来。 optimizer.step()不会在每个micro-batch之后执行。- 只有当所有 8 个
micro-batch的梯度都计算并累积完毕后,才进行一次optimizer.step()。
- 每当一个
num_microbatches 与梯度累积的关系
从效果上看,处理 num_microbatches 个 micro-batch 等价于梯度累积(Gradient Accumulation)。
- 在没有流水线并行的情况下,我们可以通过一个循环来实现梯度累积:
optimizer.zero_grad() for i in range(num_accumulation_steps): # 处理一个小 batch loss = model(small_batch) loss.backward() # 梯度会累积 optimizer.step() # 用累积的梯度更新一次 - 在流水线并行中,
num_microbatches的概念与此非常相似,但它是被整合在流水线调度中的。框架会自动处理这num_microbatches个微批次的梯度累积。
num_microbatches 的选择
选择一个合适的 num_microbatches 值是一个权衡:
- 值太小:
- 流水线气泡占比过高,硬件利用率低,训练速度慢。
- 例如,如果
num_microbatches<流水线并行度,那么流水线永远不会被填满。
- 值太大:
- 硬件利用率高,接近理论上限。
- 但是,需要存储更多中间激活值(每个正在流水线中“飞行”的微批次都需要存储其激活值),导致显存占用增加。
- 可能会稍微影响模型的收敛动态(因为批次大小的统计特性变了),但通常影响不大。
经验法则:
通常将 num_microbatches 设置为流水线并行度 (PP size) 的数倍,例如 2 倍、4 倍或 8 倍,以在硬件利用率和显存占用之间找到一个好的平衡点。例如,如果 PP size = 8,那么 num_microbatches 至少应该是 8,通常设置为 16 或 32 会获得不错的效率。
更多推荐



所有评论(0)