写在前面

CoT 是为了让模型“想得更清楚”(我的解读blog),而本文解读的 SFT + RLHF 则是为了让模型“按人类想要的方式说和做”。(我的大模型预训练解读blog)。同样,这套方法论也可以拓展到多模态大模型中,以更好的对齐不同模态的信息。SFT + RLHF 不仅能拓展到多模态,而且在多模态场景中比纯文本更“必要”。

(该图复制自SFT与RLHF优缺点全解析

原论文发表于NeurIPS-2022(原文链接),中文标题翻译过来是训练语言模型使他们能服从人类的指示,在OpenAI自己的后续工作和博客中被称为InstructGPT,是ChatGPT-3.5微调时借鉴的前置工作。

本文是我的翻译+笔记,我学习的内容是论文原文和李沐老师的论文精读b站视频(b站链接)。


目录

写在前面

1.翻译

1.0.摘要

1.1.导论

1.2.相关工作(aliment对齐)

1.3.方法和实验细节

1.3.1.高水平的方法论

1.3.2.数据集

1.3.3.人类数据采集(标数据)

1.3.4.模型

1.3.5.评估

1.4.结果

1.4.1.API 分布上的评估

1.4.2.公共NLP数据集上的评估

1.4.3.量化结果

1.5.讨论

1.5.1.对齐研究的启示

1.5.2.局限性

1.5.3.更广泛的影响

2.笔记

2.1.原有的预训练GPT3方法带来的问题

2.2.方法

2.3.结论

2.4.价值观讨论

3.代码实验

3.1.数据准备

3.2.SFT 代码(第 1 步:监督微调)

3.3.RM 代码(第 2 步:奖励模型训练)

3.4.PPO 代码(第 3 步:用 RM 做 RLHF 微调 SFT 模型)



1.翻译

1.0.摘要

将语言模型做得更大并不一定能使它们更好地遵循用户的意图。例如,大型语言模型可以生成不真实的、有毒的或者根本对用户没有帮助的输出。换句话说,这些模型并没有与他们的用户对齐。在本文中,我们展示了一种通过人工反馈微调的方式在广泛的任务上将语言模型与用户意图进行对齐的方法。从一组标注工具编写的提示和通过语言模型API提交的提示开始,我们收集了一个标注器演示所需模型行为的数据集,我们使用该数据集使用监督学习对GPT-3进行微调。 然后,我们收集了一个模型输出排名的数据集,该数据集用于使用来自人类反馈的强化学习来进一步微调这个监督模型。我们将得到的模型称为InstructGPT。在人类对我们的即时分配的评价中,输出来自于1.3B参数InstructGPT模型更适合从175B GPT-3输出,尽管它的参数少了100倍。此外,在公开的NLP数据集上,InstructGPT 模型在真实性和减少有毒产出生成方面有所改进,同时具有最小的性能回归。尽管InstructGPT仍然犯了简单的错误,但我们的结果表明,使用人类反馈进行微调是使语言模型与人类意图保持一致的一个很有前途的方向。

1.1.导论

大型语言模型( LM )可以通过“提示prompt”的方法执行一系列自然语言处理( NLP )任务,给定任务的一些实例作为输入。然而,这些模型经常表达非预期行为,如虚构事实,生成偏见或有毒文本,或根本不遵循用户指令。这是因为用于最近的许多大型LM的语言建模的目的是 -- 从互联网上预测网页上的下一个token -- 与"有帮助地和安全地遵循用户的指示"的目标不同。因此,我们说语言建模的目标是错位的(没有aligned)。避免这些非预期行为对于部署和使用在数百个应用程序中的语言模型来说尤其重要。

我们在对齐语言模型方面取得了进展,训练它们按照用户的意图行事。这既包括遵循指令等外显意图,也包括保持真实、不带偏见、不有毒或有害等内隐意图。使用Askell et al . ( 2021 )的语言,我们希望语言模型是有用的(他们应该帮助用户解决他们的任务)、诚实的(他们不应该编造信息或误导用户)和无害的(不得对人或环境造成生理、心理或社会危害)。我们将在第3.5节详细阐述对这些标准的评价。

我们重点关注语言模型对齐的微调方法。具体来说,我们使用来自人类反馈的强化学习( RLHF,最早提出于本文作者2017年的工作 )来微调GPT-3,以遵循广泛的书面指令(见图2)。该技术使用人类偏好作为奖励信号来微调我们的模型。我们首先雇佣了一个由40个承包商组成的团队来标注我们的数据,基于他们在筛选测试(见第3.3节和附录B.1)上的表现。然后,我们收集了提交给一个语言模型API和一些标签书写提示的(主要是英文)提示上的期望输出行为的人工书写演示数据集,并以此来训练我们的监督学习基线。 接下来,我们在更大的API提示集合上收集了我们的模型的输出之间的人工标注的比较的数据集。然后,我们在该数据集上训练一个奖励模型( RM ),以预测我们的标注者更喜欢哪个模型输出。最后,我们使用这个RM作为奖励函数,并使用PPO算法微调我们的监督学习基线,以使这个奖励最大化。我们在图2中说明了这一过程。这一程序将GPT-3的行为与特定人群(主要是我们的标注者和研究者)所声明的偏好保持一致,而不是任何更广泛的"人类价值"概念;我们在附录G.2中对此进行了进一步的讨论。我们将得到的模型称为InstructGPT。

我们主要通过让我们的标注者在测试集上对模型输出的质量进行评分来评估我们的模型,测试集由其他用户(在训练数据中没有表示的人)的提示组成。我们还在一系列公开的NLP数据集上进行了自动评估。我们训练了三种模型大小的( 1.3B、6B、175B参数),所有的模型都使用GPT-3架构。我们的主要发现是:

与GPT-3的输出结果相比,标注人员显著更喜欢InstructGPT的输出结果。1.3B参数InstructGPT模型的输出比从175B GPT-3的输出更受人喜欢,尽管它比175B GPT-3少了2个100倍以上的参数。这些模型具有相同的架构,区别仅在于InstructGPT是根据我们的人类数据进行微调的。这个结果即使当我们给GPT-3添加了一个小数提示,使其更好地跟随指令时也是正确的。我们的175B InstructGPT的输出在85±3 %的情况下优于175B GPT-3输出,在71±4 %的情况下优先于few-shot的175B GPT-3输出。InstructGPT 也会根据我们的标签生成更合适的输出。

InstructGPT 模型显示出比GPT-3在真实性方面有改进。在Truthful QA基准上,InstructGPT 比GPT-3更多地生成真实和信息丰富的答案。对于我们的API提示分布中的"封闭域"任务,其中输出不应该包含输入中不存在的信息,InstructGPT 模型弥补输入中不存在的信息的频率约为GPT-3 (幻觉发生率分别为21%和41%)的一半。

与GPT-3相比,InstructGPT 对毒性的改善较小,但在偏见上没有改善。为了测量毒性,我们使用了RealToxicityPrompts数据集,并进行了自动和人工评估。当提示尊重时,InstructGPT 模型比GPT-3少产生约25%的毒性输出。InstructGPT 在Winogender 和CrowSPairs 数据集上较GPT-3没有显著提升。

我们可以通过修改RLHF微调过程来最小化公开NLP数据集上的性能回归(解释见笔记-结论)。在RLHF微调过程中,我们在一些公开的NLP数据集上观察到了相对于GPT-3的性能回归。通过将PPO更新与增加预训练分布对数似然值( PPO-ptx )的更新混合,我们可以在不牺牲标签者偏好分数的情况下,极大地减少这些数据集上的性能回归。

我们的模型也更受没有参与标注的人员的偏好。为了测试我们模型的泛化性,我们对没有参与标注的人进行了初步实验,发现他们更喜欢InstructGPT的输出,而不是GPT-3的输出,其速率与我们的训练标注者差不多。然而,需要更多的工作来研究这些模型如何在更广泛的用户群体中执行,以及在人类对所期望的行为不一致的情况下如何执行(人与人之间的喜好度有一定的相关性)。

公开的NLP数据集并不能反映我们的语言模型是如何被使用的(解释见笔记-结论)。我们比较了在我们的人类偏好数据( 即InstructGPT )上微调的GPT-3和在两个不同的公共NLP任务汇编上微调的GPT-3:FLAN ( Wei et al . , 2021)和T0 ( Sanh et al , 2021) (特别地, T0++变体)。这些数据集由多种NLP任务组成,并结合了每个任务的自然语言说明。在我们的API提示分布上,我们的FLAN和T0模型的表现略差于我们的SFT基线,标签制作者显著偏好InstructGPT而非这些模型。

InstructGPT 模型对RLHF微调分布以外的指令具有很好的泛化能力(解释见笔记-结论)。我们定性地考察了InstructGPT的能力,发现它能够遵循总结代码的指令,回答关于代码的问题,有时也会遵循不同语言的指令,尽管这些关于代码的指令在微调分布中非常罕见。这个结果是令人兴奋的,因为它表明我们的模型能够推广"跟随指令"的概念。即使在他们很少得到直接监督的任务上,他们也保留了一定的一致性。

InstructGPT 仍然会犯简单的错误。例如,InstructGPT 仍然可能无法遵循指令、捏造事实、对简单问题给出冗长的对冲答案,或者无法检测带有虚假前提的指令。

总的来说,我们的结果表明,使用人类偏好微调大型语言模型显著地改善了它们在广泛任务上的行为,尽管仍然需要做许多工作来提高它们的安全性和可靠性。

1.2.相关工作(aliment对齐)

对齐和从人类反馈中学习的研究。我们基于先前的技术,将模型与人类意图进行对齐,尤其是RLHF技术。它最初是为在模拟环境和Atari游戏中训练简单机器人而开发的,最近被应用于微调语言模型,以总结文本。这项工作反过来又受到类似工作的影响,这些工作使用人类反馈作为奖励,例如对话,翻译,语义解析,故事生成,评论生成和证据提取。在并行工作中,Askell et al . ( 2021 );白重恩等 ( 2022 )提出语言助手作为对齐研究的测试平台,并使用RLHF训练模型。我们的工作可以看作是RLHF在广泛分布的语言任务上对齐语言模型的直接应用

训练语言模型以遵循指令。我们的工作还涉及到语言模型中跨任务泛化的研究,其中LM在广泛的公共NLP数据集(通常以适当的指令为前缀)上进行微调,并在不同的NLP任务集上进行评估。在该领域已经有了一系列的工作,在训练和评估数据、指令的格式、预训练模型的大小以及其他实验细节上都有所不同。

减轻语言模式的危害。修改语言模型的行为的一个目的是减轻这些模型在现实世界中部署时的危害。这些风险已经被广泛地记录在案。语言模型会产生有偏输出,泄露私有数据,产生错误信息,被恶意使用;为了全面回顾,我们将读者引向魏丁格尔et al . ( 2021 )。 减轻这些危害的方法有很多,包括在一个小的、有价值的数据集上进行微调,对预训练数据集进行过滤,或者在人在回路(human-in-the-loop)中收集数据。

1.3.方法和实验细节

1.3.1.高水平的方法论

(可见笔记--方法)

我们的研究方法沿用了齐格勒等人( 2019 )和Stiennon等人( 2020 )的研究方法,并将其应用于文体延续和文摘领域。我们从一个预训练的语言模型开始,一个我们希望模型产生对齐输出的提示分布,以及一个训练有素的人工标注器团队(见第3.3节为详细信息)。我们采用以下3个步骤(图2 )。

步骤1:收集示范数据,训练有监督的政策。我们的标注者在输入提示分布(见第3.2节关于这一分布的详细信息)上提供了期望行为的演示。然后我们使用监督学习在这个数据上微调一个预训练的GPT-3模型。

步骤2:收集对比数据,训练奖励模型RM。我们收集了一个模型输出之间比较的数据集,其中标记者表示对于给定的输入,他们更喜欢哪种输出。然后,我们训练一个奖励模型来预测人类偏好的输出。

步骤3:使用PPO针对奖励模型进行策略优化。我们使用RM的输出作为标量奖励。我们使用PPO算法微调有监督的策略来优化这个奖励。

第2步和第3步可以不断迭代;对当前最好的策略收集更多的比较数据,用于训练一个新的RM,然后训练一个新的策略。在实践中,我们的比较数据大多来自于我们的监督政策,也有一部分来自于我们的PPO政策。

1.3.2.数据集

我们的提示数据集主要包括提交给商业语言模型API的文本提示,以及少量由标注者编写的提示。这些提示非常多样化,包括生成、问答、对话、摘要、抽取和其他自然语言任务(见附录A)。我们的数据集超过96%是英文的。我们启发式地去重提示,并确保验证集和测试集中不包含数据在训练集中的用户的数据。我们还对包含个人身份信息( PII )的提示进行了筛选

从这些提示中,我们产生了三个不同的数据集用于我们的微调过程:( 1 )我们的SFT数据集,和标注者的回答答案示例一起用于训练我们的SFT模型;( 2 )我们的RM数据集,和标签者对于输出排名打分一起,用于训练我们的RM;( 3 )我们的PPO数据集,没有任何人类标签,这些标签被用作RLHF微调的输入。SFT数据集包含约13k条训练提示(来自API和人工标注者),RM数据集包含约33k条训练提示(来自API和人工标注者),PPO数据集包含约31k条训练提示(只需从API调用即可)。在表3中提供了关于数据集大小的更多细节。

1.3.3.人类数据采集(标数据)

为了提供我们的示范和比较数据,并进行我们的主要评估,我们雇用了一个由大约40个合同工组成的团队,他们是在Upwork和ScaleAI上进行的。与早期在摘要任务上收集人类偏好数据的工作相比,我们的输入跨越了更广泛的任务范围,并且可以偶尔包括有争议和敏感的主题。我们的目的是选择一组对不同人口统计群体的偏好敏感,并且善于识别潜在有害产出的标记者。因此,我们在这些(见附录B.1)轴上进行了测量贴标机性能的筛选试验。 作为一个初步的研究,以了解我们的模型如何很好地推广到其他标注者(人类)的偏好,我们雇佣了一组独立的实验者,他们不产生(标注)任何训练数据。这些标签来自相同的供应商,但不经过筛选测试。尽管任务复杂,但我们发现标注者之间的一致率很高(相互同意对方的评测):训练标注者为72.6±1.5 %、没有参与贴标签者为77.3±1.3 %。相比之下,在Stiennon等( 2020 )的总结工作中,研究者-研究者一致性为73±4 %。

1.3.4.模型

(解释见笔记--方法)

从GPT-3开始,我们用3种不同的技术训练模型:

有监督微调( Supervised Fine-Tuning,sft )。我们使用监督学习在我们的标签示例上微调GPT-3。我们训练了16个epoch,使用余弦学习率衰减,残差dropout为0.2。我们在验证集上做了基于RM得分的最终SFT模型选择。我们发现我们的SFT模型在1个epoch后对验证损失过拟合;然而,我们发现更多epoch的训练有助于RM评分和人类偏好评分

奖励建模( RM )。我们对GPT-3进行微调,使其接受一个提示和响应,并输出一个标量奖励。在本文中我们只使用了6B RM,因为这节省了大量的计算,并且我们发现175B RM训练可能是不稳定的,因此不太适合作为RL (详见附录D)期间的值函数。在Stiennon等人( 2020 )中,RM是在相同输入的两个模型输出之间的比较数据集上训练的,他们使用交叉熵损失,以比较作为标签--奖励的差异代表了一个响应会被人类标记者优先于另一个响应的对数优势。 为了加速比较集合,我们在K=4和K=9响应之间进行了标记排序,并将每个提示的所有(C(K,2)个)比较作为单个批元素进行训练,以提高计算效率(见附录D )。RM的损失函数为公式1(见笔记--方法)。其中rθ(x , y)是奖励模型对提示x和带有参数θ的结果y的标量输出,yw 是yw和yl中的优选结果,D是比较数据集。

强化学习(RL)。 同样遵循 Stiennon et al.(2020)的做法,我们使用 PPO 对 SFT 模型进行微调。环境被设定为一个 bandit 环境,该环境会给出一个随机的用户提示(prompt),并期望模型生成对该提示的响应。给定提示和响应后,环境根据奖励模型(reward model)产生一个奖励,并结束该回合(episode)。此外,我们在每个 token 处加入了相对于 SFT 模型的 逐 token KL 惩罚项,以缓解对奖励模型的过度优化。价值函数(value function)由奖励模型(RM)进行初始化。我们将这些模型称为 “PPO”。我们还尝试将预训练阶段的梯度PPO 梯度进行混合,以修复在公开 NLP 数据集上出现的性能退化问题(见附录 D.4)。我们将这些模型称为 “PPO-ptx”。除非另有说明,本文中提到的 InstructGPT 指的都是 PPO-ptx 模型

基线(Baselines)。 我们将所提出的 PPO 模型的性能与 SFT 模型以及 GPT-3 进行比较。我们还将其与在提供 few-shot 前缀、以将其“提示(prompt)”进入指令遵循模式下的 GPT-3 进行比较(记为 GPT-3-prompted)。该前缀会被拼接在用户指定的指令之前

此外,我们还将 InstructGPT 与在 FLANT0 数据集上对 175B GPT-3 进行微调的方法进行比较。这两个数据集都由多种 NLP 任务构成,并为每个任务配有自然语言指令(二者在所包含的 NLP 数据集以及所使用的指令风格上有所不同)。我们在约 100 万条样本上对其进行微调,并选择在验证集上获得最高奖励模型(RM)得分的检查点(更多细节见附录 D)。

1.3.5.评估

遵循 Askell et al.(2021),如果我们的模型是有帮助的(helpful)真实的(truthful)以及无害的(harmless),我们就认为它们是对齐的(aligned)(相关细节见附录 C.2)。我们将定量评估分为两部分:

API 分布上的评估。 我们的主要评估指标是在与训练分布来源相同、但被保留(held-out)的一组提示(prompt)上进行的人类偏好评分。在使用来自 API 的提示进行评估时,我们只选择那些未被用于训练的用户所提交的提示。对于每个模型,我们计算其输出相较于某一基线策略被偏好的频率;我们选择 175B 的 SFT 模型作为基线,因为它的性能大致处于各模型的中间水平。此外,我们还要求标注者使用 1–7 的李克特量表(Likert scale)对每个回答的整体质量进行评分,并为每个模型输出收集一系列元数据(见表 11)。其中,尤其包括旨在刻画部署模型中可能导致有害行为的不同方面的数据:我们让标注者评估模型输出在客户助理场景下是否不恰当、是否贬损受保护群体,以及是否包含性或暴力内容

公共 NLP 数据集上的评估。 我们在两类公共数据集上进行评估:一类用于刻画语言模型安全性的某些方面,尤其是真实性(truthfulness)、毒性(toxicity)和偏见(bias);另一类用于刻画模型在传统 NLP 任务(如问答、阅读理解和摘要)上的零样本(zero-shot)性能。此外,我们还在 RealToxicityPrompts 数据集上进行了人工评估。

1.4.结果

1.4.1.API 分布上的评估

标注者显著地更偏好 InstructGPT 的输出,而非 GPT-3 的输出。在我们的测试集上,标注者在不同模型规模下都显著更偏好 InstructGPT 的输出(见图 1)。我们发现,GPT-3 的输出表现最差,而可以通过以下方式获得显著的逐步提升:首先是使用精心设计的 few-shot 提示(GPT-3(prompted)),其次是使用示范数据进行有监督训练(SFT),最后是使用对比数据通过 PPO 进行训练。在 PPO 过程中加入预训练数据混合的更新,并不会在标注者偏好上带来显著变化。为了说明性能提升的幅度:在直接比较时,175B 的 InstructGPT 输出在 85 ± 3% 的情况下更受偏好于 GPT-3 的输出,并且在 71 ± 4% 的情况下更受偏好于 few-shot GPT-3

在图 4 中,我们还展示了标注者在多个更具体的维度上对 InstructGPT 输出给予了更高评价。具体而言,与 GPT-3 相比,InstructGPT 的输出在客户助理场景下更为恰当,更常遵循指令中明确给出的约束条件(例如“将你的回答写成不超过两段”),更不容易完全未能遵循正确的指令,并且在封闭领域任务中更少编造事实(“产生幻觉”)。

我们的模型能够泛化到那些**未参与任何训练数据生成的“保留(held-out)标注者”**的偏好上。保留标注者的排序偏好与用于生成训练数据的标注者非常相似(见图 3)。尤其是,根据这些保留标注者的评价,我们的所有 InstructGPT 模型仍然显著优于 GPT-3 基线。因此,我们的 InstructGPT 模型并非只是对训练标注者的偏好产生了过拟合。

公共 NLP 数据集并不能反映语言模型在实际中的使用方式。在图 5a 中,我们还将 InstructGPT 与在 FLANT0 数据集上微调的 175B GPT-3 基线模型进行了比较(详见附录 D)。我们发现,这些模型的表现优于 GPT-3,与使用精心选择提示的 GPT-3 表现相当,但仍不如我们的 SFT 基线模型。这表明,这些数据集在多样性上不足,无法提升模型在我们的 API 提示分布上的性能。我们认为,这在一定程度上是因为学术数据集主要关注那些性能易于度量的任务(如分类和问答),而我们的 API 分布中大多数(约 57%)是开放式生成任务

1.4.2.公共NLP数据集上的评估

InstructGPT 模型在真实性方面相较于 GPT-3 有所提升。 通过在 TruthfulQA 数据集上的人工评估可以看到,我们的 PPO 模型在生成真实且信息充分的输出方面,相较于 GPT-3 表现出幅度虽小但具有统计显著性的改进(见图 5b)。这种行为是模型的默认表现:模型并不需要被特别指示“要说实话”,就能够体现出更高的真实性。一个有趣的例外是 1.3B 的 PPO-ptx 模型,其表现略逊于同规模的 GPT-3 模型。我们在真实性方面的改进还体现在:在封闭领域任务中,我们的 PPO 模型更少出现编造事实(幻觉)的情况(见图 4)。

InstructGPT 在毒性方面相较于 GPT-3 有小幅改进,但在偏见方面没有明显提升。 我们首先在 RealToxicityPrompts 数据集上通过人工评估对模型进行测试,结果如图 5c 所示。我们发现,当明确指示模型生成安全且尊重他人的输出(“respectful prompt”)时,根据 Perspective API 的评估,InstructGPT 模型生成的内容比 GPT-3 的输出毒性更低;但当移除这一尊重性提示(“no prompt”)后,这一优势便消失了。使用 Perspective API 进行自动评估时也得到了类似的结果(见附录 F.7)。

通过修改 RLHF 微调流程,可以减小模型在公共 NLP 数据集上的性能退化。 在图 25 中我们展示了,在 PPO 微调过程中加入预训练数据的更新(PPO-ptx),可以缓解在公共 NLP 数据集上的性能退化,甚至在 HellaSwag 数据集上超过 GPT-3。然而,PPO-ptx 模型在 DROP、SQuADv2 以及翻译任务上的性能仍然落后于 GPT-3,仍需进一步研究以理解并消除这些性能回退。我们还发现,相较于简单地增大 KL 系数混合加入预训练更新是一种效果更好的方法(见图 36)。

1.4.3.量化结果

InstructGPT 模型在 RLHF 微调分布之外的指令上表现出良好的泛化能力。 尤其是,我们发现 InstructGPT 具备在非英语语言中遵循指令的能力,并且能够对代码执行摘要和问答任务。这一点非常有趣,因为非英语语言和代码在我们的微调数据中只占极小的一部分,这表明在某些情况下,对齐方法可以泛化到人类并未直接监督的输入上,从而产生期望的行为。我们在图 26 中展示了一些定性示例。

InstructGPT 仍然会犯一些简单错误。 在与我们的 175B PPO-ptx 模型交互时,我们注意到,尽管它在多种语言任务上表现强劲,但仍可能出现一些简单错误。举例来说:(1)当指令包含错误前提时,模型有时会错误地假设该前提为真;(2)模型有时会过度谨慎(过度回避断言):在面对一个简单问题时,可能会声称问题没有唯一答案,并给出多个可能的回答,即便在上下文中存在一个相对明确的答案;(3)当指令包含多个明确约束(例如“列出 10 部 1930 年代在法国拍摄的电影”),或当约束本身对语言模型而言较为困难(例如要求用指定数量的句子撰写摘要)时,模型的性能会下降。我们在图 27 中展示了这些行为的一些示例。

我们推测,行为(2)在一定程度上源于我们在标注过程中鼓励标注者奖励认识论上的谦逊(epistemic humility);因此,标注者可能倾向于奖励带有保留态度的回答,而这种偏好被奖励模型学习到了。我们还推测,行为(1)之所以出现,是因为训练集中包含错误前提的提示数量较少,模型无法很好地泛化到这类情形。我们认为,通过对抗式数据收集,这两类行为都可以得到显著缓解。

1.5.讨论

1.5.1.对齐研究的启示

本文中我们对对齐研究的方式是迭代式的:我们致力于改进当前已存在的 AI 系统的对齐性,而不是抽象地关注那些尚不存在的 AI 系统的对齐问题,这使我们能够获得一个清晰的经验性反馈回路,以判断哪些方法有效、哪些无效。我们认为,这种反馈回路对于不断完善对齐技术至关重要,同时也迫使我们与机器学习领域的进展保持同步。

从本工作中,我们可以总结出一些对更广泛对齐研究的启示。首先,相对于预训练而言,提升模型对齐性的成本是适中的。训练 175B 的 SFT 模型需要 4.9 petaflops/s-days,而训练 175B 的 PPO-ptx 模型需要 60 petaflops/s-days,相比之下,GPT-3 的预训练成本为 3,640 petaflops/s-days。与此同时,我们的结果表明,RLHF 在提升语言模型对用户的有帮助性方面非常有效,其效果超过了将模型规模扩大 100 倍。这表明,在当前阶段,将资源投入到对已有语言模型进行对齐,比单纯训练更大的模型更具成本效益。

其次,我们观察到一些证据表明,InstructGPT 能够将“遵循指令”的能力泛化到未被直接监督的场景。这一性质非常重要,因为让人类对模型执行的每一项任务都进行监督在成本上是不可行的。

最后,我们成功缓解了微调过程中引入的大部分性能退化。如果无法做到这一点,这些性能退化将构成一种“对齐税(alignment tax)”,即为了对齐模型而额外付出的性能成本。任何对齐方法如果伴随着过高的对齐税,都可能难以被实际采用,因此避免这种成本非常重要。

1.5.2.局限性

方法论方面。 InstructGPT 模型的行为在一定程度上由我们从标注承包人员处获得的人类反馈所决定。一些标注任务依赖于价值判断,而这些判断可能会受到标注者身份、信念、文化背景和个人经历的影响。我们保持了一个规模较小的标注团队,以便与全职标注者进行高带宽沟通,但这一群体显然无法代表所有可能受到这些模型影响的人群。一个简单的例子是,我们的标注者主要是英语使用者,数据也几乎全部由英语指令构成。

模型方面。 我们的模型既没有完全对齐,也并非完全安全;它们仍然可能生成有毒或有偏见的内容、编造事实,甚至在没有明确提示的情况下生成性或暴力内容。在某些输入下,它们也可能无法生成合理的输出(部分示例如图 27 所示)。也许我们模型最大的局限在于:在大多数情况下,它们会遵循用户的指令,即便这在现实世界中可能造成伤害。例如,当提示模型尽可能具有偏见时,InstructGPT 会生成比同规模 GPT-3 更具毒性的输出。

1.5.3.更广泛的影响

本工作的动机在于,通过训练大型语言模型去执行一组人类所期望的行为,从而提升其正向影响。默认情况下,语言模型仅优化下一个词预测目标,而该目标只是我们真正期望模型行为的一个代理。我们的结果表明,这些技术在提升语言模型的有帮助性、真实性和无害性方面具有潜力。

从长期来看,对齐失败可能会导致更加严重的后果,尤其是在这些模型被部署到安全关键场景中时。然而,使语言模型更好地遵循用户意图,也会使其更容易被滥用,例如生成具有迷惑性的错误信息,或仇恨性、攻击性内容。对齐技术并非解决大型语言模型安全问题的灵丹妙药,而应当作为更广泛安全生态系统中的一种工具。

除了蓄意滥用之外,还有许多领域需要极其谨慎地部署大型语言模型,甚至不应部署,例如:医疗诊断、基于受保护属性对人进行分类、决定信贷、就业或住房资格、生成政治广告,以及执法相关应用。

最后,一个极其重要的问题是:模型究竟与谁的价值观对齐,这一点将深刻影响这些模型的总体影响是正面还是负面;相关讨论见附录 G.2。



2.笔记

2.1.原有的预训练GPT3方法带来的问题

  1. 有效性--模型无法从预训练数据中有效的学到特定内容(或文本中压根就没有这个内容)
  2. 安全性--输出了不该输出的内容

当时急需解决LLM在生成方面的一些语言模式的危害和需要更好的遵循用户指令,且强化学习(尤其是RLHF技术)在相关领域已经有所应用。

所以本文工作的省流版本:标一点数据,最后再做一次微调。但由于大语言模型推崇自监督无监督,所以要对这种SFT做一点包装,即微调一次、强化学习一次,使得更小的模型效果更好了。

2.2.方法

  • Step1和2分别标注了一块数据,3个步骤一共训练了3个模型。
  • Step1:
    • 首先找人来写各种各样的问题(prompt,图中例:向一个六岁小孩解释以下什么是月球登录)
    • 继续找人来写答案(图中例:一些人去了月球……)
    • 做微调,起名SFT--有监督微调(在gpt模型眼里两段话没有区别--本质上就是训练一段话中去预测下一个词token,因此此步的方法和之前的预训练、微调没区别)
    • 这里就和之前的gpt系列没什么区别,问题就在于让人写出所有问题和答案成本太高
  • Step2:
    • 还是给SFT后的模型一个问题(图中例:向一个六岁小孩解释以下什么是月球登录)让模型beam search方法(概率采样)去预测答案(图中size=4)
    • 人来判断给出的n(图中为4)个答案谁好谁坏(图中D>C>A=B)做排序标注,并用这个排序标注去训练一个RM奖励模型(给prompt和输出,对输出打分,使得对答案的分数满足排序标注的关系
    • 这一层标注的难度和成本远小于第一步中标注一个生成式的答案
  • Step3:继续微调第一步的SFT模型,使得它生成的答案能尽量得到一个比较高的分数每次生成答案--放入RM打分--优化SFT的参数),最终这一步后训练出的模型就叫InstructGPT。

技术要点:Step1和2的数据怎么标注的、RM模型如何训练的、Step3中怎样通过强化学习训练

  • 数据来源:主要包括提交给商业语言模型API的文本提示(用户在真实使用时提交的),以及少量由标注者编写的提示,并经过了去重和隐去身份信息。(数据标注见原文3.3节和对应的附录)
    • SFT数据集包含约13k条训练提示(来自API和人工标注者的文本提示和答案语句对),RM数据集包含约33k条训练提示(来自API的文本提示和人工标注者对结果的打分排序),PPO数据集包含约31k条训练提示(只需从API调用即可,标注是来自于RM模型的)。
  • Step1--SFT训练16个epoch,使用余弦学习率衰减,残差dropout为0.2。在1个epoch后对验证损失过拟合(因为数据量少而GPT模型太大);但是更多epoch的训练有助于RM评分和人类偏好评分
    • Lora和QLora是轻量化实现SFT的两种方法。

  • Step2--
    • 不同于Step1中的完全另外一个模型(专门训练只为打分),微调时发现6B参数的RM不仅计算效率高且比175B的更稳定
    • 因为用户的标注是一个顺序(各个结果谁更好),而这里softmax输出的是一个最大概率值(即所谓分数),因此要把这个顺序换成一个值(要做step2和3的主要原因)。因此,损失函数使用的是排序中常见的Pairwise ranking loss
      • Pairwise:对于一个prompt x,取出一对答案yw和yl,yw 是其中排序较优的结果。分别把(x,yw)和(x,yl)放入奖励模型计算出奖励并做差(因为x比l优故结果为正),希望这个值越大越好(解释见下),因此做逻辑回归(即公式1中的-log(sigmoid(·)))。每次生成4到9个答案,即每次在4到9个答案中选出两个答案组成一对yw和yl(即若k=9则选出C(9,2)=36对),公式1中有分母是为了不让K的值影响结果太多。
      • 模型优化使得loss这个值最小化,因为有-log,故变为希望最大化r(x,yw)和r(x,yl)的差,这样就可使得两个回答之间的奖励分差尽可能大。
      • 选择K=4到9,4是因为为保证排序标注的信息量有下限;9是因为除去读prompt的时间,排9个比排4个可能只多花30%-40%的时间,但是得到的36个排序对比排4个只得到6个排序对标注信息多了9倍

  • Step3--使用 PPO(结合原始预训练GPT的目标函数) 对 SFT 模型进行微调。环境被设定为一个 bandit 环境,会给出一个随机的用户prompt,并期望模型生成对该提示的响应。给定提示和响应后,环境根据奖励模型RM产生一个奖励,并结束该回合。此外,在每个 token 处加入了相对于 SFT 模型的 逐 token KL 惩罚项,以缓解对奖励模型的过度优化。价值函数由RM进行初始化。
    • bandit环境只有一步决策、没有长期状态转移的强化学习环境
      • 状态(state):一个用户 prompt,即x
      • 动作(action):模型一次性生成的完整回答(token 序列,即y)(PPO 模型根据 prompt 生成完整回答)
      • 奖励(reward):由奖励模型 RM (Step2) 给出的一个标量分数(RM 接收 (x, y),并输出一个标量奖励,希望最大,即需要优化rθ(x,y)这一项)
      • 回合(episode):生成完回答 →RM 打分 →立即结束,没有下一步状态
    • PPO:一种稳定更新策略网络的策略梯度算法
      • 策略(policy):语言模型本身(强化学习中模型叫做policy)
        • 公式2中的ΠRL(policy)就是GPT-3模型,θ表示要学习的模型。ΠRL会初始化成ΠSFT。需要最大化公式2这一目标函数。
      • 目标:提高 RM 给高分的回答概率,又不能离原模型(SFT)太远(主要实现见下方KL处)。
    • 价值函数(value function:从当前 token 状态开始,后面能拿到的期望奖励。
      • 从 RM 初始化,因为RM 已经学会了“什么回答好”。
      • 此处用来估计 advantage(“这次比我预期的好多少?”)、降低策略梯度的方差。
    • 相对于 SFT 模型的逐 token KL 惩罚项:防止模型为了骗 RM,生成奇怪、非人类语言但得分高的回答。
      • KL散度:衡量两个概率分布 p 和 q 之间差异的非对称度量。
      • 逐token因为:语言模型是 自回归生成;每一步都可能开始“走歪”;必须局部地约束。
      • 即公式2中的 -β*log() 项,目的是希望RL得到的新模型和之前的得到的SFT模型只做一些改动即可,不要跑太远,这也是PPO的主要思想。这里是使用KL散度评估两个y的概率(各个token的概率积)的相似度
    • 公式2的加的最后一项是预训练GPT-3时的损失函数,x是预训练时的数据,乘γ系数,是为了避免微调后模型只是在微调的领域效果好。若γ=0,则为PPO;若γ≠0,则为PPO-ptx。
    • 我的强化学习基础blog

总结:InstructGPT 的完整 SFT–RLHF 流程是:先用人工编写的 prompt–回答对对预训练 GPT 模型做有监督微调得到一个 SFT 模型;再让该模型对同一 prompt 生成多个候选回答,由人工对这些回答进行排序标注,并据此训练一个奖励模型 RM,用于将人类偏好映射为标量奖励;最后在一个 bandit 强化学习环境中,以 prompt 为状态、模型生成的完整回答为动作、RM 给出的分数为奖励,通过 PPO 对 SFT 模型进行进一步微调,同时在每个 token 处加入相对于 SFT 模型的 KL 惩罚项,以防止模型为了最大化奖励而偏离人类语言分布,最终得到的模型即 InstructGPT。

该图来自微信公众号:https://mp.weixin.qq.com/s/_nfRdKJn-Ra-QmgTSW4jlw

2.3.结论

  • Figure 1:在真实 API 提示分布上的人工偏好评测中,InstructGPT(PPO / PPO-ptx)在所有模型规模上都显著优于 GPT-3 与 GPT-3(prompted),甚至 1.3B 的 PPO-ptx 输出也优于 175B 的 GPT-3,体现了 RLHF 带来的巨大收益。

  • Figure 3:无论是参与训练的标注者还是未参与训练的保留标注者,偏好排序结果高度一致,且 InstructGPT 始终显著优于 GPT-3,说明模型并未过拟合特定标注者的偏好。
  • Figure 4:在 API 分布的行为元数据分析中,InstructGPT 更常正确理解并遵循指令与显式约束,在客服语境下更合适,同时显著减少幻觉,相比 GPT-3 表现出更符合人类期望的行为模式。

  • Figure 5(a):在整体主观质量(1–7 Likert)评分上,InstructGPT 明显优于 GPT-3 及其在 FLAN/T0 上微调的版本,表明学术指令数据集并不能替代针对真实使用分布的 RLHF 对齐。
  • Figure 5(b):在 TruthfulQA 上的人工评估显示,PPO / PPO-ptx 模型在真实性与信息充分性上相较 GPT-3 有小但显著的提升,说明 RLHF 能减少事实性错误与幻觉。
  • Figure 5(c):在 RealToxicityPrompts 上,当明确要求“尊重性输出”时,InstructGPT 的毒性显著低于 GPT-3,但在没有该提示时优势消失,表明其安全改进依赖于指令语境。

本文从人类偏好评估公共 NLP 基准评测两个层面系统评估 InstructGPT。在方法上,作者主要采用人工评估作为核心指标,在与训练分布一致的 API 提示上比较不同模型输出相对于基线(175B SFT)的偏好胜率,并辅以 1–7 Likert 量表质量评分及多项安全相关元数据标注(如不当内容、歧视、性与暴力);同时在 TruthfulQA、RealToxicityPrompts 等安全相关数据集以及问答、阅读理解、摘要等公共 NLP 任务上进行零样本评测。实验结果表明,InstructGPT 在所有模型规模上均显著优于 GPT-3,性能呈现出从 GPT-3 → few-shot GPT-3 → SFT → PPO 的逐步提升,175B InstructGPT 在人类偏好评估中相较 GPT-3 的胜率高达约 85%;在真实性方面,PPO 模型在 TruthfulQA 和封闭领域任务中显著减少幻觉;在毒性方面,在明确安全指令下优于 GPT-3,但在偏见上改进有限;此外,引入预训练梯度的 PPO-ptx 能有效缓解 RLHF 对公共 NLP 基准的性能回退,但仍未在所有任务上超过 GPT-3。整体结论是,RLHF 能显著提升模型在真实使用场景中的有用性、真实性与安全性,同时需要额外机制(如 PPO-ptx)来平衡通用语言能力的保持

以下是原文1.1节中的重复和解释。

  • 与GPT-3的输出结果相比,标注人员和没有参与标注的人员都显著更喜欢InstructGPT的输出结果。
  • InstructGPT 模型显示出比GPT-3在真实性方面有改进,对毒性的改善较小,但在偏见上没有改善
  • 可以通过修改RLHF微调过程来最小化公开NLP数据集上的性能回归,即专注于一个领域的微调,其他方面性能不会下降太多
  • 公开的NLP数据集并不能反映语言模型是如何被使用的,即证明了微调对于数据集的使用是比较敏感的
  • InstructGPT 模型对RLHF微调分布以外的指令具有很好的泛化能力,即不一定需要标注所有的问答类型
  • InstructGPT 仍然会犯简单的错误。

2.4.价值观讨论

本文在讨论部分指出,对齐研究应采用以现实系统为对象的迭代式方法,通过持续的经验反馈改进对齐技术;实验结果表明,RLHF 以远低于预训练的成本显著提升了模型的有帮助性,其收益甚至超过单纯扩大模型规模,同时还表现出一定的跨任务与跨分布泛化能力,并能在较大程度上避免“对齐税”。但作者也明确指出方法与模型的局限性,包括人类反馈的价值偏置、语言与文化覆盖不足,以及模型仍可能生成有害内容、在危险指令下过度服从用户。总体而言,RLHF 被认为是提升语言模型有帮助性、真实性与安全性的有前景工具,但并非万能方案,其社会影响与价值对齐对象仍需审慎对待。



3.代码实验

首先在魔搭社区或huggingface上下载好所需模型的权重文件,并配置好环境。

总体sft+rlhf的项目文件目录大致如:

rlhf_MODEL/
  MODEL/
  data/
    sft.jsonl
    rm.jsonl
    ppo_prompts.jsonl
  sft_train.py
  rm_train.py
  ppo_train.py

3.1.数据准备

这里只给出一行样本样例。

SFT 数据,每行一个样本 data/sft.jsonl

{"prompt":"请用两句话解释什么是哈希表","response":"哈希表是一种..."}

RM 偏好数据 data/rm.jsonl

{
  "prompt": "写一个 Python 函数计算斐波那契数列",
  "responses": [
    "下面是一个递归实现……",
    "你可以用循环来避免递归……",
    "斐波那契数列是一个数学概念……",
    "我不确定你指的是哪种斐波那契……",
    "def fib(n): return n if n <= 1 else fib(n-1)+fib(n-2)"
  ],
  "ranking": [4, 1, 0, 2, 3]
}

PPO prompts data/ppo_prompts.jsonl只需要 prompt(PPO 的 reward 来自 RM,不需要人再标注)

{"prompt":"给我一个周末学习计划"}

3.2.SFT 代码(第 1 步:监督微调)

sft_train.py

# =======================
# Step 1: SFT(监督微调)
# =======================
import os
import torch
from datasets import load_dataset

from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
# AutoTokenizer:加载 和模型严格匹配 的 tokenizer(Qwen 的 chat 模板就在 tokenizer 里)
# AutoModelForCausalLM:因果语言模型(GPT-style)
# TrainingArguments:HF Trainer 体系的统一训练配置

from trl import SFTTrainer
# TRL = RLHF 官方库
# SFTTrainer 是一层封装,专门为 instruction tuning / SFT 设计

from peft import LoraConfig
# PEFT(Parameter-Efficient Fine-Tuning),是Lora的配置类

MODEL_DIR = os.environ.get("MODEL_DIR", "/path/to/model")  # 你的本地/服务器权重目录
SFT_DATA = os.environ.get("SFT_DATA", "data/sft.jsonl")
OUT_DIR  = os.environ.get("SFT_OUT",  "outputs/sft_lora")

def build_chat_text(tokenizer, prompt: str, response: str) -> str:
'''
把 (prompt, response) 变成模型“真正看到的训练文本”
'''
    # 构造对话结构,Qwen/LLaMA/ChatGPT的chat模板都是:user -> assistant
    messages = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False) # SFT 必须 100% 复用模型原生 chat template

def main():
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_DIR, 
        trust_remote_code=True, # Qwen 的 tokenizer / model 有自定义逻辑
        use_fast=False # SFT 稳定性优先
    )

    # 加载 base LM
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" # HF 自动切 GPU(单卡 / 多卡)
    )

    # 加载 SFT 数据,期望字段:prompt, response
    ds = load_dataset("json", data_files=SFT_DATA, split="train")

    def map_fn(ex):
    '''
    把原始数据转成{ "text": "<|user|>...<|assistant|>..." }
    '''
        ex["text"] = build_chat_text(tokenizer, ex["prompt"], ex["response"])
        return ex

    # 只保留 text,SFTTrainer 只需要一个文本字段
    ds = ds.map(map_fn, remove_columns=ds.column_names)

    # 用 LoRA 做 SFT(省显存/快) Lora配置
    peft_config = LoraConfig(
        r=16, # LoRA rank,16 是经验安全值
        lora_alpha=32, # LoRA scaling,实际更新幅度 ≈ alpha / r
        lora_dropout=0.05, # 防止过拟合,SFT 数据量通常较小
        bias="none", # 不训练 bias,进一步省参数
        task_type="CAUSAL_LM", # 告诉 PEFT:这是 GPT-style 模型
        target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],  # Qwen 的 Transformer 层结构,只在注意力 + FFN 上加 LoRA,embedding / LM head 保持冻结
    )

    # 这是 7B 在单卡 24GB 上的常见配置
    args = TrainingArguments(
        output_dir=OUT_DIR,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8, # 等效 batch size = 8
        num_train_epochs=1, # SFT 通常不需要多 epoch
        learning_rate=2e-5, # SFT 的常见 LR,远小于预训练阶段
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        bf16=torch.cuda.is_available(),
        fp16=False,
        report_to="none",
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=ds,
        dataset_text_field="text", # 告诉 Trainer:哪个字段是要送进 LM 的完整文本
        max_seq_length=2048, # 超长会 OOM(out of memory)
        args=args,
        peft_config=peft_config,
        packing=False, # 不把多个样本 pack 到同一个 sequence
    )

    trainer.train() # 训练
    # 保存base model(只读)和LoRA adapter(可训练部分)
    trainer.save_model(OUT_DIR)
    tokenizer.save_pretrained(OUT_DIR)

if __name__ == "__main__":
    main()

运行:

MODEL_DIR=/path/to/MODEL \
python sft_train.py

产物:outputs/sft_lora/(SFT 后的“策略初始模型”)

3.3.RM 代码(第 2 步:奖励模型训练)

rm_train.py

# ===========================
# Step 2: RM(奖励模型训练)
# Pairwise ranking loss: -log(sigmoid(r_chosen - r_rejected))
# ===========================
import os
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification, # 在 base LM 顶上加一个 reward head,输出 shape: (batch, 1)
    TrainingArguments,
    Trainer,
)

MODEL_DIR = os.environ.get("MODEL_DIR", "/path/to/model")
RM_DATA   = os.environ.get("RM_DATA", "data/rm.jsonl")
OUT_DIR   = os.environ.get("RM_OUT", "outputs/rm")

def build_chat(tokenizer, prompt: str, answer: str) -> str:
    messages = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": answer},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

class PairwiseRewardTrainer(Trainer):
'''
继承 HF Trainer
RLHF 核心之一
'''
    def compute_loss(self, model, inputs, return_outputs=False):
        # inputs: chosen_* / rejected_*
        chosen_ids = inputs["chosen_input_ids"]
        chosen_attn = inputs["chosen_attention_mask"]
        rej_ids = inputs["rejected_input_ids"]
        rej_attn = inputs["rejected_attention_mask"]

        r_chosen = model(input_ids=chosen_ids, attention_mask=chosen_attn).logits.squeeze(-1) # 对 (prompt, chosen) 打分
        r_rej    = model(input_ids=rej_ids, attention_mask=rej_attn).logits.squeeze(-1) # 对 (prompt, rejected) 打分

        # pairwise ranking loss
        loss = -F.logsigmoid(r_chosen - r_rej).mean()

        if return_outputs:
            return loss, {"r_chosen": r_chosen, "r_rejected": r_rej}
        return loss

def main():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True, use_fast=False)

    # 用同一个 base LM 做一个 reward head(sequence classification)
    rm = AutoModelForSequenceClassification.from_pretrained(
        MODEL_DIR,
        trust_remote_code=True,
        num_labels=1,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
    )

    ds = load_dataset("json", data_files=RM_DATA, split="train")

    def tok(ex):
        chosen_text = build_chat(tokenizer, ex["prompt"], ex["chosen"])
        rejected_text = build_chat(tokenizer, ex["prompt"], ex["rejected"])

        chosen = tokenizer(chosen_text, truncation=True, max_length=2048)
        rejected = tokenizer(rejected_text, truncation=True, max_length=2048)

        return {
            "chosen_input_ids": chosen["input_ids"],
            "chosen_attention_mask": chosen["attention_mask"],
            "rejected_input_ids": rejected["input_ids"],
            "rejected_attention_mask": rejected["attention_mask"],
        }

    ds = ds.map(tok, remove_columns=ds.column_names)

    args = TrainingArguments(
        output_dir=OUT_DIR,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        num_train_epochs=1,
        learning_rate=1e-5,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        bf16=torch.cuda.is_available(),
        report_to="none",
    )

    trainer = PairwiseRewardTrainer(
        model=rm,
        args=args,
        train_dataset=ds,
        tokenizer=tokenizer,
    )

    trainer.train()
    trainer.save_model(OUT_DIR)
    tokenizer.save_pretrained(OUT_DIR)

if __name__ == "__main__":
    main()

运行:

MODEL_DIR=/path/to/MODEL \
python rm_train.py

产物:outputs/rm/(奖励模型 RM)

3.4.PPO 代码(第 3 步:用 RM 做 RLHF 微调 SFT 模型)

这里用 TRL 的 PPOTrainer:策略模型从 SFT 初始化;每个 prompt 采样回答;交给 RM 打分作为 reward;同时加 KL(TRL 默认会用 reference model 约束,等价于 “相对于 SFT 的 KL 罚项”)。

保存为 ppo_train.py

# ==========================================
# Step 3: PPO(RLHF:用 RM 奖励微调 SFT 模型)
# - policy: SFT 模型(带 LoRA)
# - reward: RM(prompt, response) -> scalar
# - KL: policy vs reference(SFT) 约束,防止跑飞
# ==========================================
import os
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
from trl import PPOTrainer, PPOConfig
from peft import PeftModel

BASE_MODEL_DIR = os.environ.get("BASE_MODEL_DIR", "/path/to/model")
SFT_LORA_DIR   = os.environ.get("SFT_LORA_DIR",  "outputs/sft_lora")
RM_DIR         = os.environ.get("RM_DIR",        "outputs/rm")
PPO_PROMPTS    = os.environ.get("PPO_PROMPTS",   "data/ppo_prompts.jsonl")
OUT_DIR        = os.environ.get("PPO_OUT",       "outputs/ppo_lora")

def build_prompt(tokenizer, prompt: str) -> str:
    messages = [
        {"role": "user", "content": prompt},
    ]
    # add_generation_prompt=True 会让模板在最后加上 assistant 起始标记,便于生成
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

@torch.no_grad()
def rm_score(tokenizer, rm_model, prompt: str, response: str) -> float:
    text = tokenizer.apply_chat_template(
        [{"role":"user","content":prompt},{"role":"assistant","content":response}],
        tokenize=False,
        add_generation_prompt=False
    )
    toks = tokenizer(text, return_tensors="pt", truncation=True, max_length=2048).to(rm_model.device)
    score = rm_model(**toks).logits.squeeze(-1)
    return float(score.item())

def main():
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_DIR, trust_remote_code=True, use_fast=False)

    # policy 初始化为 base + SFT LoRA
    policy_base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_DIR,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
    )
    policy = PeftModel.from_pretrained(policy_base, SFT_LORA_DIR) # 初始化为 SFT 模型

    # reference model:用来做 KL 约束(通常是“冻结的 SFT”)
    ref_base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_DIR,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
    )
    ref_model = PeftModel.from_pretrained(ref_base, SFT_LORA_DIR) # reference model = π_SFT
    ref_model.eval() # RM 不更新,只负责给 reward

    # reward model
    rm_model = AutoModelForSequenceClassification.from_pretrained(
        RM_DIR,
        trust_remote_code=True,
        num_labels=1,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
    )
    rm_model.eval()

    ds = load_dataset("json", data_files=PPO_PROMPTS, split="train")

    config = PPOConfig(
        batch_size=8,
        mini_batch_size=2,
        gradient_accumulation_steps=1,
        learning_rate=1e-6,
        ppo_epochs=4,
        kl_penalty="kl",   # KL 约束,防止模型为了 reward 乱说话
        target_kl=6.0,     # 可以调小/调大
        log_with=None,
    )

    ppo_trainer = PPOTrainer(
        config=config,
        model=policy,
        ref_model=ref_model,
        tokenizer=tokenizer,
    )

    gen_kwargs = dict(
        max_new_tokens=256,
        do_sample=True,
        top_p=0.9,
        temperature=0.8,
    )

    # PPO 主循环
    for step, ex in enumerate(ds):
        prompt = ex["prompt"]
        prompt_text = build_prompt(tokenizer, prompt)
        query_toks = tokenizer(prompt_text, return_tensors="pt").input_ids.to(policy.device)

        # 1) 采样回复(动作:完整 response tokens)
        response_toks = ppo_trainer.generate(query_toks[0], **gen_kwargs) # action
        response_text = tokenizer.decode(response_toks[0][query_toks.shape[-1]:], skip_special_tokens=True)

        # 2) RM 打分(奖励)
        reward = rm_score(tokenizer, rm_model, prompt, response_text) # 环境给 reward,bandit setting

        # 3) PPO 更新(带 KL 惩罚,避免偏离 SFT)
        ppo_trainer.step([query_toks[0]], [response_toks[0][query_toks.shape[-1]:]], [reward])

        if step % 20 == 0:
            print(f"[step {step}] reward={reward:.4f}\nPROMPT={prompt}\nRESPONSE={response_text}\n")

        if step > 0 and step % 200 == 0:
            ppo_trainer.save_pretrained(OUT_DIR)

    ppo_trainer.save_pretrained(OUT_DIR)
    tokenizer.save_pretrained(OUT_DIR)

if __name__ == "__main__":
    main()

运行:

BASE_MODEL_DIR=/path/to/MODEL \
SFT_LORA_DIR=outputs/sft_lora \
RM_DIR=outputs/rm \
python ppo_train.py

产物:outputs/ppo_lora/(这就是 RLHF 后的“策略模型”,也就是 InstructGPT 风格最终模型)

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐