推导Multi-Head Attention计算过程,解释Q/K/V的作用

摘要

本文从单头Attention出发,通过完整的数学推导和代码实现,深入解析Multi-Head Attention的计算流程。重点阐述Query/Key/Value三者各自的作用机制,以及多头设计如何增强模型的表达能力。适合已了解Attention基础概念的读者深入理解其核心原理。


目录


1. 从单头Attention说起

Attention机制的核心思想是:根据Query从一组Key-Value对中提取有价值的信息。在Transformer中,我们使用缩放点积注意力(Scaled Dot-Product Attention)作为基础单元。

给定输入序列 X ∈ R n × d X \in \mathbb{R}^{n \times d} XRn×d n n n个token,每个维度 d d d),我们首先生成三个矩阵:

Q = X W Q 其中  W Q ∈ R d × d k K = X W K 其中  W K ∈ R d × d k V = X W V 其中  W V ∈ R d × d v \begin{align*} Q &= XW_Q \quad &\text{其中 } W_Q \in \mathbb{R}^{d \times d_k} \\ K &= XW_K \quad &\text{其中 } W_K \in \mathbb{R}^{d \times d_k} \\ V &= XW_V \quad &\text{其中 } W_V \in \mathbb{R}^{d \times d_v} \end{align*} QKV=XWQ=XWK=XWV其中 WQRd×dk其中 WKRd×dk其中 WVRd×dv


2. Q/K/V的数学本质与作用

Query (查询): 我想要什么信息

  • 作用:表示当前token的查询意图。类比数据库查询,Q是SQL中的SELECT语句。
  • 维度 Q ∈ R n × d k Q \in \mathbb{R}^{n \times d_k} QRn×dk
  • 理解:每个token通过 W Q W_Q WQ投影到查询空间,生成一个"需求向量",用于寻找相关的key。

Key (键): 我能提供什么信息

  • 作用:表示token的索引标识。类比数据库,K是PRIMARY KEY,用于被查询匹配。
  • 维度 K ∈ R n × d k K \in \mathbb{R}^{n \times d_k} KRn×dk
  • 理解:每个token通过 W K W_K WK投影到键空间,生成一个"身份标识",响应查询请求。

Value (值): 我实际的信息内容

  • 作用:存储token的实际语义信息。类比数据库,V是存储的DATA本身。
  • 维度 V ∈ R n × d v V \in \mathbb{R}^{n \times d_v} VRn×dv
  • 理解:每个token通过 W V W_V WV投影到值空间,生成真正的"信息载体"。

三者关系:通过计算Query与Key的相似度,决定从Value中提取多少信息。这是一种可学习的动态加权机制。


3. Scaled Dot-Product Attention推导

Step 1: 相似度计算

计算Query与所有Key的点积,得到注意力分数:
scores = Q K T ∈ R n × n \text{scores} = QK^T \in \mathbb{R}^{n \times n} scores=QKTRn×n
其中每个元素 s i j = q i ⋅ k j T s_{ij} = q_i \cdot k_j^T sij=qikjT 表示token i i i对token j j j的关注程度。

Step 2: 缩放(Scale)

为了防止梯度消失,除以 d k \sqrt{d_k} dk
scores scaled = Q K T d k \text{scores}_{\text{scaled}} = \frac{QK^T}{\sqrt{d_k}} scoresscaled=dk QKT
原因:当 d k d_k dk较大时,点积的方差会增大,softmax会进入饱和区,梯度变小。

Step 3: Softmax归一化

对每一行进行softmax,得到注意力权重:
A = softmax ( Q K T d k ) ∈ R n × n A = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) \in \mathbb{R}^{n \times n} A=softmax(dk QKT)Rn×n
其中 A i j A_{ij} Aij表示token i i i对token j j j的注意力权重。

Step 4: 加权求和

用注意力权重对Value进行加权:
Attention ( Q , K , V ) = A V ∈ R n × d v \text{Attention}(Q,K,V) = AV \in \mathbb{R}^{n \times d_v} Attention(Q,K,V)=AVRn×dv
最终输出是Value的凸组合(权重和为1)。

完整公式
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V


4. Multi-Head Attention完整计算流程

Multi-Head Attention通过 h h h个独立的注意力头并行计算,每个头关注不同的子空间信息。

Step 1: 多组线性投影

对每个头 i ∈ [ 1 , h ] i \in [1, h] i[1,h]
Q i = X W Q i W Q i ∈ R d × d k K i = X W K i W K i ∈ R d × d k V i = X W V i W V i ∈ R d × d v \begin{align*} Q_i &= XW_Q^i \quad &W_Q^i \in \mathbb{R}^{d \times d_k} \\ K_i &= XW_K^i \quad &W_K^i \in \mathbb{R}^{d \times d_k} \\ V_i &= XW_V^i \quad &W_V^i \in \mathbb{R}^{d \times d_v} \end{align*} QiKiVi=XWQi=XWKi=XWViWQiRd×dkWKiRd×dkWViRd×dv
通常设置 d k = d v = d / h d_k = d_v = d/h dk=dv=d/h

Step 2: 每个头独立计算Attention

head i = Attention ( Q i , K i , V i ) ∈ R n × d v \text{head}_i = \text{Attention}(Q_i, K_i, V_i) \in \mathbb{R}^{n \times d_v} headi=Attention(Qi,Ki,Vi)Rn×dv

Step 3: 拼接所有头的输出

MultiHead ( Q , K , V ) = Concat ( head 1 , head 2 , . . . , head h ) ∈ R n × ( h ⋅ d v ) \text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, \text{head}_2, ..., \text{head}_h) \in \mathbb{R}^{n \times (h \cdot d_v)} MultiHead(Q,K,V)=Concat(head1,head2,...,headh)Rn×(hdv)

Step 4: 最终线性投影

Output = MultiHead ( Q , K , V ) W O 其中  W O ∈ R ( h ⋅ d v ) × d \text{Output} = \text{MultiHead}(Q,K,V)W_O \quad \text{其中 } W_O \in \mathbb{R}^{(h \cdot d_v) \times d} Output=MultiHead(Q,K,V)WO其中 WOR(hdv)×d
将维度映射回原始维度 d d d,实现不同头之间的信息融合。

完整公式
MultiHead ( Q , K , V ) = Concat ( head 1 , . . . , head h ) W O where  head i = Attention ( X W Q i , X W K i , X W V i ) \text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W_O \\ \text{where } \text{head}_i = \text{Attention}(XW_Q^i, XW_K^i, XW_V^i) MultiHead(Q,K,V)=Concat(head1,...,headh)WOwhere headi=Attention(XWQi,XWKi,XWVi)


5. PyTorch实现解析

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model          # 模型维度
        self.num_heads = num_heads      # 头数
        self.d_k = d_model // num_heads # 每个头的维度
        
        # 定义Q/K/V的线性变换矩阵
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        
        # 输出投影矩阵
        self.W_O = nn.Linear(d_model, d_model)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        核心注意力计算
        Q: [batch_size, num_heads, seq_len, d_k]
        K: [batch_size, num_heads, seq_len, d_k]
        V: [batch_size, num_heads, seq_len, d_v]
        """
        # 1. 计算QK^T并缩放
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # 2. 应用掩码(可选)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # 3. Softmax归一化
        attn_weights = F.softmax(scores, dim=-1)
        
        # 4. 加权求和
        output = torch.matmul(attn_weights, V)
        
        return output, attn_weights
    
    def forward(self, X, mask=None):
        """
        X: [batch_size, seq_len, d_model]
        """
        batch_size, seq_len, d_model = X.size()
        
        # Step 1: 线性变换生成Q/K/V
        # [b, seq, d_model] -> [b, seq, d_model]
        Q = self.W_Q(X)
        K = self.W_K(X)
        V = self.W_V(X)
        
        # Step 2: 重塑为多头形式
        # [b, seq, d_model] -> [b, seq, num_heads, d_k] -> [b, num_heads, seq, d_k]
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # Step 3: 计算每个头的注意力
        # attn_output: [b, num_heads, seq, d_k]
        attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)
        
        # Step 4: 拼接多头输出
        # [b, num_heads, seq, d_k] -> [b, seq, num_heads, d_k] -> [b, seq, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, d_model
        )
        
        # Step 5: 最终线性投影
        output = self.W_O(attn_output)
        
        return output, attn_weights

# 使用示例
d_model, num_heads = 512, 8
mha = MultiHeadAttention(d_model, num_heads)

# 模拟输入: batch_size=32, seq_len=100, d_model=512
X = torch.randn(32, 100, 512)
output, attn_weights = mha(X)

print(f"Output shape: {output.shape}")          # [32, 100, 512]
print(f"Attention weights shape: {attn_weights.shape}")  # [32, 8, 100, 100]
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐