详解推测性采样加速推理的算法逻辑
推测性采样算法通过小模型快速生成草稿序列,再经大模型并行验证与修正,在不改变目标模型分布的前提下实现2-2.5倍加速。该算法分为三阶段:草稿生成阶段由小模型自回归生成K个候选token;并行评分阶段由大模型一次性计算K+1个上下文的概率分布;验证修正阶段通过概率比较和随机采样决定是否保留或修正草稿token。算法关键创新在于通过并行计算减少大模型调用次数,同时利用数学设计确保生成质量与单独使用大模
文章目录
推测性采样加速推理算法,是DeepMind 在《Accelerating Large Language Model Decoding with Speculative Sampling》中提出的核心算法,其本质是通过“小模型快速草稿生成+大模型并行验证与修正”的组合,在不改变目标模型(大模型)分布、不损失生成质量的前提下,将自回归解码速度提升2-2.5倍。该算法针对传统自回归生成(ArS)“一次调用仅生成1个token、效率低下”的痛点,通过三阶段逻辑实现多token并行加速,以下结合算法步骤、数学原理和实际案例展开解析。
一、算法核心目标与前置定义
1. 核心目标
- 加速不降质:在保证生成结果完全符合“目标模型(大模型)概率分布”的前提下,减少目标模型的调用次数,提升token生成速度;
- 突破内存带宽瓶颈:传统自回归生成受限于“每次调用目标模型仅处理1个token”的内存带宽限制,算法通过“一次目标模型调用验证多个草稿token”,让硬件资源利用率更高。
2. 关键定义(对应算法输入)
符号/术语 | 定义 |
---|---|
目标模型(q) | 性能强但速度慢的大模型(如论文中的Chinchilla-70B),是最终生成质量的“基准”; |
草稿模型(p) | 性能弱但速度快的小模型(如论文中的4B参数模型),用于快速生成“候选token序列”; |
K | 草稿模型每次生成的“草稿token数量”(即“前瞻长度”,论文中K=3-4时效果最优); |
x₁…xₙ | 已确定的“历史序列”(如对话前文、生成文本的已完成部分); |
T | 最终需要生成的“总序列长度”(算法终止条件:生成的token总数达到T); |
ṽ₁…ṽ_K | 草稿模型生成的“候选token序列”(需经目标模型验证后决定是否保留); |
二、算法完整步骤拆解
算法分为草稿生成、并行评分、逐token验证与修正三大阶段,循环执行直至达到总序列长度T,每个阶段的操作和背后逻辑如下:
阶段1:草稿生成(Draft Generation)——小模型快速出“候选答案”
while n < T do # 循环:直到生成的token总数n达到目标长度T
for t=1:K do
Sample draft auto-regressively ṽₜ ~ p(·|x₁,...,xₙ, ṽ₁,...,ṽₜ₋₁)
end for
操作逻辑:
草稿模型(p)以“已确定的历史序列x₁…xₙ”为上下文,自回归生成K个连续的候选token(ṽ₁到ṽ_K) 。
例如:历史序列是“今天天气”,K=3,草稿模型可能生成候选token序列ṽ₁=“很”、ṽ₂=“好”、ṽ₃=“适合”,即完整草稿为“今天天气很好适合”。
关键设计原因:
- 草稿模型速度远快于目标模型(论文中草稿模型生成1个token仅需1.8ms,目标模型需14.1ms),用K次草稿模型调用生成K个token,总耗时仅1.8K ms,远低于目标模型生成K个token的14.1K ms;
- 采用“自回归生成草稿”而非“并行生成”,是为了保证草稿序列的“语义连贯性”(并行生成易出现逻辑断裂,而自回归生成的草稿更符合语言习惯)。
阶段2:并行评分(Parallel Scoring)——大模型批量算“可信度”
In parallel, compute K+1 sets of logits from drafts ṽ₁...ṽ_K:
q(·|x₁,...,xₙ), q(·|x₁,...,xₙ, ṽ₁), ..., q(·|x₁,...,xₙ, ṽ₁,...,ṽ_K)
操作逻辑:
目标模型(q)一次性并行计算“K+1个上下文对应的logits(概率分布)” ,这K+1个上下文分别是:
- 基础上下文:仅历史序列x₁…xₙ(对应“生成ṽ₁前的上下文”);
- 累加上下文1:x₁…xₙ + ṽ₁(对应“生成ṽ₂前的上下文”);
- 累加上下文2:x₁…xₙ + ṽ₁ + ṽ₂(对应“生成ṽ₃前的上下文”);
…
K. 累加上下文K:x₁…xₙ + ṽ₁ + … + ṽ_K(对应“生成ṽ_K后,下一个token的上下文”)。
关键设计原因(算法核心创新点):
- 并行计算突破效率瓶颈:传统自回归生成中,目标模型需“逐个计算每个token的logits”(先算x₁…xₙ→ṽ₁的logits,再算x₁…xₙ+ṽ₁→ṽ₂的logits,共K次调用);而本步骤中,目标模型通过“一次性加载所有上下文并并行计算”,仅需1次调用就能得到K+1个logits,大幅减少目标模型调用次数;
- 硬件友好性:论文证明,“并行计算K个短序列的logits”与“串行计算1个token的logits”耗时几乎相等(因两者均受限于内存带宽,而非计算量)——这是算法能“加速不增耗”的核心前提。
阶段3:逐token验证与修正(Modified Rejection Sampling)——保留可信、修正错误
这是算法最关键的阶段:通过“概率比较+随机验证”逐token判断是否保留草稿token,若遇到不满足条件的token则立即修正,确保最终生成结果完全符合目标模型的分布。步骤拆解如下:
步骤3.1:循环验证每个草稿token(从ṽ₁到ṽ_K)
for t=1:K do
Sample r ~ U[0,1] from a uniform distribution. # 生成[0,1]间的随机数r
# 条件判断:比较目标模型与草稿模型对当前token的概率
if r < min(1, q(ṽₜ|x₁,...,xₙ₊ₜ₋₁) / p(ṽₜ|x₁,...,xₙ₊ₜ₋₁)) then
# 接受草稿token:将ṽₜ加入确定序列,更新已生成token数n
Set xₙ₊ₜ ← ṽₜ and n ← n+1
else
# 拒绝草稿token:从修正分布中重新采样token
Sample xₙ₊ₜ ~ (q(·|x₁,...,xₙ₊ₜ₋₁) - p(·|x₁,...,xₙ₊ₜ₋₁))₊
exit for loop # 拒绝后立即退出循环,剩余草稿token不再验证
end if
end for
操作逻辑(分“接受”和“拒绝”两种情况):
情况1:接受草稿token(ṽₜ可信,保留)
- 判断依据:随机数r小于“目标模型概率与草稿模型概率的比值上限1”(即
min(1, q/ p)
)。
这里的q(ṽₜ|...)
是目标模型对“当前上下文下生成ṽₜ”的概率,p(ṽₜ|...)
是草稿模型的对应概率——比值q/p
越大,说明目标模型越认可草稿token,接受概率越高;若q/p ≥1
,则min(1, q/p)=1
,此时r必然小于1,token一定会被接受。 - 实例:上下文是“今天天气”,草稿tokenṽ₁=“很”,目标模型对“很”的概率q=0.8,草稿模型的概率p=0.7,
q/p≈1.14
,min(1,1.14)=1
,随机数r=0.6<1,因此接受“很”,确定序列变为“今天天气很”,n从4变为5。
情况2:拒绝草稿token(ṽₜ不可信,修正)
- 判断依据:随机数r大于等于
min(1, q/p)
,说明目标模型与草稿模型对当前token的认知差异大,草稿token不可信。 - 修正逻辑:从“修正分布”
(q - p)₊
中重新采样token,其中(·)₊
表示“取正后归一化”(即max(0, q(·)-p(·))
除以所有正差值的和,确保是合法概率分布)。
这个修正分布的本质是“提取目标模型认为‘草稿模型低估’的token”——只选择那些“目标模型概率高于草稿模型”的token,保证修正后的token符合目标模型的偏好。 - 实例:上下文是“今天天气很”,草稿tokenṽ₂=“热”,目标模型对“热”的概率q=0.3,草稿模型的概率p=0.6,
q/p=0.5
,min(1,0.5)=0.5
,若随机数r=0.6≥0.5,则拒绝“热”;此时从修正分布(q - p)₊
中采样,假设目标模型对“好”的概率q=0.9,p=0.2,q-p=0.7
(为正),则大概率采样到“好”,确定序列变为“今天天气很好”,n从5变为6,同时退出循环,剩余草稿tokenṽ₃=“适合”不再验证。
步骤3.2:全接受额外奖励(所有草稿token均可信,多生成1个token)
If all tokens xₙ₊₁,...,xₙ₊K are accepted, sample extra token xₙ₊K₊₁ ~ q(·|x₁,...,xₙ₊K) and set n ← n+1.
- 操作逻辑:若K个草稿token全部被接受,说明草稿模型与目标模型高度一致,此时利用阶段2中“累加上下文K”的logits(已提前计算好),直接从目标模型中采样1个额外token——相当于“奖励性多生成1个token”,进一步提升效率(最多一次循环生成K+1个token)。
- 实例:K=3,草稿token“很”“好”“适合”全部被接受,此时利用上下文“今天天气很好适合”的logits,从目标模型采样额外token“出门”,确定序列变为“今天天气很好适合出门”,n额外+1。
三、算法关键特性:为何能“加速且不降质”?
推测性采样的算法,并非简单的“用小模型替代大模型”,而是通过严谨的数学设计保证“生成质量与目标模型完全一致”,同时实现加速,核心特性有两点:
1. 概率分布保真(Lossless)
论文通过定理1证明,无论token被“接受”还是“拒绝修正”,最终生成的token分布完全等于目标模型的分布q:
- 接受时:token来自草稿模型,但通过
min(1, q/p)
的概率筛选,确保其在目标模型中的概率权重被正确保留; - 拒绝时:token从
(q-p)₊
分布中采样,本质是“补全目标模型中草稿模型未覆盖的高概率token”,两者结合后总分布恰好为q。
这意味着生成结果的质量与“直接用目标模型自回归生成”完全一致,无任何性能损失。
2. 效率最大化(Linear Speedup)
算法的加速来自“减少目标模型调用次数”:
- 传统自回归生成:生成M个token需调用目标模型M次;
- 投机采样:若每次循环平均接受m个token(m≤K+1),则生成M个token仅需调用目标模型M/m次——论文中m≈2-2.5,因此速度提升2-2.5倍,且K=3-4时效率最优(K过大时,后期草稿token的接受率下降,反而增加计算 overhead)。
四、总结:
推测性采样算法的本质是“用小模型的速度优势做‘前置筛选’,用大模型的并行计算做‘批量验证’,用修正采样做‘质量兜底’ ”,形成“快速生成→高效验证→精准修正”的闭环:
- 草稿模型快速生成K个候选token,解决“速度慢”问题;
- 目标模型并行计算K+1个logits,解决“大模型调用频繁”问题;
- 修正采样确保分布保真,解决“质量降维”问题。
这种设计既突破了传统自回归的效率瓶颈,又保留了大模型的生成质量,成为长序列生成(如代码、文档)中提升推理速度的核心算法之一。
附录:
1. 用通俗的例子理解 r < min(1, q/p)
r < min(1, q/p)”这个判断条件的核心目的是,用草稿模型p的“快速采样”做基础,再通过“q/p的比例修正”,确保最终留下的样本符合大模型q的偏好,而不是被p的偏差带偏。我们可以用“选候选人”的生活场景类比,把抽象的概率公式转化为通俗逻辑:
第一步:先理解“为什么需要p(草稿模型),又为什么不能只信p”
假设你要招聘一位“符合公司需求的员工”(对应“生成符合大模型q分布的token”),但直接筛选所有候选人(对应“用q直接采样”)太慢、成本太高。
于是你找了一位“快速初筛员”p(草稿模型):p能快速推荐一批候选人(生成K个草稿token),但p的判断标准和公司真实需求q不完全一致——比如p可能更看重学历(对应p对某些token的概率偏高),而公司q更看重能力(对应q对另一些token的概率偏高)。
此时问题就来了:p推荐的候选人里,有些符合q的需求,有些不符合。我们需要一个规则,把“符合q需求的候选人留下,不符合的剔除”——而“r < min(1, q/p)”就是这个规则。
第二步:拆解“q/p”的本质——“公司需求”和“初筛员判断”的对齐度
公式里的“q(x|…)”是“公司q认为候选人x符合需求的概率”(大模型对token x的置信度),“p(x|…)”是“初筛员p推荐候选人x的概率”(草稿模型对token x的置信度)。
“q/p”这个比值,本质是**“公司需求与初筛员判断的对齐程度”**:
- 若q/p ≥ 1(比如q=0.8,p=0.6):说明“公司q比初筛员p更认可x”——x不仅过了p的初筛,还更符合q的真实需求,这种候选人必须优先留下;
- 若q/p < 1(比如q=0.3,p=0.6):说明“初筛员p比公司q更认可x”——x是p推荐的,但q觉得x不太行,这种候选人需要“打个折”考虑,不能直接留下。
第三步:为什么要“min(1, q/p)”——给“对齐度”设个“安全上限”
“min(1, q/p)”的作用是避免“过度认可”,让规则更合理:
- 当q/p ≥ 1时,min(1, q/p)=1:此时无论随机数r(后面会说r的作用)是多少(只要r在[0,1]之间),都会满足r < 1——相当于“公司q特别认可的候选人,直接录用,不用再抽签”;
- 当q/p < 1时,min(1, q/p)=q/p(比如q/p=0.5):此时只有当r < 0.5时才会留下x——相当于“公司q不太认可的候选人,只有50%的概率录用”,概率高低完全由“q比p认可多少”决定。
第四步:随机数r的作用——“用概率公平筛选,避免一刀切”
为什么要引入一个[0,1]的随机数r?因为我们不想“一刀切”地拒绝所有q/p < 1的x,而是想“按q的认可度概率筛选”,让最终留下的样本整体符合q的分布。
比如q/p=0.5:如果直接拒绝所有x,会漏掉一些q虽然不那么认可、但仍有价值的x;用r < 0.5来筛选,既能让“q不太认可的x”有合理的入选概率,又不会让它们占比过高——长期下来,所有留下的x的整体分布,就会无限接近q的分布,而不是p的分布。
一句话总结:这个公式在干一件什么事?
用初筛员p快速找出“候选名单”,再用“公司q对候选人的认可度 ÷ p对候选人的认可度”这个比值,决定候选人的“录用概率”——q越认可、p越不认可的,录用概率越高;q越不认可、p越认可的,录用概率越低。最终通过这种“修正”,确保录用的人符合公司q的需求,而不是被p的初筛标准带偏。
更多推荐
所有评论(0)