batch size、sequence length 对显存的非线性影响
摘要:大模型微调中常见的OOM问题往往源于对显存消耗的线性误判。batch size和sequence length并非独立变量,而是相互放大的乘法因子。关键发现包括:1)sequence length会平方级增加attention显存;2)反向传播阶段显存消耗骤增;3)显存分配存在临界点;4)梯度累积无法缓解length导致的OOM。工程建议:优先缩短sequence length而非减小bat
几乎所有 OOM,都是“我以为还能再加一点”
如果你做过大模型微调,你一定经历过这种时刻:
- batch size 调小一点 → 能跑
- sequence length 加一点 → 还能跑
- 两个一起微调 → 显存直接炸
你看着监控面板,心里非常困惑:
“不对啊,我是算过的。”
这正是问题的关键。
你算的,是线性的;
而显存消耗,从来不是。

工程师心里计算的显存 vs 实际显存曲线对比
一个必须先说清楚的结论
在正式展开之前,我先把这篇文章最重要的一句话写出来:
batch size 和 sequence length,
不是两个独立的显存旋钮,
而是一个相互放大的乘法因子。
如果你还在用:
- “batch 翻倍,显存翻倍”
- “长度翻倍,显存翻倍”
这样的直觉来理解显存,
那你几乎一定会被 OOM 教育。
第一层误解:把显存消耗理解成“参数规模问题”
很多人一说显存,第一反应是:
- 模型多大
- 参数多少
- 是不是该用 LoRA / QLoRA
这些当然重要,
但它们只决定了显存的底座。
真正让你在训练时反复爆显存的,往往不是参数,而是:
中间态(activations)。
而 batch size 和 sequence length,
正是中间态的最大放大器。
第二层:为什么 sequence length 比你想象中“更贵”
很多人会觉得:
“sequence length 只是多几个 token,
显存应该线性增加吧?”
这是一个非常危险的直觉。
一个必须面对的事实
在 Transformer 里,sequence length 影响的不只是:
- embedding
- attention 输入
而是:
- attention score
- KV cache
- 每一层的中间激活
尤其是 self-attention,
它的计算和存储复杂度是:
O(L²)
也就是说:
- length 从 1024 → 2048
- token 数翻倍
- attention 相关显存,可能直接 ×4
这就是为什么你“只是把 max_length 调大了一点”,
显存却突然不讲道理。

sequence length ↑ → attention 显存平方增长
第三层:batch size 为什么会“乘上” sequence length
单看 batch size,好像也很直观:
- batch ×2 → 数据 ×2
但问题在于:
batch size 决定的是:
同一时间,有多少条序列在走完整前向和反向。
于是显存里同时存在的,是:
batch_size × sequence_length × hidden_dim × layer_count
这不是加法,是堆叠。
当你把 batch size 和 sequence length 同时往上拉时,
你做的事情其实是:
让显存同时承载更多、更长、而且还没释放的中间态。
第四层:非线性真正出现的地方——反向传播
如果只是前向,其实很多时候还能勉强扛住。
真正让显存爆炸的,是:
反向传播阶段。
原因很简单:
- 前向:可以边算边丢
- 反向:必须留住中间态
这意味着:
- batch 越大 → 需要保留的中间结果越多
- length 越长 → 每一层要保存的激活越重
于是显存曲线会出现一个非常典型的形态:
前向看着还行,
反向直接炸。

前向 vs 反向 显存占用对比
第五层:为什么“只加一点点”,却跨过了临界点
这是最让人崩溃的地方。
你可能经历过:
- batch=2,length=2048,OK
- batch=3,length=2048,OOM
你会觉得:
“就多了一条样本,怎么就炸了?”
原因在于:
显存不是连续可用的,
而是存在碎片和临界点的。
当你跨过某个阈值:
- CUDA 需要分配一整块新的 buffer
- allocator 找不到足够连续空间
- 于是直接失败
这就是为什么:
显存不是“慢慢用完”的,
而是“突然不够用”的。
第六层:梯度累积,为什么没你想得那么“省”
很多人会说:
“batch 太大?那我用 gradient accumulation。”
这确实能缓解一部分问题,
但它并不是免费午餐。
因为:
- accumulation 并不会减少单步的 activation 显存
- 它只是减少了一次 forward/backward 中的 batch
如果你的 OOM 来自:
- sequence length 太长
- attention 中间态太重
那梯度累积几乎救不了你。
这也是为什么有些人会困惑:
“我 batch 已经很小了,为什么还 OOM?”
答案往往是:
真正压垮显存的,是 length,不是 batch。
第七层:评估阶段为什么反而更容易炸显存
这是一个很多人没想到的坑。
在评估时,你可能会:
- 关掉 dropout
- 不算 loss
- 以为显存会更省
但实际情况是:
- 推理 batch 往往更大
- sequence length 往往更长
- KV cache 占用持续存在
于是你会看到:
训练能跑,评估反而 OOM。
这不是 bug,
而是你在评估阶段:
把 batch × length 推到了另一个非线性区域。
一个非常真实的“显存误判路径”
我算过参数显存 → 应该够
我减过 batch → 应该稳
我只加了点 length → 应该没事
OOM
注意:
每一步判断,单独看都“合理”。
错的是:
你在用线性思维,面对非线性系统。
那工程上该怎么“正确理解” batch 和 length?
不是给你一个公式,
而是给你一个更安全的判断方式:
sequence length 决定了“单样本的重量”,
batch size 决定了“同时搬多少个”。
当你不知道哪里该省的时候,优先问:
- 单条样本,是不是太重了?
- attention 的 L² 是否已经不可接受?
很多时候:
减 length,比减 batch 更有效。
一个非常实用的显存自检问题
在你准备调 batch 或 length 之前,可以问自己一句话:
如果显存炸了,
我更希望模型“少看几条”,
还是“每条看短一点”?
如果你无法回答,
说明你对当前显存结构还不够清楚。
很多团队在 batch size 和 sequence length 上反复试探显存上限,本质问题不是参数没算清,而是缺乏对“中间态显存结构”的直观感知。用LLaMA-Factory online观察不同 batch / length 配置下的训练行为,更容易理解:是哪一部分在非线性放大,而不是盲目试错。
总结:显存不是被“用完”的,而是被“触发”的
我用一句话,把这篇文章彻底收住:
batch size 和 sequence length
并不是慢慢吃掉显存的,
而是在某个点上,
一起把你推下悬崖。
当你开始:
- 把显存理解成结构问题
- 把 length 当成一等公民
- 放弃“线性估算”的安全感
你才真正开始工程化地调显存。
更多推荐


所有评论(0)