02-Transformer核心架构详解-自注意力与多头注意力
#Transformer详解 #Self-Attention #Multi-Head-Attention #QKV矩阵 #位置编码 #LayerNorm #残差连接 #吴恩达课程 #PyTorch实现 #大模型架构
Transformer核心架构详解:自注意力与多头注意力 | 吴恩达2025最新课程笔记
本文深度剖析Transformer的核心机制——自注意力(Self-Attention)和多头注意力(Multi-Head Attention)。通过数学推导、可视化图表和PyTorch代码实现,详细讲解QKV矩阵计算、注意力分数、缩放点积注意力等关键技术。涵盖Transformer Block完整结构、残差连接、层归一化等工程实践要点,是理解现代大语言模型架构的必读教程。
一、Transformer架构全景
1.1 Transformer诞生的背景
2017年Google发表的《Attention Is All You Need》彻底改变了NLP领域。其核心创新:完全抛弃RNN/CNN,只用注意力机制。
1.2 完整架构图
1.3 三大核心组件
| 组件 | 作用 | 关键技术 |
|---|---|---|
| Self-Attention | 捕捉序列内部依赖 | QKV矩阵、缩放点积 |
| Multi-Head Attention | 多视角信息提取 | 多个注意力头并行 |
| Feed-Forward Network | 非线性变换 | 两层全连接+激活 |
二、自注意力机制(Self-Attention)详解
2.1 核心思想
Self-Attention的本质:让序列中的每个元素都能"看到"其他所有元素,并决定关注程度。
2.2 QKV三剑客:Query、Key、Value
核心概念:
- Query (Q): “我要找什么?”(查询向量)
- Key (K): “我是什么?”(键向量)
- Value (V): “我有什么信息?”(值向量)
类比理解:就像图书馆检索系统
2.3 数学公式推导
Step 1: 生成QKV矩阵
对于输入序列 X∈Rn×dX \in \mathbb{R}^{n \times d}X∈Rn×d (n个词,每个d维):
Q=XWQ,WQ∈Rd×dkK=XWK,WK∈Rd×dkV=XWV,WV∈Rd×dv \begin{align} Q &= XW^Q, \quad W^Q \in \mathbb{R}^{d \times d_k} \\ K &= XW^K, \quad W^K \in \mathbb{R}^{d \times d_k} \\ V &= XW^V, \quad W^V \in \mathbb{R}^{d \times d_v} \end{align} QKV=XWQ,WQ∈Rd×dk=XWK,WK∈Rd×dk=XWV,WV∈Rd×dv
Step 2: 计算注意力分数
Score=QKTdk \text{Score} = \frac{QK^T}{\sqrt{d_k}} Score=dkQKT
为什么要除以 dk\sqrt{d_k}dk?
Step 3: Softmax归一化
Attention Weights=softmax(Score)=softmax(QKTdk) \text{Attention Weights} = \text{softmax}(\text{Score}) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) Attention Weights=softmax(Score)=softmax(dkQKT)
Step 4: 加权求和
Output=Attention Weights×V \text{Output} = \text{Attention Weights} \times V Output=Attention Weights×V
完整公式:
Attention(Q,K,V)=softmax(QKTdk)V \boxed{\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V} Attention(Q,K,V)=softmax(dkQKT)V
2.4 直观示例:计算"爱"的新表示
输入句子:“我 爱 学习 AI”
假设计算"爱"对其他词的注意力:
| 词 | Query·Key | Score | Softmax | 最终贡献 |
|---|---|---|---|---|
| 我 | 2.1 | 2.1/√64=0.26 | 0.1 | 0.1×V(我) |
| 爱 | 3.5 | 0.44 | 0.2 | 0.2×V(爱) |
| 学习 | 7.2 | 0.90 | 0.4 | 0.4×V(学习) |
| AI | 5.8 | 0.73 | 0.3 | 0.3×V(AI) |
最终:"爱"的新表示 = 0.1×V(我) + 0.2×V(爱) + 0.4×V(学习) + 0.3×V(AI)
三、多头注意力(Multi-Head Attention)
3.1 为什么需要多头?
单头的局限:只能捕捉一种关系模式
3.2 多头注意力架构
3.3 数学公式
对于 hhh 个注意力头:
headi=Attention(QWiQ,KWiK,VWiV)MultiHead(Q,K,V)=Concat(head1,...,headh)WO \begin{align} \text{head}_i &= \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) \\ \text{MultiHead}(Q,K,V) &= \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O \end{align} headiMultiHead(Q,K,V)=Attention(QWiQ,KWiK,VWiV)=Concat(head1,...,headh)WO
其中:
- WiQ∈Rdmodel×dkW_i^Q \in \mathbb{R}^{d_{model} \times d_k}WiQ∈Rdmodel×dk
- WiK∈Rdmodel×dkW_i^K \in \mathbb{R}^{d_{model} \times d_k}WiK∈Rdmodel×dk
- WiV∈Rdmodel×dvW_i^V \in \mathbb{R}^{d_{model} \times d_v}WiV∈Rdmodel×dv
- WO∈Rhdv×dmodelW^O \in \mathbb{R}^{hd_v \times d_{model}}WO∈Rhdv×dmodel
典型配置(如BERT):
- dmodel=768d_{model} = 768dmodel=768 (模型维度)
- h=12h = 12h=12 (注意力头数)
- dk=dv=dmodel/h=64d_k = d_v = d_{model}/h = 64dk=dv=dmodel/h=64 (每个头的维度)
3.4 多头的优势
四、PyTorch完整实现
4.1 Scaled Dot-Product Attention
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
参数:
Q: [batch_size, n_heads, seq_len, d_k]
K: [batch_size, n_heads, seq_len, d_k]
V: [batch_size, n_heads, seq_len, d_v]
mask: [batch_size, 1, 1, seq_len] 可选
返回:
output: [batch_size, n_heads, seq_len, d_v]
attention_weights: [batch_size, n_heads, seq_len, seq_len]
"""
d_k = Q.size(-1)
# 1. 计算注意力分数: QK^T / √d_k
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
# scores: [batch_size, n_heads, seq_len, seq_len]
# 2. 可选:应用mask(用于Decoder中的自回归)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 3. Softmax归一化
attention_weights = F.softmax(scores, dim=-1)
# 4. 加权求和
output = torch.matmul(attention_weights, V)
return output, attention_weights
# 测试
batch_size, n_heads, seq_len, d_k = 2, 8, 10, 64
Q = torch.randn(batch_size, n_heads, seq_len, d_k)
K = torch.randn(batch_size, n_heads, seq_len, d_k)
V = torch.randn(batch_size, n_heads, seq_len, d_k)
output, attn_weights = scaled_dot_product_attention(Q, K, V)
print(f"输出形状: {output.shape}") # torch.Size([2, 8, 10, 64])
print(f"注意力权重: {attn_weights.shape}") # torch.Size([2, 8, 10, 10])
4.2 Multi-Head Attention完整实现
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads):
"""
参数:
d_model: 模型维度(如768)
n_heads: 注意力头数(如12)
"""
super().__init__()
assert d_model % n_heads == 0, "d_model必须能被n_heads整除"
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads # 每个头的维度
# QKV的线性变换矩阵
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
# 输出的线性变换
self.W_O = nn.Linear(d_model, d_model)
def split_heads(self, x):
"""
将输入分割成多个头
x: [batch_size, seq_len, d_model]
返回: [batch_size, n_heads, seq_len, d_k]
"""
batch_size, seq_len, d_model = x.size()
return x.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
def forward(self, Q, K, V, mask=None):
"""
参数:
Q, K, V: [batch_size, seq_len, d_model]
mask: [batch_size, 1, 1, seq_len]
"""
batch_size = Q.size(0)
# 1. 线性变换
Q = self.W_Q(Q) # [batch_size, seq_len, d_model]
K = self.W_K(K)
V = self.W_V(V)
# 2. 分割成多个头
Q = self.split_heads(Q) # [batch_size, n_heads, seq_len, d_k]
K = self.split_heads(K)
V = self.split_heads(V)
# 3. 缩放点积注意力
attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
# attn_output: [batch_size, n_heads, seq_len, d_k]
# 4. 合并多个头
attn_output = attn_output.transpose(1, 2).contiguous()
# [batch_size, seq_len, n_heads, d_k]
attn_output = attn_output.view(batch_size, -1, self.d_model)
# [batch_size, seq_len, d_model]
# 5. 最终线性变换
output = self.W_O(attn_output)
return output, attn_weights
# 使用示例
d_model = 512
n_heads = 8
batch_size = 2
seq_len = 10
mha = MultiHeadAttention(d_model, n_heads)
# 输入
x = torch.randn(batch_size, seq_len, d_model)
# 自注意力:Q=K=V
output, attn_weights = mha(x, x, x)
print(f"输入形状: {x.shape}") # torch.Size([2, 10, 512])
print(f"输出形状: {output.shape}") # torch.Size([2, 10, 512])
print(f"注意力权重: {attn_weights.shape}") # torch.Size([2, 8, 10, 10])
4.3 可视化注意力权重
import matplotlib.pyplot as plt
import seaborn as sns
def visualize_attention(attn_weights, tokens, head_idx=0):
"""
可视化某个注意力头的权重
attn_weights: [batch_size, n_heads, seq_len, seq_len]
tokens: 词列表
head_idx: 要可视化的头索引
"""
# 提取第一个样本的指定头
attn = attn_weights[0, head_idx].detach().cpu().numpy()
plt.figure(figsize=(10, 8))
sns.heatmap(attn,
xticklabels=tokens,
yticklabels=tokens,
cmap='YlOrRd',
annot=True,
fmt='.2f',
cbar_kws={'label': '注意力权重'})
plt.title(f'Attention Head {head_idx}')
plt.xlabel('Key')
plt.ylabel('Query')
plt.tight_layout()
plt.show()
# 示例:可视化英译中的注意力
tokens = ['I', 'love', 'learning', 'AI', '<EOS>']
seq_len = len(tokens)
# 模拟注意力权重
mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(1, seq_len, 512)
_, attn_weights = mha(x, x, x)
# 可视化第0个头
visualize_attention(attn_weights, tokens, head_idx=0)
输出效果:
注意力矩阵 (Head 0)
I love learning AI <EOS>
I 0.20 0.15 0.10 0.35 0.20
love 0.10 0.50 0.30 0.05 0.05
learning 0.05 0.30 0.40 0.20 0.05
AI 0.15 0.10 0.25 0.45 0.05
<EOS> 0.05 0.05 0.05 0.10 0.75
五、Transformer Block完整结构
5.1 单个Block的组成
5.2 残差连接(Residual Connection)
为什么需要?
数学表示:
Output=LayerNorm(X+MultiHeadAttention(X)) \text{Output} = \text{LayerNorm}(X + \text{MultiHeadAttention}(X)) Output=LayerNorm(X+MultiHeadAttention(X))
5.3 Layer Normalization
与Batch Norm的区别:
| 特性 | Batch Norm | Layer Norm |
|---|---|---|
| 归一化维度 | 跨batch维度 | 跨特征维度 |
| 适用场景 | CNN(固定batch) | NLP(变长序列) |
| 依赖性 | 依赖batch大小 | 独立于batch |
Layer Norm公式:
LayerNorm(x)=γx−μσ2+ϵ+β \text{LayerNorm}(x) = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta LayerNorm(x)=γσ2+ϵx−μ+β
其中 μ,σ\mu, \sigmaμ,σ 是当前层所有特征的均值和标准差。
5.4 前馈网络(Feed-Forward Network)
结构:两层全连接+激活函数
FFN(x)=ReLU(xW1+b1)W2+b2 \text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2 FFN(x)=ReLU(xW1+b1)W2+b2
典型配置:
- BERT: dmodel=768d_{model}=768dmodel=768, dff=3072d_{ff}=3072dff=3072 (4倍)
- GPT-3: dmodel=12288d_{model}=12288dmodel=12288, dff=49152d_{ff}=49152dff=49152 (4倍)
5.5 完整Transformer Block实现
class TransformerBlock(nn.Module):
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
"""
参数:
d_model: 模型维度
n_heads: 注意力头数
d_ff: 前馈网络隐藏层维度
dropout: Dropout比率
"""
super().__init__()
# 多头注意力
self.mha = MultiHeadAttention(d_model, n_heads)
# 前馈网络
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
# Layer Normalization
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
# Dropout
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x, mask=None):
"""
x: [batch_size, seq_len, d_model]
"""
# 1. 多头自注意力 + 残差连接 + LayerNorm
attn_output, _ = self.mha(x, x, x, mask)
attn_output = self.dropout1(attn_output)
x = self.norm1(x + attn_output) # 残差连接
# 2. 前馈网络 + 残差连接 + LayerNorm
ffn_output = self.ffn(x)
ffn_output = self.dropout2(ffn_output)
x = self.norm2(x + ffn_output) # 残差连接
return x
# 测试
d_model = 512
n_heads = 8
d_ff = 2048
block = TransformerBlock(d_model, n_heads, d_ff)
x = torch.randn(2, 10, d_model) # [batch, seq_len, d_model]
output = block(x)
print(f"输入形状: {x.shape}") # torch.Size([2, 10, 512])
print(f"输出形状: {output.shape}") # torch.Size([2, 10, 512])
六、位置编码(Positional Encoding)
6.1 为什么需要位置编码?
问题:自注意力是置换不变的(permutation-invariant)
sentence1 = "我 爱 AI"
sentence2 = "AI 爱 我"
# 如果没有位置编码,Self-Attention会给出相同的结果!
6.2 正弦位置编码
公式:
PE(pos,2i)=sin(pos100002i/dmodel)PE(pos,2i+1)=cos(pos100002i/dmodel) \begin{align} PE_{(pos, 2i)} &= \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) \\ PE_{(pos, 2i+1)} &= \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) \end{align} PE(pos,2i)PE(pos,2i+1)=sin(100002i/dmodelpos)=cos(100002i/dmodelpos)
其中:
- pospospos: 词的位置(0, 1, 2, …)
- iii: 维度索引(0, 1, …, d_model/2)
优点:
- ✅ 可以处理任意长度的序列
- ✅ 不需要训练参数
- ✅ 相对位置关系可以通过线性变换表达
6.3 实现代码
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
# 创建位置编码矩阵
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
# 计算分母
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model))
# 计算正弦和余弦
pe[:, 0::2] = torch.sin(position * div_term) # 偶数维度
pe[:, 1::2] = torch.cos(position * div_term) # 奇数维度
# 添加batch维度
pe = pe.unsqueeze(0) # [1, max_len, d_model]
self.register_buffer('pe', pe)
def forward(self, x):
"""
x: [batch_size, seq_len, d_model]
"""
# 添加位置编码
x = x + self.pe[:, :x.size(1), :]
return x
# 可视化位置编码
def visualize_positional_encoding(d_model=128, max_len=100):
pe = PositionalEncoding(d_model, max_len)
encoding = pe.pe[0, :max_len, :].numpy()
plt.figure(figsize=(15, 5))
plt.imshow(encoding.T, aspect='auto', cmap='RdBu',
interpolation='nearest')
plt.colorbar(label='编码值')
plt.xlabel('位置')
plt.ylabel('维度')
plt.title('正弦位置编码可视化')
plt.tight_layout()
plt.show()
visualize_positional_encoding()
输出效果:会看到规律的波浪状图案,不同频率编码不同维度。
七、完整Encoder实现
class TransformerEncoder(nn.Module):
def __init__(self, vocab_size, d_model, n_heads, d_ff,
n_layers, max_len=5000, dropout=0.1):
"""
参数:
vocab_size: 词汇表大小
d_model: 模型维度
n_heads: 注意力头数
d_ff: 前馈网络维度
n_layers: Transformer Block层数
max_len: 最大序列长度
dropout: Dropout比率
"""
super().__init__()
# 词嵌入
self.embedding = nn.Embedding(vocab_size, d_model)
# 位置编码
self.pos_encoding = PositionalEncoding(d_model, max_len)
# 多层Transformer Block
self.layers = nn.ModuleList([
TransformerBlock(d_model, n_heads, d_ff, dropout)
for _ in range(n_layers)
])
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
"""
x: [batch_size, seq_len] (token indices)
"""
# 1. 词嵌入 + 位置编码
x = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
x = self.pos_encoding(x)
x = self.dropout(x)
# 2. 通过多层Transformer Block
for layer in self.layers:
x = layer(x, mask)
return x
# 使用示例:构建一个小型BERT
vocab_size = 30000
d_model = 768
n_heads = 12
d_ff = 3072
n_layers = 12
encoder = TransformerEncoder(vocab_size, d_model, n_heads, d_ff, n_layers)
# 输入token indices
input_ids = torch.randint(0, vocab_size, (2, 20)) # [batch=2, seq_len=20]
output = encoder(input_ids)
print(f"输入形状: {input_ids.shape}") # torch.Size([2, 20])
print(f"输出形状: {output.shape}") # torch.Size([2, 20, 768])
print(f"参数量: {sum(p.numel() for p in encoder.parameters())/1e6:.1f}M")
# 输出: 参数量: 110.1M (接近BERT-Base的110M)
八、关键概念总结
8.1 公式总结
| 组件 | 公式 | 说明 |
|---|---|---|
| Self-Attention | Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})VAttention(Q,K,V)=softmax(dkQKT)V | 缩放点积注意力 |
| Multi-Head | Concat(head1,...,headh)WO\text{Concat}(\text{head}_1,...,\text{head}_h)W^OConcat(head1,...,headh)WO | 多视角融合 |
| FFN | ReLU(xW1+b1)W2+b2\text{ReLU}(xW_1+b_1)W_2+b_2ReLU(xW1+b1)W2+b2 | 两层全连接 |
| Layer Norm | γx−μσ2+ϵ+β\gamma\frac{x-\mu}{\sqrt{\sigma^2+\epsilon}}+\betaγσ2+ϵx−μ+β | 特征归一化 |
| Positional | sin(pos100002i/d)\sin(\frac{pos}{10000^{2i/d}})sin(100002i/dpos) | 位置编码 |
8.2 配置对比
| 模型 | d_model | n_heads | n_layers | d_ff | 参数量 |
|---|---|---|---|---|---|
| BERT-Base | 768 | 12 | 12 | 3072 | 110M |
| BERT-Large | 1024 | 16 | 24 | 4096 | 340M |
| GPT-2 | 768 | 12 | 12 | 3072 | 117M |
| GPT-3 | 12288 | 96 | 96 | 49152 | 175B |
8.3 架构流程图
九、实战练习
练习1:计算注意力权重
题目:给定QKV矩阵,手工计算注意力输出
# 已知(简化为2x2矩阵方便计算)
Q = torch.tensor([[1.0, 0.0],
[0.0, 1.0]])
K = torch.tensor([[1.0, 0.0],
[0.5, 0.5]])
V = torch.tensor([[2.0, 0.0],
[1.0, 1.0]])
# 步骤:
# 1. 计算 QK^T / √d_k
# 2. Softmax
# 3. 乘以V
# 你的答案:
练习2:实现Masked Self-Attention
任务:修改Self-Attention,实现Decoder中的mask机制(当前词不能看到未来词)
def masked_self_attention(Q, K, V):
"""
TODO: 实现masked attention
提示: 使用torch.tril创建下三角mask
"""
pass
练习3:分析注意力头
任务:加载预训练BERT,可视化不同注意力头关注的模式
from transformers import BertModel, BertTokenizer
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
text = "The animal didn't cross the street because it was too tired."
inputs = tokenizer(text, return_tensors='pt')
with torch.no_grad():
outputs = model(**inputs)
attentions = outputs.attentions # 12层,每层12个头
更多推荐

所有评论(0)