在大模型基于人类反馈的强化学习(RLHF)微调中,PPO 是最经典的策略优化方法,但实际落地中我们常遇到 “更新幅度过大”“模型偏离原始能力” 等问题,而GRPO(Group Relative Policy Optimization) 作为 PPO 的优化变体,通过更精细的正则化策略解决了这些痛点。

本文会从基础版 GRPO 出发,逐步讲解 “带裁剪的 GRPO”“带 KL 散度惩罚的 GRPO”,结合代码逐行解析 + 数值实例,让你彻底搞懂 GRPO 的核心逻辑。

一、GRPO 核心思想:先搞懂 3 个基础概念

在看代码前,先明确 GRPO 的核心设计思路,这是理解后续所有变体的关键:

  1. 策略模型(Policy Model):正在微调的模型(比如加了 LoRA 的 BabyLlama),目标是让它生成更符合人类偏好的文本;
  2. 参考模型(Reference Model):原始的预训练模型(深拷贝得到,参数固定),作为 “基准”,避免策略模型偏离基础能力;
  3. 优势值(Advantage):衡量每个 token 的 “收益”—— 正数表示该 token 符合优化目标(比如人类喜欢),负数表示不符合(人类不喜欢)。

GRPO 的核心逻辑:对比策略模型和参考模型生成 token 的概率,结合优势值衡量 token 的 “正向 / 负向贡献”,通过损失函数让策略模型 “放大正向贡献的 token 生成概率,缩小负向贡献的 token 生成概率”,同时通过正则化(裁剪 / KL 散度)避免模型更新失控。

二、基础版 GRPO 损失函数(grpo_loss)

先从最基础的 GRPO 开始,这是所有变体的核心骨架。

2.1 完整代码

import torch
import torch.nn.functional as F

# 前置依赖函数
def prepare_inputs(prompt, completion):
    """预处理输入:拼接prompt+completion,生成input_ids、attention_mask、completion_mask"""
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("babylm/babyllama-100m-2024")
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"
    
    # 分别分词
    prompt_tokens = tokenizer(prompt, return_tensors="pt")
    completion_tokens = tokenizer(completion, return_tensors="pt")
    
    # 拼接input_ids和attention_mask(沿序列维度dim=1)
    input_ids = torch.cat([prompt_tokens["input_ids"], completion_tokens["input_ids"]], dim=1)
    attention_mask = torch.cat([prompt_tokens["attention_mask"], completion_tokens["attention_mask"]], dim=1)
    
    # 生成completion_mask:仅completion部分计算损失(1表示计算,0表示不计算)
    prompt_length = prompt_tokens["input_ids"].shape[1]
    total_length = input_ids.shape[1]
    completion_mask = torch.zeros(total_length, dtype=torch.float32)
    completion_mask[prompt_length:] = 1.0
    
    return input_ids, attention_mask, completion_mask

def compute_log_probs(model, input_ids, attention_mask):
    """计算每个token的对数概率:模型输出logits→log_softmax→提取对应token的对数概率"""
    # 前向传播
    outputs = model(input_ids, attention_mask=attention_mask)
    # logits转对数概率(dim=-1表示在词汇表维度归一化)
    log_probs = F.log_softmax(outputs.logits, dim=-1)
    # 提取每个input_ids对应的对数概率,匹配输入形状
    return log_probs.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

# 核心:基础版GRPO损失
def grpo_loss(model, ref_model, prompt, completion, advantage):
    """
    基础GRPO损失计算
    Args:
        model: 策略模型(微调中的模型)
        ref_model: 参考模型(原始预训练模型)
        prompt: 提示文本
        completion: 模型需要生成的目标文本
        advantage: 优势值(形状与token序列一致)
    Returns:
        仅completion部分的平均GRPO损失
    """
    # 1. 预处理输入:得到拼接后的token、掩码、损失计算掩码
    input_ids, attention_mask, completion_mask = prepare_inputs(prompt, completion)
 
    # 2. 计算策略模型的token对数概率
    token_log_probs = compute_log_probs(model, input_ids, attention_mask)
 
    # 3. 计算参考模型的token对数概率(固定参考模型,不计算梯度)
    with torch.no_grad():
        ref_token_log_probs = compute_log_probs(ref_model, input_ids, attention_mask)
 
    # 4. 计算概率比率:策略模型概率 / 参考模型概率(exp消除对数)
    # ratio>1 → 策略模型更倾向生成该token;ratio<1 → 更不倾向
    ratio = torch.exp(token_log_probs - ref_token_log_probs)
 
    # 5. 结合优势值计算策略损失:优势值越大,该token的正向贡献越高
    policy_loss = ratio * advantage
 
    # 6. 反转损失符号:优化器最小化损失 → 等价于最大化奖励
    # (因为优化器只能minimize,而我们需要maximize reward)
    per_token_loss = -policy_loss
 
    # 7. 仅计算completion部分的平均损失(prompt仅作为输入,不参与损失)
    # 求和后除以completion的token数,避免长度影响损失大小
    loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
    return loss

2.2 核心步骤解读

1.  预处理,准备输入数据

def prepare_inputs(prompt, completion):
    """预处理输入:拼接prompt+completion,生成input_ids、attention_mask、completion_mask"""
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("babylm/babyllama-100m-2024")
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"
    
    # 分别分词
    prompt_tokens = tokenizer(prompt, return_tensors="pt")
    completion_tokens = tokenizer(completion, return_tensors="pt")
    
    # 拼接input_ids和attention_mask(沿序列维度dim=1)
    input_ids = torch.cat([
                    prompt_tokens["input_ids"], 
                    completion_tokens["input_ids"]], dim=1)
    attention_mask = torch.cat([
                        prompt_tokens["attention_mask"], 
                        completion_tokens["attention_mask"]], dim=1)
    
    # 生成completion_mask:仅completion部分计算损失(1表示计算,0表示不计算)
    prompt_length = prompt_tokens["input_ids"].shape[1]
    total_length = input_ids.shape[1]
    completion_mask = torch.zeros(total_length, dtype=torch.float32)
    completion_mask[prompt_length:] = 1.0
    
    return input_ids, attention_mask, completion_mask

(1)输入数据:

  • prompt:用户输入或问题部分(例如 "The capital of France is ")
  • completion:期望的模型回复或答案部分(例如 "Paris.")

(2)分词

  • 分别对 prompt 和 completion 进行独立分
  • return_tensors="pt":返回 PyTorch 张量,而不是 Python 列表
  • 结果是两个字典,每个包含:
    • "input_ids": token ID 序列,形状 [1, seq_len]
    • "attention_mask": 注意力掩码(1 表示真实 token,0 表示 padding),这里因为没有 padding,全是 1 (因为我们处理的是单个样本(single example),而且这个样本的长度远远小于模型支持的最大序列长度,所以不需要 padding(填充))

(3)拼接

  • 拼接输入 input_ids:
    • 将 prompt 和 completion 的 token ID 在序列维度(dim=1)上拼接
    • 得到完整的输入序列:[prompt_tokens] + [completion_tokens]
    • 形状变为 [1, prompt_length + completion_length]
  • 同样拼接注意力掩码 attention_mask(注意力掩码的作用是告诉模型 “哪些 token 是真实文本(值为 1),哪些是填充(值为 0)”):
    • 因为两个部分都没有 padding,所以结果是一个全为 1 的掩码,长度等于总序列长。
    • 这个掩码告诉 Transformer 哪些位置是有效的 token。

(4)计算三个重要的长度:

  • prompt_length:prompt 部分的 token 数
  • completion_length:completion 部分的 token 数
  • total_length:完整序列的总 token 数(用于后续创建掩码)

(5)补全掩码

  • 创建一个补全掩码(completion mask),长度为 total_length
  • 初始化全为 0
  • 从第 prompt_length 个位置(包含)到末尾设为 1.0
  • 含义:
    • 0:属于 prompt 部分 → 在计算损失时忽略(模型不需要预测已知输入)
    • 1:属于 completion 部分 → 在计算损失时使用(模型需要学会预测这些 token)

(6)返回三个张量:

  • input_ids:[1, total_length] → 完整的输入 token ID
  • attention_mask:[1, total_length] → 全 1 的注意力掩码
  • completion_mask:[total_length] → 1D 张量,标记哪些位置需要计算损失

示例:

(1)输入:

  • prompt: "The capital of France is"
  • completion: " Paris."

(2)分词过程:( 假设 tokenizer 编码为:The=1000, capital=2000, of=3000, France=4000, is=5000, Paris=6000, .=7)

  • prompt → ["The", " capital", " of", " France", " is"] (注意:LLaMA tokenizer 会把空格合并到下一个词,所以常见是 "▁The", "▁capital" 等,但这里简化)→ {"input_ids": tensor([[1000, 2000, 3000, 4000, 5000]]), "attention_mask": tensor([[1, 1, 1, 1, 1]])}
  • completion → [" Paris", "."] →  {"input_ids": tensor([[6000, 7]]), "attention_mask": tensor([[1, 1]])}

(3)拼接:

  • input_ids:形状 [1, 7]  → tensor([[1000, 2000, 3000, 4000, 5000, 6000,    7]])

  • attention_mask:形状 [1, 7] →  tensor([[1, 1, 1, 1, 1, 1, 1]])     

(4)计算长度:

  • prompt_length:7(prompt 部分的 token 数)
  • completion_length:2(completion 部分的 token 数)
  • total_length:9(完整序列的总 token 数(用于后续创建掩码))

(5)补全掩码(completion_mask):形状 [7](1D) →  tensor([0., 0., 0., 0., 0., 1., 1.])      

(6)返回(张量):

  • input_ids:tensor([[1000, 2000, 3000, 4000, 5000, 6000,    7]])
  • attention_mask: tensor([[1, 1, 1, 1, 1, 1, 1]]) 
  • completion_mask: tensor([0., 0., 0., 0., 0., 1., 1.]) 

2. 对数概率计算

def compute_log_probs(model, input_ids, attention_mask):
    """计算每个token的对数概率:模型输出logits→log_softmax→提取对应token的对数概率"""
    # 前向传播
    outputs = model(input_ids, attention_mask=attention_mask)
    # logits转对数概率(dim=-1表示在词汇表维度归一化)
    log_probs = F.log_softmax(outputs.logits, dim=-1)
    # 提取每个input_ids对应的对数概率,匹配输入形状
    return log_probs.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

(1)输入数据:

  • model:已经包装好的 PEFT 模型(或原始 Transformers 模型),支持 LoRA 等
  • input_ids:形状为 [batch_size, seq_len] 的 token ID 张量(通常 batch_size=1)
  • attention_mask:形状相同的注意力掩码(1 表示有效 token,0 表示 padding)

(2)调用模型的前向传播,得到 outputs:

  • 对于 AutoModelForCausalLM,outputs 是一个命名元组,主要包含:
    • outputs.logits:形状 [batch_size, seq_len, vocab_size] 的未归一化 logits(每个位置对整个词汇表的得分)

(3)对 logits 应用 log_softmax

  • 先在词汇表维度(dim=-1)上做 softmax,得到概率分布(每行和为1)
  • 再取自然对数,得到对数概率
  • 结果 log_probs 形状仍为 [batch_size, seq_len, vocab_size]
  • 使用 log_softmax 而不是单独 softmax + log 的原因:数值更稳定(避免下溢)

(4)从每个位置的概率分布中提取“正确”token(即输入序列中实际出现的下一个 token)的对数概率:

  • input_ids.unsqueeze(-1):
    • 将 input_ids 从 [batch_size, seq_len] 扩展为 [batch_size, seq_len, 1],增加一个维度,便于作为索引使用。
  • log_probs.gather(dim=-1, index=...):
    • 在词汇表维度(dim=-1)上,根据 index(即正确的 token ID)提取对应的 log_prob 值。
    • 结果形状:[batch_size, seq_len, 1]
  • .squeeze(-1):
    • 移除最后一个维度(大小为1),得到最终形状 [batch_size, seq_len]

(5)返回的张量:

  • 每个位置 (i, t) 的值是:模型在看到序列前 t 个 token 后,预测第 t+1 个 token(即 input_ids[i, t])的对数概率。
  • 注意:由于是因果模型,第1个位置(t=0)的 log_prob 对应模型对序列第1个 token 的预测(在空上下文下的概率),通常会被忽略或掩码掉。

示例:

(1)输入数据

  • model:已经包装好的 PEFT 模型(或原始 Transformers 模型),支持 LoRA 等
  • input_ids:上一步输出的input_ids
  • attention_mask:上一步输出的attention_mask

(2)调用前向传播模型,输出outputs

(3)对 logits 应用 log_softmax,并从每个位置的概率分布中提取“正确”token(即输入序列中实际出现的下一个 token)的对数概率:假设输入序列是:"The capital of France is Paris."

  • input_ids:["The", "capital", "of", "France", "is", "Paris", "."](假设长度为7)
  • 模型会输出5个位置的 logits:
    • 位置0:基于 BOS(或无上下文)预测 "The" 的概率
    • 位置1:基于 "The" 预测 "capital" 的概率
    • 位置2:基于 "The capital" 预测 "of" 的概率
    • 以此类推……
  • 具体实现如下:
位置 输入上下文(到当前位置之前) 需要预测的正确 token 该 token 的 log_prob(示例值)
0 (空,或 BOS) The -10.5(模型对第一个词预测不准)
1 The capital -2.1
2 The capital of -0.8
3 The capital of France -1.2
4 The capital of France is -0.5
5 The capital of France is Paris -0.3 ← 我们关心的部分
6 The capital of France is Paris . -0.1 ← 我们关心的部分
  • 实际过程拆解:
    • input_ids:tensor([[1000, 2000, 3000, 4000, 5000, 6000,    7]])
    • 模型前向传播后,输出的 logits(未归一化分数)形状为 [1, 5, 7],经过 log_softmax 后得到 log_probs = tensor([[
          [-99.0, -10.5, -13.0, -14.4, -15.0, -16.5, -17.0],   # 位置0:预测 "The" (正确ID=1) 的 log_prob = -10.5
          [-4.0, -6.0, -2.1, -2.0, -3.3, -4.2, -7.1, -8.0],    # 位置1:预测 "capital"  (正确ID=2) 的 log_prob = -2.1
          [-5.0, -9.0, -4.0, -0.8, -3.1, -5.5, -6.7, - 7.9 ]     # 位置2:预测 "of"  (正确ID=3) 的 log_prob = -0.8
      …………]])

    • input_ids.unsqueeze(-1) = tensor([[[1000], [2000], [3000], [4000], [5000], [6000], [7]]])
      • 原来形状 [1, 7] → 变成 [1, 7, 1]
      • 目的:让它能作为索引广播到词汇表维度
    • log_probs.gather(dim=-1, index=input_ids.unsqueeze(-1))

      • gather 的作用是:沿着指定的维度(这里是 dim=-1,即词汇表维度),按照 index 中的 ID 取出对应的值

      • 具体过程如下:

        • 在位置 0:index = 1 → 从第1行第0个位置的词汇表分布中取出下标为1的值 → -10.5
        • 在位置 1:index = 2 → 取出下标为2的值 → -2.1
        • 在位置 2:index = 3 → 取出下标为3的值 → -0.8
        • ……以此类推
        • 结果:gather_result = tensor([[[-10.5], [-2.1], [-0.8], [-1.2], [-0.5], [-0.3], [-0.1]]])   # 形状 [1, 7, 1]
    • .squeeze(-1):去除大小为1的最后一个维度

(4)返回的 log_probs 将是这7个预测的对数值: tensor([-10.5, -2.1, -0.8, -1.2, -0.5, -0.3, -0.1])

策略模型对 “lazy” 的对数概率是 - 0.2,参考模型是 - 0.5 → ratio=exp(-0.2 - (-0.5))=exp(0.3)≈1.35

3. 计算GRPO损失

def grpo_loss(model, ref_model, prompt, completion, advantage):
    """
    基础GRPO损失计算
    Args:
        model: 策略模型(微调中的模型)
        ref_model: 参考模型(原始预训练模型)
        prompt: 提示文本
        completion: 模型需要生成的目标文本
        advantage: 优势值(形状与token序列一致)
    Returns:
        仅completion部分的平均GRPO损失
    """
    # 1. 预处理输入:得到拼接后的token、掩码、损失计算掩码
    input_ids, attention_mask, completion_mask = prepare_inputs(prompt, completion)
 
    # 2. 计算策略模型的token对数概率
    token_log_probs = compute_log_probs(model, input_ids, attention_mask)
 
    # 3. 计算参考模型的token对数概率(固定参考模型,不计算梯度)
    with torch.no_grad():
        ref_token_log_probs = compute_log_probs(ref_model, input_ids, attention_mask)
 
    # 4. 计算概率比率:策略模型概率 / 参考模型概率(exp消除对数)
    # ratio>1 → 策略模型更倾向生成该token;ratio<1 → 更不倾向
    ratio = torch.exp(token_log_probs - ref_token_log_probs)
 
    # 5. 结合优势值计算策略损失:优势值越大,该token的正向贡献越高
    policy_loss = ratio * advantage
 
    # 6. 反转损失符号:优化器最小化损失 → 等价于最大化奖励
    # (因为优化器只能minimize,而我们需要maximize reward)
    per_token_loss = -policy_loss
 
    # 7. 仅计算completion部分的平均损失(prompt仅作为输入,不参与损失)
    # 求和后除以completion的token数,避免长度影响损失大小
    loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
    return loss

(1)预处理输入数据:调用之前定义的prepare_inputs函数,将prompt+completion转换为模型可直接处理的张量:

  • input_ids:拼接后的完整 token 序列(prompt + completion);
  • attention_mask:拼接后的注意力掩码(标记真实 token / 填充 token);
  • completion_mask:0/1 掩码(0=prompt 部分,1=completion 部分),用于后续仅计算 completion 的损失;

(2)计算策略模型的 token 对数概率:

  • 调用compute_log_probs函数,传入策略模型和预处理后的张量
  • token_log_probs:策略模型对每个输入 token 的对数概率(形状[1, 序列长度]),反映 “策略模型认为该 token 应该被生成的置信度”

(3)计算参考模型的 token 对数概率:

  • 无梯度上下文管理器:
    • 参考模型是 “固定的基准”,不需要计算梯度(梯度仅用于更新策略模型的参数);
    • 该操作能节省显存、加快计算,同时避免参考模型的参数被意外修改。
  • 调用compute_log_probs函数,传入参考模型
  • ref_token_log_probs:参考模型对每个 token 的对数概率(形状与token_log_probs一致),作为 “模型原始生成习惯” 的基准

(4)计算策略模型 vs 参考模型的概率比率(GRPO 核心):

  • token_log_probs - ref_token_log_probs:对数概率的差值(对数域中,这个差值等价于 “策略模型概率 / 参考模型概率”)
  • torch.exp(...):将对数域的差值转换为原始概率的比率(因为数学上exp(log(a) - log(b)) = a/b
  • ratio的含义:
    • ratio > 1:策略模型生成该 token 的概率 > 参考模型 → 模型更倾向于生成这个 token
    • ratio < 1:策略模型生成该 token 的概率 < 参考模型 → 模型更不倾向于生成这个 token

(5)结合优势值计算策略损失:

  • advantage(优势值):核心是告诉模型 “哪些 token 对优化目标有正向贡献(正数值)、哪些有负向贡献(负数值)”
  • ratio * advantage:用概率比率缩放优势值,核心逻辑:
    • 若 advantage 为正(正向贡献):ratio 越大,policy_loss 越大 → 模型越应该强化生成这个 token
    • 若 advantage 为负(负向贡献):ratio 越大,policy_loss 越小 → 模型越应该弱化生成这个 token

(6)反转损失符号(适配优化器逻辑):

  • 核心原因:我们的目标是最大化奖励(reward),但 PyTorch 的优化器(如 AdamW)只能最小化损失(loss)
  • 取负后:
    • 若 policy_loss>0(正向贡献):per_token_loss 为负 → 优化器最小化损失时,会让这个值更负 → 对应 policy_loss 更大 → 奖励更高;
    • 若 policy_loss<0(负向贡献):per_token_loss 为正 → 优化器最小化损失时,会让这个值更小 → 对应 policy_loss 更小 → 奖励更高。

(7)计算最终的 GRPO 平均损失(仅针对 completion 部分):

  • 第一步:per_token_loss * completion_mask → 用掩码过滤损失,仅保留 completion 部分的 token 损失(prompt 部分乘以 0,不参与计算);
  • 第二步:.sum() → 将 completion 部分的所有 token 损失求和;
  • 第三步:/ completion_mask.sum() → 除以 completion 部分的 token 数量(即掩码中 1 的个数),得到平均损失
  • 这样做的目的:只优化模型 “生成目标补全文本(completion)” 的能力,prompt 仅作为输入,不参与损失计算。

(8)返回最终损失值:函数返回计算好的 GRPO 平均损失,供优化器使用 —— 优化器会通过反向传播更新策略模型的参数(如 LoRA 参数),最小化该损失,从而优化模型的生成行为。

GRPO 损失核心思想

可以把这个过程简化理解为:

  1. 参考模型:代表 “模型原本会怎么生成”;
  2. 策略模型:代表 “我们想让模型怎么生成”;
  3. 优势值:告诉模型 “哪些生成是好的、哪些是坏的”;
  4. 比率:告诉模型 “新生成习惯比旧习惯强 / 弱多少”;
  5. GRPO 损失:让模型 “放大好的生成习惯,缩小坏的生成习惯”。

2.3 存在的问题

基础 GRPO 没有限制ratio的范围,如果ratio过大(比如 10),会导致模型单次更新幅度过大,训练震荡甚至崩溃。

三、带Clip(裁剪) + KL 散度的 GRPO(grpo_loss_with_kl)

带裁剪的 GRPO 解决了 “更新幅度” 问题,但仍可能让策略模型完全偏离参考模型(比如微调后丢失基础语言能力)。新增KL 散度惩罚:惩罚策略模型与参考模型概率分布的差异,平衡 “优化目标” 和 “保留原始能力”。

3.1 完整代码

def grpo_loss_with_kl(model, ref_model, prompt, completion, advantage, epsilon = 0.2, beta = 0.1):
    """
    带裁剪+KL散度的GRPO损失:既限制更新幅度,又避免偏离参考模型
    Args:
        epsilon: 裁剪阈值(默认0.2,ratio限制在0.8~1.2)
        beta: KL散度惩罚权重(默认0.1,值越大惩罚越重)
    """
    # 1. 预处理输入(同前)
    input_ids, attention_mask, completion_mask = prepare_inputs(prompt, completion)
 
    # 2. 计算策略/参考模型的对数概率(同前)
    token_log_probs = compute_log_probs(model, input_ids, attention_mask)
    with torch.no_grad():
        ref_token_log_probs = compute_log_probs(ref_model, input_ids, attention_mask)
 
    # 3. 计算概率比率(同前)
    ratio = torch.exp(token_log_probs - ref_token_log_probs)
 
    # 4. 裁剪逻辑
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1-epsilon, 1+epsilon) * advantage
    policy_loss = torch.min(unclipped, clipped)
 
    # 5. KL散度惩罚核心:衡量策略模型与参考模型的差异
    # delta = 策略对数概率 - 参考对数概率 → delta>0更自信,delta<0更保守
    delta = token_log_probs - ref_token_log_probs
    # 逐token KL散度计算(公式:exp(-delta) + delta - 1)
    # delta≠0时,KL>0(产生惩罚);delta偏离0越远,惩罚越重
    per_token_kl = torch.exp(-delta) - (-delta) -1
    
    # 6. 结合策略损失和KL惩罚:净收益 = 策略损失 - beta*KL惩罚
    # 反转符号:优化器最小化损失 → 最大化净收益
    per_token_loss = -(policy_loss - beta * per_token_kl)
 
    # 7. 仅计算completion部分的平均损失(同前)
    loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
    return loss

3.2 核心步骤解读

1. 比率裁剪

将 ratio 限制在 [1-epsilon, 1+epsilon](默认 epsilon=0.2,即 0.8~1.2),避免极端值。

(1)代码解读

  • 计算裁剪后的策略损失
    • torch.clamp(ratio, 1-epsilon, 1+epsilon):对ratio进行范围裁剪,强制将比率限制在[1-epsilon, 1+epsilon]区间内:
      • epsilon=0.2时,比率被限制在[0.8, 1.2]
      • ratio>1.2,会被截断为 1.2;若ratio<0.8,会被截断为 0.8;
    • 裁剪后再乘以优势值,得到clipped(裁剪后的策略损失);
    • 裁剪的目的:避免ratio过大(如策略模型概率是参考模型的 10 倍)或过小(如 1/10),导致模型单次更新幅度过大,引发训练震荡、过拟合甚至崩溃。
  • 取无裁剪 / 裁剪损失的最小值:
    • torch.min(unclipped, clipped):对每个 token 的unclippedclipped取最小值,作为最终的policy_loss
    • 核心思想(保守更新):
      • 如果unclipped(无裁剪损失)更小 → 说明原始比率未超出范围,直接用原始损失;
      • 如果clipped(裁剪损失)更小 → 说明原始比率超出了安全范围,用裁剪后的损失,避免模型 “激进更新”;
    • 简单举例:若unclipped=5clipped=2.4(因 ratio 被裁剪),则取 2.4;若unclipped=1clipped=1,则取 1。

(2)裁剪逻辑示例:假设 ratio=1.35(超过 1.2)

  • unclipped=1.35*1.5=2.025
  • clipped=torch.clamp(1.35, 0.8, 1.2)*1.5=1.2*1.5=1.8
  • policy_loss=min(2.025, 1.8)=1.8

最终损失会基于 1.8 计算,而非 2.025,避免模型 “激进更新”。

(3)裁剪的核心作用:给模型加 “刹车”:允许模型微调(ratio 在 0.8~1.2 内),但禁止单次更新太夸张,保证训练稳定性。

2.  KL 散度数值

(1)代码解读

  • 计算对数概率差值delta
    • delta = 策略模型对数概率 - 参考模型对数概率,等价于log(策略模型概率 / 参考模型概率)
    • 注释含义:delta>0 → 策略模型对该 token 的生成概率 > 参考模型 → 策略模型更 “自信”;delta<0 → 策略模型更 “不自信”
  • 计算逐 token 的 KL 散度惩罚项(KL 散度衡量两个概率分布的差异):
    • 核心数学逻辑:
      • delta=0(策略模型概率 = 参考模型概率):per_token_kl = exp(0) - 0 -1 = 1-0-1=0 → 无惩罚
      • delta≠0(策略与参考概率差异越大):per_token_kl的值越大 → 惩罚越重
    • 作用:KL 散度越大,说明策略模型和参考模型的生成习惯差异越大,通过惩罚项避免模型完全偏离原始能力(比如微调后丢失基础语言能力)
  • 结合策略损失和 KL 惩罚计算逐 token 损失:
    • 核心矛盾:policy_loss(策略损失)越大,说明 token 的正向收益越高(越好)per_token_kl(KL 散度)越大,说明模型偏离参考模型越远(越差)
    • policy_loss - beta * per_token_kl:先从 “正向收益” 中扣除 “KL 惩罚”(beta控制惩罚强度),得到 “净收益”(净收益越大越好)
    • 加负号-():因为优化器只能 “最小化损失”,而我们需要 “最大化净收益”,所以取负后,优化器最小化per_token_loss等价于最大化 “净收益”
    • 简单举例:若policy_loss=2per_token_kl=1beta=0.1,则净收益 = 2-0.1×1=1.9,per_token_loss=-1.9(优化器会让这个值更负,即净收益更大)

(2)示例:

先明确 KL 散度公式:per_token_kl = exp(-δ) + δ - 1(δ=delta),δ≠0 时 KL 必为正,偏离越远惩罚越重

示例 1:δ>0(策略模型更自信)

  • δ=0.5(轻微自信):per_token_kl = exp(-0.5) + 0.5 -1 ≈0.6065+0.5-1=0.1065(轻惩罚)
  • δ=1.0(明显自信):per_token_kl = exp(-1.0)+1.0-1≈0.3679+1-1=0.3679(中惩罚)

示例 2:δ<0(策略模型更保守)

  • δ=-0.5(轻微保守):per_token_kl = exp(0.5) -0.5 -1≈1.6487-0.5-1=0.1487(轻惩罚)
  • δ=-1.0(明显保守):per_token_kl = exp(1.0)-1.0-1≈2.7183-1-1=0.7183(重惩罚)

示例计算(结合 beta=0.1)

假设 policy_loss=1.8,δ=0.5(KL=0.1065):

  • 净收益 = 1.8 - 0.1*0.1065 = 1.78935
  • per_token_loss = -1.78935

最终损失会基于 1.78935 计算,既保留了正向收益,又通过 KL 惩罚避免模型过度自信。

(3)KL 散度的核心作用

给模型加 “安全带”:允许模型向优化目标微调,但禁止完全抛弃参考模型的基础能力(比如微调后只会生成 “Paris.”,却不认识 “The capital of France is”)。

最后

GRPO 作为 PPO 的优化变体,通过 “裁剪 + KL 散度” 两大正则化手段,解决了大模型 RLHF 微调中 “更新失控”“偏离原始能力” 的核心问题:

  1. 基础 GRPO:核心是对比策略 / 参考模型的概率,结合优势值优化生成偏好;
  2. 裁剪 GRPO:限制 ratio 范围,避免单次更新幅度过大;
  3. KL 散度 GRPO:惩罚模型偏离参考模型,保留基础语言能力;

实际落地中,带裁剪 + KL 散度的 GRPO 是最稳定的选择,也是工业界微调大模型的主流方案之一。

如果本文对你理解 GRPO 有帮助,欢迎点赞 + 收藏,后续会继续分享 RLHF 微调的实战技巧~

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐