几乎所有 OOM,都是“我以为还能再加一点”

如果你做过大模型微调,你一定经历过这种时刻:

  • batch size 调小一点 → 能跑
  • sequence length 加一点 → 还能跑
  • 两个一起微调 → 显存直接炸

你看着监控面板,心里非常困惑:

“不对啊,我是算过的。”

这正是问题的关键。

你算的,是线性的;
而显存消耗,从来不是。

在这里插入图片描述

工程师心里计算的显存 vs 实际显存曲线对比

一个必须先说清楚的结论

在正式展开之前,我先把这篇文章最重要的一句话写出来:

batch size 和 sequence length,
不是两个独立的显存旋钮,
而是一个相互放大的乘法因子。

如果你还在用:

  • “batch 翻倍,显存翻倍”
  • “长度翻倍,显存翻倍”

这样的直觉来理解显存,
那你几乎一定会被 OOM 教育。

第一层误解:把显存消耗理解成“参数规模问题”

很多人一说显存,第一反应是:

  • 模型多大
  • 参数多少
  • 是不是该用 LoRA / QLoRA

这些当然重要,
但它们只决定了显存的底座

真正让你在训练时反复爆显存的,往往不是参数,而是:

中间态(activations)。

而 batch size 和 sequence length,
正是中间态的最大放大器。

第二层:为什么 sequence length 比你想象中“更贵”

很多人会觉得:

“sequence length 只是多几个 token,
显存应该线性增加吧?”

这是一个非常危险的直觉。

一个必须面对的事实

在 Transformer 里,sequence length 影响的不只是:

  • embedding
  • attention 输入

而是:

  • attention score
  • KV cache
  • 每一层的中间激活

尤其是 self-attention
它的计算和存储复杂度是:

O(L²)

也就是说:

  • length 从 1024 → 2048
  • token 数翻倍
  • attention 相关显存,可能直接 ×4

这就是为什么你“只是把 max_length 调大了一点”,
显存却突然不讲道理。

在这里插入图片描述

sequence length ↑ → attention 显存平方增长

第三层:batch size 为什么会“乘上” sequence length

单看 batch size,好像也很直观:

  • batch ×2 → 数据 ×2

但问题在于:

batch size 决定的是:
同一时间,有多少条序列在走完整前向和反向。

于是显存里同时存在的,是:

batch_size × sequence_length × hidden_dim × layer_count

这不是加法,是堆叠

当你把 batch size 和 sequence length 同时往上拉时,
你做的事情其实是:

让显存同时承载更多、更长、而且还没释放的中间态。

第四层:非线性真正出现的地方——反向传播

如果只是前向,其实很多时候还能勉强扛住。

真正让显存爆炸的,是:

反向传播阶段。

原因很简单:

  • 前向:可以边算边丢
  • 反向:必须留住中间态

这意味着:

  • batch 越大 → 需要保留的中间结果越多
  • length 越长 → 每一层要保存的激活越重

于是显存曲线会出现一个非常典型的形态:

前向看着还行,
反向直接炸。

在这里插入图片描述

前向 vs 反向 显存占用对比

第五层:为什么“只加一点点”,却跨过了临界点

这是最让人崩溃的地方。

你可能经历过:

  • batch=2,length=2048,OK
  • batch=3,length=2048,OOM

你会觉得:

“就多了一条样本,怎么就炸了?”

原因在于:

显存不是连续可用的,
而是存在碎片和临界点的。

当你跨过某个阈值:

  • CUDA 需要分配一整块新的 buffer
  • allocator 找不到足够连续空间
  • 于是直接失败

这就是为什么:

显存不是“慢慢用完”的,
而是“突然不够用”的。

第六层:梯度累积,为什么没你想得那么“省”

很多人会说:

“batch 太大?那我用 gradient accumulation。”

这确实能缓解一部分问题,
但它并不是免费午餐。

因为:

  • accumulation 并不会减少单步的 activation 显存
  • 它只是减少了一次 forward/backward 中的 batch

如果你的 OOM 来自:

  • sequence length 太长
  • attention 中间态太重

那梯度累积几乎救不了你

这也是为什么有些人会困惑:

“我 batch 已经很小了,为什么还 OOM?”

答案往往是:

真正压垮显存的,是 length,不是 batch。

第七层:评估阶段为什么反而更容易炸显存

这是一个很多人没想到的坑。

在评估时,你可能会:

  • 关掉 dropout
  • 不算 loss
  • 以为显存会更省

但实际情况是:

  • 推理 batch 往往更大
  • sequence length 往往更长
  • KV cache 占用持续存在

于是你会看到:

训练能跑,评估反而 OOM。

这不是 bug,
而是你在评估阶段:

把 batch × length 推到了另一个非线性区域。

一个非常真实的“显存误判路径”

我算过参数显存 → 应该够
我减过 batch → 应该稳
我只加了点 length → 应该没事
OOM

注意:
每一步判断,单独看都“合理”。

错的是:
你在用线性思维,面对非线性系统。

那工程上该怎么“正确理解” batch 和 length?

不是给你一个公式,
而是给你一个更安全的判断方式

sequence length 决定了“单样本的重量”,
batch size 决定了“同时搬多少个”。

当你不知道哪里该省的时候,优先问:

  • 单条样本,是不是太重了?
  • attention 的 L² 是否已经不可接受?

很多时候:

减 length,比减 batch 更有效。

一个非常实用的显存自检问题

在你准备调 batch 或 length 之前,可以问自己一句话:

如果显存炸了,
我更希望模型“少看几条”,
还是“每条看短一点”?

如果你无法回答,
说明你对当前显存结构还不够清楚。

很多团队在 batch size 和 sequence length 上反复试探显存上限,本质问题不是参数没算清,而是缺乏对“中间态显存结构”的直观感知。用LLaMA-Factory online观察不同 batch / length 配置下的训练行为,更容易理解:是哪一部分在非线性放大,而不是盲目试错。

总结:显存不是被“用完”的,而是被“触发”的

我用一句话,把这篇文章彻底收住:

batch size 和 sequence length
并不是慢慢吃掉显存的,
而是在某个点上,
一起把你推下悬崖。

当你开始:

  • 把显存理解成结构问题
  • 把 length 当成一等公民
  • 放弃“线性估算”的安全感

你才真正开始工程化地调显存

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐