【生成模型】【强化学习】(一)RLHF & DPO
大型语言模型(如 GPT 系列)在训练后,往往需要“对齐”(alignment),让它们生成更符合人类偏好的输出。比如,你问 AI 一个问题,它应该给出有帮助、礼貌的回答,而不是胡说八道或有害的内容。传统的对齐方法是 RLHF(Reinforcement Learning from Human Feedback,从人类反馈中强化学习):Step1:收集人类反馈数据。通常是给模型一个提示(promp
RLHF
大型语言模型(如 GPT 系列)在训练后,往往需要“对齐”(alignment),让它们生成更符合人类偏好的输出。比如,你问 AI 一个问题,它应该给出有帮助、礼貌的回答,而不是胡说八道或有害的内容。
传统的对齐方法是 RLHF(Reinforcement Learning from Human Feedback,从人类反馈中强化学习):
Step1:收集人类反馈数据。通常是给模型一个提示(prompt,比如“解释量子力学”),模型生成多个回答,然后人类标注哪个更好或者给回答打分 s s s。
Step2:用这些偏好数据 s s s训练一个“奖励模型”(reward model),它像一个打分器,能给回答打分(高分表示好,低分表示差)。
θ ∗ = m i n θ L o s s ( R ( x ; θ ) , s ) \theta^* = min_\theta Loss(R(x;\theta), s) θ∗=minθLoss(R(x;θ),s)
Step3:用这个奖励模型作为指导,通过强化学习(如 PPO 算法)来微调语言模型,让它生成更高分的回答。同时,要防止模型偏离原版(用 KL 散度来约束)。
R()
问题:训练奖励模型需要大量计算,PPO 等强化学习算法不稳定,容易崩溃。而且,整个过程像是在“间接”优化:先学奖励,再用奖励优化模型。
DPO 的出现:DPO 就是为了解决这些问题。它直接从偏好数据中优化模型,而不需要单独训练奖励模型。简单来说,DPO 像是一个“捷径”:它把偏好数据直接转化为模型的优化目标,过程更稳定、更高效。实验显示,DPO 在对齐任务上能媲美甚至超过 RLHF,但计算成本低得多。
https://km.woa.com/articles/show/635261?kmref=search&from_page=1&no=7#7-DPO(direct-preference-optimization
Diffusion DPO伪代码
下面是Diffusion DPO的伪代码 (Diffusion Model Alignment Using Direct Preference Optimization)
def diffusion_dpo_loss(model, ref_model, x_w, x_l, prompt, beta):
"""
Args:
model: 待训练的Diffusion模型(输入:带噪图像、prompt、时间步)
ref_model: 冻结的参考模型(如预训练SDXL)
x_w: 偏好图像(latent空间)
x_l: 非偏好图像(latent空间)
prompt: 文本条件
beta: 温度超参数(SDXL用5000,SD1.5用2000)
Returns:
单样本对的DPO损失
"""
# 1. 随机采样时间步和噪声
timestep = torch.randint(0, 1000, (1,)).item() # T=1000
noise = torch.randn_like(x_w) # 真实噪声
# 2. 对偏好/非偏好样本前向加噪
noisy_x_w = add_noise(x_w, noise, timestep) # q(x_t | x_0^w)
noisy_x_l = add_noise(x_l, noise, timestep) # q(x_t | x_0^l)
# 3. 模型与参考模型的噪声预测
pred_noise_w = model(noisy_x_w, prompt, timestep)
pred_noise_l = model(noisy_x_l, prompt, timestep)
ref_noise_w = ref_model(noisy_x_w, prompt, timestep)
ref_noise_l = ref_model(noisy_x_l, prompt, timestep)
# 4. 计算噪声预测误差(MSE)
model_err_w = (pred_noise_w - noise).norm(p=2).pow(2)
model_err_l = (pred_noise_l - noise).norm(p=2).pow(2)
ref_err_w = (ref_noise_w - noise).norm(p=2).pow(2)
ref_err_l = (ref_noise_l - noise).norm(p=2).pow(2)
# 5. 计算DPO损失核心项
w_diff = model_err_w - ref_err_w # 偏好样本的误差改进
l_diff = model_err_l - ref_err_l # 非偏好样本的误差改进
inside_term = -beta * (w_diff - l_diff) # 鼓励w_diff < l_diff
loss = -torch.log(torch.sigmoid(inside_term)) # 负log似然损失
return loss
更多推荐



所有评论(0)