注意力的目标是将一个词嵌入 Query 和一组词嵌入及其 Key,Value 点积映射到一个输出。注意力机制同时处理一组 Query,点积成矩阵 Q。Key,Value 也点积成矩阵 K 和 V。原始的点积注意力公式为:

Attention ( Q , K , V ) = softmax ( Q K T ) \text{Attention}(Q, K, V) = \text{softmax}(QK^T) Attention(Q,K,V)=softmax(QKT)

其中 Q K T QK^T QKT 矩阵的每个元素 A i j = q i ⋅ k j A_{ij} = q_i \cdot k_j Aij=qikj 代表了第 i 个查询与第 j 个键的相似度。

但这个公式存在问题,点积是所有对应位置乘积之和,当把很多这样的乘积加起来时,总结果的统计方差会随着维度的增加而线性增大,维度越高,点积结果的数值就越可能走向极端,要么非常大,要么非常小。

假设查询 q 和键 k 是 d k d_k dk 维向量,其分量是独立随机变量,均值为 μ = 0 \mu = 0 μ=0,方差为 σ 2 = 1 \sigma^2 = 1 σ2=1,它们的点积:

s = q ⋅ k = ∑ i = 1 d k q i k i s = q \cdot k = \sum_{i=1}^{d_k} q_i k_i s=qk=i=1dkqiki

s 的期望值:

E [ s ] = E [ ∑ i = 1 d k q i k i ] = ∑ i = 1 d k E [ q i k i ] = ∑ i = 1 d k E [ q i ] E [ k i ] = 0 \text{E}[s] = \text{E}[\sum_{i=1}^{d_k} q_i k_i] = \sum_{i=1}^{d_k} \text{E}[q_i k_i] = \sum_{i=1}^{d_k} \text{E}[q_i] \text{E}[k_i] = 0 E[s]=E[i=1dkqiki]=i=1dkE[qiki]=i=1dkE[qi]E[ki]=0

其方差为:

Var ( s ) = Var ( ∑ i = 1 d k q i k i \text{Var}(s) = \text{Var}(\sum_{i=1}^{d_k} q_i k_i Var(s)=Var(i=1dkqiki

由于 q i q_i qi k i k_i ki 独立,协方差项为零,方差具有可加性:

Var ( s ) = ∑ i = 1 d k Var ( q i k i \text{Var}(s) = \sum_{i=1}^{d_k} \text{Var}(q_i k_i Var(s)=i=1dkVar(qiki

计算 Var ( q i k i ) \text{Var}(q_i k_i) Var(qiki)

Var ( q i k i ) = E [ ( q i k i ) 2 ] − ( E [ q i k i ] ) 2 = E [ q i 2 ] E [ k i 2 ] − 0 \text{Var}(q_i k_i) = \text{E}[(q_i k_i)^2] - (\text{E}[q_i k_i])^2 = \text{E}[q_i^2] \text{E}[k_i^2] - 0 Var(qiki)=E[(qiki)2](E[qiki])2=E[qi2]E[ki2]0

因为 E [ q i 2 ] = Var ( q i ) + ( E [ q i ] ) 2 = 1 + 0 = 1 \text{E}[q_i^2] = \text{Var}(q_i) + (\text{E}[q_i])^2 = 1 + 0 = 1 E[qi2]=Var(qi)+(E[qi])2=1+0=1,同理 E [ k i 2 ] = 1 \text{E}[k_i^2] = 1 E[ki2]=1。所以:

Var ( q i k i ) = 1 × 1 = 1 \text{Var}(q_i k_i) = 1 \times 1 = 1 Var(qiki)=1×1=1

最终:

Var ( s ) = ∑ i = 1 d k 1 = d k \text{Var}(s) = \sum_{i=1}^{d_k} 1 = d_k Var(s)=i=1dk1=dk

结论很显然,点积 s 的方差与维度 d k d_k dk 成正比,当 d k d_k dk 很大时,点积结果的绝对值可能会变得非常大。后果是 Softmax 造成梯度消失。

注意力权重通过 softmax 函数获得:

a i = exp ⁡ ( s i ) ∑ j exp ⁡ ( s j ) a_i = \dfrac{\exp(s_i)}{\sum_j \exp(s_j)} ai=jexp(sj)exp(si)

当某个 s i s_i si 远大于其他得分时, a i → 1 a_i \to 1 ai1,而其他 a j → 0 a_j \to 0 aj0,这被称为 softmax 函数饱和。反向传播中,softmax 的梯度为:

∂ a i ∂ s j = a i ( δ i j − a j ) \frac{\partial a_i}{\partial s_j} = a_i (\delta_{ij} - a_j) sjai=ai(δijaj)

当分布饱和,即一个 a i ≈ 1 a_i \approx 1 ai1,其余 a j ≈ 0 a_j \approx 0 aj0 时,所有这些梯度项都趋近于 0,这导致了梯度消失,使得模型参数更新极其缓慢,训练效率低下。

解决方案就是采用缩放点积。

为了解决方差随维度增长的问题,我们将点积按其维度平方根缩放:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}(\dfrac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V

缩放后的点积为 s ′ = s d k s' = \dfrac{s}{\sqrt{d_k}} s=dk s,其方差为:

Var ( s ′ ) = Var ( s d k ) = ( 1 d k ) 2 × Var ( s ) = 1 d k × d k = 1 \text{Var}(s') = \text{Var}(\dfrac{s}{\sqrt{d_k}}) = (\dfrac{1}{\sqrt{d_k}})^2 \times \text{Var}(s) = \dfrac{1}{d_k} \times d_k = 1 Var(s)=Var(dk s)=(dk 1)2×Var(s)=dk1×dk=1

这意味着,缩放操作将点积得分的方差稳定在 1,与维度 d k d_k dk 无关。

通过缩放点积,将 softmax 函数的输入控制在一个合理的动态范围,防止了梯度消失,确保了训练过程的稳定和高效,缩放后的注意力权重分布也更平滑,允许模型同时关注多个相关位置,而不是过度聚焦于单一位置,从而捕获更丰富的上下文信息。

如果缩放因子是 d k d_k dk 本身而不是 d k \sqrt{d_k} dk 会怎样。

如此,缩放后的点积 s ′ ′ = s d k s'' = \dfrac{s}{d_k} s′′=dks 的方差为:

Var ( s ′ ′ ) = ( 1 d k ) 2 × Var ( s ) = 1 d k 2 × d k = 1 d k \text{Var}(s'') = (\dfrac{1}{d_k})^2 \times \text{Var}(s) = \dfrac{1}{d_k^2} \times d_k = \dfrac{1}{d_k} Var(s′′)=(dk1)2×Var(s)=dk21×dk=dk1

d k d_k dk 很大时,方差趋近于0。这意味着所有点积得分都密集地集中在 0 附近,经过softmax后,注意力权重会趋近于均匀分布 ( 1 L , 1 L , . . . , 1 L ) (\dfrac{1}{L}, \dfrac{1}{L}, ...,\dfrac{1}{L}) (L1,L1,...,L1),这完全破坏了注意力机制的选择性聚焦能力。

d k \sqrt{d_k} dk 是个精确缩放因子,它完美地抵消了方差随维度的增长的影响,保证注意力机制有效工作。

看个无关但有趣的,主宰 TCP AIMD 的中心极限定理(buffer 的平方反比律),我是三句话不离本行,AIMD 对 buffer 的占用就是按 n 缩放,所以它随 n 越来越小,与本文结论一致,非常公平地使 “权重会趋近于均匀分布”,公平性是 AIMD 特征,但 softmax 恰恰需要 “选择性聚焦”,而不是公平。

另一方面,如果找到另一个反馈控制算法,收敛值为 BDP 的 1 n \dfrac{1}{\sqrt{n}} n 1,便可全世界统一 buffer 大小了,但这时公平性问题就浮上水面,正如点积缩放一样,“同时关注多个相关位置,而不是过度聚焦于单一位置”,这印证了 AIMD 确实高尚,一切都是想通的。

浙江文章皮鞋湿,下雨进水不会胖。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐