【面试必问】大模型算法 推导Multi-Head Attention计算过程,解释Q/K/V的作用
Multi-Head Attention计算过程解析 本文深入分析了Multi-Head Attention的计算机制,从单头Attention扩展到多头设计。主要内容包括: 核心计算流程:通过Q/K/V三组矩阵实现注意力计算,包括相似度计算、缩放、softmax归一化和加权求和四个步骤。 Q/K/V作用: Query表示查询意图 Key作为索引标识 Value存储实际信息 多头设计:通过并行多个
推导Multi-Head Attention计算过程,解释Q/K/V的作用
摘要
本文从单头Attention出发,通过完整的数学推导和代码实现,深入解析Multi-Head Attention的计算流程。重点阐述Query/Key/Value三者各自的作用机制,以及多头设计如何增强模型的表达能力。适合已了解Attention基础概念的读者深入理解其核心原理。
目录
- 1. 从单头Attention说起
- 2. Q/K/V的数学本质与作用
- 3. Scaled Dot-Product Attention推导
- 4. Multi-Head Attention完整计算流程
- 5. PyTorch实现解析
- 6. 为什么需要多头设计?
- 7. 总结
1. 从单头Attention说起
Attention机制的核心思想是:根据Query从一组Key-Value对中提取有价值的信息。在Transformer中,我们使用缩放点积注意力(Scaled Dot-Product Attention)作为基础单元。
给定输入序列 X ∈ R n × d X \in \mathbb{R}^{n \times d} X∈Rn×d( n n n个token,每个维度 d d d),我们首先生成三个矩阵:
Q = X W Q 其中 W Q ∈ R d × d k K = X W K 其中 W K ∈ R d × d k V = X W V 其中 W V ∈ R d × d v \begin{align*} Q &= XW_Q \quad &\text{其中 } W_Q \in \mathbb{R}^{d \times d_k} \\ K &= XW_K \quad &\text{其中 } W_K \in \mathbb{R}^{d \times d_k} \\ V &= XW_V \quad &\text{其中 } W_V \in \mathbb{R}^{d \times d_v} \end{align*} QKV=XWQ=XWK=XWV其中 WQ∈Rd×dk其中 WK∈Rd×dk其中 WV∈Rd×dv
2. Q/K/V的数学本质与作用
Query (查询): 我想要什么信息
- 作用:表示当前token的查询意图。类比数据库查询,Q是SQL中的
SELECT语句。 - 维度: Q ∈ R n × d k Q \in \mathbb{R}^{n \times d_k} Q∈Rn×dk
- 理解:每个token通过 W Q W_Q WQ投影到查询空间,生成一个"需求向量",用于寻找相关的key。
Key (键): 我能提供什么信息
- 作用:表示token的索引标识。类比数据库,K是
PRIMARY KEY,用于被查询匹配。 - 维度: K ∈ R n × d k K \in \mathbb{R}^{n \times d_k} K∈Rn×dk
- 理解:每个token通过 W K W_K WK投影到键空间,生成一个"身份标识",响应查询请求。
Value (值): 我实际的信息内容
- 作用:存储token的实际语义信息。类比数据库,V是存储的
DATA本身。 - 维度: V ∈ R n × d v V \in \mathbb{R}^{n \times d_v} V∈Rn×dv
- 理解:每个token通过 W V W_V WV投影到值空间,生成真正的"信息载体"。
三者关系:通过计算Query与Key的相似度,决定从Value中提取多少信息。这是一种可学习的动态加权机制。
3. Scaled Dot-Product Attention推导
Step 1: 相似度计算
计算Query与所有Key的点积,得到注意力分数:
scores = Q K T ∈ R n × n \text{scores} = QK^T \in \mathbb{R}^{n \times n} scores=QKT∈Rn×n
其中每个元素 s i j = q i ⋅ k j T s_{ij} = q_i \cdot k_j^T sij=qi⋅kjT 表示token i i i对token j j j的关注程度。
Step 2: 缩放(Scale)
为了防止梯度消失,除以 d k \sqrt{d_k} dk:
scores scaled = Q K T d k \text{scores}_{\text{scaled}} = \frac{QK^T}{\sqrt{d_k}} scoresscaled=dkQKT
原因:当 d k d_k dk较大时,点积的方差会增大,softmax会进入饱和区,梯度变小。
Step 3: Softmax归一化
对每一行进行softmax,得到注意力权重:
A = softmax ( Q K T d k ) ∈ R n × n A = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) \in \mathbb{R}^{n \times n} A=softmax(dkQKT)∈Rn×n
其中 A i j A_{ij} Aij表示token i i i对token j j j的注意力权重。
Step 4: 加权求和
用注意力权重对Value进行加权:
Attention ( Q , K , V ) = A V ∈ R n × d v \text{Attention}(Q,K,V) = AV \in \mathbb{R}^{n \times d_v} Attention(Q,K,V)=AV∈Rn×dv
最终输出是Value的凸组合(权重和为1)。
完整公式:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dkQKT)V
4. Multi-Head Attention完整计算流程
Multi-Head Attention通过 h h h个独立的注意力头并行计算,每个头关注不同的子空间信息。
Step 1: 多组线性投影
对每个头 i ∈ [ 1 , h ] i \in [1, h] i∈[1,h]:
Q i = X W Q i W Q i ∈ R d × d k K i = X W K i W K i ∈ R d × d k V i = X W V i W V i ∈ R d × d v \begin{align*} Q_i &= XW_Q^i \quad &W_Q^i \in \mathbb{R}^{d \times d_k} \\ K_i &= XW_K^i \quad &W_K^i \in \mathbb{R}^{d \times d_k} \\ V_i &= XW_V^i \quad &W_V^i \in \mathbb{R}^{d \times d_v} \end{align*} QiKiVi=XWQi=XWKi=XWViWQi∈Rd×dkWKi∈Rd×dkWVi∈Rd×dv
通常设置 d k = d v = d / h d_k = d_v = d/h dk=dv=d/h。
Step 2: 每个头独立计算Attention
head i = Attention ( Q i , K i , V i ) ∈ R n × d v \text{head}_i = \text{Attention}(Q_i, K_i, V_i) \in \mathbb{R}^{n \times d_v} headi=Attention(Qi,Ki,Vi)∈Rn×dv
Step 3: 拼接所有头的输出
MultiHead ( Q , K , V ) = Concat ( head 1 , head 2 , . . . , head h ) ∈ R n × ( h ⋅ d v ) \text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, \text{head}_2, ..., \text{head}_h) \in \mathbb{R}^{n \times (h \cdot d_v)} MultiHead(Q,K,V)=Concat(head1,head2,...,headh)∈Rn×(h⋅dv)
Step 4: 最终线性投影
Output = MultiHead ( Q , K , V ) W O 其中 W O ∈ R ( h ⋅ d v ) × d \text{Output} = \text{MultiHead}(Q,K,V)W_O \quad \text{其中 } W_O \in \mathbb{R}^{(h \cdot d_v) \times d} Output=MultiHead(Q,K,V)WO其中 WO∈R(h⋅dv)×d
将维度映射回原始维度 d d d,实现不同头之间的信息融合。
完整公式:
MultiHead ( Q , K , V ) = Concat ( head 1 , . . . , head h ) W O where head i = Attention ( X W Q i , X W K i , X W V i ) \text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W_O \\ \text{where } \text{head}_i = \text{Attention}(XW_Q^i, XW_K^i, XW_V^i) MultiHead(Q,K,V)=Concat(head1,...,headh)WOwhere headi=Attention(XWQi,XWKi,XWVi)
5. PyTorch实现解析
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
def __init__(self, d_model=512, num_heads=8):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model # 模型维度
self.num_heads = num_heads # 头数
self.d_k = d_model // num_heads # 每个头的维度
# 定义Q/K/V的线性变换矩阵
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
# 输出投影矩阵
self.W_O = nn.Linear(d_model, d_model)
def scaled_dot_product_attention(self, Q, K, V, mask=None):
"""
核心注意力计算
Q: [batch_size, num_heads, seq_len, d_k]
K: [batch_size, num_heads, seq_len, d_k]
V: [batch_size, num_heads, seq_len, d_v]
"""
# 1. 计算QK^T并缩放
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# 2. 应用掩码(可选)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 3. Softmax归一化
attn_weights = F.softmax(scores, dim=-1)
# 4. 加权求和
output = torch.matmul(attn_weights, V)
return output, attn_weights
def forward(self, X, mask=None):
"""
X: [batch_size, seq_len, d_model]
"""
batch_size, seq_len, d_model = X.size()
# Step 1: 线性变换生成Q/K/V
# [b, seq, d_model] -> [b, seq, d_model]
Q = self.W_Q(X)
K = self.W_K(X)
V = self.W_V(X)
# Step 2: 重塑为多头形式
# [b, seq, d_model] -> [b, seq, num_heads, d_k] -> [b, num_heads, seq, d_k]
Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
# Step 3: 计算每个头的注意力
# attn_output: [b, num_heads, seq, d_k]
attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)
# Step 4: 拼接多头输出
# [b, num_heads, seq, d_k] -> [b, seq, num_heads, d_k] -> [b, seq, d_model]
attn_output = attn_output.transpose(1, 2).contiguous().view(
batch_size, seq_len, d_model
)
# Step 5: 最终线性投影
output = self.W_O(attn_output)
return output, attn_weights
# 使用示例
d_model, num_heads = 512, 8
mha = MultiHeadAttention(d_model, num_heads)
# 模拟输入: batch_size=32, seq_len=100, d_model=512
X = torch.randn(32, 100, 512)
output, attn_weights = mha(X)
print(f"Output shape: {output.shape}") # [32, 100, 512]
print(f"Attention weights shape: {attn_weights.shape}") # [32, 8, 100, 100]
更多推荐



所有评论(0)