深入理解 Transformer:从自注意力机制到大模型的基石原理详解
Query (Q):当前 token 在问“我应该关注谁?Key (K):每个 token 在说“我是这个类型的信息”Value (V):每个 token 实际携带的内容AttentionQKVsoftmaxQKTdkVAttentionQKVsoftmaxdkQKTV其中dkd_kdk是 Key 向量的维度,除以dk\sqrt{d_k}dk是为了防止点积过大导致 softmax 梯度
深入理解 Transformer:从自注意力机制到大模型的基石原理详解
摘要:2017 年 Google 论文《Attention Is All You Need》提出了 Transformer 架构,彻底颠覆了 NLP 乃至整个深度学习领域的格局。本文将从动机出发,逐步拆解每一个核心模块,配套完整 PyTorch 实现代码、数学推导和工程实践经验,力求让你真正"吃透"Transformer,而不只是会背公式。
目录
- 为什么需要 Transformer?RNN 的致命缺陷
- Transformer 整体架构总览
- 输入处理:Token Embedding + Positional Encoding
- 核心模块:缩放点积注意力(Scaled Dot-Product Attention)
- 多头注意力机制(Multi-Head Attention)
- 前馈神经网络(Feed-Forward Network)
- 残差连接与层归一化(Residual + LayerNorm)
- 编码器(Encoder)完整实现
- 解码器(Decoder)与 Mask 机制
- 完整 Transformer 端到端实现
- 训练技巧与超参数选择
- 常见变体对比:BERT / GPT / T5
- 推理优化:KV Cache、Flash Attention、量化
- 总结与学习路线
— [Token Embedding]
|
[Positional Encoding]
|
┌────────────────────────┐
│ Encoder │
│ ┌──────────────────┐ │
│ │ Multi-Head Attn │ │
│ │ (Self-Attention) │ │
│ └────────┬─────────┘ │
│ │ Add & Norm │
│ ┌────────┴─────────┐ │
│ │ Feed Forward │ │
│ └────────┬─────────┘ │
│ │ Add & Norm │
│ × N layers │
└────────────────────────┘
|
Context (Memory)
|
┌────────────────────────┐
│ Decoder │
│ ┌──────────────────┐ │
│ │ Masked MH Attn │ │ <-- 输出序列 (tgt)
│ └────────┬─────────┘ │
│ │ Add & Norm │
│ ┌────────┴─────────┐ │
│ │ Cross-Attention │ │ <-- Q来自Decoder, K/V来自Encoder
│ └────────┬─────────┘ │
│ │ Add & Norm │
│ ┌────────┴─────────┐ │
│ │ Feed Forward │ │
│ └────────┬─────────┘ │
│ │ Add & Norm │
│ × N layers │
└────────────────────────┘
|
[Linear + Softmax]
|
输出概率分布
**关键超参数(原始论文 Base 版本):**
| 超参数 | 值 | 含义 |
|:---|:---:|:---|
| $d_{model}$ | 512 | 模型维度(嵌入维度) |
| $d_{ff}$ | 2048 | FFN 中间层维度 |
| $h$ | 8 | 注意力头数 |
| $d_k = d_v$ | 64 | 每个头的维度(512/8) |
| $N$ | 6 | Encoder/Decoder 层数 |
| $p_{drop}$ | 0.1 | Dropout 概率 |
---
## 3. 输入处理:Token Embedding + Positional Encoding
### 3.1 Token Embedding
将词表中的每个 token 映射为 $d_{model}$ 维的稠密向量。嵌入矩阵 $W_E \in \mathbb{R}^{|V| \times d_{model}}$,其中 $|V|$ 是词表大小。
原始论文中,**Encoder 和 Decoder 的嵌入矩阵共享权重**,并且与最终的 Linear 投影层共享,减少参数量。
```python
import torch
import torch.nn as nn
import math
class TokenEmbedding(nn.Module):
def __init__(self, vocab_size: int, d_model: int):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.d_model = d_model
def forward(self, x):
# 论文中对嵌入乘以 sqrt(d_model) 以匹配位置编码的数值范围
return self.embedding(x) * math.sqrt(self.d_model)
3.2 位置编码(Positional Encoding)
为什么需要? 自注意力本身是置换不变的(permutation-invariant),不管你怎么打乱输入顺序,注意力矩阵经过重排后结果一样,无法感知位置信息。
原始论文使用固定的正弦/余弦位置编码:
PE(pos,2i)=sin(pos100002i/dmodel)PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)PE(pos,2i)=sin(100002i/dmodelpos)
PE(pos,2i+1)=cos(pos100002i/dmodel)PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)PE(pos,2i+1)=cos(100002i/dmodelpos)
设计直觉:
- 不同频率的正弦波组合,可以唯一表示任意位置
- 对于固定的位置偏移 kkk,PEpos+kPE_{pos+k}PEpos+k 可以表示为 PEposPE_{pos}PEpos 的线性变换,有利于模型学习相对位置关系
- 可以外推到训练时未见过的更长序列
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
# 预计算位置编码矩阵 [max_len, d_model]
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # [max_len, 1]
# 计算分母:10000^(2i/d_model),用 exp(log) 形式数值更稳定
div_term = torch.exp(
torch.arange(0, d_model, 2, dtype=torch.float) *
(-math.log(10000.0) / d_model)
) # [d_model/2]
pe[:, 0::2] = torch.sin(position * div_term) # 偶数维
pe[:, 1::2] = torch.cos(position * div_term) # 奇数维
pe = pe.unsqueeze(0) # [1, max_len, d_model],广播用
self.register_buffer('pe', pe) # 不参与梯度更新
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [batch_size, seq_len, d_model]
x = x + self.pe[:, :x.size(1), :]
return self.dropout(x)
现代改进:BERT 使用可学习的位置编码(learnable positional embedding),RoPE(旋转位置编码,LLaMA 使用)、ALiBi(偏置注意力线性插值)等方案在长序列外推上表现更好。
4. 核心模块:缩放点积注意力
4.1 公式与推导
Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) VAttention(Q,K,V)=softmax(dkQKT)V
逐步拆解:
Step 1:计算相似度分数
scoreij=qi⋅kjdk\text{score}_{ij} = \frac{q_i \cdot k_j}{\sqrt{d_k}}scoreij=dkqi⋅kj
qiq_iqi 是第 iii 个 query 向量,kjk_jkj 是第 jjj 个 key 向量,点积衡量它们的相似度。
Step 2:为什么除以 dk\sqrt{d_k}dk?
当 dkd_kdk 较大时,点积的方差会增大(约为 dkd_kdk),导致 softmax 进入梯度极小的饱和区。除以 dk\sqrt{d_k}dk 将方差归一化到 1,保持梯度流动。
证明:若 q,kq, kq,k 各分量独立同分布,均值 0 方差 1,则:
Var(q⋅k)=∑i=1dkVar(qiki)=dk\text{Var}(q \cdot k) = \sum_{i=1}^{d_k} \text{Var}(q_i k_i) = d_kVar(q⋅k)=i=1∑dkVar(qiki)=dk
Var(q⋅kdk)=1\text{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = 1Var(dkq⋅k)=1
Step 3:Softmax 归一化
αij=exp(scoreij)∑jexp(scoreij)\alpha_{ij} = \frac{\exp(\text{score}_{ij})}{\sum_j \exp(\text{score}_{ij})}αij=∑jexp(scoreij)exp(scoreij)
得到注意力权重,∑jαij=1\sum_j \alpha_{ij} = 1∑jαij=1。
Step 4:加权聚合 Value
outputi=∑jαijvj\text{output}_i = \sum_j \alpha_{ij} v_joutputi=j∑αijvj
直觉上:第 iii 个 token 的输出是所有 token 的 Value 向量按注意力权重的加权平均。
4.2 矩阵形式与复杂度
- 输入:Q∈Rn×dkQ \in \mathbb{R}^{n \times d_k}Q∈Rn×dk,K∈Rm×dkK \in \mathbb{R}^{m \times d_k}K∈Rm×dk,V∈Rm×dvV \in \mathbb{R}^{m \times d_v}V∈Rm×dv
- 计算 QKTQK^TQKT:O(n⋅m⋅dk)O(n \cdot m \cdot d_k)O(n⋅m⋅dk) 时间,O(n⋅m)O(n \cdot m)O(n⋅m) 空间
- 对于自注意力(n=mn = mn=m):时间和空间复杂度均为 O(n2)O(n^2)O(n2),这是 Ransformer 意力
4.3 完整实现
key: [batch, heads, seq_k, d_k]
value: [batch, heads, seq_k, d_v]
mask: [batch, 1, seq_q, seq_k] 或 [batch, 1, 1, seq_k],True 表示需要 mask 掉
dropout_p: attention 权重的 dropout 概率
Returns:
output: [batch, heads, seq_q, d_v]
attn_weights: [batch, heads, seq_q, seq_k]
"""
d_k = query.size(-1)
# [batch, heads, seq_q, seq_k]
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# 应用 mask:将需要屏蔽的位置设为极大负值,softmax 后趋近于 0
if mask is not None:
scores = scores.masked_fill(mask, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
# 处理全行为 -inf 导致 softmax 输出 NaN 的情况
attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
if dropout_p > 0.0 and training:
attn_weights = F.dropout(attn_weights, p=dropout_p)
output = torch.matmul(attn_weights, value)
return output, attn_weights
---
## 5. 多头注意力机制(Multi-Head Attention)
### 5.1 设计动机
单头注意力将所有信息压缩到一个注意力分布上,表达能力有限。**多头注意力**的思路是:
> 与其让一个注意力头学所有东西,不如让多个头各自专注于不同类型的关系。
例如,在句子 "The animal didn't cross the street because **it** was too tired" 中:
- 某个头可能专注于句法依存关系(it → animal)
- 另一个头可能关注语义相似性
- 还有头关注位置邻近性
### 5.2 公式
将 $d_{model}$ 维的 Q/K/V 分别投影到 $h$ 个 $d_k$ 维子空间,独立计算注意力后拼接:
$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O$$
其中:
- $W_i^Q \in \mathbb{R}^{d_{model} \times d_k}$,$W_i^K \in \mathbb{R}^{d_{model} \times d_k}$,$W_i^V \in \mathbb{R}^{d_{model} \times d_v}$
- $W^O \in \mathbb{R}^{hd_v \times d_{model}}$
- 原始论文中 $d_k = d_v = d_{model} / h = 64$
**参数量**:$3 \times d_{model} \times d_{model} + d_{model} \times d_{model} = 4 d_{model}^2$(与单头相同!)
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# 将 Q/K/V 的投影合并为一个大矩阵,效率更高
self.W_Q = nn.Linear(d_model, d_model, bias=False)
self.W_K = nn.Linear(d_model, d_model, bias=False)
self.W_V = nn.Linear(d_model, d_model, bias=False)
self.W_O = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
self.attn_weights = None # 保存供可视化
def split_heads(self, x: torch.Tensor) -> torch.Tensor:
"""[batch, seq, d_model] -> [batch, heads, seq, d_k]"""
batch, seq, _ = x.size()
x = x.view(batch, seq, self.num_heads, self.d_k)
return x.transpose(1, 2) # [batch, heads, seq, d_k]
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size = query.size(0)
# 线性投影
Q = self.split_heads(self.W_Q(query)) # [batch, heads, seq_q, d_k]
K = self.split_heads(self.W_K(key)) # [batch, heads, seq_k, d_k]
V = self.split_heads(self.W_V(value)) # [batch, heads, seq_k, d_k]
# 计算注意力
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
self.attn_weights = attn_weights # 保存用于可视化
attn_weights = self.dropout(attn_weights)
out = torch.matmul(attn_weights, V) # [batch, heads, seq_q, d_k]
# 合并多头:[batch, heads, seq, d_k] -> [batch, seq, d_model]
out = out.transpose(1, 2).contiguous()
out = out.view(batch_size, -1, self.d_model)
return self.W_O(out)
6. 前馈神经网络(Feed-Forward Network)
每个 Encoder/Decoder 层中,注意力层之后跟一个 位置独立的全连接网络(每个位置单独处理,权重共享):
FFN(x)=max(0,xW1+b1)W2+b2\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2FFN(x)=max(0,xW1+b1)W2+b2
- W1∈Rdmodel×dffW_1 \in \mathbb{R}^{d_{model} \times d_{ff}}W1∈Rdmodel×dff,W2∈Rdff×dmodelW_2 \in \mathbb{R}^{d_{ff} \times d_{model}}W2∈Rdff×dmodel
- 原始论文:dff=4×dmodel=2048d_{ff} = 4 \times d_{model} = 2048dff=4×dmodel=2048(扩张再压缩)
- 参数量:2×dmodel×dff=2×512×2048≈2M2 \times d_{model} \times d_{ff} = 2 \times 512 \times 2048 \approx 2M2×dmodel×dff=2×512×2048≈2M
FFN 的作用是什么? 注意力层负责"信息路由"(决定关注谁),FFN 则对每个 token 的特征进行非线性变换,存储和提取"知识"。有研究表明 FFN 层像是一个 key-value 记忆系统。
class FeedForwardNetwork(nn.Module):
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
self.activation = nn.ReLU()
# 现代模型常用 GELU:self.activation = nn.GELU()
# LLaMA 用 SwiGLU:更复杂但效果更好
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear2(self.dropout(self.activation(self.linear1(x))))
7. 残差连接与层归一化
7.1 残差连接(Residual Connection)
每个子层(注意力/FFN)的输出都通过残差连接与输入相加:
output=LayerNorm(x+Sublayer(x))\text{output} = \text{LayerNorm}(x + \text{Sublayer}(x))output=LayerNorm(x+Sublayer(x))
作用:
- 提供梯度的"高速公路",解决深层网络梯度消失
- 使模型能轻松学习恒等映射(初始化时子层输出接近 0,整体接近恒等变换)
7.2 层归一化(Layer Normalization)
对每个样本在特征维度上归一化(不同于 BatchNorm 在 batch 维度上归一化):
LayerNorm(x)=γ⊙x−μσ+ϵ+β\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sigma + \epsilon} + \betaLayerNorm(x)=γ⊙σ+ϵx−μ+β
其中 μ,σ\mu, \sigmaμ,σ 是对 xxx 在特征维度上计算的均值和标准差,γ,β\gamma, \betaγ,β 是可学习的缩放和偏移参数。
为什么用 LayerNorm 而不是 BatchNorm?
- 序列长度不固定,batch 维度统计量不稳定
- 推理时 batch_size=1 时 BatchNorm 效果差
- LayerNorm 在每个样本内部归一化,不受 batch size 影响
Pre-Norm vs Post-Norm:
- 原始论文用 Post-Norm(加完再 Norm)
- 现代大模型(GPT-2+,LLaMA)普遍用 Pre-Norm(先 Norm 再过子层),训练更稳定
class PreNormLayer(nn.Module):
"""Pre-LN 包装器(现代做法)"""
def __init__(self, d_model: int, sublayer: nn.Module):
super().__init__()
self.norm = nn.LayerNorm(d_model)
self.sublayer = sublayer
def forward(self, x, **kwargs):
return x + self.sublayer(self.norm(x), **kwargs)
8. 编码器(Encoder)完整实现
f, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
super().init()
self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.ffn = FeedForwardNetwork(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(
self,
x: torch.Tensor,
src_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
# Post-Norm(原始论文)
# Self-Attention + 残差 + Norm
attn_out = self.self_attn(x, x, x, mask=src_mask)
x = self.norm1(x + self.dropout(attn_out))
# FFN + 残差 + Norm
ffn_out = self.ffn(x)
x = self.norm2(x + self.dropout(ffn_out))
return x
class Encoder(nn.Module):
def init(
self,
vocab_size: int,
d_model: int = 512,
num_heads: int = 8,
d_ff: int = 2048,
num_layers: int = 6,
dropout: float = 0.1,
max_len: int = 5000,
):
super().init()
self.embedding = TokenEmbedding(vocab_size, d_model)
self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
self.layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(d_model)
def forward(
self,
src: torch.Tensor, # [batch, src_len]
src_mask: Optional[torch.Tensor] = None # [batch, 1, 1, src_len]
) -> torch.Tensor:
x = self.pos_encoding(self.embedding(src))
for layer in self.layers:
x = layer(x, src_mask)
return self.norm(x)
---
## 9. 解码器(Decoder)与 Mask 机制
### 9.1 两种 Mask
**Padding Mask(填充遮罩):**
序列长度不一致时用 `<PAD>` 填充,注意力不应该关注这些填充位置:
```python
def make_pad_mask(seq: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
"""
seq: [batch, seq_len]
返回: [batch, 1, 1, seq_len],True 表示该位置需要被 mask
"""
return (seq == pad_idx).unsqueeze(1).unsqueeze(2)
Causal Mask / Look-ahead Mask(因果遮罩):
Decoder 在自回归生成时,位置 iii 不能看到 i+1i+1i+1 及之后的位置(防止"作弊"):
def make_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
"""
返回上三角矩阵(True 表示需要被 mask 的未来位置)
[1, 1, seq_len, seq_len]
"""
mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
return mask.bool().unsqueeze(0).unsqueeze(0)
可视化(seq_len=4):
False True True True 位置0只能看自己
False False True True 位置1能看0,1
False False False True 位置2能看0,1,2
False False False False 位置3能看全部
9.2 Decoder 完整实现
class DecoderLayer(nn.Module):
def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
super().__init__()
# 1. Masked Self-Attention(只看已生成的部分)
self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
# 2. Cross-Attention(Q 来自 Decoder,K/V 来自 Encoder)
self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.ffn = FeedForwardNetwork(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
tgt_mask: Optional[torch.Tensor] = None,
memory_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Masked Self-Attention
attn1 = self.self_attn(tgt, tgt, tgt, mask=tgt_mask)
tgt = self.norm1(tgt + self.dropout(attn1))
# Cross-Attention:Q 来自 Decoder,K/V 来自 Encoder memory
attn2 = self.cross_attn(tgt, memory, memory, mask=memory_mask)
tgt = self.norm2(tgt + self.dropout(attn2))
# FFN
ffn_out = self.ffn(tgt)
tgt = self.norm3(tgt + self.dropout(ffn_out))
return tgt
class Decoder(nn.Module):
def __init__(
self,
vocab_size: int,
d_model: int = 512,
num_heads: int = 8,
d_ff: int = 2048,
num_layers: int = 6,
dropout: float = 0.1,
max_len: int = 5000,
):
super().__init__()
self.embedding = TokenEmbedding(vocab_size, d_model)
self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
self.layers = nn.ModuleList([
DecoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(d_model)
def forward(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
tgt_mask: Optional[torch.Tensor] = None,
memory_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
x = self.pos_encoding(self.embedding(tgt))
for layer in self.layers:
x = layer(x, memory, tgt_mask, memory_mask)
return self.norm(x)
---
## 10. 完整 Transformer 端到端
```python
class Transformer(nn.Module):
def __init__(
self,
src_vocab_size: int,
tgt_vocab_size: int,
d_model: int = 512,
num_heads: int = 8,
d_ff: int = 2048,
num_encoder_layers: int = 6,
num_decoder_layers: int = 6,
dropout: float = 0.1,
max_len: int = 5000,
pad_idx: int = 0,
):
super().__init__()
self.pad_idx = pad_idx
self.encoder = Encoder(src_vocab_size, d_model, num_heads, d_ff,
num_encoder_layers, dropout, max_len)
self.decoder = Decoder(tgt_vocab_size, d_model, num_heads, d_ff,
num_decoder_layers, dropout, max_len)
self.output_projection = nn.Linear(d_model, tgt_vocab_size)
self._init_weights()
def _init_weights(self):
"""Xavier 初始化所有线性层"""
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def encode(self, src: torch.Tensor) -> torch.Tensor:
src_mask = make_pad_mask(src, self.pad_idx)
return self.encoder(src, src_mask), src_mask
def decode(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
memory_mask: torch.Tensor,
) -> torch.Tensor:
tgt_len = tgt.size(1)
tgt_pad_mask = make_pad_mask(tgt, self.pad_idx)
tgt_causal_mask = make_causal_mask(tgt_len, tgt.device)
# 合并 padding mask 和 causal mask
tgt_mask = tgt_pad_mask | tgt_causal_mask
return self.decoder(tgt, memory, tgt_mask, memory_mask)
def forward(
self,
src: torch.Tensor, # [batch, src_len]
tgt: torch.Tensor, # [batch, tgt_len]
) -> torch.Tensor:
memory, src_mask = self.encode(src)
dec_out = self.decode(tgt, memory, src_mask)
return self.output_projection(dec_out) # [batch, tgt_len, tgt_vocab_size]
11. 训练技巧与超参数选择
11.1 Warmup 学习率调度
原始论文提出了一个独特的学习率 schedule:
lr=dmodel−0.5⋅min(step−0.5, step⋅warmup_steps−1.5)lr = d_{model}^{-0.5} \cdot \min(\text{step}^{-0.5},\ \text{step} \cdot \text{warmup\_steps}^{-1.5})lr=dmodel−0.5⋅min(step−0.5, step⋅warmup_steps−1.5)
- 前
warmup_steps步线性增大学习率 - 之后按步数的 −0.5-0.5−0.5 次方衰减
class WarmupScheduler:
def __init__(self, optimizer, d_model: int, warmup_steps: int = 4000):
self.optimizer = optimizer
self.d_model = d_model
self.warmup_steps = warmup_steps
self.step_num = 0
def step(self):
self.step_num += 1
lr = self.d_model ** (-0.5) * min(
self.step_num ** (-0.5),
self.step_num * self.warmup_steps ** (-1.5)
)
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
11.2 标签平滑(Label Smoothing)
原始论文使用 ϵls=0.1\epsilon_{ls} = 0.1ϵls=0.1,将硬标签改为软标签:
ysmooth=(1−ϵ)⋅yone-hot+ϵ∣V∣y_{\text{smooth}} = (1 - \epsilon) \cdot y_{\text{one-hot}} + \frac{\epsilon}{|V|}ysmooth=(1−ϵ)⋅yone-hot+∣V∣ϵ
class LabelSmoothingLoss(nn.Module):
def __init__(self, vocab_size: int, padding_idx: int, smoothing: float = 0.1):
super().__init__()
self.smoothing = smoothing
self.padding_idx = padding_idx
self.vocab_size = vocab_size
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
# pred: [batch * seq, vocab], target: [batch * seq]
confidence = 1.0 - self.smoothing
smooth_val = self.smoothing / (self.vocab_size - 2)
true_dist = torch.full_like(pred, smooth_val)
true_dist.scatter_(1, target.unsqueeze(1), confidence)
true_dist[:, self.padding_idx] = 0
mask = (target == self.padding_idx)
true_dist[mask] = 0
return F.kl_div(F.log_softmax(pred, dim=-1), true_dist, reduction='sum')
11.3 其他关键技巧
| 技巧 | 说明 |
|---|---|
| Gradient Clipping | clip_grad_norm_(params, max_norm=1.0) 防止梯度爆炸 |
| Eropout | 注意力权重和 FFN 输出都加 dropout=0.1 |
| 权重共享 | Encoder/Decoder embedding 和输出投影层共享权重 |
| 混合精度训练 | torch.cuda.amp 使用 FP16 加速,节省显存 |
| 梯度累积 | 显存不够时,多步累积再更新 |
12. 常见变体对比:BERT / GPT / T5
| 特性 | BERT | GPT 系列 | T5 |
|---|---|---|---|
| 架构 | 仅 Encoder | 仅 Decoder | Encoder-Decoder |
| 注意力方向 | 双向 | 单向(因果) | 混合 |
| 预训练任务 | MLM + NSP | 语言模型(自回归) | Span 预测 |
| 适合任务 | 理解(分类/NER) | 生成(对话/补全) | 生成+理解 |
| 代表模型 | RoBERTa, ALBERT | GPT-4, LLaMA | T5, FLAN-T5 |
BERT 的 MLM(Masked Language Model):
随机遮蔽 15% 的 token,让模型预测被遮蔽的词。
GPT 的自回归生成:
每次根据已有 token 预测下一个 token,使用 Causal Mask 确保只看左侧上下文。
13. 推理优化:KV Cache、Flash Attention、量化
13.1 KV Cache
自回归生成时,每次生成新 token 都需要对所有历史 token 重新计算 K/V,非常浪费。
KV Cache 的思路:缓存历史步骤的 K/V 矩阵,新步骤只计算新 token 的 K/V 并拼接进去:
# 伪代码示意
cache_k, cache_v = [], []
for step in range(max_new_tokens):
q = compute_query(new_token) # 只计算新 token 的 Q
k = compute_key(new_token)
v = compute_value(new_token)
# 拼接历史 K/V
cache_k.append(k)
cache_v.append(v)
K_all = torch.cat(cache_k, dim=-2)
V_all = torch.cat(cache_v, dim=-2)
output = attention(q, K_all, V_all)
时间复杂度从 O(n2)O(n^2)O(n2) 降为 O(n)O(n)O(n)(每步只做一次向量-矩阵乘法)。
13.2 Flash Attention
标准注意力计算瓶颈在 HBM(显存)带宽,而非计算量本身。注意力矩阵 n×nn \times nn×n 需要频繁读写 HBM。
Flash Attention 利用 Tiling(分块计算) 将整个注意力计算放在 SRAM(片上高速缓存)中完成,避免频繁的 HBM 读写:
- 速度提升 2-4x
- 内存占用从 O(n2)O(n^2)O(n2) 降为 O(n)O(n)O(n)
- 数值结果与标准注意力完全一致
PyTorch 2.0+ 已内置:F.scaled_dot_product_attention()(自动使用 Flash Attention)
13.3 模型量化
| 方案 | 精度损失 | 显存节省 | 速度 |
|---|---|---|---|
| FP32 → FP16 | 极小 | 50% | +50% |
| FP32 → INT8 | 小 | 75% | +2x |
| FP32 → INT4 | 中等 | 87.5% | +3x |
| GPTQ / AWQ | 小(后训练量化) | 75-87% | 依硬件 |
14. 总结与学习路线
为什么 Transformer 统治了 AI?
- O(1)O(1)O(1) $��息传递距离:任意两个 token 可以直接交互
- 完全可并行:训练时所有 token 同时计算,充分利用 GPU
- 扩展性强(Scaling Law):模型越大、数据越多、效果越好,无明显天花板
- 通用架构:NLP / CV / 多模态 / 蛋白质结构预测均可使用
├── 线性代数(矩阵乘法、特征值)
├── 概率论(softmax、交叉熵)
└── PyTorch 基础
Transformer 核心
├── 精读《Attention Is All You Need》
├── 手写本文所有代码并跑通
└── 可视化注意力权重(BertViz)
进阶应用
├── BERT 微调(文本分类、NER)
├── GPT-2 文本生成
└── HuggingFace Transformers 库源码阅读
大模型方向
├── LLaMA 架构(RoPE, SwiGLU, RMSNorm)
├── RLHF(InstructGPT)
└── 参数高效微调(LoRA, Prefix-Tuning)
---
**参考文献**
- Vaswani et al., 2017. *Attention Is All You Need*. NeurIPS
- Devlin et al., 2018. *BERT: Pre-training of Deep Bidirectional Transformers*. NAACL
- Brown et al., 2020. *Language Models are Few-Shot Learners*. NeurIPS
- Dao et al., 2022. *FlashAttention: Fast and Memory-Efficient Exact Attention*. NeurIPS
---
## 关于作者
> 如果这篇文章对你有帮助,欢迎关注我的 GitHub 项目,那里有更多实战代码、学习笔记和开源资源,持续更新中!
GitHub: https://github.com/JerryZ01/openclaw-guide
> 如果觉得有用,请给个 Star 支持一下!你的支持是我持续创作的最大动力 🙏
*觉得不错的话,也欢迎点赞、收藏、转发,让更多人看到这篇文章~*
更多推荐


所有评论(0)