大模型训练实践之SFT、PPO、DPO、GRPO详细介绍(附代码)
前言
本文主要介绍大模型训练(优化)的几种高效策略的理论原理与工程实现代码:监督微调(SFT)、近端策略优化(PPO)、直接偏好优化(DPO)以及组相对策略优化(GRPO)。
一、SFT
监督微调(SFT)让预训练模型学会理解人类指令并生成符合预期的回应,实现“基础指令对齐”。是所有后续对齐技术的基础。
1.1 理论介绍
1.1.1 核心思想
模仿学习。为模型提供高质量的“问题-答案”配对数据,让它学习如何生成正确的回答。
1.1.2 如何工作
-
你有一个已经预训练好的基础模型(比如学会了海量互联网文本,有通用知识)。
-
你准备一个高质量的数据集,里面都是精心编写的示例。例如:
-
用户: “写一首关于春天的诗。”
-
助手: “春风吹绿柳梢头,细雨沾红花影柔...”
-
-
用这个数据集去继续训练(微调)基础模型。训练的目标很简单:让模型在看到输入(用户问题)时,能最大概率地输出我们期望的答案。
1.1.3 目的
-
让模型掌握特定的任务格式(如对话、写作、编程)。
-
激发模型在预训练阶段已经学会但未展现的能力。
-
让模型的输出更可控、更符合指令。
1.1.4 优点
-
简单、直接、稳定。
1.1.5 缺点
-
非常依赖于高质量的人工标注数据,成本高。
-
只能让模型模仿已有的答案,无法学会在复杂情况下做出“更好”的权衡(比如更有帮助、更无害)。
-
存在“模仿漂移”:模型可能会模仿数据中一些不必要的风格或错误。
1.2 工程实现代码
1.2.1 数据输入
(1)数据格式:“指令-回应”配对数据(JSON格式)。示例如下:
[
{"instruction": "写一封请假邮件", "input": "", "output": "尊敬的领导:因身体不适,需请假1天,望批准。"}
]
(2)数据规模:10k~100k条(中小模型),100k~1M条(大模型)。
(3)数据来源:公开数据集(Alpaca、ShareGPT)、人工标注、大模型生成(数据蒸馏)。
1.2.2 训练输出
(1)具备基础指令遵循能力的模型(如Llama-3-8B-SFT、Qwen-1.5-7B-SFT);
(2)模型 checkpoint 文件(含权重、配置、tokenizer)。
1.2.3 实现代码
(1)数据准备与清洗
import json
# 加载原始数据
with open("sft_data.json", "r") as f:
data = json.load(f)
# 清洗:过滤长度>500token、含有害词的样本
clean_data = []
for item in data:
if len(item["output"]) < 500 and "有害词" not in item["output"]:
clean_data.append(item)
# 保存清洗后数据
with open("clean_sft_data.json", "w") as f:
json.dump(clean_data, f, indent=2)
(2)模型与tokenizer加载
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "meta-llama/Llama-3-8B" # 预训练模型
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto", # 自动分配GPU/CPU
torch_dtype="auto" # 自动选择数据类型(如BF16)
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token # 设置pad token(必填)
(3)数据集格式化(转为模型输入格式)
def format_prompt(item):
return f"[INST] {item['instruction']} [/INST] {item['output']}"
# 转换为tokenized数据集
from datasets import Dataset
dataset = Dataset.from_list(clean_data)
def tokenize_function(examples):
prompts = [format_prompt(item) for item in examples]
return tokenizer(prompts, truncation=True, max_length=512, padding="max_length")
tokenized_dataset = dataset.map(tokenize_function, batched=True)
(4)训练配置与启动
from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(
output_dir="./llama3-8b-sft", # 模型保存路径
per_device_train_batch_size=8, # 单卡batch size(12GB显存建议8)
learning_rate=3e-5, # 学习率(中小模型常用3e-5)
num_train_epochs=3, # 训练轮数(避免过拟合)
logging_steps=100, # 每100步打印日志
save_strategy="epoch", # 每轮保存一次模型
fp16=True, # 混合精度训练(加速且省显存)
report_to="none" # 不使用wandb等监控工具
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset
)
trainer.train() # 启动训练
1.2.4 注意事项
(1)数据质量优先:10k高质量样本(无错误、逻辑清晰)优于100k低质样本;
(2)避免过拟合:验证集损失上升时停止训练(设置load_best_model_at_end=True);
(3)学习率选择:模型越大,学习率越小(如70B模型用1e-5,7B模型用3e-5);
(4)保留预训练能力:冻结底层20%权重(可选),防止模型遗忘通用知识。
二、PPO
近端策略优化PPO通过强化学习(RLHF)的思想,让模型在“奖励模型(RM)打分”引导下优化输出,实现“复杂偏好对齐”。
2.1 理论介绍
2.1.1 核心思想
强化学习。让模型通过试错,根据一个“奖励模型”的反馈来优化自己的策略(即生成答案的方式)。
2.1.2 如何工作
-
第一步:训练一个奖励模型:收集对模型不同回答的偏好数据(例如,给定一个问题,标注员(或LLM)认为答案A比答案B更好)。用这些数据训练一个独立的“奖励模型”,这个模型学会给“更好”的回答打高分,“更差”的回答打低分。
-
第二步:用PPO算法优化SFT模型:
-
人员:当前SFT后的模型(需要被优化的策略)。
-
环境:用户提出一个问题。
-
行动:模型生成一个回答。
-
奖励:奖励模型给这个回答打一个分数。
-
优化:PPO算法根据这个奖励分数,调整模型参数,使其未来生成能获得更高奖励的回答。同时,它有一个“近端”约束,防止一次更新太多,导致模型崩溃或忘记之前学会的知识。
-
2.1.3 目的
-
优化那些难以用简单“对错”衡量,但关乎“质量”的维度,如:有帮助性、真实性、无害性、流畅性等。
-
让模型学会在多个约束下做出平衡(例如,既要有帮助,又必须安全)。
2.1.4 优点
-
能处理非常复杂、多维度的优化目标。
-
是ChatGPT等模型实现卓越对话能力的核心技术。
2.1.5 缺点
-
流程极其复杂:需要训练和维护两个模型(策略模型和奖励模型)。
-
不稳定:训练过程像“走钢丝”,容易失控。
-
计算成本高。
2.2 工程实现代码
2.2.1 数据输入
(1)SFT模型:作为策略网络初始化权重;
(2)奖励模型(RM)训练数据:“指令+回应+人类打分(1-10分)”三元组;
(3)策略更新数据:大规模指令集(用于生成候选回应)。
2.2.2 模型输出
(1)符合复杂偏好的强化学习模型(如Llama-3-8B-PPO);
(2)模型checkpoint(相比DPO,更能平衡多维度偏好)。
2.2.3 实现代码(基于TRL库)
(1)训练奖励模型
# 定义奖励模型(基于SFT模型添加奖励头)
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn as nn
class RewardModel(nn.Module):
def __init__(self, base_model):
super().__init__()
self.base_model = base_model
self.reward_head = nn.Linear(base_model.config.hidden_size, 1) # 输出标量奖励
def forward(self, input_ids, attention_mask):
outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
last_hidden = outputs.last_hidden_state[:, -1, :] # 取最后token的隐藏态
return self.reward_head(last_hidden).squeeze(-1) # 输出奖励分
# 加载SFT模型作为基础
base_model = AutoModelForCausalLM.from_pretrained("./llama3-8b-sft")
reward_model = RewardModel(base_model)
# 训练RM(简化代码,实际需用MSE损失训练)
(2)PPO策略优化
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
# 策略网络(带价值头,用于估计预期奖励)
model = AutoModelForCausalLMWithValueHead.from_pretrained("./llama3-8b-sft")
# 配置PPO
ppo_config = PPOConfig(
output_dir="./llama3-8b-ppo",
learning_rate=5e-6, # 策略更新需谨慎,用小学习率
batch_size=32,
eps_clip=0.2, # 裁剪系数(0.1~0.2)
kl_coef=0.05, # KL惩罚(防止策略偏移)
num_train_epochs=3
)
# 加载指令数据集(用于生成回应)
instruction_dataset = Dataset.from_list([{"instruction": "..."}, ...])
# 初始化PPO Trainer
ppo_trainer = PPOTrainer(
model=model,
config=ppo_config,
tokenizer=tokenizer,
dataset=instruction_dataset,
reward_model=reward_model # 奖励模型
)
# 迭代训练(采样-打分-更新)
for epoch in range(3):
# 1. 生成回应(采样)
outputs = ppo_trainer.generate(batch_size=128, max_new_tokens=200)
# 2. 奖励打分
rewards = reward_model(** outputs["response_ids"])
# 3. 更新策略
stats = ppo_trainer.train(minibatch_size=32)
2.2.4 注意事项
(1)奖励模型设计:多维度奖励需合理加权(如安全分权重过高会导致模型“不说有用内容”);
(2)控制KL散度:策略与SFT模型的KL值需保持在0.01~0.05(过高说明策略偏移,过低说明无提升);
(3)避免奖励欺骗:定期人工检查模型输出,防止模型生成“看似高分实则无意义”的内容(如重复关键词);
(4)混合SFT数据:每轮训练加入10%的SFT数据,防止策略遗忘基础指令能力。
三、DPO
直接偏好优化DPO通过奖励模型(RM)训练,直接用“人类偏好对”(优质回应vs劣质回应)优化模型,实现“高效偏好对齐”,降低PPO的复杂性。
3.1 理论介绍
3.1.1 核心思想
绕过奖励模型,直接使用偏好数据来优化模型。它是一种更聪明、更简洁的数学转换。
3.1.2 如何工作
-
它发现,训练一个奖励模型然后再用PPO去优化,这个流程可以数学等价为直接用一个损失函数在偏好数据上微调模型。
-
所需数据:偏好数据(对于问题X,回答A优于回答B)。
-
直接优化:DPO设定一个损失函数,它的目标是:让模型对自己生成“优选回答”的概率,与生成“劣选回答”的概率之差越来越大。模型参数直接根据这个目标进行调整。
3.1.3 目的
-
实现与PPO类似的目标(根据人类偏好优化模型),但方法更直接。
3.1.4 优点
-
极其简单:整个流程就像做SFT一样简单,只需要一个模型和一份偏好数据集。
-
稳定高效:避免了PPO复杂的强化学习循环,训练更稳定,计算成本更低。
-
在许多任务上被证明能达到甚至超过PPO的效果。
3.1.5 缺点
-
理论框架假设偏好数据是基于某个潜在的奖励模型产生的,对于极度复杂或动态变化的偏好,其理论最优性可能不如PPO灵活。
3.2 工程实现代码
3.2.1 数据输入
(1)数据格式:
[
{
"instruction": "推荐一本科幻小说",
"chosen": "《三体》:刘慈欣的经典,探讨宇宙文明法则", # 优质
"rejected": "《西游记》:不是科幻小说" # 劣质
}
]
(2)数据规模:1k~10k条(远少于SFT);
(3)数据来源:人工标注偏好对、SFT模型生成多候选后筛选。
3.2.2 训练输出
(1)符合人类偏好的模型(如Llama-3-8B-DPO);
(2)模型checkpoint(相比SFT,更倾向生成“优质回应”风格的内容)。
3.2.3 实现代码(基于TRL库)
(1)数据准备
# 加载偏好数据
with open("dpo_data.json", "r") as f:
dpo_data = json.load(f)
# 转换为TRL库要求的格式(需包含"chosen"和"rejected"字段)
from datasets import Dataset
dpo_dataset = Dataset.from_list(dpo_data)
(2)模型初始化
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOTrainer, DPOConfig
# 参考模型(固定SFT模型)
ref_model = AutoModelForCausalLM.from_pretrained("./llama3-8b-sft", device_map="auto")
# 待优化模型(复制SFT权重)
model = AutoModelForCausalLM.from_pretrained("./llama3-8b-sft", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("./llama3-8b-sft")
(3)DPO训练配置与启动
dpo_config = DPOConfig(
output_dir="./llama3-8b-dpo",
per_device_train_batch_size=16, # 偏好数据少,可设大batch
learning_rate=2e-5, # 小于SFT,避免破坏基础能力
num_train_epochs=2, # 偏好数据易过拟合
beta=0.1, # 温度参数(常用0.1~0.5)
logging_steps=50
)
# 初始化DPO Trainer
dpo_trainer = DPOTrainer(
model=model,
ref_model=ref_model,
args=dpo_config,
train_dataset=dpo_dataset,
tokenizer=tokenizer,
max_prompt_length=256, # 指令最大长度
max_length=512 # 回应最大长度
)
dpo_trainer.train() # 启动训练
3.2.4 注意事项
(1)偏好对质量:避免“模糊对”(优质/劣质差异不明显),否则模型无法学习偏好;
(2)β值调整:β过小(如0.05)会导致偏好学习不充分;β过大(如1.0)会导致模型输出保守(只说安全但无用的内容);
(3)参考模型固定:训练中ref_model权重不可更新,否则失去约束作用;
(4)数据增强:对稀缺偏好类型(如“代码简洁性”),用大模型生成相似样本扩充。
四、GRPO
组相对策略优点GRPO通过改进PPO,用“组内相对奖励”替代额外的评价,降低计算成本,提升推理密集型任务(数学、代码)的对齐效果。
4.1 理论介绍
4.1.1 核心思想
用简单规则对模型自己实时生成的结果进行分组比较,从而实现自我迭代优化。 其革命性在于,它完全摆脱了对“人类标注的偏好数据”或“奖励模型”的依赖,仅依靠一个可编程的规则(如“是否遵循了格式要求”)和模型自身生成的内容,就能实现模型行为的定向对齐。
4.1.2 如何工作
GRPO是一个在线的迭代优化循环,可以分解为四个步骤:
-
第一步:生成
对于同一个提示(例如,“用Python写一个排序函数,并附上注释”),让当前待优化的模型生成一组(例如4-9个)不同的回答。这些回答是在训练过程中实时采样得到的,而不是使用一个固定的数据集。 -
第二步:评估与分组
用一个预先定义的、确定性的规则函数对每个回答进行打分。这个规则通常非常简单、可量化,例如:-
格式合规:回答是否包含了要求的代码块(```python)?
-
内容包含:回答是否包含了关键词“冒泡排序”或“快速排序”?
-
拒绝安全性:对于恶意提问,模型是否给出了拒绝回答?
-
长度控制:回答是否在要求的字数范围内?
根据得分,将所有回答分成 “优胜组” (得分高)和 “普通组” (得分低)。
-
-
第三步:优化
这是GRPO的核心。其损失函数旨在:-
最大化模型产生“优胜组”回答的概率(鼓励模型学习好的行为)。
-
最小化模型产生“普通组”回答的概率(抑制模型的普通或不达标行为)。
这个优化目标通过一种对比损失函数来实现,它直接比较两组回答的概率分布差异,推动模型参数向生成“优胜组”风格答案的方向更新。
-
-
第四步:迭代
更新模型参数后,回到第一步,用更新后的模型生成新一批回答,继续评估、分组和优化。这个过程不断重复,模型在“自我生成-自我评判-自我改进”的循环中持续进化。
4.1.3 目的
-
主要目的:以一种极低成本、高稳定性的方式,让模型可靠地遵守明确、可验证的约束规则(如输出格式、安全准则、特定内容包含等)。
-
深层目的:探索一条不依赖于人类主观偏好标注的模型对齐新路径,为大规模、自动化地打磨模型基础能力(如指令跟随)提供解决方案。
4.1.4 优点
-
成本极低:完全不需要昂贵且耗时的人类偏好数据标注,也无需训练和维护复杂的奖励模型。
-
流程简单稳定:整个流程类似“增强版的SFT”,避免了强化学习(如PPO)的复杂性和不稳定性,易于实现和调试。
-
高度可控:优化目标完全由设计者定义的清晰规则所决定,过程透明,结果可预测。
-
在线学习,自我迭代:模型在训练中实时看到自己生成结果的好坏,能进行快速、有针对性的自我改进。
4.1.5 缺点
-
规则的天花板:模型的最终能力被预设的简单规则所限制。它无法学习规则无法描述的、复杂微妙的“质量”概念,如创造性、智慧深度或情感共鸣。
-
过拟合风险:模型可能学会“钻空子”,机械地满足规则的表象(例如,为了满足“包含关键词”而强行插入无关的关键词),却牺牲了回答的真实质量和逻辑性。
-
适用场景受限:它擅长的是“合规性对齐”,而非“质量性对齐”。对于提升开放性对话的趣味性、有帮助性,它的能力远不及基于人类真实偏好的DPO/PPO。
-
规则设计有挑战:如何设计出全面、无漏洞且能引导模型向好方向发展的规则,本身需要一定的专业知识和迭代。
4.2 工程代码实现
4.2.1 数据输入
(1)SFT模型:作为策略网络初始化;
(2)组采样数据:对每个指令生成GGG个候选回应(G=8∼32,推理任务用大G);
(3)奖励模型:同PPO(支持多维度打分)。
4.2.2 训练输出
(1)强推理能力的对齐模型(如Llama-3-8B-GRPO);
(2)模型checkpoint(在数学、代码任务上表现优于PPO)。
4.2.3 实现代码(基于TRL库)
(1)组采样数据生成
def generate_group(instruction, model, tokenizer, G=16):
"""为单个指令生成G个候选回应"""
inputs = tokenizer(instruction, return_tensors="pt").to("cuda")
responses = []
for _ in range(G):
# 带随机性生成,确保候选多样性
outputs = model.generate(**inputs, max_new_tokens=200, temperature=0.8, do_sample=True)
responses.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
return {"instruction": instruction, "responses": responses}
# 生成组数据(示例)
group_data = [generate_group(inst["instruction"], model, tokenizer) for inst in instruction_dataset]
(2)GRPO训练配置与启动
from trl import GRPOTrainer, GRPOConfig
grpo_config = GRPOConfig(
output_dir="./llama3-8b-grpo",
learning_rate=5e-6,
group_size=16, # 组内候选数(推理任务用32)
kl_coef=0.03, # KL约束(比PPO更严格)
num_train_epochs=2
)
# 初始化GRPO Trainer
grpo_trainer = GRPOTrainer(
model=model,
config=grpo_config,
tokenizer=tokenizer,
train_dataset=group_data, # 组采样数据
reward_model=reward_model
)
grpo_trainer.train() # 启动训练
4.2.4 注意事项
(1)组大小选择:数学/代码任务用G=32G=32G=32(提升对比效果),简单任务用G=8G=8G=8(降低成本);
(2)异步采样:用单独进程生成组数据,避免训练中断(可节省50%时间);
(3)奖励归一化:组内奖励必须标准化(否则相对优势计算失效);
(4)拒绝采样:每轮筛选组内Top 20%高奖励回应,作为下一轮训练的补充数据(提升效率)。
五、总结
用同一个比喻来贯穿这四种技术,假设我们的目标是训练一个写作助手:
-
SFT:临摹范文。我给你大量优秀的文章范例,你一遍遍临摹,学习基本的写作结构和修辞手法。
-
PPO:参加高级写作班。老师(奖励模型)根据一套复杂的审美和思想标准(人类偏好)给你的每篇习作打分。你根据分数反复修改,试图揣摩并满足老师的高标准。过程艰难但上限高。
-
DPO:私人教练点拨。教练不直接打分,而是直接给你看两篇习作(一篇好,一篇差),告诉你“这一篇比那一篇好,多学学这篇”。你从中直接领悟要旨。
-
GRPO:自我量化训练。你给自己定几条硬性指标:每篇文章必须用3个成语、不少于500字、必须有首尾呼应。你写好几篇文章,自己检查,把达标的归为一组,不达标的归为另一组。然后你专门练习那些能让自己“达标”的写法。高效、自律,但可能写不出有灵性的传世之作。
更多推荐



所有评论(0)