作为AI技术专家兼学习规划博主,我每天都会收到读者的类似提问:
“南木,Transformer的自注意力到底怎么算?公式里的Q/K/V看晕了”
“多头注意力为什么要拆分?合并后的维度怎么对齐?”
“明明懂了每个模块的结构,却还是不明白为什么它能撑起GPT、BERT这些大模型”

其实Transformer的门槛不在于“公式复杂”,而在于“跳出RNN的时序思维定式”——它用“自注意力机制”打破了序列处理的“时间依赖”,用“并行计算”重构了模型效率,更用“可堆叠架构”奠定了大模型的 scalability 基础。

这篇文章会用“问题-原理-公式-案例-代码”的五层拆解法,把Transformer讲透:从“为什么需要Transformer”的背景切入,到自注意力、多头注意力、前馈网络的公式推导与通俗解释,再到“Transformer为何能成为大模型基石”的底层逻辑,最后用PyTorch实现核心模块并可视化注意力权重。全程无冗余理论,每个公式都配“数学推导+文字解读”,每个模块都有“案例验证”,看完就能动手复现。

同时需要学习规划、就业指导、技术答疑和系统课程学习的同学 欢迎扫码交流
在这里插入图片描述

一、为什么需要Transformer?——从RNN的“死穴”说起

在Transformer(2017年《Attention Is All You Need》)出现前,时序数据处理的主流是RNN/LSTM/GRU。但这些模型有两个致命缺陷,直接限制了大模型的发展:

1.1 RNN的两大“死穴”

(1)并行性极差,训练效率低

RNN的核心是“时序依赖”——第t步的计算必须等第t-1步完成,就像“流水线必须按顺序走”,无法并行处理整个序列。
比如处理一个1000词的句子,RNN需要依次计算每个词的隐藏状态,而Transformer可以同时计算所有词的注意力,训练速度相差10~100倍(序列越长,差距越大)。

(2)长序列依赖弱,记忆能力有限

LSTM/GRU虽然用门控机制缓解了RNN的梯度消失问题,但仍属于“局部依赖”——每个位置的信息只能通过相邻位置逐步传递,长序列(如1000词以上)的远端信息会严重衰减。
比如处理“小明今天去了北京,他说______会带特产回来”,LSTM可能记不住“他”指的是“小明”,但Transformer通过全局注意力能直接关联“他”和“小明”。

1.2 Transformer的革命性突破

Transformer用“自注意力机制”一次性解决了这两个问题:

  • 并行计算:自注意力对所有位置的处理是独立的,无需等待前一步,可充分利用GPU算力;
  • 全局依赖:每个位置能直接“看到”序列中的所有其他位置,远端信息无衰减;
  • 可扩展性:模块结构简洁,支持多层堆叠(比如GPT-3有96层Transformer),能不断提升模型容量。

1.3 Transformer整体架构:先看“骨架”再拆“细节”

Transformer分为Encoder(编码器)Decoder(解码器) 两部分,核心模块包括:自注意力、多头注意力、前馈网络、残差连接、LayerNorm、位置编码。
先记住整体架构(以机器翻译任务为例):

  • Encoder(N层堆叠):输入源语言序列(如英文),输出“源语言语义表示”;每层包含“多头自注意力”和“前馈网络”;
  • Decoder(N层堆叠):输入目标语言序列(如中文,自回归生成),结合Encoder的输出,输出目标语言下一个词的概率;每层包含“掩码多头自注意力”“Encoder-Decoder注意力”和“前馈网络”;
  • 输出层:线性变换+Softmax,将Decoder的输出转为词表概率分布。

Transformer架构图(源自原论文)
(注:原论文中N=6,即Encoder和Decoder各6层)

接下来,我们从“最核心的自注意力”开始,逐个拆解每个模块。

二、Transformer的“心脏”:自注意力机制(Self-Attention)

自注意力是Transformer的核心,它的本质是“让序列中的每个位置,通过与其他位置的‘注意力交互’,学习到更丰富的语义表示”——通俗说就是“每个词都要‘关注’序列中对自己重要的其他词”。

2.1 先理解“注意力”:从人类到机器

人类的注意力是“选择性关注”——比如读“猫坐在红色的垫子上”,我们会自然关注“猫”(主语)和“垫子”(宾语),而“红色的”(定语)关注度较低。
机器的注意力机制模拟了这一过程,核心是**“根据‘查询’(Query),对‘键’(Key)进行权重分配,再用权重加权‘值’(Value)”**,即:
注意力输出 = 对Value按Key与Query的相似度加权求和

而“自注意力”(Self-Attention)是“Query、Key、Value都来自同一个输入序列”——比如处理句子时,每个词既是Query(要查询其他词),也是Key(被其他词查询),还是Value(提供信息给其他词)。

2.2 自注意力三要素:Q、K、V的定义与作用

首先明确三个核心向量的来源和含义(假设输入序列的嵌入维度为d_model,如512):

  • Query(查询向量,Q):代表“当前位置需要什么信息”,维度d_k(如64);
  • Key(键向量,K):代表“当前位置提供什么信息”,维度d_k
  • Value(值向量,V):代表“当前位置的具体信息”,维度d_v(通常d_v = d_k)。

它们的生成方式很简单:对输入序列的嵌入矩阵X(维度[batch_size, seq_len, d_model])做三次线性变换:
Q=XWQ Q = X W_Q Q=XWQ
K=XWK K = X W_K K=XWK
V=XWV V = X W_V V=XWV

  • W_Q:Query的权重矩阵,维度[d_model, d_k]
  • W_K:Key的权重矩阵,维度[d_model, d_k]
  • W_V:Value的权重矩阵,维度[d_model, d_v]

通俗比喻:把每个词看作“学生”,Q是“学生的问题”(比如“我和其他词是什么关系?”),K是“其他学生的标签”(比如“我是主语”“我是宾语”),V是“其他学生的答案”(比如“我是‘猫’,代表动物”)。

2.3 自注意力完整计算流程:4步推导(附公式+案例)

假设输入序列是“猫坐在垫子上”(seq_len=5),嵌入维度d_model=512d_k=d_v=64,我们一步步计算自注意力输出。

步骤1:生成Q、K、V(线性变换)

对每个词的嵌入向量x_i(维度d_model=512),通过W_QW_KW_V生成q_ik_iv_i(维度d_k=64):
比如“猫”的x_1q_1(“猫的问题”)、k_1(“猫的标签”)、v_1(“猫的信息”);
“垫子”的x_4q_4k_4v_4

最终得到Q、K、V矩阵:

  • Q:[batch_size, 5, 64](5个词,每个词的Query向量);
  • K:[batch_size, 5, 64]
  • V:[batch_size, 5, 64]
步骤2:计算注意力得分(相似度)

注意力得分衡量“Query与每个Key的匹配程度”——得分越高,说明该Key对应的Value对当前Query越重要。
Transformer用点积计算相似度(计算效率高,且当Q/K维度足够大时,效果优于余弦相似度):
Score(Q,K)=QKT \text{Score}(Q, K) = Q K^T Score(Q,K)=QKT

  • K^T:K的转置,维度[batch_size, d_k, seq_len]
  • 得分矩阵维度:[batch_size, seq_len, seq_len](比如[batch, 5, 5],每个元素Score[i][j]是第i个词的Q与第j个词的K的相似度)。

关键优化:缩放点积(Scaled Dot-Product)
点积的结果会随d_k增大而增大(比如d_k=64时,点积结果可能达到64×(0.1)^2=0.64,而d_k=1024时会达到10.24),导致Softmax后梯度消失(数值过大,Softmax输出趋近于0或1)。
解决方案:除以√d_k做缩放:
Scaled Score(Q,K)=QKTdk \text{Scaled Score}(Q, K) = \frac{Q K^T}{\sqrt{d_k}} Scaled Score(Q,K)=dk QKT

案例:“猫”(i=1)的Q与“垫子”(j=4)的K点积后除以√64=8,得到得分Score[1][4] = 3.2(高分,说明“垫子”对“猫”很重要);与“的”(j=3)的得分Score[1][3] = 0.5(低分,说明“的”不重要)。

步骤3:Softmax归一化(得到注意力权重)

对步骤2的得分矩阵每行做Softmax,将得分转为“0~1”的权重,且每行权重和为1:
Attention Weight=Softmax(QKTdk) \text{Attention Weight} = \text{Softmax}\left( \frac{Q K^T}{\sqrt{d_k}} \right) Attention Weight=Softmax(dk QKT)

  • 权重矩阵维度:[batch_size, seq_len, seq_len],每个元素Weight[i][j]是第i个词对第j个词的注意力权重。

案例:“猫”(i=1)的权重行是[0.1, 0.05, 0.05, 0.7, 0.1]——对“垫子”(j=4)的权重0.7(最高),对其他词的权重低,符合语义逻辑。

步骤4:加权求和(得到自注意力输出)

用注意力权重对V矩阵的列进行加权求和,得到每个位置的自注意力输出:
Self-Attention(Q,K,V)=Softmax(QKTdk)V \text{Self-Attention}(Q, K, V) = \text{Softmax}\left( \frac{Q K^T}{\sqrt{d_k}} \right) V Self-Attention(Q,K,V)=Softmax(dk QKT)V

  • 输出维度:[batch_size, seq_len, d_v](比如[batch, 5, 64],每个词的输出向量融合了所有词的信息)。

案例:“猫”的自注意力输出 = 0.1×v_1(猫自身) + 0.05×v_2(坐) + 0.05×v_3(的) + 0.7×v_4(垫子) + 0.1×v_5(上)——核心融合了“垫子”的信息,语义更丰富。

2.4 关键问题:为什么用点积?为什么要缩放?

(1)为什么用点积而不是其他相似度(如余弦、MLP)?
  • 计算效率高:点积可通过矩阵乘法并行实现,而MLP(如v^T W q)需要额外参数,计算量更大;
  • 效果足够好:当d_k足够大时,点积与余弦相似度的效果接近(余弦相似度是归一化后的点积);
  • 原论文验证:作者对比了点积、缩放点积、MLP三种方式,发现缩放点积在效果和效率上最优。
(2)为什么缩放能缓解梯度消失?

Softmax函数的梯度与exp(x_i)相关——当x_i(缩放前的得分)过大时,exp(x_i)会溢出,导致Softmax输出趋近于1(对应位置)和0(其他位置),梯度趋近于0;
除以√d_k后,得分的方差被控制在1附近(原论文证明:当Q/K的元素独立同分布于N(0,1)时,QK^T的元素方差为d_k,缩放后方差为1),避免数值溢出,梯度更稳定。

三、更强的表达能力:多头注意力(Multi-Head Attention)

自注意力能捕捉序列的全局依赖,但“单头”只能关注一种维度的关系(比如语义关系)。多头注意力通过“拆分多个头并行计算,再合并”,让模型同时捕捉多维度的关系(如语义、语法、位置)。

3.1 单头注意力的局限:“一叶障目”

比如处理句子“他在公园散步,那里有很多花”:

  • 单头注意力可能只关注“他-散步”的语义关系,却忽略“公园-那里”的指代关系、“散步-花”的场景关系;
  • 多头注意力可以用不同的头分别捕捉这三种关系,再融合所有头的信息,表达能力更强。

3.2 多头注意力的设计思路:“分而治之,合而用之”

核心逻辑是“将Q、K、V拆分为h个独立的子空间(头),每个头计算自注意力,最后将所有头的输出合并并线性变换”,共5步:

步骤1:拆分Q、K、V(按头划分维度)

假设多头数h=8(原论文设置),d_model=512,则每个头的维度d_k = d_model / h = 64
将Q、K、V沿d_model维度拆分为h个部分:

  • Q_split:[batch_size, h, seq_len, d_k](每个头的Q维度[batch, seq_len, 64]);
  • K_split、V_split同理。

注意:拆分时需调整维度顺序(将“头数h”提前),方便后续并行计算。

步骤2:每个头独立计算自注意力

对每个头的Q_split[i]、K_split[i]、V_split[i],按自注意力流程计算输出:
headi=Self-Attention(Qsplit[i],Ksplit[i],Vsplit[i]) \text{head}_i = \text{Self-Attention}(Q_{\text{split}[i]}, K_{\text{split}[i]}, V_{\text{split}[i]}) headi=Self-Attention(Qsplit[i],Ksplit[i],Vsplit[i])

  • 每个head的输出维度:[batch_size, seq_len, d_k]
  • h个头的总输出:[batch_size, h, seq_len, d_k]
步骤3:合并所有头的输出

将h个头的输出沿d_k维度拼接(concat),恢复到d_model维度:
Concat=cat(head1,head2,...,headh) \text{Concat} = \text{cat}(\text{head}_1, \text{head}_2, ..., \text{head}_h) Concat=cat(head1,head2,...,headh)

  • 拼接后维度:[batch_size, seq_len, h×d_k] = [batch, seq_len, d_model](因为h×d_k = 8×64=512)。
步骤4:线性变换(维度对齐+信息融合)

对拼接后的结果做一次线性变换,让模型学习如何融合不同头的信息:
MultiHead(Q,K,V)=Concat×WO \text{MultiHead}(Q,K,V) = \text{Concat} \times W_O MultiHead(Q,K,V)=Concat×WO

  • W_O:输出权重矩阵,维度[d_model, d_model]
  • 最终输出维度:[batch_size, seq_len, d_model](与输入X的维度一致,方便后续残差连接)。
步骤5:完整公式总结

MultiHead(Q,K,V)=cat(head1,...,headh)WO \text{MultiHead}(Q,K,V) = \text{cat}(\text{head}_1, ..., \text{head}_h) W_O MultiHead(Q,K,V)=cat(head1,...,headh)WO
where headi=Self-Attention(QWQi,KWKi,VWVi) \text{where } \text{head}_i = \text{Self-Attention}(Q W_{Q_i}, K W_{K_i}, V W_{V_i}) where headi=Self-Attention(QWQi,KWKi,VWVi)

  • W_{Q_i}:第i个头的Query权重矩阵(W_Q拆分为h个,每个维度[d_model, d_k]);
  • W_{K_i}W_{V_i}同理。

3.3 多头注意力的核心价值:3个维度的提升

  1. 多关系捕捉:不同头关注不同类型的关系(语义、语法、指代),比如头1关注“主谓”,头2关注“指代”;
  2. 泛化能力增强:多个头的独立计算相当于“模型集成”,减少单一头的偏见;
  3. 维度拆分优化:将高维d_model拆分为低维d_k,降低每个头的计算复杂度(总计算量与单头相当,因为h×(seq_len^2×d_k) ≈ seq_len^2×d_model)。

3.4 Encoder与Decoder中的多头注意力差异

Transformer的Encoder和Decoder用的多头注意力略有不同,核心是“是否需要掩码”:

  • Encoder的多头自注意力:无掩码,所有位置可相互关注(比如“猫”可关注“垫子”,“垫子”也可关注“猫”),适合“理解输入序列”;
  • Decoder的掩码多头自注意力:有“未来掩码”(Future Mask),即第i个位置只能关注第1~i个位置,不能关注i之后的位置(避免“看到未来的词”,符合自回归生成逻辑);
  • Decoder的Encoder-Decoder注意力:Q来自Decoder,K/V来自Encoder,让Decoder“关注输入序列中与当前生成词相关的位置”(比如翻译时,生成“猫”要关注输入的“cat”)。

四、特征强化:前馈网络(Feed-Forward Network, FFN)

多头注意力捕捉了序列的全局依赖,但每个位置的特征仍需进一步“加工”——前馈网络的作用就是“对每个位置的特征进行独立的非线性变换,增强模型的表达能力”。

4.1 前馈网络的结构:简单但高效

前馈网络是一个“两层线性变换+ReLU激活+Dropout”的结构,对每个位置的特征独立处理(即不同位置的计算互不干扰):
FFN(x)=max⁡(0,xW1+b1)W2+b2 \text{FFN}(x) = \max(0, x W_1 + b_1) W_2 + b_2 FFN(x)=max(0,xW1+b1)W2+b2

各部分解析:
  1. 第一层线性变换:将输入维度d_model提升到d_ff(原论文d_ff=2048),增加非线性表达能力:
    x1=xW1+b1 x_1 = x W_1 + b_1 x1=xW1+b1

    • W_1:维度[d_model, d_ff](512→2048);
    • b_1:偏置项,维度[d_ff]
  2. ReLU激活函数:引入非线性,让模型能学习复杂特征(原论文用ReLU,后续模型如GPT用GELU,效果更好):
    x1act=max⁡(0,x1) x_{1_{\text{act}}} = \max(0, x_1) x1act=max(0,x1)

    • ReLU的优势:计算快,缓解梯度消失(相比sigmoid)。
  3. 第二层线性变换:将维度从d_ff降回d_model,与输入维度对齐(方便残差连接):
    x2=x1actW2+b2 x_2 = x_{1_{\text{act}}} W_2 + b_2 x2=x1actW2+b2

    • W_2:维度[d_ff, d_model](2048→512);
    • b_2:偏置项,维度[d_model]
  4. Dropout:训练时随机丢弃部分神经元(原论文 dropout rate=0.1),防止过拟合。

4.2 为什么需要前馈网络?——“全局依赖+局部强化”

多头注意力解决了“全局关系捕捉”,但每个位置的特征仍是“其他位置信息的加权和”,缺乏“局部非线性加工”——比如“猫坐在垫子上”,“坐”的特征需要结合“猫”和“垫子”的信息,但还需要进一步强化“动作”属性,这就是前馈网络的作用。

通俗比喻:多头注意力是“收集所有同学的笔记”,前馈网络是“自己整理笔记,提炼重点”——两者结合,才能得到更优质的学习成果。

4.3 关键设计:为什么中间维度d_ff4×d_model

原论文选择d_ff=2048=4×512,不是随机的:

  1. 非线性表达能力:更高的中间维度能让模型学习更复杂的特征映射(比如从512维的语义特征映射到2048维的细粒度特征);
  2. 计算效率平衡4×d_model是“效果”和“计算量”的平衡点——维度太小,表达能力不足;维度太大,计算量剧增(如8×d_model会让计算量翻倍);
  3. 实践验证:后续大模型(如BERT、GPT)均沿用这一比例,证明其有效性。

五、Transformer的“稳定器”:残差连接与LayerNorm

Transformer通常堆叠6~100层(如GPT-3有96层),但深层模型容易出现“梯度消失”和“内部协变量偏移”——残差连接和LayerNorm就是解决这两个问题的“稳定器”。

5.1 残差连接(Residual Connection):缓解梯度消失

残差连接的核心是“将输入直接加到输出上”,让梯度能“跳过”中间层直接传递,避免梯度在深层中衰减到0。

公式与结构:

Transformer中,每个模块(多头注意力、前馈网络)的输出都加了残差连接:
Output=LayerNorm(Module(X)+X) \text{Output} = \text{LayerNorm}(\text{Module}(X) + X) Output=LayerNorm(Module(X)+X)

  • X:模块的输入;
  • Module(X):模块的输出(如多头注意力输出、前馈网络输出);
  • “+X”:残差连接,输入与输出直接相加(需保证维度一致,Transformer中所有模块输入输出维度都是d_model)。
为什么有效?

假设模块是线性变换Module(X) = WX + b,则输出为WX + b + X = (W+I)X + b——梯度传递时,会额外保留“X的直接梯度”(即1),避免梯度随层数增加而指数衰减。

案例:100层Transformer,若没有残差连接,梯度可能衰减到0.9^100 ≈ 1e-5(几乎为0);有残差连接时,梯度能稳定传递。

5.2 LayerNorm:解决内部协变量偏移

“内部协变量偏移”是指“深层模型的输入分布随训练迭代不断变化”,导致模型需要不断适应新分布,训练不稳定。
LayerNorm通过“对每个样本的每个序列位置做归一化”,让输入分布保持稳定。

LayerNorm的公式推导:

对输入X(维度[batch_size, seq_len, d_model]),按“特征维度”(d_model)做归一化:
μi=1dmodel∑k=1dmodelXi,k \mu_i = \frac{1}{d_model} \sum_{k=1}^{d_model} X_{i,k} μi=dmodel1k=1dmodelXi,k
σi2=1dmodel∑k=1dmodel(Xi,k−μi)2 \sigma_i^2 = \frac{1}{d_model} \sum_{k=1}^{d_model} (X_{i,k} - \mu_i)^2 σi2=dmodel1k=1dmodel(Xi,kμi)2
X^i,k=Xi,k−μiσi2+ϵ \hat{X}_{i,k} = \frac{X_{i,k} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}} X^i,k=σi2+ϵ Xi,kμi
LayerNorm(X)=γX^i,k+β \text{LayerNorm}(X) = \gamma \hat{X}_{i,k} + \beta LayerNorm(X)=γX^i,k+β

  • μ_i:第i个位置特征的均值;
  • σ_i^2:第i个位置特征的方差;
  • ε:小常数(如1e-6),避免分母为0;
  • γβ:可学习的缩放和偏移参数(维度d_model),让模型能调整归一化后的分布(若γ=√d_modelβ=0,则等价于不调整)。
LayerNorm vs BatchNorm:为什么选LayerNorm?

BatchNorm是“按批次维度归一化”(对每个批次的所有样本同一特征做归一化),适合图像(固定尺寸,批次无关);
LayerNorm是“按样本维度归一化”(对每个样本的所有特征做归一化),适合序列(序列长度可变,每个样本的分布独立)。
比如处理不同长度的句子,BatchNorm会因批次中序列长度不同而失效,LayerNorm则不受影响。

5.3 位置编码(Positional Encoding):给序列“加时间戳”

自注意力本身没有“时序信息”——比如“猫坐在垫子上”和“垫子坐在猫上”,自注意力会计算相同的权重,但语义完全相反。
位置编码的作用是“给每个位置添加独特的位置信息”,让模型能区分不同位置。

正弦余弦位置编码:原论文方案

原论文用正弦和余弦函数生成位置编码,公式如下(pos是位置索引,i是特征维度索引):
PEpos,2i=sin⁡(pos100002i/dmodel) PE_{pos, 2i} = \sin\left( \frac{pos}{10000^{2i/d_model}} \right) PEpos,2i=sin(100002i/dmodelpos)
PEpos,2i+1=cos⁡(pos100002i/dmodel) PE_{pos, 2i+1} = \cos\left( \frac{pos}{10000^{2i/d_model}} \right) PEpos,2i+1=cos(100002i/dmodelpos)

  • PE:位置编码矩阵,维度[seq_len, d_model]
  • 2i2i+1:偶数维度用正弦,奇数维度用余弦(保证相邻位置的编码差异);
  • 10000:控制位置编码的周期(位置越远,编码差异越大)。
为什么用正弦余弦?
  1. 可扩展性:无需训练,可生成任意长度的位置编码(即使测试时序列长度超过训练时);
  2. 相对位置信息PE_{pos+k}可表示为PE_pos的线性组合(比如sin(a+b)=sin a cos b + cos a sin b),让模型能学习相对位置关系;
  3. 计算高效:无需额外参数,直接生成。

直观理解:不同位置的编码向量在高维空间中是正交的,模型能通过学习区分“pos=1”和“pos=5”的差异。

六、Transformer的底层逻辑:为什么能成为大模型基石?

理解了各个模块,我们需要回到核心问题:为什么Transformer能支撑起GPT、BERT、LLaMA这些千亿参数的大模型? 关键在于它的三个底层优势:

6.1 并行计算:大模型训练的“效率引擎”

大模型的训练需要处理海量数据(如万亿tokens),效率是关键。Transformer的并行性体现在两个层面:

  1. 序列内并行:自注意力对所有位置的处理是独立的,可同时计算整个序列的注意力权重;
  2. 层间并行:多层Transformer可通过模型并行(不同层放在不同GPU上)或数据并行(不同批次放在不同GPU上)加速训练。

相比之下,RNN的串行计算无法支撑大模型——即使是10亿参数的RNN,训练速度也会慢到无法接受。

6.2 可扩展性:从百万到千亿参数的“架构基础”

Transformer的模块结构极其简洁,所有模块(多头注意力、前馈网络)都是“可堆叠”的:

  • 横向扩展:增加多头数h(如从8头到16头),提升多关系捕捉能力;
  • 纵向扩展:增加层数N(如从6层到96层),提升特征抽象能力;
  • 维度扩展:增加d_model(如从512到12288),提升每个位置的特征表达能力。

这种“模块化+可堆叠”的设计,让模型能通过增加参数持续提升性能(符合“缩放定律”),而RNN/LSTM因结构限制,参数增加到一定程度后性能会饱和。

6.3 预训练-微调范式:大模型生态的“核心载体”

Transformer的结构天然适合“预训练-微调”范式:

  • 预训练阶段:用海量无标注数据(如互联网文本)训练一个通用Transformer模型(如BERT、GPT),学习语言的通用规律;
  • 微调阶段:用少量标注数据(如分类任务的1000条样本)微调预训练模型,快速适配具体任务(文本分类、机器翻译、问答)。

这种范式的核心是“Transformer能通过预训练学习到通用的特征表示”,而RNN/LSTM因泛化能力弱,预训练效果远不如Transformer。

七、实战:用PyTorch实现自注意力与多头注意力(附代码+可视化)

理论讲完,我们用PyTorch实现核心模块,并用注意力热力图可视化,直观理解注意力的分布。

7.1 环境准备

# 安装依赖(已安装可跳过)
pip install torch torchvision pandas numpy matplotlib seaborn

7.2 实现自注意力模块

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, d_model, d_k):
        super().__init__()
        self.d_k = d_k  # 每个头的维度
        # 定义Q、K、V的线性变换权重
        self.W_Q = nn.Linear(d_model, d_k)
        self.W_K = nn.Linear(d_model, d_k)
        self.W_V = nn.Linear(d_model, d_k)
    
    def forward(self, X):
        # X: [batch_size, seq_len, d_model]
        batch_size, seq_len, _ = X.shape
        
        # 步骤1:生成Q、K、V
        Q = self.W_Q(X)  # [batch_size, seq_len, d_k]
        K = self.W_K(X)  # [batch_size, seq_len, d_k]
        V = self.W_V(X)  # [batch_size, seq_len, d_k]
        
        # 步骤2:计算缩放点积得分
        score = torch.matmul(Q, K.transpose(-2, -1))  # [batch_size, seq_len, seq_len]
        score = score / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))  # 缩放
        
        # 步骤3:Softmax得到注意力权重
        attn_weight = F.softmax(score, dim=-1)  # [batch_size, seq_len, seq_len]
        
        # 步骤4:加权求和得到输出
        output = torch.matmul(attn_weight, V)  # [batch_size, seq_len, d_k]
        
        return output, attn_weight  # 返回输出和注意力权重(用于可视化)

7.3 实现多头注意力模块

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, h=8):
        super().__init__()
        self.h = h  # 多头数
        self.d_k = d_model // h  # 每个头的维度(d_model必须能被h整除)
        assert d_model % h == 0, "d_model must be divisible by h"
        
        # 定义Q、K、V的线性变换(共享权重,后续拆分)
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        
        # 输出线性变换
        self.W_O = nn.Linear(d_model, d_model)
        
        # 自注意力模块(每个头共用)
        self.self_attn = SelfAttention(d_model, self.d_k)
    
    def forward(self, X):
        # X: [batch_size, seq_len, d_model]
        batch_size, seq_len, d_model = X.shape
        
        # 步骤1:生成Q、K、V并拆分多头
        Q = self.W_Q(X)  # [batch_size, seq_len, d_model]
        K = self.W_K(X)
        V = self.W_V(X)
        
        # 拆分维度:[batch_size, seq_len, d_model] → [batch_size, h, seq_len, d_k]
        Q_split = Q.view(batch_size, seq_len, self.h, self.d_k).transpose(1, 2)
        K_split = K.view(batch_size, seq_len, self.h, self.d_k).transpose(1, 2)
        V_split = V.view(batch_size, seq_len, self.h, self.d_k).transpose(1, 2)
        
        # 步骤2:每个头计算自注意力
        attn_outputs = []
        attn_weights = []
        for i in range(self.h):
            q = Q_split[:, i, :, :]  # [batch_size, seq_len, d_k]
            k = K_split[:, i, :, :]
            v = V_split[:, i, :, :]
            output, weight = self.self_attn(q)  # 每个头的输出和权重
            attn_outputs.append(output)
            attn_weights.append(weight)
        
        # 步骤3:合并所有头的输出
        attn_outputs = torch.stack(attn_outputs, dim=1)  # [batch_size, h, seq_len, d_k]
        # 转置+拼接:[batch_size, seq_len, h*d_k] = [batch_size, seq_len, d_model]
        attn_outputs = attn_outputs.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        
        # 步骤4:输出线性变换
        output = self.W_O(attn_outputs)  # [batch_size, seq_len, d_model]
        
        # 合并注意力权重(用于可视化,取所有头的平均)
        attn_weights = torch.mean(torch.stack(attn_weights, dim=1), dim=1)  # [batch_size, seq_len, seq_len]
        
        return output, attn_weights

7.4 注意力权重可视化(案例:句子“猫坐在垫子上”)

我们用一个简单案例展示注意力分布,步骤如下:

  1. 构造句子的嵌入向量(用随机嵌入模拟,实际中用预训练词嵌入);
  2. 输入多头注意力模块,得到注意力权重;
  3. 用热力图可视化权重分布。
import matplotlib.pyplot as plt
import seaborn as sns

# 1. 构造输入数据
sentence = ["猫", "坐", "在", "垫子", "上"]  # 输入句子
seq_len = len(sentence)
d_model = 512  # 嵌入维度
batch_size = 1

# 随机生成嵌入向量(模拟词嵌入)
torch.manual_seed(42)  # 固定随机种子,结果可复现
X = torch.randn(batch_size, seq_len, d_model)  # [1, 5, 512]

# 2. 初始化多头注意力模块
multi_head_attn = MultiHeadAttention(d_model=d_model, h=8)
output, attn_weight = multi_head_attn(X)  # attn_weight: [1, 5, 5]

# 3. 提取注意力权重(取第一个batch)
attn_weight = attn_weight[0].detach().cpu().numpy()  # [5, 5]

# 4. 绘制热力图
plt.figure(figsize=(10, 8))
sns.heatmap(
    attn_weight,
    annot=True,  # 显示数值
    cmap="YlOrRd",  # 颜色映射
    xticklabels=sentence,  # x轴标签(被关注的词)
    yticklabels=sentence   # y轴标签(关注其他词的词)
)
plt.title("Multi-Head Attention Weight Heatmap (Sentence: 猫坐在垫子上)")
plt.xlabel("Attended Words (Key)")
plt.ylabel("Query Words")
plt.show()
可视化结果分析:

热力图中,颜色越红、数值越大,说明注意力权重越高。预期结果:

  • “猫”(y=0)对“垫子”(x=3)的权重最高(如0.6),符合“猫坐在垫子上”的语义;
  • “垫子”(y=3)对“猫”(x=0)和“上”(x=4)的权重较高;
  • “在”(y=2)对所有词的权重较低(功能词,语义贡献小)。

7.5 用Hugging Face验证:加载BERT查看注意力

实际应用中,我们常用Hugging Face的transformers库加载预训练模型,直接查看注意力权重:

from transformers import BertTokenizer, BertModel

# 1. 加载BERT模型和Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertModel.from_pretrained('bert-base-chinese', output_attentions=True)  # 输出注意力权重

# 2. 输入句子并编码
sentence = "猫坐在垫子上"
inputs = tokenizer(sentence, return_tensors="pt")  # 编码为token ID

# 3. 前向传播,获取注意力权重
with torch.no_grad():
    outputs = model(**inputs)
    attentions = outputs.attentions  # 所有层的注意力权重,形状:[12层, batch_size, 12头, seq_len, seq_len]

# 4. 提取第1层第1个头的注意力权重(可调整层和头)
layer_idx = 0  # 第1层
head_idx = 0   # 第1个头
attn_weight = attentions[layer_idx][0][head_idx].detach().cpu().numpy()  # [seq_len, seq_len]

# 5. 获取token标签(BERT会添加[CLS]和[SEP])
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])  # ['[CLS]', '猫', '坐', '在', '垫', '子', '上', '[SEP]']

# 6. 绘制热力图
plt.figure(figsize=(12, 10))
sns.heatmap(
    attn_weight,
    annot=True,
    cmap="YlOrRd",
    xticklabels=tokens,
    yticklabels=tokens
)
plt.title(f"BERT Attention Weight (Layer {layer_idx+1}, Head {head_idx+1})")
plt.xlabel("Attended Tokens")
plt.ylabel("Query Tokens")
plt.show()
结果分析:

BERT的注意力权重更符合语言规律,比如:

  • “垫”(token=4)和“子”(token=5)的权重极高(因为“垫子”是一个词);
  • “猫”(token=1)对“垫”(token=4)、“子”(token=5)的权重较高;
  • [CLS](token=0)对所有词的权重较均匀(负责整合整个句子的语义)。

八、南木的学习路径建议:从入门到精通Transformer

很多人学Transformer会陷入“公式劝退”或“代码跑不通”的困境,我结合自己的经验,整理了一条“循序渐进”的学习路径:

阶段1:基础铺垫(1~2周)

  • 数学基础:回顾线性代数(矩阵乘法、转置)、概率论(Softmax、交叉熵)、微积分(梯度下降)——不用深钻,能理解公式含义即可;
  • 工具基础:掌握PyTorch核心操作(张量、模型定义、前向传播),熟悉Hugging Face transformers库的基本使用(加载模型、编码数据);
  • 前置知识:了解RNN/LSTM的缺陷(并行性、长序列依赖),明白Transformer的设计动机。

阶段2:核心拆解(2~3周)

  • 第一步:吃透自注意力
    1. 手动计算一个简单案例(如3个词的句子),推导Q/K/V生成、得分计算、Softmax、加权求和的全过程;
    2. 用PyTorch实现自注意力模块,调整d_k观察缩放的影响;
  • 第二步:理解多头注意力
    1. 重点搞懂“拆分-合并”的维度变化(比如d_model=512h=8d_k=64);
    2. 实现多头注意力,对比单头和多头的注意力权重分布;
  • 第三步:掌握其他模块
    1. 实现前馈网络,理解“维度升降”的作用;
    2. 推导LayerNorm的计算过程,对比BatchNorm的差异;
    3. 生成正弦余弦位置编码,可视化不同位置的编码向量。

阶段3:实战落地(3~4周)

  • 入门项目
    1. 用Transformer实现文本分类(如IMDB情感分析),熟悉Encoder的使用;
    2. 用注意力可视化工具(如bertviz)分析预训练模型的注意力分布;
  • 进阶项目
    1. 用Transformer实现机器翻译(Encoder-Decoder结构),理解Decoder的掩码机制;
    2. 微调预训练模型(如BERT微调做问答任务),掌握“预训练-微调”范式;
  • 挑战项目
    1. 实现简化版GPT(Decoder-only结构),理解自回归生成逻辑;
    2. 用ViT(Vision Transformer)做图像分类,理解Transformer在CV领域的应用。

阶段4:进阶扩展(长期)

  • 大模型优化:学习稀疏注意力(如Longformer)、量化(如INT8量化)、蒸馏(模型压缩),解决大模型部署问题;
  • 理论深入:阅读Transformer的改进论文(如GPT-2/3、BERT、T5),理解“自回归”“双向注意力”“编码器-解码器”的差异;
  • 跨领域应用:探索Transformer在CV(ViT)、语音(Wav2Vec 2.0)、多模态(CLIP)的应用,理解“注意力是通用机制”。

九、常见误区与解答

误区1:自注意力一定是全局的?

错。自注意力可以通过“掩码”限制关注范围:

  • Decoder的“未来掩码”:让第i个位置只能关注前i个位置;
  • 稀疏注意力(如Longformer):只关注局部窗口+全局关键位置,降低长序列的计算量。

误区2:多头注意力头数越多越好?

错。头数增加会带来两个问题:

  • 计算量剧增:头数h翻倍,计算量也翻倍(O(h×seq_len²×d_k));
  • 边际效益递减:头数超过16后,模型性能提升不明显(原论文h=8,GPT-3 h=96是因为d_model更大)。

误区3:Transformer不需要时序信息?

错。Transformer本身没有时序信息,必须通过“位置编码”补充——如果去掉位置编码,模型无法区分“猫坐在垫子上”和“垫子坐在猫上”。

误区4:Transformer在CV领域不如CNN?

错。ViT(2020年)证明Transformer在CV领域可超越CNN:

  • ViT将图像拆分为“图像块”(如16×16),视为“序列”输入Transformer;
  • 大尺度预训练(如ImageNet-21K)后,ViT在图像分类、目标检测等任务上的性能超过CNN。

十、总结

Transformer的核心不是“复杂的公式”,而是“简洁高效的注意力机制”,它用自注意力解决了RNN的并行性和长序列依赖问题,用多头注意力提升了多关系捕捉能力,用可堆叠架构支撑了大模型的 scalability。

从技术发展来看,Transformer不仅是NLP领域的革命,更是AI的通用基础架构——它统一了NLP、CV、语音、多模态等领域,成为大模型时代的“基石”。

如果你觉得这篇文章有用,记得点赞+收藏,关注我(南木),后续会分享更多大模型实战干货(如GPT简化版实现、LLaMA微调)。如果在学习过程中遇到问题,欢迎在评论区留言,我会定期回复!
在这里插入图片描述

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐