SPARSE ATTENTION FOR EFFICIENT LLM REINFORCEMENT LEARNING

它想用 sparse attention 加速 RL rollout,但发现 naive sparse rollout 会把 RL 训练搞崩;所以提出 DISTILLSPARSE,用 LoRA distillation + reward-aware oversampling/filtering 来让 sparse rollout 接近 dense policy。

它是 ICLR 2026 SPOT workshop paper,作者来自 CMU、Intel、Amazon 等。OpenReview 页面显示发布时间是 2026-03-03,最后修改 2026-03-23。

它要解决什么问题?

LLM RL 训练慢,尤其慢在 online rollout generation

比如 GRPO/PPO 做数学 reasoning 时,一个 prompt 要采很多条长 CoT rollout。rollout 是自回归 decode,一 token 一 token 生成,长输出 + 多样本会让 inference 侧成为瓶颈。论文说 RL 可扩展性越来越受 online rollout 成本限制,尤其是 long chain-of-thought generation 和 large-batch sampling。

所以作者自然想到:

能不能 rollout 时不用 dense attention,
而用 sparse attention,
降低每个 token 的 attention cost,
提高 rollout throughput?

这听起来很合理。

但是论文发现:直接用 sparse attention rollout 很危险。

为什么 naive sparse rollout 会崩?

这里是论文最重要的 insight。

RL 训练里有两个模型分布:

actor / rollout policy:
  负责生成 trajectory

policy / training model:
  负责做 PPO / GRPO 更新

如果 rollout 用 sparse attention,而 trainer 用 dense attention,那么生成数据的分布其实来自:

μ_sparse

但训练更新时算 logprob / ratio 的模型是:

π_dense

这就产生了 actor-policy distribution mismatch

更麻烦的是,这个 mismatch 不是普通异步 RL 的 staleness。异步 RL 里 old policy 和 current policy 的差距通常是“几个 update 版本”的差距;但 sparse attention 和 dense attention 的差距是 attention 近似误差导致的结构性分布差异。论文强调,这种 sparse-to-dense KL divergence 可能比普通 staleness 问题大很多,而且会随着 generation length 增大而累积。

在 RL 训练里,一条 rollout 通常经历两段:

1. rollout worker 生成:
   prompt x -> 生成 response y

2. trainer worker 训练:
   拿到 x, y,再对这条序列做一次 forward,
   计算 logprob / ratio / loss / gradient

如果 rollout worker 用的是 sparse attention,那么这条 response 实际是从:

μ_sparse(y | x)

采样出来的。

但 trainer worker 拿到这条 response 后,用 dense attention policy 重新算 logprob:

π_dense(y_t | x, y_<t)

所以训练时看到的是:

“这条 y 在 dense policy 下的概率是多少?”

而不是:

“这条 y 当初到底是怎么被 sparse policy 采出来的?”

于是就出现了分布差别。

importance sampling 为什么救不了?

正常我们可能会想:

既然 rollout policy 和 training policy 不一样,
那用 importance sampling correction 不就好了?

论文说不够。

理想的 sequence-level importance ratio 是:

π_dense(y | x) / μ_sparse(y | x)

但长序列下这个东西数值非常不稳定,所以实际一般做 token-level approximation:

π_dense(y_t | x, y_<t) / μ_sparse(y_t | x, y_<t)

问题是 token-level IS correction 只有在 ratio 接近 1 的时候比较可靠。可是 sparse attention 下,这个 ratio 会随着生成长度增加显著偏离 1。论文在方法部分明确说,token-level approximation works well only when probability ratio is close to 1,而 sparse rollout 的 ratio 会随 generation length 增大明显偏离 1。

所以简单的 TIS / masking / rejection sampling 都会遇到问题:

ratio 不稳定 → correction 方差大
截断太狠 → bias 大
rejection 太多 → sample efficiency 低

这就是它为什么不满足于“rollout 用 sparse,trainer 用 dense,然后加个 IS correction”。

DISTILLSPARSE 怎么做?

它有两个核心组件。

第一:LoRA-based on-policy distillation

它维护两个 policy:

dense policy:
  full attention,主训练模型 θ

sparse policy:
  sparse attention + LoRA adapter ϕ,用于 rollout
每轮流程大概是:
1. 用 sparse policy μ_{θ,ϕ}^{sparse} 生成 rollout
2. 对这些 rollout 算 reward / advantage 
3. 用 dense policy 重新算 token logprob 
4. 用 dense logprob 做 PPO/GRPO 更新 dense base model θ 
5. 再用 sparse policy 算 logprob 
6. 冻结 θ,只更新 LoRA ϕ 
7. 让 sparse policy 通过 distillation / KL loss 贴近 dense policy 
8. rollout engine 刷新 θ 和 LoRA ϕ
1. rollout worker 用 sparse attention + θ + LoRA ϕ 生成 rollout

2. 对 rollout 算 reward / advantage

3. trainer 用 dense attention + θ 对同一条 rollout 重新 forward,
   得到 dense logprob

4. 用 dense logprob 做 PPO/GRPO,
   更新 dense base model θ

5. 冻结更新后的 θ

6. 用 sparse attention + θ + LoRA ϕ 对同一批 rollout 再 forward,
   得到 sparse logprob / sparse logits

7. 用 dense policy 的 logits/logprobs 作为 teacher,
   更新 LoRA ϕ,让 sparse policy 贴近 dense policy

8. rollout engine 同步新的 θ 和新的 LoRA ϕ

所以 LoRA 的使用位置有两个:

rollout 阶段:
  用 θ + ϕ + sparse attention 生成数据

distillation 阶段:
  更新 ϕ,让 sparse policy 继续贴近 dense policy

它不是只在第 6 步更新时才用,而是 下一轮 rollout 时就会用这个更新后的 LoRA

论文 Algorithm 1 就是这个流程:先 sparse rollout,然后 dense forward recompute log-probs,更新 base model;再 freeze base model,更新 LoRA,让 sparse rollout policy 对齐 dense policy。

这里很妙的一点是:它不额外用 dense policy 生成新 trajectory。因为如果 dense policy 也 rollout,那加速意义就没了。它只复用 sparse rollout 出来的 trajectories,在这些 token 上同时算 dense logprob 和 sparse logprob,做 on-policy distillation。

所以它的思想不是:

用 dense 生成数据教 sparse

而是:

用 sparse 生成数据;
dense 在这些数据上给 soft target / logprob;
LoRA 让 sparse policy 向 dense policy 靠拢。

这样 sparse actor 不会随着 RL 迭代越走越偏。

第 6、7 步确实额外增加了一次 forward

原本普通 dense RL 大概是:


rollout:
  dense/sampling autoregressive decode

trainer:
  dense forward 重新算 logprob
  dense backward 更新 θ

DISTILLSPARSE 变成:

rollout:
  sparse + LoRA autoregressive decode

trainer:
  dense forward 重新算 logprob
  dense backward 更新 θ

distillation:
  sparse + LoRA forward
  LoRA backward 更新 ϕ

所以它确实多了:

一次 sparse forward + 一次 LoRA backward

这个不是免费的。


为什么它还能加速?

因为最贵的部分通常不是这次 distillation forward,而是 rollout decode

rollout decode 是自回归的:

生成第 1 个 token
生成第 2 个 token
生成第 3 个 token
...
生成第 T 个 token

每一步都要跑一次模型 forward,而且长 CoT 可能几千到上万 token。这个过程很慢。

而第 6 步的 sparse forward 是 trainer 上对完整序列做一次 teacher-forcing forward:

一次性输入完整 x + y
并行计算所有 token 的 logits/logprob

它不是自回归一个 token 一个 token 慢慢生成,所以吞吐会高很多。

也就是说:

rollout decode:
  T 次 sequential forward,慢

distillation sparse forward:
  1 次 batched teacher-forcing forward,快很多

所以虽然多了一次 forward,但它不是同等级别的开销。


第二:oversampling + reward-aware filtering

对于 extreme sparsity / long generation,光 LoRA distillation 还不一定够稳。

所以它做了一个简单但有效的策略:

原来每个 prompt 需要 n 条 samples
现在先生成 M 条 samples,M > n
然后选 reward 最高的 n 条进入训练

论文把这个叫 parallel sampling and filtering。它的经验观察是:更接近 dense policy 的 trajectory 往往 reward 更高。虽然直接用 dense logprob 去挑最接近 dense 的样本很贵,但 reward 是 RL 本来就要算的,所以用 reward-aware filtering 当 proxy。

换句话说,它用 sparse attention 省下来的 rollout 成本,换成更多 sample width:

dense rollout:
  贵,所以每个 prompt 采 n 条

sparse rollout:
  便宜,所以每个 prompt 采 M 条
  然后只保留 reward top-n

这既提高样本质量,也降低 sparse-to-dense mismatch 的风险。

它的训练目标和普通 GRPO/PPO 的关系

这篇论文不是重新发明 RL objective。它基本还是 PPO / GRPO-style objective。

不同的是,rollout 来自 sparse policy,但 policy update 用 dense logprob。

你可以理解成:

sample:
  y ~ π_sparse,LoRA

train dense model:
  用 π_dense(y_t | x, y_<t) 算 PPO/GRPO loss

train sparse LoRA:
  用 KL / distillation loss 让 π_sparse,LoRA 接近 π_dense

所以它的 contribution 不是新的 advantage 估计,也不是新的 PPO clip,而是:

如何让 sparse rollout policy 在 RL 迭代中不偏离 dense training policy。

https://arxiv.org/html/2603.24840v1

Prune as You Generate: Online Rollout Pruning for Faster and Better RLVR

https://arxiv.org/pdf/2604.26779

NeMo-RL speculative decoding

https://arxiv.org/pdf/2601.09083

SRT: Accelerating Reinforcement Learning via Speculative Rollout with Tree-Structured Cache

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐