投机解码、异步训练
它想用 sparse attention 加速 RL rollout,但发现 naive sparse rollout 会把 RL 训练搞崩;所以提出 DISTILLSPARSE,用 LoRA distillation + reward-aware oversampling/filtering 来让 sparse rollout 接近 dense policy。它是 ICLR 2026 SPOT
SPARSE ATTENTION FOR EFFICIENT LLM REINFORCEMENT LEARNING
它想用 sparse attention 加速 RL rollout,但发现 naive sparse rollout 会把 RL 训练搞崩;所以提出 DISTILLSPARSE,用 LoRA distillation + reward-aware oversampling/filtering 来让 sparse rollout 接近 dense policy。
它是 ICLR 2026 SPOT workshop paper,作者来自 CMU、Intel、Amazon 等。OpenReview 页面显示发布时间是 2026-03-03,最后修改 2026-03-23。
它要解决什么问题?
LLM RL 训练慢,尤其慢在 online rollout generation。
比如 GRPO/PPO 做数学 reasoning 时,一个 prompt 要采很多条长 CoT rollout。rollout 是自回归 decode,一 token 一 token 生成,长输出 + 多样本会让 inference 侧成为瓶颈。论文说 RL 可扩展性越来越受 online rollout 成本限制,尤其是 long chain-of-thought generation 和 large-batch sampling。
所以作者自然想到:
能不能 rollout 时不用 dense attention,
而用 sparse attention,
降低每个 token 的 attention cost,
提高 rollout throughput?
这听起来很合理。
但是论文发现:直接用 sparse attention rollout 很危险。
为什么 naive sparse rollout 会崩?
这里是论文最重要的 insight。
RL 训练里有两个模型分布:
actor / rollout policy:
负责生成 trajectory
policy / training model:
负责做 PPO / GRPO 更新
如果 rollout 用 sparse attention,而 trainer 用 dense attention,那么生成数据的分布其实来自:
μ_sparse
但训练更新时算 logprob / ratio 的模型是:
π_dense
这就产生了 actor-policy distribution mismatch。
更麻烦的是,这个 mismatch 不是普通异步 RL 的 staleness。异步 RL 里 old policy 和 current policy 的差距通常是“几个 update 版本”的差距;但 sparse attention 和 dense attention 的差距是 attention 近似误差导致的结构性分布差异。论文强调,这种 sparse-to-dense KL divergence 可能比普通 staleness 问题大很多,而且会随着 generation length 增大而累积。
在 RL 训练里,一条 rollout 通常经历两段:
1. rollout worker 生成: prompt x -> 生成 response y 2. trainer worker 训练: 拿到 x, y,再对这条序列做一次 forward, 计算 logprob / ratio / loss / gradient如果 rollout worker 用的是 sparse attention,那么这条 response 实际是从:
μ_sparse(y | x)采样出来的。
但 trainer worker 拿到这条 response 后,用 dense attention policy 重新算 logprob:
π_dense(y_t | x, y_<t)所以训练时看到的是:
“这条 y 在 dense policy 下的概率是多少?”而不是:
“这条 y 当初到底是怎么被 sparse policy 采出来的?”于是就出现了分布差别。
importance sampling 为什么救不了?
正常我们可能会想:
既然 rollout policy 和 training policy 不一样,
那用 importance sampling correction 不就好了?
论文说不够。
理想的 sequence-level importance ratio 是:
π_dense(y | x) / μ_sparse(y | x)
但长序列下这个东西数值非常不稳定,所以实际一般做 token-level approximation:
π_dense(y_t | x, y_<t) / μ_sparse(y_t | x, y_<t)
问题是 token-level IS correction 只有在 ratio 接近 1 的时候比较可靠。可是 sparse attention 下,这个 ratio 会随着生成长度增加显著偏离 1。论文在方法部分明确说,token-level approximation works well only when probability ratio is close to 1,而 sparse rollout 的 ratio 会随 generation length 增大明显偏离 1。
所以简单的 TIS / masking / rejection sampling 都会遇到问题:
ratio 不稳定 → correction 方差大
截断太狠 → bias 大
rejection 太多 → sample efficiency 低
这就是它为什么不满足于“rollout 用 sparse,trainer 用 dense,然后加个 IS correction”。
DISTILLSPARSE 怎么做?
它有两个核心组件。
第一:LoRA-based on-policy distillation
它维护两个 policy:
dense policy:
full attention,主训练模型 θ
sparse policy:
sparse attention + LoRA adapter ϕ,用于 rollout
每轮流程大概是:
1. 用 sparse policy μ_{θ,ϕ}^{sparse} 生成 rollout
2. 对这些 rollout 算 reward / advantage
3. 用 dense policy 重新算 token logprob
4. 用 dense logprob 做 PPO/GRPO 更新 dense base model θ
5. 再用 sparse policy 算 logprob
6. 冻结 θ,只更新 LoRA ϕ
7. 让 sparse policy 通过 distillation / KL loss 贴近 dense policy
8. rollout engine 刷新 θ 和 LoRA ϕ
1. rollout worker 用 sparse attention + θ + LoRA ϕ 生成 rollout
2. 对 rollout 算 reward / advantage
3. trainer 用 dense attention + θ 对同一条 rollout 重新 forward,
得到 dense logprob
4. 用 dense logprob 做 PPO/GRPO,
更新 dense base model θ
5. 冻结更新后的 θ
6. 用 sparse attention + θ + LoRA ϕ 对同一批 rollout 再 forward,
得到 sparse logprob / sparse logits
7. 用 dense policy 的 logits/logprobs 作为 teacher,
更新 LoRA ϕ,让 sparse policy 贴近 dense policy
8. rollout engine 同步新的 θ 和新的 LoRA ϕ
所以 LoRA 的使用位置有两个:
rollout 阶段:
用 θ + ϕ + sparse attention 生成数据
distillation 阶段:
更新 ϕ,让 sparse policy 继续贴近 dense policy
它不是只在第 6 步更新时才用,而是 下一轮 rollout 时就会用这个更新后的 LoRA。
论文 Algorithm 1 就是这个流程:先 sparse rollout,然后 dense forward recompute log-probs,更新 base model;再 freeze base model,更新 LoRA,让 sparse rollout policy 对齐 dense policy。
这里很妙的一点是:它不额外用 dense policy 生成新 trajectory。因为如果 dense policy 也 rollout,那加速意义就没了。它只复用 sparse rollout 出来的 trajectories,在这些 token 上同时算 dense logprob 和 sparse logprob,做 on-policy distillation。
所以它的思想不是:
用 dense 生成数据教 sparse
而是:
用 sparse 生成数据;
dense 在这些数据上给 soft target / logprob;
LoRA 让 sparse policy 向 dense policy 靠拢。
这样 sparse actor 不会随着 RL 迭代越走越偏。
第 6、7 步确实额外增加了一次 forward
原本普通 dense RL 大概是:
rollout:
dense/sampling autoregressive decode
trainer:
dense forward 重新算 logprob
dense backward 更新 θ
DISTILLSPARSE 变成:
rollout:
sparse + LoRA autoregressive decode
trainer:
dense forward 重新算 logprob
dense backward 更新 θ
distillation:
sparse + LoRA forward
LoRA backward 更新 ϕ
所以它确实多了:
一次 sparse forward + 一次 LoRA backward
这个不是免费的。
为什么它还能加速?
因为最贵的部分通常不是这次 distillation forward,而是 rollout decode。
rollout decode 是自回归的:
生成第 1 个 token
生成第 2 个 token
生成第 3 个 token
...
生成第 T 个 token
每一步都要跑一次模型 forward,而且长 CoT 可能几千到上万 token。这个过程很慢。
而第 6 步的 sparse forward 是 trainer 上对完整序列做一次 teacher-forcing forward:
一次性输入完整 x + y
并行计算所有 token 的 logits/logprob
它不是自回归一个 token 一个 token 慢慢生成,所以吞吐会高很多。
也就是说:
rollout decode:
T 次 sequential forward,慢
distillation sparse forward:
1 次 batched teacher-forcing forward,快很多
所以虽然多了一次 forward,但它不是同等级别的开销。
第二:oversampling + reward-aware filtering
对于 extreme sparsity / long generation,光 LoRA distillation 还不一定够稳。
所以它做了一个简单但有效的策略:
原来每个 prompt 需要 n 条 samples
现在先生成 M 条 samples,M > n
然后选 reward 最高的 n 条进入训练
论文把这个叫 parallel sampling and filtering。它的经验观察是:更接近 dense policy 的 trajectory 往往 reward 更高。虽然直接用 dense logprob 去挑最接近 dense 的样本很贵,但 reward 是 RL 本来就要算的,所以用 reward-aware filtering 当 proxy。
换句话说,它用 sparse attention 省下来的 rollout 成本,换成更多 sample width:
dense rollout:
贵,所以每个 prompt 采 n 条
sparse rollout:
便宜,所以每个 prompt 采 M 条
然后只保留 reward top-n
这既提高样本质量,也降低 sparse-to-dense mismatch 的风险。
它的训练目标和普通 GRPO/PPO 的关系
这篇论文不是重新发明 RL objective。它基本还是 PPO / GRPO-style objective。
不同的是,rollout 来自 sparse policy,但 policy update 用 dense logprob。
你可以理解成:
sample:
y ~ π_sparse,LoRA
train dense model:
用 π_dense(y_t | x, y_<t) 算 PPO/GRPO loss
train sparse LoRA:
用 KL / distillation loss 让 π_sparse,LoRA 接近 π_dense
所以它的 contribution 不是新的 advantage 估计,也不是新的 PPO clip,而是:
如何让 sparse rollout policy 在 RL 迭代中不偏离 dense training policy。
https://arxiv.org/html/2603.24840v1
Prune as You Generate: Online Rollout Pruning for Faster and Better RLVR
https://arxiv.org/pdf/2604.26779
NeMo-RL speculative decoding
https://arxiv.org/pdf/2601.09083
SRT: Accelerating Reinforcement Learning via Speculative Rollout with Tree-Structured Cache
更多推荐


所有评论(0)