深入理解 Transformer:从自注意力机制到大模型的基石原理详解

摘要:2017 年 Google 论文《Attention Is All You Need》提出了 Transformer 架构,彻底颠覆了 NLP 乃至整个深度学习领域的格局。本文将从动机出发,逐步拆解每一个核心模块,配套完整 PyTorch 实现代码数学推导工程实践经验,力求让你真正"吃透"Transformer,而不只是会背公式。


目录

  1. 为什么需要 Transformer?RNN 的致命缺陷
  2. Transformer 整体架构总览
  3. 输入处理:Token Embedding + Positional Encoding
  4. 核心模块:缩放点积注意力(Scaled Dot-Product Attention)
  5. 多头注意力机制(Multi-Head Attention)
  6. 前馈神经网络(Feed-Forward Network)
  7. 残差连接与层归一化(Residual + LayerNorm)
  8. 编码器(Encoder)完整实现
  9. 解码器(Decoder)与 Mask 机制
  10. 完整 Transformer 端到端实现
  11. 训练技巧与超参数选择
  12. 常见变体对比:BERT / GPT / T5
  13. 推理优化:KV Cache、Flash Attention、量化
  14. 总结与学习路线

— [Token Embedding]
|
[Positional Encoding]
|
┌────────────────────────┐
│ Encoder │
│ ┌──────────────────┐ │
│ │ Multi-Head Attn │ │
│ │ (Self-Attention) │ │
│ └────────┬─────────┘ │
│ │ Add & Norm │
│ ┌────────┴─────────┐ │
│ │ Feed Forward │ │
│ └────────┬─────────┘ │
│ │ Add & Norm │
│ × N layers │
└────────────────────────┘
|
Context (Memory)
|
┌────────────────────────┐
│ Decoder │
│ ┌──────────────────┐ │
│ │ Masked MH Attn │ │ <-- 输出序列 (tgt)
│ └────────┬─────────┘ │
│ │ Add & Norm │
│ ┌────────┴─────────┐ │
│ │ Cross-Attention │ │ <-- Q来自Decoder, K/V来自Encoder
│ └────────┬─────────┘ │
│ │ Add & Norm │
│ ┌────────┴─────────┐ │
│ │ Feed Forward │ │
│ └────────┬─────────┘ │
│ │ Add & Norm │
│ × N layers │
└────────────────────────┘
|
[Linear + Softmax]
|
输出概率分布


**关键超参数(原始论文 Base 版本):**

| 超参数 | 值 | 含义 |
|:---|:---:|:---|
| $d_{model}$ | 512 | 模型维度(嵌入维度) |
| $d_{ff}$ | 2048 | FFN 中间层维度 |
| $h$ | 8 | 注意力头数 |
| $d_k = d_v$ | 64 | 每个头的维度(512/8) |
| $N$ | 6 | Encoder/Decoder 层数 |
| $p_{drop}$ | 0.1 | Dropout 概率 |

---

## 3. 输入处理:Token Embedding + Positional Encoding

### 3.1 Token Embedding

将词表中的每个 token 映射为 $d_{model}$ 维的稠密向量。嵌入矩阵 $W_E \in \mathbb{R}^{|V| \times d_{model}}$,其中 $|V|$ 是词表大小。

原始论文中,**Encoder 和 Decoder 的嵌入矩阵共享权重**,并且与最终的 Linear 投影层共享,减少参数量。

```python
import torch
import torch.nn as nn
import math

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x):
        # 论文中对嵌入乘以 sqrt(d_model) 以匹配位置编码的数值范围
        return self.embedding(x) * math.sqrt(self.d_model)

3.2 位置编码(Positional Encoding)

为什么需要? 自注意力本身是置换不变的(permutation-invariant),不管你怎么打乱输入顺序,注意力矩阵经过重排后结果一样,无法感知位置信息。

原始论文使用固定的正弦/余弦位置编码

PE(pos,2i)=sin⁡(pos100002i/dmodel)PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)PE(pos,2i)=sin(100002i/dmodelpos)

PE(pos,2i+1)=cos⁡(pos100002i/dmodel)PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)PE(pos,2i+1)=cos(100002i/dmodelpos)

设计直觉:

  • 不同频率的正弦波组合,可以唯一表示任意位置
  • 对于固定的位置偏移 kkkPEpos+kPE_{pos+k}PEpos+k 可以表示为 PEposPE_{pos}PEpos 的线性变换,有利于模型学习相对位置关系
  • 可以外推到训练时未见过的更长序列
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # 预计算位置编码矩阵 [max_len, d_model]
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        
        # 计算分母:10000^(2i/d_model),用 exp(log) 形式数值更稳定
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * 
            (-math.log(10000.0) / d_model)
        )  # [d_model/2]

        pe[:, 0::2] = torch.sin(position * div_term)  # 偶数维
        pe[:, 1::2] = torch.cos(position * div_term)  # 奇数维
        
        pe = pe.unsqueeze(0)  # [1, max_len, d_model],广播用
        self.register_buffer('pe', pe)  # 不参与梯度更新

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, seq_len, d_model]
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

现代改进:BERT 使用可学习的位置编码(learnable positional embedding),RoPE(旋转位置编码,LLaMA 使用)、ALiBi(偏置注意力线性插值)等方案在长序列外推上表现更好。


4. 核心模块:缩放点积注意力

4.1 公式与推导

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) VAttention(Q,K,V)=softmax(dk QKT)V

逐步拆解:

Step 1:计算相似度分数

scoreij=qi⋅kjdk\text{score}_{ij} = \frac{q_i \cdot k_j}{\sqrt{d_k}}scoreij=dk qikj

qiq_iqi 是第 iii 个 query 向量,kjk_jkj 是第 jjj 个 key 向量,点积衡量它们的相似度。

Step 2:为什么除以 dk\sqrt{d_k}dk

dkd_kdk 较大时,点积的方差会增大(约为 dkd_kdk),导致 softmax 进入梯度极小的饱和区。除以 dk\sqrt{d_k}dk 将方差归一化到 1,保持梯度流动。

证明:若 q,kq, kq,k 各分量独立同分布,均值 0 方差 1,则:
Var(q⋅k)=∑i=1dkVar(qiki)=dk\text{Var}(q \cdot k) = \sum_{i=1}^{d_k} \text{Var}(q_i k_i) = d_kVar(qk)=i=1dkVar(qiki)=dk
Var(q⋅kdk)=1\text{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = 1Var(dk qk)=1

Step 3:Softmax 归一化

αij=exp⁡(scoreij)∑jexp⁡(scoreij)\alpha_{ij} = \frac{\exp(\text{score}_{ij})}{\sum_j \exp(\text{score}_{ij})}αij=jexp(scoreij)exp(scoreij)

得到注意力权重,∑jαij=1\sum_j \alpha_{ij} = 1jαij=1

Step 4:加权聚合 Value

outputi=∑jαijvj\text{output}_i = \sum_j \alpha_{ij} v_joutputi=jαijvj

直觉上:第 iii 个 token 的输出是所有 token 的 Value 向量按注意力权重的加权平均。

4.2 矩阵形式与复杂度

  • 输入:Q∈Rn×dkQ \in \mathbb{R}^{n \times d_k}QRn×dkK∈Rm×dkK \in \mathbb{R}^{m \times d_k}KRm×dkV∈Rm×dvV \in \mathbb{R}^{m \times d_v}VRm×dv
  • 计算 QKTQK^TQKTO(n⋅m⋅dk)O(n \cdot m \cdot d_k)O(nmdk) 时间,O(n⋅m)O(n \cdot m)O(nm) 空间
  • 对于自注意力(n=mn = mn=m):时间和空间复杂度均为 O(n2)O(n^2)O(n2),这是 Ransformer 意力

4.3 完整实现

  key:    [batch, heads, seq_k, d_k]
    value:  [batch, heads, seq_k, d_v]
    mask:   [batch, 1, seq_q, seq_k] 或 [batch, 1, 1, seq_k],True 表示需要 mask 掉
    dropout_p: attention 权重的 dropout 概率

Returns:
    output: [batch, heads, seq_q, d_v]
    attn_weights: [batch, heads, seq_q, seq_k]
"""
d_k = query.size(-1)

# [batch, heads, seq_q, seq_k]
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

# 应用 mask:将需要屏蔽的位置设为极大负值,softmax 后趋近于 0
if mask is not None:
    scores = scores.masked_fill(mask, float('-inf'))

attn_weights = F.softmax(scores, dim=-1)

# 处理全行为 -inf 导致 softmax 输出 NaN 的情况
attn_weights = torch.nan_to_num(attn_weights, nan=0.0)

if dropout_p > 0.0 and training:
    attn_weights = F.dropout(attn_weights, p=dropout_p)

output = torch.matmul(attn_weights, value)
return output, attn_weights

---

## 5. 多头注意力机制(Multi-Head Attention)

### 5.1 设计动机

单头注意力将所有信息压缩到一个注意力分布上,表达能力有限。**多头注意力**的思路是:

> 与其让一个注意力头学所有东西,不如让多个头各自专注于不同类型的关系。

例如,在句子 "The animal didn't cross the street because **it** was too tired" 中:
- 某个头可能专注于句法依存关系(it → animal)
- 另一个头可能关注语义相似性
- 还有头关注位置邻近性

### 5.2 公式

将 $d_{model}$ 维的 Q/K/V 分别投影到 $h$ 个 $d_k$ 维子空间,独立计算注意力后拼接:

$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O$$

其中:
- $W_i^Q \in \mathbb{R}^{d_{model} \times d_k}$,$W_i^K \in \mathbb{R}^{d_{model} \times d_k}$,$W_i^V \in \mathbb{R}^{d_{model} \times d_v}$
- $W^O \in \mathbb{R}^{hd_v \times d_{model}}$
- 原始论文中 $d_k = d_v = d_{model} / h = 64$

**参数量**:$3 \times d_{model} \times d_{model} + d_{model} \times d_{model} = 4 d_{model}^2$(与单头相同!)
   self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # 将 Q/K/V 的投影合并为一个大矩阵,效率更高
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)
        
        self.dropout = nn.Dropout(dropout)
        self.attn_weights = None  # 保存供可视化

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """[batch, seq, d_model] -> [batch, heads, seq, d_k]"""
        batch, seq, _ = x.size()
        x = x.view(batch, seq, self.num_heads, self.d_k)
        return x.transpose(1, 2)  # [batch, heads, seq, d_k]

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size = query.size(0)

        # 线性投影
        Q = self.split_heads(self.W_Q(query))  # [batch, heads, seq_q, d_k]
        K = self.split_heads(self.W_K(key))    # [batch, heads, seq_k, d_k]
        V = self.split_heads(self.W_V(value))  # [batch, heads, seq_k, d_k]

        # 计算注意力
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
        self.attn_weights = attn_weights  # 保存用于可视化
        
        attn_weights = self.dropout(attn_weights)
        
        out = torch.matmul(attn_weights, V)  # [batch, heads, seq_q, d_k]

        # 合并多头:[batch, heads, seq, d_k] -> [batch, seq, d_model]
        out = out.transpose(1, 2).contiguous()
        out = out.view(batch_size, -1, self.d_model)

        return self.W_O(out)

6. 前馈神经网络(Feed-Forward Network)

每个 Encoder/Decoder 层中,注意力层之后跟一个 位置独立的全连接网络(每个位置单独处理,权重共享):

FFN(x)=max⁡(0,xW1+b1)W2+b2\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2FFN(x)=max(0,xW1+b1)W2+b2

  • W1∈Rdmodel×dffW_1 \in \mathbb{R}^{d_{model} \times d_{ff}}W1Rdmodel×dffW2∈Rdff×dmodelW_2 \in \mathbb{R}^{d_{ff} \times d_{model}}W2Rdff×dmodel
  • 原始论文:dff=4×dmodel=2048d_{ff} = 4 \times d_{model} = 2048dff=4×dmodel=2048(扩张再压缩)
  • 参数量:2×dmodel×dff=2×512×2048≈2M2 \times d_{model} \times d_{ff} = 2 \times 512 \times 2048 \approx 2M2×dmodel×dff=2×512×20482M

FFN 的作用是什么? 注意力层负责"信息路由"(决定关注谁),FFN 则对每个 token 的特征进行非线性变换,存储和提取"知识"。有研究表明 FFN 层像是一个 key-value 记忆系统。

class FeedForwardNetwork(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU()
        # 现代模型常用 GELU:self.activation = nn.GELU()
        # LLaMA 用 SwiGLU:更复杂但效果更好

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

7. 残差连接与层归一化

7.1 残差连接(Residual Connection)

每个子层(注意力/FFN)的输出都通过残差连接与输入相加:

output=LayerNorm(x+Sublayer(x))\text{output} = \text{LayerNorm}(x + \text{Sublayer}(x))output=LayerNorm(x+Sublayer(x))

作用:

  • 提供梯度的"高速公路",解决深层网络梯度消失
  • 使模型能轻松学习恒等映射(初始化时子层输出接近 0,整体接近恒等变换)

7.2 层归一化(Layer Normalization)

对每个样本在特征维度上归一化(不同于 BatchNorm 在 batch 维度上归一化):

LayerNorm(x)=γ⊙x−μσ+ϵ+β\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sigma + \epsilon} + \betaLayerNorm(x)=γσ+ϵxμ+β

其中 μ,σ\mu, \sigmaμ,σ 是对 xxx 在特征维度上计算的均值和标准差,γ,β\gamma, \betaγ,β 是可学习的缩放和偏移参数。

为什么用 LayerNorm 而不是 BatchNorm?

  • 序列长度不固定,batch 维度统计量不稳定
  • 推理时 batch_size=1 时 BatchNorm 效果差
  • LayerNorm 在每个样本内部归一化,不受 batch size 影响

Pre-Norm vs Post-Norm:

  • 原始论文用 Post-Norm(加完再 Norm)
  • 现代大模型(GPT-2+,LLaMA)普遍用 Pre-Norm(先 Norm 再过子层),训练更稳定
class PreNormLayer(nn.Module):
    """Pre-LN 包装器(现代做法)"""
    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer

    def forward(self, x, **kwargs):
        return x + self.sublayer(self.norm(x), **kwargs)

8. 编码器(Encoder)完整实现

f, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
super().init()
self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
self.ffn = FeedForwardNetwork(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)

def forward(
    self, 
    x: torch.Tensor, 
    src_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    # Post-Norm(原始论文)
    # Self-Attention + 残差 + Norm
    attn_out = self.self_attn(x, x, x, mask=src_mask)
    x = self.norm1(x + self.dropout(attn_out))
    
    # FFN + 残差 + Norm
    ffn_out = self.ffn(x)
    x = self.norm2(x + self.dropout(ffn_out))
    return x

class Encoder(nn.Module):
def init(
self,
vocab_size: int,
d_model: int = 512,
num_heads: int = 8,
d_ff: int = 2048,
num_layers: int = 6,
dropout: float = 0.1,
max_len: int = 5000,
):
super().init()
self.embedding = TokenEmbedding(vocab_size, d_model)
self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
self.layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(d_model)

def forward(
    self, 
    src: torch.Tensor,           # [batch, src_len]
    src_mask: Optional[torch.Tensor] = None  # [batch, 1, 1, src_len]
) -> torch.Tensor:
    x = self.pos_encoding(self.embedding(src))
    for layer in self.layers:
        x = layer(x, src_mask)
    return self.norm(x)

---

## 9. 解码器(Decoder)与 Mask 机制

### 9.1 两种 Mask

**Padding Mask(填充遮罩):**
序列长度不一致时用 `<PAD>` 填充,注意力不应该关注这些填充位置:

```python
def make_pad_mask(seq: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
    """
    seq: [batch, seq_len]
    返回: [batch, 1, 1, seq_len],True 表示该位置需要被 mask
    """
    return (seq == pad_idx).unsqueeze(1).unsqueeze(2)

Causal Mask / Look-ahead Mask(因果遮罩):
Decoder 在自回归生成时,位置 iii 不能看到 i+1i+1i+1 及之后的位置(防止"作弊"):

def make_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """
    返回上三角矩阵(True 表示需要被 mask 的未来位置)
    [1, 1, seq_len, seq_len]
    """
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
    return mask.bool().unsqueeze(0).unsqueeze(0)

可视化(seq_len=4):

False True  True  True     位置0只能看自己
False False True  True     位置1能看0,1
False False False True     位置2能看0,1,2
False False False False    位置3能看全部

9.2 Decoder 完整实现

class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # 1. Masked Self-Attention(只看已生成的部分)
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        # 2. Cross-Attention(Q 来自 Decoder,K/V 来自 Encoder)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForwardNetwork(d_model, d_ff, dropout)
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Masked Self-Attention
        attn1 = self.self_attn(tgt, tgt, tgt, mask=tgt_mask)
        tgt = self.norm1(tgt + self.dropout(attn1))

        # Cross-Attention:Q 来自 Decoder,K/V 来自 Encoder memory
        attn2 = self.cross_attn(tgt, memory, memory, mask=memory_mask)
        tgt = self.norm2(tgt + self.dropout(attn2))

        # FFN
        ffn_out = self.ffn(tgt)
        tgt = self.norm3(tgt + self.dropout(ffn_out))
        return tgt


class Decoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        d_ff: int = 2048,
        num_layers: int = 6,
        dropout: float = 0.1,
        max_len: int = 5000,
    ):
        super().__init__()
        self.embedding = TokenEmbedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        x = self.pos_encoding(self.embedding(tgt))
        for layer in self.layers:
            x = layer(x, memory, tgt_mask, memory_mask)
        return self.norm(x)

---

## 10. 完整 Transformer 端到端

```python
class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        d_ff: int = 2048,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dropout: float = 0.1,
        max_len: int = 5000,
        pad_idx: int = 0,
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.encoder = Encoder(src_vocab_size, d_model, num_heads, d_ff,
                               num_encoder_layers, dropout, max_len)
        self.decoder = Decoder(tgt_vocab_size, d_model, num_heads, d_ff,
                               num_decoder_layers, dropout, max_len)
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)
        
        self._init_weights()

    def _init_weights(self):
        """Xavier 初始化所有线性层"""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def encode(self, src: torch.Tensor) -> torch.Tensor:
        src_mask = make_pad_mask(src, self.pad_idx)
        return self.encoder(src, src_mask), src_mask

    def decode(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        memory_mask: torch.Tensor,
    ) -> torch.Tensor:
        tgt_len = tgt.size(1)
        tgt_pad_mask = make_pad_mask(tgt, self.pad_idx)
        tgt_causal_mask = make_causal_mask(tgt_len, tgt.device)
        # 合并 padding mask 和 causal mask
        tgt_mask = tgt_pad_mask | tgt_causal_mask
        return self.decoder(tgt, memory, tgt_mask, memory_mask)

    def forward(
        self, 
        src: torch.Tensor,  # [batch, src_len]
        tgt: torch.Tensor,  # [batch, tgt_len]
    ) -> torch.Tensor:
        memory, src_mask = self.encode(src)
        dec_out = self.decode(tgt, memory, src_mask)
        return self.output_projection(dec_out)  # [batch, tgt_len, tgt_vocab_size]

11. 训练技巧与超参数选择

11.1 Warmup 学习率调度

原始论文提出了一个独特的学习率 schedule:

lr=dmodel−0.5⋅min⁡(step−0.5, step⋅warmup_steps−1.5)lr = d_{model}^{-0.5} \cdot \min(\text{step}^{-0.5},\ \text{step} \cdot \text{warmup\_steps}^{-1.5})lr=dmodel0.5min(step0.5, stepwarmup_steps1.5)

  • warmup_steps 步线性增大学习率
  • 之后按步数的 −0.5-0.50.5 次方衰减
class WarmupScheduler:
    def __init__(self, optimizer, d_model: int, warmup_steps: int = 4000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.step_num = 0

    def step(self):
        self.step_num += 1
        lr = self.d_model ** (-0.5) * min(
            self.step_num ** (-0.5),
            self.step_num * self.warmup_steps ** (-1.5)
        )
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

11.2 标签平滑(Label Smoothing)

原始论文使用 ϵls=0.1\epsilon_{ls} = 0.1ϵls=0.1,将硬标签改为软标签:

ysmooth=(1−ϵ)⋅yone-hot+ϵ∣V∣y_{\text{smooth}} = (1 - \epsilon) \cdot y_{\text{one-hot}} + \frac{\epsilon}{|V|}ysmooth=(1ϵ)yone-hot+Vϵ

class LabelSmoothingLoss(nn.Module):
    def __init__(self, vocab_size: int, padding_idx: int, smoothing: float = 0.1):
        super().__init__()
        self.smoothing = smoothing
        self.padding_idx = padding_idx
        self.vocab_size = vocab_size

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # pred: [batch * seq, vocab], target: [batch * seq]
        confidence = 1.0 - self.smoothing
        smooth_val = self.smoothing / (self.vocab_size - 2)
        
        true_dist = torch.full_like(pred, smooth_val)
        true_dist.scatter_(1, target.unsqueeze(1), confidence)
        true_dist[:, self.padding_idx] = 0
        
        mask = (target == self.padding_idx)
        true_dist[mask] = 0
        
        return F.kl_div(F.log_softmax(pred, dim=-1), true_dist, reduction='sum')

11.3 其他关键技巧

技巧 说明
Gradient Clipping clip_grad_norm_(params, max_norm=1.0) 防止梯度爆炸
Eropout 注意力权重和 FFN 输出都加 dropout=0.1
权重共享 Encoder/Decoder embedding 和输出投影层共享权重
混合精度训练 torch.cuda.amp 使用 FP16 加速,节省显存
梯度累积 显存不够时,多步累积再更新

12. 常见变体对比:BERT / GPT / T5

特性 BERT GPT 系列 T5
架构 仅 Encoder 仅 Decoder Encoder-Decoder
注意力方向 双向 单向(因果) 混合
预训练任务 MLM + NSP 语言模型(自回归) Span 预测
适合任务 理解(分类/NER) 生成(对话/补全) 生成+理解
代表模型 RoBERTa, ALBERT GPT-4, LLaMA T5, FLAN-T5

BERT 的 MLM(Masked Language Model):
随机遮蔽 15% 的 token,让模型预测被遮蔽的词。

GPT 的自回归生成:
每次根据已有 token 预测下一个 token,使用 Causal Mask 确保只看左侧上下文。


13. 推理优化:KV Cache、Flash Attention、量化

13.1 KV Cache

自回归生成时,每次生成新 token 都需要对所有历史 token 重新计算 K/V,非常浪费。

KV Cache 的思路:缓存历史步骤的 K/V 矩阵,新步骤只计算新 token 的 K/V 并拼接进去:

# 伪代码示意
cache_k, cache_v = [], []

for step in range(max_new_tokens):
    q = compute_query(new_token)   # 只计算新 token 的 Q
    k = compute_key(new_token)
    v = compute_value(new_token)
    
    # 拼接历史 K/V
    cache_k.append(k)
    cache_v.append(v)
    K_all = torch.cat(cache_k, dim=-2)
    V_all = torch.cat(cache_v, dim=-2)
    
    output = attention(q, K_all, V_all)

时间复杂度从 O(n2)O(n^2)O(n2) 降为 O(n)O(n)O(n)(每步只做一次向量-矩阵乘法)。

13.2 Flash Attention

标准注意力计算瓶颈在 HBM(显存)带宽,而非计算量本身。注意力矩阵 n×nn \times nn×n 需要频繁读写 HBM。

Flash Attention 利用 Tiling(分块计算) 将整个注意力计算放在 SRAM(片上高速缓存)中完成,避免频繁的 HBM 读写:

  • 速度提升 2-4x
  • 内存占用从 O(n2)O(n^2)O(n2) 降为 O(n)O(n)O(n)
  • 数值结果与标准注意力完全一致

PyTorch 2.0+ 已内置:F.scaled_dot_product_attention()(自动使用 Flash Attention)

13.3 模型量化

方案 精度损失 显存节省 速度
FP32 → FP16 极小 50% +50%
FP32 → INT8 75% +2x
FP32 → INT4 中等 87.5% +3x
GPTQ / AWQ 小(后训练量化) 75-87% 依硬件

14. 总结与学习路线

为什么 Transformer 统治了 AI?

  1. O(1)O(1)O(1) $��息传递距离:任意两个 token 可以直接交互
  2. 完全可并行:训练时所有 token 同时计算,充分利用 GPU
  3. 扩展性强(Scaling Law):模型越大、数据越多、效果越好,无明显天花板
  4. 通用架构:NLP / CV / 多模态 / 蛋白质结构预测均可使用

├── 线性代数(矩阵乘法、特征值)
├── 概率论(softmax、交叉熵)
└── PyTorch 基础

Transformer 核心
├── 精读《Attention Is All You Need》
├── 手写本文所有代码并跑通
└── 可视化注意力权重(BertViz)

进阶应用
├── BERT 微调(文本分类、NER)
├── GPT-2 文本生成
└── HuggingFace Transformers 库源码阅读

大模型方向
├── LLaMA 架构(RoPE, SwiGLU, RMSNorm)
├── RLHF(InstructGPT)
└── 参数高效微调(LoRA, Prefix-Tuning)


---

**参考文献**
- Vaswani et al., 2017. *Attention Is All You Need*. NeurIPS
- Devlin et al., 2018. *BERT: Pre-training of Deep Bidirectional Transformers*. NAACL
- Brown et al., 2020. *Language Models are Few-Shot Learners*. NeurIPS
- Dao et al., 2022. *FlashAttention: Fast and Memory-Efficient Exact Attention*. NeurIPS

---

## 关于作者

> 如果这篇文章对你有帮助,欢迎关注我的 GitHub 项目,那里有更多实战代码、学习笔记和开源资源,持续更新中!

GitHub: https://github.com/JerryZ01/openclaw-guide

> 如果觉得有用,请给个 Star 支持一下!你的支持是我持续创作的最大动力 🙏

*觉得不错的话,也欢迎点赞、收藏、转发,让更多人看到这篇文章~*

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐