【大模型训练】roll GRPO源码学习 ref_log_probs 计算原理
好的,我们来详细梳理一下 RLVR (Reinforcement Learning from Video/Valuable Responses) 的完整训练流程,并结合你提供的代码进行深入解析。假设我们现在有一批prompt,我们的目标是训练一个 Actor 模型,让它能根据这些prompt生成高质量的response。整个过程可以分解为以下几个宏观阶段,这与promptreferencecrit
https://github.com/alibaba/ROLL/blob/434cabff689d5275704ca4de11d6d7359bf22c8b/roll/distributed/strategy/megatron_strategy.py
好的,我们来详细梳理一下 RLVR (Reinforcement Learning from Video/Valuable Responses) 的完整训练流程,并结合你提供的代码进行深入解析。
假设我们现在有一批 prompt,我们的目标是训练一个 Actor 模型,让它能根据这些 prompt 生成高质量的 response。
整个过程可以分解为以下几个宏观阶段,这与 RLVRPipeline.run() 方法中的主循环一一对应:
阶段 0: 准备工作 (在 RLVRPipeline.__init__ 中完成)
- 加载数据: 从文件中加载
prompt数据集。 - 初始化模型: 初始化所有需要的模型集群:
actor_train,actor_infer,reference,critic, 和reward。它们初始时可能使用相同的预训练模型权重。 - 状态卸载: 所有模型权重默认被卸载到 CPU,以节省 GPU 显存。
阶段 1: Rollout - 生成经验数据 (step_generate)
这是 PPO 算法中的“与环境交互”阶段。在 RLVR 中,“环境”就是由 prompt 和 reward 模型构成的虚拟环境。
-
权重同步:
- 代码:
self.model_update(global_step) - 目的: 将
actor_train(正在学习的模型)最新的权重同步到actor_infer(负责生成的模型)。这确保了我们接下来生成的经验数据是基于当前最新的策略。
- 代码:
-
生成回复 (Response Generation):
- 代码:
self.generate_schedulers[domain].get_batch.remote(...)->ActorWorker.generate/ActorWorker.start_server - 过程:
RLVRPipeline从数据集中取出一批prompt。- 将这批
prompt发送给actor_infer集群。 ActorWorker收到请求后,使用state_offload_manger将模型加载到 GPU。- 调用
self.strategy.generate()方法,让模型根据prompt生成response。 - 生成的结果是一个包含
prompts、responses、完整的input_ids和attention_mask等信息的DataProto对象。
- 代码:
-
奖励评估 (Reward Evaluation):
- 代码:
scheduler内部逻辑会调用RewardWorker.compute_rewards。 - 过程:
- 上一步生成的
input_ids(prompt + response)被发送到reward模型集群。 RewardWorker收到请求,加载模型,并对序列进行前向传播。- 它会输出一个分数,这个分数就是对这次生成质量的原始奖励信号(
response_level_rewards)。
- 上一步生成的
- 代码:
此阶段结束时,我们得到了一批“经验数据”,它至少包含:
prompts: 原始提示responses: Actor 生成的回复input_ids: prompt + response 拼接后的 token IDattention_mask: 对应的注意力掩码response_level_rewards: 奖励模型给出的原始分数
阶段 2: 评估与计算 - 为训练准备数据 (cal_ref_log_probs, cal_old_log_probs_values 等)
这个阶段的目标是为我们刚刚收集到的经验数据计算所有训练需要用到的值,比如优势函数(Advantage)。
-
计算参考模型对数概率 (
ref_log_probs):- 代码:
self.reference.compute_log_probs(batch, ...)->ActorWorker.compute_log_probs - 目的: 计算 KL 散度惩罚。
- 过程: 将
input_ids发送给reference集群。ActorWorker(此时扮演 reference 角色)计算并返回序列中每个 token 在固定不变的基础模型策略下的对数概率ref_log_probs。
- 代码:
-
计算旧策略信息 (
old_log_probs和values):- 代码:
self.actor_train.compute_log_probs(batch, ...)->ActorWorker.compute_log_probsself.critic.compute_values(batch, ...)->CriticWorker.compute_values
- 目的: 获取 PPO 训练和优势估计所需的数据。
- 过程: 这两个计算是并行执行的,以提高效率。
old_log_probs: 将input_ids发送给actor_train集群。ActorWorker计算并返回序列中每个 token 在生成数据时所用的那个旧策略(即同步前的actor_train权重)下的对数概率。values: 将input_ids发送给critic集群。CriticWorker计算并返回序列中每个 token 的价值估计V(s_t)。
- 代码:
-
计算最终奖励和优势 (Advantage Estimation):
- 代码:
compute_token_reward和compute_advantage - 过程:
- KL 惩罚: 在
compute_token_reward中,我们计算 KL 散度:kl = old_log_probs - ref_log_probs。 - 每步奖励 (Token-level Reward): 最终的每一步奖励由两部分组成:
token_rewards = -kl_coef * kl + reward_model_score。其中,reward_model_score只在序列的最后一步有值,其余步为 0。 - 优势计算 (GAE):
compute_advantage函数实现了 GAE (Generalized Advantage Estimation) 算法。它利用token_rewards和critic输出的values来计算优势函数advantages和回报returns。advantages_t = δ_t + (γλ)δ_{t+1} + ...- 其中
δ_t = reward_t + γ * value_{t+1} - value_t是 TD 误差。 returns_t = advantages_t + values_t。
- KL 惩罚: 在
- 代码:
此阶段结束时,我们的 DataProto 对象中已经包含了所有训练所需的数据:
input_idsattention_mask,response_maskold_log_probsref_log_probsadvantagesreturns
阶段 3: 训练 - 更新模型参数 (step_train)
这是最后一步,利用我们精心准备好的数据来更新 actor_train 和 critic 模型的权重。
-
Actor 训练:
- 代码:
self.actor_train.train_step(batch, ...)->ActorWorker.train_step->ActorWorker.loss_func - 过程:
actor_train的Worker收到包含所有计算好数据的batch。- 它会迭代
ppo_epochs次。在每次迭代中:ActorWorker对当前模型(新策略)进行前向传播,得到新的对数概率log_probs。- 进入
ActorWorker.loss_func计算 PPO 目标函数。
- 代码:
-
Critic 训练:
- 代码:
self.critic.train_step(batch, ...)->CriticWorker.train_step->CriticWorker.loss_func - 过程:
critic的Worker收到同样的batch。- 它对当前
critic模型进行前向传播,得到新的价值预测values。 - 进入
CriticWorker.loss_func,计算价值损失:vf_loss = 0.5 * mean((values - returns)^2)。 - 通过反向传播和优化器步骤更新
critic模型。
- 代码:
这两个训练过程是并行提交的,以最大化硬件利用率。
目标函数详解 (结合 ActorWorker.loss_func)
ActorWorker.loss_func 是整个 RLVR 流程的“心脏”,它定义了 Actor 模型到底在优化什么。
def loss_func(self, data: DataProto, output_tensor: torch.Tensor):
# 1. 准备数据
response_mask = data.batch["response_mask"][:, 1:].long() # 只关心 response 部分
ref_log_probs = data.batch["ref_log_probs"] # 参考模型的 logp
old_log_probs = data.batch["old_log_probs"] # 生成数据时的 logp
advantages = data.batch["advantages"] # GAE 计算出的优势
# 2. 计算新策略的 logp
log_probs = self.strategy.op_compute_log_probs( # 这是新策略的 logp
logits=output_tensor, input_ids=data.batch["input_ids"], attention_mask=data.batch["response_mask"]
)
# 3. 计算 PPO 的 Clipped Surrogate Objective
# 概率比:衡量新旧策略的差异
ratio = (log_probs - old_log_probs).exp()
# 无裁剪的目标函数项
surr1 = ratio * advantages
# 裁剪后的目标函数项,将 ratio 限制在 [1-ε, 1+ε] 范围内
surr2 = ratio.clamp(1 - pg_clip_low, 1 + pg_clip_high) * advantages
# PPO 损失:取两者的较小值,并加负号,因为我们要最大化目标函数
pg_loss = -torch.min(surr1, surr2)
# 对所有有效的 token 取平均
pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=response_mask, ...)
# 4. 计算 KL 散度损失 (作为正则项)
kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, ...)
kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=response_mask, ...)
# 5. 计算总损失
if self.pipeline_config.use_kl_loss:
# PPO 损失 + KL 惩罚
total_loss = pg_loss + kl_loss * self.pipeline_config.kl_loss_coef
else:
total_loss = pg_loss
# 6. (可选) 熵损失 (作为正则项,鼓励探索)
if self.pipeline_config.entropy_loss_coef > 0:
entropy = self.strategy.op_compute_entropy(...)
entropy_loss = agg_loss(loss_mat=entropy, ...)
total_loss = total_loss - entropy_loss * self.pipeline_config.entropy_loss_coef
# 7. 返回总损失和用于监控的指标
return total_loss, pg_metrics
总结一下,Actor 的最终目标函数是:
Total Loss = Policy Loss (PPO) + kl_coef * KL Loss - entropy_coef * Entropy Loss
- Policy Loss (PPO):
E[min(ratio * A, clip(ratio, 1-ε, 1+ε) * A)]- 核心驱动力:如果一个动作的优势
A是正的(好动作),就提高采取这个动作的概率(增大ratio);如果是负的(坏动作),就降低概率。 clip的作用:防止单步更新过大,导致策略崩溃,从而保证训练的稳定性。
- 核心驱动力:如果一个动作的优势
- KL Loss:
E[KL(π_new || π_ref)]- 正则化项:惩罚新策略
π_new与原始基础模型π_ref之间的差异,防止模型为了迎合奖励而“忘掉”语言本身,保证生成文本的流畅性和多样性。
- 正则化项:惩罚新策略
- Entropy Loss:
E[H(π_new)]- 正则化项:鼓励策略保持一定的随机性(熵),避免策略过早地收敛到一个次优的确定性策略,有助于探索。
通过最小化这个 Total Loss,actor_train 模型就能学会在保持语言流畅性的同时,生成更能获得高奖励的回复。
计算ref_log_probs,请以prompt + response 举例,详细举例怎么算的
好的,我们来通过一个具体的例子,详细拆解 ref_log_probs 的计算过程。这对于理解 PPO 算法中的 KL 散度惩罚至关重要。
1. 目标与角色
- 目标: 计算由 Actor 模型 生成的完整序列(
prompt + response)在 Reference 模型 下的对数概率。 - Actor 模型 (
actor_infer): 这是一个正在被优化的、动态变化的策略。它生成了我们要评估的数据。 - Reference 模型 (
reference): 这是一个固定不变的基础模型(通常是未经任何 RL 微调的预训练模型)。它代表了我们不希望 Actor 模型偏离太远的“安全”基准。
2. 假设与输入数据
假设我们的词汇表非常简单,只包含 {<bos>, a, b, c, d, <eos>}。
ID 分别为 {0, 1, 2, 3, 4, 5}。
actor_infer 收到一个 prompt,并生成了一个 response:
- Prompt:
a bprompt_ids:[1, 2]
- Response:
c d <eos>response_ids:[3, 4, 5]
现在,我们将它们拼接起来,并加上起始符 <bos>,形成一个完整的序列,这就是我们要喂给 Reference 模型的 input_ids:
input_ids:<bos> a b c d <eos>- Token IDs:
[0, 1, 2, 3, 4, 5] - 序列长度: 6
- Token IDs:
这个 input_ids 会被打包成 DataProto 对象,并发送给 reference 集群的 compute_log_probs 方法。
3. ref_log_probs 计算步骤详解
reference 集群中的 ActorWorker 接收到 input_ids后,会执行以下步骤(主要在 ActorWorker.compute_log_probs -> forward_func_log_probs -> op_compute_log_probs 中):
步骤 1: Reference 模型前向传播
将 input_ids [0, 1, 2, 3, 4, 5] 输入到 Reference 模型中。模型会进行一次完整的前向传播,为序列中的每个位置都生成一个 logit 向量。这个 logit 向量的维度等于整个词汇表的大小。
logits = reference_model([0, 1, 2, 3, 4, 5])
假设模型输出了一个形状为 [1, 6, 6] 的 logits 张量 (batch_size=1, seq_len=6, vocab_size=6)。
| 位置 (Time Step) | 输入 Token | 输出 Logits (用来预测下一个 Token) |
|---|---|---|
| 0 | <bos> |
logits[0, 0, :] -> 预测 a |
| 1 | a |
logits[0, 1, :] -> 预测 b |
| 2 | b |
logits[0, 2, :] -> 预测 c |
| 3 | c |
logits[0, 3, :] -> 预测 d |
| 4 | d |
logits[0, 4, :] -> 预测 <eos> |
| 5 | <eos> |
logits[0, 5, :] -> 预测… (这个位置的预测无用) |
步骤 2: 计算对数概率 (在 op_compute_log_probs 函数内)
现在,我们需要计算实际出现的 token 的对数概率。这涉及到 “teacher-forcing” 的思想。
-
创建
labels: 我们需要为每个位置的logits提供一个“正确答案”(label)。这个 label 就是input_ids向左移动一位的结果。input_ids:[0, 1, 2, 3, 4, 5]labels=input_ids[1:]:[1, 2, 3, 4, 5]
-
对齐
logits和labels:logits[0, 0, :]应该对应label=1(a)logits[0, 1, :]应该对应label=2(b)logits[0, 2, :]应该对应label=3(c)logits[0, 3, :]应该对应label=4(d)logits[0, 4, :]应该对应label=5(<eos>)
-
计算 Log-Softmax 和 Gather:
对于每个位置t(从 0 到 4),我们执行以下操作:- a. 计算 Softmax 分母: 对
logits[0, t, :]应用 Softmax,得到一个概率分布。在roll的实现中,为了数值稳定性和分布式计算,它会先计算log_softmax。# 伪代码 log_probs_distribution_t = log_softmax(logits[0, t, :]) - b. 提取 (Gather) 目标概率: 从这个对数概率分布中,取出
labels[t]索引对应的那个值。# 伪代码 ref_log_prob_t = log_probs_distribution_t[labels[t]]
举例说明:
- t=0:
ref_log_prob_0 = log_softmax(logits[0, 0, :])[1](tokena的对数概率) - t=1:
ref_log_prob_1 = log_softmax(logits[0, 1, :])[2](tokenb的对数概率) - t=2:
ref_log_prob_2 = log_softmax(logits[0, 2, :])[3](tokenc的对数概率) - … 以此类推,直到
t=4。
- a. 计算 Softmax 分母: 对
步骤 3: 掩码 (Masking)
计算出的 ref_log_probs 序列是 [ref_log_prob_0, ref_log_prob_1, ..., ref_log_prob_4]。
然而,在 PPO 损失函数中,我们只关心 response 部分的 KL 散度。prompt 部分是给定的,我们不需要对它进行优化或惩罚。
response_mask: 这是一个与input_ids等长的掩码,标记了哪些是responsetoken。input_ids:[<bos>, a, b, c, d, <eos>]response_mask:[0, 0, 0, 1, 1, 1](假设 BOS 和 prompt 不属于 response)- 代码中通常使用
response_mask[:, 1:],因为 log_probs 序列比input_ids短一位。 response_mask[:, 1:]:[0, 0, 1, 1, 1]
在 ActorWorker.loss_func 中,当计算 KL 损失时,会用到这个掩码:
kl_loss = compute_approx_kl(
log_probs=log_probs, # 新策略的 logp
log_probs_base=ref_log_probs, # 参考模型的 logp
action_mask=response_mask # 只在 response 部分计算
)
compute_approx_kl 内部会做类似 (log_probs_base - log_probs) * action_mask 的操作,这样 prompt 部分的 KL 值就自动变为 0 了。
4. 最终结果
compute_log_probs 方法最终返回一个 DataProto 对象,其中包含一个名为 ref_log_probs 的张量。
对于我们的例子,这个张量的(未被掩码的)值是:[logP_ref(<bos> -> a), logP_ref(a -> b), logP_ref(b -> c), logP_ref(c -> d), logP_ref(d -> <eos>)]
这个张量将被用于 ActorWorker.loss_func 中,与新策略的 log_probs 一起计算 KL 散度,作为对 Actor 模型更新的正则化惩罚。
简而言之,ref_log_probs 就是用一个固定的、权威的“裁判”(Reference 模型)来评估 Actor 生成的每一个词的“合理性”,并给出一个分数。如果 Actor 生成了一个非常规、但在 Reward 模型看来分很高的词,这个词在 Reference 模型下的 ref_log_probs 就会很低,从而在 KL 损失中产生一个较大的惩罚,拉住 Actor,防止它“走火入魔”。
好的,我们完全换一个例子,用更清晰的词元和ID,来彻底讲清楚 log_probs_distribution_t 的含义和整个计算过程。
1. 目标与角色 (不变)
- 目标: 计算 Actor 生成的序列在 Reference 模型下的对数概率。
- Actor 模型: 动态变化的策略,生成了数据。
- Reference 模型: 固定不变的基础模型,作为“裁判”。
2. 假设与输入数据
这次,我们的词汇表和ID如下:
| Token | Token ID |
|---|---|
<start> |
10 |
the |
25 |
cat |
33 |
sat |
42 |
on |
55 |
mat |
68 |
<end> |
99 |
Actor 模型收到了一个 prompt,并生成了一个 response:
- Prompt:
the catprompt_ids:[25, 33]
- Response:
sat on the mat <end>response_ids:[42, 55, 25, 68, 99]
现在,我们将它们拼接,并加上起始符 <start>,形成完整的 input_ids,发送给 Reference 模型:
input_ids:<start> the cat sat on the mat <end>- Token IDs:
[10, 25, 33, 42, 55, 25, 68, 99] - 序列长度 (Sequence Length): 8
- Token IDs:
3. ref_log_probs 计算步骤详解
reference 集群的 Worker 接收到 input_ids [10, 25, ...] 后,开始计算。
步骤 1: Reference 模型前向传播
将 input_ids 输入到 Reference 模型。模型会为序列中的每个位置(从 0 到 7)输出一个 logit 向量。
logits = reference_model([10, 25, 33, 42, 55, 25, 68, 99])
这个 logits 张量的形状是 [1, 8, 7] (batch_size=1, seq_len=8, vocab_size=7)。
logits[0, 0, :]是一个 7 维向量,是模型在看到<start>(ID: 10) 后,对下一个词元的预测。logits[0, 1, :]是一个 7 维向量,是模型在看到<start> the(ID: 10, 25) 后,对下一个词元的预测。- …以此类推。
步骤 2: 计算对数概率 (在 op_compute_log_probs 函数内)
这一步是核心,我们来详细拆解 log_probs_distribution_t。
log_probs_distribution_t 的含义:
它是在时间步 t,模型预测的整个词汇表的对数概率分布。它是一个向量,维度等于词汇表大小。向量中的每一个值,代表模型认为下一个词元是该词元的对-数-概-率。
它是如何计算的?
通过对 logits 向量应用 log_softmax 函数。log_softmax(x_i) = x_i - log(sum(exp(x_j)))
现在,我们一步一步来看:
a. 创建 labels
为了计算我们关心的实际 token 的概率,我们需要一个“正确答案”序列。
input_ids:[10, 25, 33, 42, 55, 25, 68, 99]labels=input_ids[1:]:[25, 33, 42, 55, 25, 68, 99]
b. 逐个位置计算
-
时间步 t=0:
- 输入:
<start>(ID: 10) - 模型输出:
logits_0 = logits[0, 0, :]。这是一个 7 维向量,例如[-1.2, 0.5, -3.1, ...]。 - 计算分布:
log_probs_distribution_0 = log_softmax(logits_0)。
这仍然是一个 7 维向量,代表模型在看到<start>后,对下一个词元是<start>,the,cat,sat,on,mat,<end>的对数概率。
例如,它可能是[-3.5, -1.8, -5.4, -4.1, -3.9, -6.0, -8.2]。 - 提取目标概率: 我们实际的下一个词元是
the(ID: 25)。labels[0]就是25。我们需要从log_probs_distribution_0中找到ID为25的词元对应的那个值。ref_log_prob_0 = log_probs_distribution_0[index_for_id_25]
(假设 ‘the’ 是词汇表第2个,索引为1)ref_log_prob_0 = log_probs_distribution_0[1] = -1.8。
含义: 在看到<start>后,Reference 模型认为下一个词是the的对数概率是 -1.8。
- 输入:
-
时间步 t=1:
- 输入:
<start> the - 模型输出:
logits_1 = logits[0, 1, :]。 - 计算分布:
log_probs_distribution_1 = log_softmax(logits_1)。 - 提取目标概率: 实际的下一个词元是
cat(ID: 33)。labels[1]是33。ref_log_prob_1 = log_probs_distribution_1[index_for_id_33]
含义: 在看到<start> the后,Reference 模型认为下一个词是cat的对数概率。
- 输入:
-
时间步 t=2:
- 输入:
<start> the cat - 模型输出:
logits_2 = logits[0, 2, :]。 - 计算分布:
log_probs_distribution_2 = log_softmax(logits_2)。 - 提取目标概率: 实际的下一个词元是
sat(ID: 42)。labels[2]是42。ref_log_prob_2 = log_probs_distribution_2[index_for_id_42]
含义: 在看到<start> the cat后,Reference 模型认为下一个词是sat的对数概率。
- 输入:
这个过程会一直持续到序列的末尾。
步骤 3: 掩码 (Masking)
我们会得到一个完整的对数概率序列 ref_log_probs,它的长度是 7 (比 input_ids 短 1)。
ref_log_probs = [ref_log_prob_0, ref_log_prob_1, ..., ref_log_prob_6]
response_mask 用于标记哪些是 response 部分。
input_ids:[<start>, the, cat, sat, on, the, mat, <end>]response_mask[:, 1:]:[0, 0, 1, 1, 1, 1, 1][0]对应the(prompt的一部分)[0]对应cat(prompt的一部分)[1]对应sat(response的一部分)- …
当计算 KL 损失时,只有 response_mask 为 1 的位置才会被计算在内。
4. 最终结果
compute_log_probs 方法返回的 ref_log_probs 张量,其内容就是我们一步步计算出来的序列:[logP_ref(the|start), logP_ref(cat|...the), logP_ref(sat|...cat), ...]
这个张量精确地记录了,对于 Actor 生成的这个序列,固定的 Reference 模型 认为它有多“合理”。这个“合理性”得分将作为一种“引力”,防止 Actor 在追逐奖励的路上跑得太偏,从而保证了生成文本的基本质量和连贯性。
更多推荐


所有评论(0)