FlashAttention 原理之 softmax 分块计算

在这里插入图片描述

标准 Softmax

Softmax 函数(也称为归一化指数函数)是一个将向量转换成概率分布的函数。对于输入向量 x,softmax 函数将其转换为一个概率分布向量,其中每个元素的值在 (0,1) 之间,且所有元素之和为 1。

softmax(xi)=exi∑jexj softmax(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}} softmax(xi)=jexjexi

归一化 Softmax

m(x):=max⁡ixi m(x) := \max_i x_i m(x):=imaxxi

这里,x=[x1,x2,…,xB]x = [x_1, x_2, \dots, x_B]x=[x1,x2,,xB] 表示一个有 B 个分量的向量(例如,模型输出的对各类的打分)。m(x)m(x)m(x) 则是向量 xxx 中所有分量的最大值。

f(x):=[e x1−m(x), e x2−m(x), …, e xB−m(x)] f(x) := \bigl[e^{\,x_1 - m(x)},\, e^{\,x_2 - m(x)},\,\dots,\, e^{\,x_B - m(x)}\bigr] f(x):=[ex1m(x),ex2m(x),,exBm(x)]
这里做了一个“减最大值”的操作,即把每个 xix_ixi 都减去整个向量的最大分量 m(x)m(x)m(x),然后取指数。这样做的好处是数值更稳定:当 xix_ixi 很大时,直接算 exie^{x_i}exi 容易导致溢出;但减去最大值以后,指数部分变为 xi−m(x)x_i - m(x)xim(x)(一个相对较小的或非正的数),从而避免数值爆炸。

ℓ(x):=∑if(x)i \ell(x) := \sum_i f(x)_i (x):=if(x)i
这里把向量 f(x)f(x)f(x) 的各个分量加起来得到标量 ℓ(x)\ell(x)(x)

softmax(x):=f(x)ℓ(x) \text{softmax}(x) := \frac{f(x)}{\ell(x)} softmax(x):=(x)f(x)
把每个分量 f(x)if(x)_if(x)i 除以 ℓ(x)\ell(x)(x) 后,就得到了标准的 Softmax 输出向量。

softmax(x)i=e xi−m(x)∑je xj−m(x) . \text{softmax}(x)_i = \frac{e^{\,x_i - m(x)}}{\sum_j e^{\,x_j - m(x)}} \,. softmax(x)i=jexjm(x)exim(x).

由于每一项都经过指数函数且被总和归一化,它满足所有分量都非负且所有分量之和为 1,因此是一个有效的概率分布。

分块 Softmax

假设我们有两个同维度向量
x(1)和x(2)∈RB\mathbf{x}^{(1)} 和 \mathbf{x}^{(2)} \in \mathbb{R}^Bx(1)x(2)RB,把它们拼接(concatenate)成
x=[x(1), x(2)]∈R2B. \mathbf{x} = \bigl[\mathbf{x}^{(1)},\,\mathbf{x}^{(2)}\bigr] \in \mathbb{R}^{2B}. x=[x(1),x(2)]R2B.

下面的公式说明,如何在不重复完整计算的情况下,用“各自的部分计算结果”组合成拼接后向量的 Softmax。先给出它的步骤,再解释其意义和好处:

最大值 m(x) 的分块计算

定义单个向量的最大值

对于 x(1)∈RB\mathbf{x}^{(1)}\in \mathbb{R}^Bx(1)RB,我们先定义
m(x(1))  =  max⁡i(xi(1)),m(x(2))  =  max⁡i(xi(2)). m\bigl(\mathbf{x}^{(1)}\bigr) \;=\; \max_i \Bigl(\mathbf{x}^{(1)}_i\Bigr), \quad m\bigl(\mathbf{x}^{(2)}\bigr) \;=\; \max_i \Bigl(\mathbf{x}^{(2)}_i\Bigr). m(x(1))=imax(xi(1)),m(x(2))=imax(xi(2)).

定义拼接向量的最大值

由于 x\mathbf{x}x 是把 x(1)\mathbf{x}^{(1)}x(1)x(2)\mathbf{x}^{(2)}x(2) 拼到一起,那么
m(x)  =  m([x(1), x(2)])  =  max⁡(m(x(1)),  m(x(2))). m(\mathbf{x}) \;=\; m\bigl([\mathbf{x}^{(1)},\, \mathbf{x}^{(2)}]\bigr) \;=\; \max\bigl(m(\mathbf{x}^{(1)}), \;m(\mathbf{x}^{(2)})\bigr). m(x)=m([x(1),x(2)])=max(m(x(1)),m(x(2))).

这样就不需要对拼接后的 x\mathbf{x}x 再扫描一次去找最大值,而是只要比较两个子向量各自的最大值即可。

“指数向量” f(x) 的分块计算

Recall
f(x)  =  [e x1−m(x),  e x2−m(x),  …,  e x2B−m(x)]. f(\mathbf{x}) \;=\; \Bigl[e^{\,x_1 - m(\mathbf{x})},\; e^{\,x_2 - m(\mathbf{x})},\;\dots,\; e^{\,x_{2B} - m(\mathbf{x})}\Bigr]. f(x)=[ex1m(x),ex2m(x),,ex2Bm(x)].

由于 x\mathbf{x}x 拆成了两块 x(1)\mathbf{x}^{(1)}x(1)x(2)\mathbf{x}^{(2)}x(2),我们分别对每块计算其对应的“指数向量”:
f(x(1))和f(x(2)). f\bigl(\mathbf{x}^{(1)}\bigr) \quad\text{和}\quad f\bigl(\mathbf{x}^{(2)}\bigr). f(x(1))f(x(2)).

然后拼起来即可。但要记住,每一块真正要减去的“中心化值”是整个 x\mathbf{x}x 的最大值 m(x)m(\mathbf{x})m(x)
,因此它们之间会出现一个额外的“补偿系数”:

f(x)  =  [e m(x(1))−m(x) ⋅f(x(1)),    e m(x(2))−m(x) ⋅f(x(2))]. f(\mathbf{x}) \;=\; \Bigl[ e^{\,m(\mathbf{x}^{(1)}) - m(\mathbf{x})}\,\cdot f\bigl(\mathbf{x}^{(1)}\bigr), \;\; e^{\,m(\mathbf{x}^{(2)}) - m(\mathbf{x})}\,\cdot f\bigl(\mathbf{x}^{(2)}\bigr) \Bigr]. f(x)=[em(x(1))m(x)f(x(1)),em(x(2))m(x)f(x(2))].

直观上看,如果某一块(比如 x(1)\mathbf{x}^{(1)}x(1))的最大元素是整个拼接向量的最大元素,那么它带来的指数因子就会是 e m(x(1))−m(x)=e0=1e^{\,m(\mathbf{x}^{(1)}) - m(\mathbf{x})} = e^0 = 1em(x(1))m(x)=e0=1。而另一块若不是最大的,就会额外乘上一个小于 1 的因子。

归一化项 ℓ(x)\ell(x)(x) 的分块计算

Softmax 要把向量的指数项归一化到和为 1,所以我们需要计算
ℓ(x)  =  ∑i=12Be xi−m(x). \ell(\mathbf{x}) \;=\; \sum_{i=1}^{2B} e^{\,x_i - m(\mathbf{x})}. (x)=i=12Bexim(x).

利用分块思想,可以分成两段求和,再用与上一步相同的补偿系数连接起来:

ℓ(x)  =  e m(x(1))−m(x) ℓ(x(1))⏟第1块贡献  +  e m(x(2))−m(x) ℓ(x(2))⏟第2块贡献, \ell(\mathbf{x}) \;=\; \underbrace{e^{\,m(\mathbf{x}^{(1)}) - m(\mathbf{x})}\,\ell\bigl(\mathbf{x}^{(1)}\bigr)}{\text{第1块贡献}} \;+\; \underbrace{e^{\,m(\mathbf{x}^{(2)}) - m(\mathbf{x})}\,\ell\bigl(\mathbf{x}^{(2)}\bigr)}{\text{第2块贡献}}, (x)= em(x(1))m(x)(x(1))1块贡献+ em(x(2))m(x)(x(2))2块贡献,

同理,也只需要各块自己内部的和,再用一个相对的尺度因子即可。

最大值 softmax(x)\mathrm{softmax}(x)softmax(x) 的分块形式

把上面得到的 f(x)f(\mathbf{x})f(x)ℓ(x)\ell(\mathbf{x})(x) 带入到
softmax(x)  =  f(x)ℓ(x), \mathrm{softmax}(\mathbf{x}) \;=\; \frac{f(\mathbf{x})}{\ell(\mathbf{x})}, softmax(x)=(x)f(x),

就得到在分块后的 Softmax 形式:

softmax(x)  =  [e m(x(1))−m(x) f(x(1)),    e m(x(2))−m(x) f(x(2))]e m(x(1))−m(x) ℓ(x(1))  +  e m(x(2))−m(x) ℓ(x(2)). \mathrm{softmax}(\mathbf{x}) \;=\; \frac{ \Bigl[ e^{\,m(\mathbf{x}^{(1)}) - m(\mathbf{x})} \,f\bigl(\mathbf{x}^{(1)}\bigr), \;\; e^{\,m(\mathbf{x}^{(2)}) - m(\mathbf{x})} \,f\bigl(\mathbf{x}^{(2)}\bigr) \Bigr] }{ e^{\,m(\mathbf{x}^{(1)}) - m(\mathbf{x})}\,\ell\bigl(\mathbf{x}^{(1)}\bigr) \;+\; e^{\,m(\mathbf{x}^{(2)}) - m(\mathbf{x})}\,\ell\bigl(\mathbf{x}^{(2)}\bigr) }. softmax(x)=em(x(1))m(x)(x(1))+em(x(2))m(x)(x(2))[em(x(1))m(x)f(x(1)),em(x(2))m(x)f(x(2))].

为什么这样做?

  1. 数值稳定性
    跟单向量计算 Softmax 类似,这里也要减去整段向量的最大值 m(x)m(\mathbf{x})m(x),以避免 eze^zez 里的 zzz 太大或太小导致溢出/下溢。
  2. 减少重复计算
    如果我们已经知道各块各自的 max⁡\maxmax 值和求和 ℓ(x(k))\ell(\mathbf{x}^{(k)})(x(k)),那就无需把 x\mathbf{x}x 整体重新扫描、求最大值、求指数和,总结出公式即可快速拼成拼接后向量的软最大值。
  3. 方便分布式或分块处理
    在实际系统里,x(1)\mathbf{x}^{(1)}x(1)x(2)\mathbf{x}^{(2)}x(2) 可能来自不同子网络或不同设备。这种分块计算可以让每一块先在本地完成自己的 Softmax 部分计算,最后再做一次简短的组合归一化即可。
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐