Transformer模型全拆解:自注意力/多头注意力/前馈网络,看透大模型“基石”的底层逻辑(附公式推导+代码)
首先明确三个核心向量的来源和含义(假设输入序列的嵌入维度为d_modelQuery(查询向量,Q):代表“当前位置需要什么信息”,维度d_k(如64);Key(键向量,K):代表“当前位置提供什么信息”,维度d_k;Value(值向量,V):代表“当前位置的具体信息”,维度d_v(通常d_v = d_k它们的生成方式很简单:对输入序列的嵌入矩阵X(维度QXWQQ = X W_QQXWQKXWKK
作为AI技术专家兼学习规划博主,我每天都会收到读者的类似提问:
“南木,Transformer的自注意力到底怎么算?公式里的Q/K/V看晕了”
“多头注意力为什么要拆分?合并后的维度怎么对齐?”
“明明懂了每个模块的结构,却还是不明白为什么它能撑起GPT、BERT这些大模型”
其实Transformer的门槛不在于“公式复杂”,而在于“跳出RNN的时序思维定式”——它用“自注意力机制”打破了序列处理的“时间依赖”,用“并行计算”重构了模型效率,更用“可堆叠架构”奠定了大模型的 scalability 基础。
这篇文章会用“问题-原理-公式-案例-代码”的五层拆解法,把Transformer讲透:从“为什么需要Transformer”的背景切入,到自注意力、多头注意力、前馈网络的公式推导与通俗解释,再到“Transformer为何能成为大模型基石”的底层逻辑,最后用PyTorch实现核心模块并可视化注意力权重。全程无冗余理论,每个公式都配“数学推导+文字解读”,每个模块都有“案例验证”,看完就能动手复现。
同时需要学习规划、就业指导、技术答疑和系统课程学习的同学 欢迎扫码交流
一、为什么需要Transformer?——从RNN的“死穴”说起
在Transformer(2017年《Attention Is All You Need》)出现前,时序数据处理的主流是RNN/LSTM/GRU。但这些模型有两个致命缺陷,直接限制了大模型的发展:
1.1 RNN的两大“死穴”
(1)并行性极差,训练效率低
RNN的核心是“时序依赖”——第t步的计算必须等第t-1步完成,就像“流水线必须按顺序走”,无法并行处理整个序列。
比如处理一个1000词的句子,RNN需要依次计算每个词的隐藏状态,而Transformer可以同时计算所有词的注意力,训练速度相差10~100倍(序列越长,差距越大)。
(2)长序列依赖弱,记忆能力有限
LSTM/GRU虽然用门控机制缓解了RNN的梯度消失问题,但仍属于“局部依赖”——每个位置的信息只能通过相邻位置逐步传递,长序列(如1000词以上)的远端信息会严重衰减。
比如处理“小明今天去了北京,他说______会带特产回来”,LSTM可能记不住“他”指的是“小明”,但Transformer通过全局注意力能直接关联“他”和“小明”。
1.2 Transformer的革命性突破
Transformer用“自注意力机制”一次性解决了这两个问题:
- 并行计算:自注意力对所有位置的处理是独立的,无需等待前一步,可充分利用GPU算力;
- 全局依赖:每个位置能直接“看到”序列中的所有其他位置,远端信息无衰减;
- 可扩展性:模块结构简洁,支持多层堆叠(比如GPT-3有96层Transformer),能不断提升模型容量。
1.3 Transformer整体架构:先看“骨架”再拆“细节”
Transformer分为Encoder(编码器) 和Decoder(解码器) 两部分,核心模块包括:自注意力、多头注意力、前馈网络、残差连接、LayerNorm、位置编码。
先记住整体架构(以机器翻译任务为例):
- Encoder(N层堆叠):输入源语言序列(如英文),输出“源语言语义表示”;每层包含“多头自注意力”和“前馈网络”;
- Decoder(N层堆叠):输入目标语言序列(如中文,自回归生成),结合Encoder的输出,输出目标语言下一个词的概率;每层包含“掩码多头自注意力”“Encoder-Decoder注意力”和“前馈网络”;
- 输出层:线性变换+Softmax,将Decoder的输出转为词表概率分布。
(注:原论文中N=6,即Encoder和Decoder各6层)
接下来,我们从“最核心的自注意力”开始,逐个拆解每个模块。
二、Transformer的“心脏”:自注意力机制(Self-Attention)
自注意力是Transformer的核心,它的本质是“让序列中的每个位置,通过与其他位置的‘注意力交互’,学习到更丰富的语义表示”——通俗说就是“每个词都要‘关注’序列中对自己重要的其他词”。
2.1 先理解“注意力”:从人类到机器
人类的注意力是“选择性关注”——比如读“猫坐在红色的垫子上”,我们会自然关注“猫”(主语)和“垫子”(宾语),而“红色的”(定语)关注度较低。
机器的注意力机制模拟了这一过程,核心是**“根据‘查询’(Query),对‘键’(Key)进行权重分配,再用权重加权‘值’(Value)”**,即:注意力输出 = 对Value按Key与Query的相似度加权求和
而“自注意力”(Self-Attention)是“Query、Key、Value都来自同一个输入序列”——比如处理句子时,每个词既是Query(要查询其他词),也是Key(被其他词查询),还是Value(提供信息给其他词)。
2.2 自注意力三要素:Q、K、V的定义与作用
首先明确三个核心向量的来源和含义(假设输入序列的嵌入维度为d_model
,如512):
- Query(查询向量,Q):代表“当前位置需要什么信息”,维度
d_k
(如64); - Key(键向量,K):代表“当前位置提供什么信息”,维度
d_k
; - Value(值向量,V):代表“当前位置的具体信息”,维度
d_v
(通常d_v = d_k
)。
它们的生成方式很简单:对输入序列的嵌入矩阵X
(维度[batch_size, seq_len, d_model]
)做三次线性变换:
Q=XWQ Q = X W_Q Q=XWQ
K=XWK K = X W_K K=XWK
V=XWV V = X W_V V=XWV
W_Q
:Query的权重矩阵,维度[d_model, d_k]
;W_K
:Key的权重矩阵,维度[d_model, d_k]
;W_V
:Value的权重矩阵,维度[d_model, d_v]
。
通俗比喻:把每个词看作“学生”,Q
是“学生的问题”(比如“我和其他词是什么关系?”),K
是“其他学生的标签”(比如“我是主语”“我是宾语”),V
是“其他学生的答案”(比如“我是‘猫’,代表动物”)。
2.3 自注意力完整计算流程:4步推导(附公式+案例)
假设输入序列是“猫坐在垫子上”(seq_len=5),嵌入维度d_model=512
,d_k=d_v=64
,我们一步步计算自注意力输出。
步骤1:生成Q、K、V(线性变换)
对每个词的嵌入向量x_i
(维度d_model=512
),通过W_Q
、W_K
、W_V
生成q_i
、k_i
、v_i
(维度d_k=64
):
比如“猫”的x_1
→ q_1
(“猫的问题”)、k_1
(“猫的标签”)、v_1
(“猫的信息”);
“垫子”的x_4
→ q_4
、k_4
、v_4
。
最终得到Q、K、V矩阵:
- Q:
[batch_size, 5, 64]
(5个词,每个词的Query向量); - K:
[batch_size, 5, 64]
; - V:
[batch_size, 5, 64]
。
步骤2:计算注意力得分(相似度)
注意力得分衡量“Query与每个Key的匹配程度”——得分越高,说明该Key对应的Value对当前Query越重要。
Transformer用点积计算相似度(计算效率高,且当Q/K维度足够大时,效果优于余弦相似度):
Score(Q,K)=QKT \text{Score}(Q, K) = Q K^T Score(Q,K)=QKT
K^T
:K的转置,维度[batch_size, d_k, seq_len]
;- 得分矩阵维度:
[batch_size, seq_len, seq_len]
(比如[batch, 5, 5]
,每个元素Score[i][j]
是第i个词的Q与第j个词的K的相似度)。
关键优化:缩放点积(Scaled Dot-Product)
点积的结果会随d_k
增大而增大(比如d_k=64
时,点积结果可能达到64×(0.1)^2=0.64,而d_k=1024
时会达到10.24),导致Softmax后梯度消失(数值过大,Softmax输出趋近于0或1)。
解决方案:除以√d_k
做缩放:
Scaled Score(Q,K)=QKTdk \text{Scaled Score}(Q, K) = \frac{Q K^T}{\sqrt{d_k}} Scaled Score(Q,K)=dkQKT
案例:“猫”(i=1)的Q与“垫子”(j=4)的K点积后除以√64=8
,得到得分Score[1][4] = 3.2
(高分,说明“垫子”对“猫”很重要);与“的”(j=3)的得分Score[1][3] = 0.5
(低分,说明“的”不重要)。
步骤3:Softmax归一化(得到注意力权重)
对步骤2的得分矩阵每行做Softmax,将得分转为“0~1”的权重,且每行权重和为1:
Attention Weight=Softmax(QKTdk) \text{Attention Weight} = \text{Softmax}\left( \frac{Q K^T}{\sqrt{d_k}} \right) Attention Weight=Softmax(dkQKT)
- 权重矩阵维度:
[batch_size, seq_len, seq_len]
,每个元素Weight[i][j]
是第i个词对第j个词的注意力权重。
案例:“猫”(i=1)的权重行是[0.1, 0.05, 0.05, 0.7, 0.1]
——对“垫子”(j=4)的权重0.7(最高),对其他词的权重低,符合语义逻辑。
步骤4:加权求和(得到自注意力输出)
用注意力权重对V矩阵的列进行加权求和,得到每个位置的自注意力输出:
Self-Attention(Q,K,V)=Softmax(QKTdk)V \text{Self-Attention}(Q, K, V) = \text{Softmax}\left( \frac{Q K^T}{\sqrt{d_k}} \right) V Self-Attention(Q,K,V)=Softmax(dkQKT)V
- 输出维度:
[batch_size, seq_len, d_v]
(比如[batch, 5, 64]
,每个词的输出向量融合了所有词的信息)。
案例:“猫”的自注意力输出 = 0.1×v_1
(猫自身) + 0.05×v_2
(坐) + 0.05×v_3
(的) + 0.7×v_4
(垫子) + 0.1×v_5
(上)——核心融合了“垫子”的信息,语义更丰富。
2.4 关键问题:为什么用点积?为什么要缩放?
(1)为什么用点积而不是其他相似度(如余弦、MLP)?
- 计算效率高:点积可通过矩阵乘法并行实现,而MLP(如
v^T W q
)需要额外参数,计算量更大; - 效果足够好:当
d_k
足够大时,点积与余弦相似度的效果接近(余弦相似度是归一化后的点积); - 原论文验证:作者对比了点积、缩放点积、MLP三种方式,发现缩放点积在效果和效率上最优。
(2)为什么缩放能缓解梯度消失?
Softmax函数的梯度与exp(x_i)
相关——当x_i
(缩放前的得分)过大时,exp(x_i)
会溢出,导致Softmax输出趋近于1(对应位置)和0(其他位置),梯度趋近于0;
除以√d_k
后,得分的方差被控制在1附近(原论文证明:当Q/K的元素独立同分布于N(0,1)时,QK^T
的元素方差为d_k
,缩放后方差为1),避免数值溢出,梯度更稳定。
三、更强的表达能力:多头注意力(Multi-Head Attention)
自注意力能捕捉序列的全局依赖,但“单头”只能关注一种维度的关系(比如语义关系)。多头注意力通过“拆分多个头并行计算,再合并”,让模型同时捕捉多维度的关系(如语义、语法、位置)。
3.1 单头注意力的局限:“一叶障目”
比如处理句子“他在公园散步,那里有很多花”:
- 单头注意力可能只关注“他-散步”的语义关系,却忽略“公园-那里”的指代关系、“散步-花”的场景关系;
- 多头注意力可以用不同的头分别捕捉这三种关系,再融合所有头的信息,表达能力更强。
3.2 多头注意力的设计思路:“分而治之,合而用之”
核心逻辑是“将Q、K、V拆分为h个独立的子空间(头),每个头计算自注意力,最后将所有头的输出合并并线性变换”,共5步:
步骤1:拆分Q、K、V(按头划分维度)
假设多头数h=8
(原论文设置),d_model=512
,则每个头的维度d_k = d_model / h = 64
。
将Q、K、V沿d_model
维度拆分为h个部分:
- Q_split:
[batch_size, h, seq_len, d_k]
(每个头的Q维度[batch, seq_len, 64]
); - K_split、V_split同理。
注意:拆分时需调整维度顺序(将“头数h”提前),方便后续并行计算。
步骤2:每个头独立计算自注意力
对每个头的Q_split[i]、K_split[i]、V_split[i],按自注意力流程计算输出:
headi=Self-Attention(Qsplit[i],Ksplit[i],Vsplit[i]) \text{head}_i = \text{Self-Attention}(Q_{\text{split}[i]}, K_{\text{split}[i]}, V_{\text{split}[i]}) headi=Self-Attention(Qsplit[i],Ksplit[i],Vsplit[i])
- 每个head的输出维度:
[batch_size, seq_len, d_k]
; - h个头的总输出:
[batch_size, h, seq_len, d_k]
。
步骤3:合并所有头的输出
将h个头的输出沿d_k
维度拼接(concat),恢复到d_model
维度:
Concat=cat(head1,head2,...,headh) \text{Concat} = \text{cat}(\text{head}_1, \text{head}_2, ..., \text{head}_h) Concat=cat(head1,head2,...,headh)
- 拼接后维度:
[batch_size, seq_len, h×d_k] = [batch, seq_len, d_model]
(因为h×d_k = 8×64=512
)。
步骤4:线性变换(维度对齐+信息融合)
对拼接后的结果做一次线性变换,让模型学习如何融合不同头的信息:
MultiHead(Q,K,V)=Concat×WO \text{MultiHead}(Q,K,V) = \text{Concat} \times W_O MultiHead(Q,K,V)=Concat×WO
W_O
:输出权重矩阵,维度[d_model, d_model]
;- 最终输出维度:
[batch_size, seq_len, d_model]
(与输入X的维度一致,方便后续残差连接)。
步骤5:完整公式总结
MultiHead(Q,K,V)=cat(head1,...,headh)WO \text{MultiHead}(Q,K,V) = \text{cat}(\text{head}_1, ..., \text{head}_h) W_O MultiHead(Q,K,V)=cat(head1,...,headh)WO
where headi=Self-Attention(QWQi,KWKi,VWVi) \text{where } \text{head}_i = \text{Self-Attention}(Q W_{Q_i}, K W_{K_i}, V W_{V_i}) where headi=Self-Attention(QWQi,KWKi,VWVi)
W_{Q_i}
:第i个头的Query权重矩阵(W_Q
拆分为h个,每个维度[d_model, d_k]
);W_{K_i}
、W_{V_i}
同理。
3.3 多头注意力的核心价值:3个维度的提升
- 多关系捕捉:不同头关注不同类型的关系(语义、语法、指代),比如头1关注“主谓”,头2关注“指代”;
- 泛化能力增强:多个头的独立计算相当于“模型集成”,减少单一头的偏见;
- 维度拆分优化:将高维
d_model
拆分为低维d_k
,降低每个头的计算复杂度(总计算量与单头相当,因为h×(seq_len^2×d_k) ≈ seq_len^2×d_model
)。
3.4 Encoder与Decoder中的多头注意力差异
Transformer的Encoder和Decoder用的多头注意力略有不同,核心是“是否需要掩码”:
- Encoder的多头自注意力:无掩码,所有位置可相互关注(比如“猫”可关注“垫子”,“垫子”也可关注“猫”),适合“理解输入序列”;
- Decoder的掩码多头自注意力:有“未来掩码”(Future Mask),即第i个位置只能关注第1~i个位置,不能关注i之后的位置(避免“看到未来的词”,符合自回归生成逻辑);
- Decoder的Encoder-Decoder注意力:Q来自Decoder,K/V来自Encoder,让Decoder“关注输入序列中与当前生成词相关的位置”(比如翻译时,生成“猫”要关注输入的“cat”)。
四、特征强化:前馈网络(Feed-Forward Network, FFN)
多头注意力捕捉了序列的全局依赖,但每个位置的特征仍需进一步“加工”——前馈网络的作用就是“对每个位置的特征进行独立的非线性变换,增强模型的表达能力”。
4.1 前馈网络的结构:简单但高效
前馈网络是一个“两层线性变换+ReLU激活+Dropout”的结构,对每个位置的特征独立处理(即不同位置的计算互不干扰):
FFN(x)=max(0,xW1+b1)W2+b2 \text{FFN}(x) = \max(0, x W_1 + b_1) W_2 + b_2 FFN(x)=max(0,xW1+b1)W2+b2
各部分解析:
-
第一层线性变换:将输入维度
d_model
提升到d_ff
(原论文d_ff=2048
),增加非线性表达能力:
x1=xW1+b1 x_1 = x W_1 + b_1 x1=xW1+b1W_1
:维度[d_model, d_ff]
(512→2048);b_1
:偏置项,维度[d_ff]
。
-
ReLU激活函数:引入非线性,让模型能学习复杂特征(原论文用ReLU,后续模型如GPT用GELU,效果更好):
x1act=max(0,x1) x_{1_{\text{act}}} = \max(0, x_1) x1act=max(0,x1)- ReLU的优势:计算快,缓解梯度消失(相比sigmoid)。
-
第二层线性变换:将维度从
d_ff
降回d_model
,与输入维度对齐(方便残差连接):
x2=x1actW2+b2 x_2 = x_{1_{\text{act}}} W_2 + b_2 x2=x1actW2+b2W_2
:维度[d_ff, d_model]
(2048→512);b_2
:偏置项,维度[d_model]
。
-
Dropout:训练时随机丢弃部分神经元(原论文 dropout rate=0.1),防止过拟合。
4.2 为什么需要前馈网络?——“全局依赖+局部强化”
多头注意力解决了“全局关系捕捉”,但每个位置的特征仍是“其他位置信息的加权和”,缺乏“局部非线性加工”——比如“猫坐在垫子上”,“坐”的特征需要结合“猫”和“垫子”的信息,但还需要进一步强化“动作”属性,这就是前馈网络的作用。
通俗比喻:多头注意力是“收集所有同学的笔记”,前馈网络是“自己整理笔记,提炼重点”——两者结合,才能得到更优质的学习成果。
4.3 关键设计:为什么中间维度d_ff
是4×d_model
?
原论文选择d_ff=2048=4×512
,不是随机的:
- 非线性表达能力:更高的中间维度能让模型学习更复杂的特征映射(比如从512维的语义特征映射到2048维的细粒度特征);
- 计算效率平衡:
4×d_model
是“效果”和“计算量”的平衡点——维度太小,表达能力不足;维度太大,计算量剧增(如8×d_model
会让计算量翻倍); - 实践验证:后续大模型(如BERT、GPT)均沿用这一比例,证明其有效性。
五、Transformer的“稳定器”:残差连接与LayerNorm
Transformer通常堆叠6~100层(如GPT-3有96层),但深层模型容易出现“梯度消失”和“内部协变量偏移”——残差连接和LayerNorm就是解决这两个问题的“稳定器”。
5.1 残差连接(Residual Connection):缓解梯度消失
残差连接的核心是“将输入直接加到输出上”,让梯度能“跳过”中间层直接传递,避免梯度在深层中衰减到0。
公式与结构:
Transformer中,每个模块(多头注意力、前馈网络)的输出都加了残差连接:
Output=LayerNorm(Module(X)+X) \text{Output} = \text{LayerNorm}(\text{Module}(X) + X) Output=LayerNorm(Module(X)+X)
X
:模块的输入;Module(X)
:模块的输出(如多头注意力输出、前馈网络输出);- “+X”:残差连接,输入与输出直接相加(需保证维度一致,Transformer中所有模块输入输出维度都是
d_model
)。
为什么有效?
假设模块是线性变换Module(X) = WX + b
,则输出为WX + b + X = (W+I)X + b
——梯度传递时,会额外保留“X的直接梯度”(即1),避免梯度随层数增加而指数衰减。
案例:100层Transformer,若没有残差连接,梯度可能衰减到0.9^100 ≈ 1e-5
(几乎为0);有残差连接时,梯度能稳定传递。
5.2 LayerNorm:解决内部协变量偏移
“内部协变量偏移”是指“深层模型的输入分布随训练迭代不断变化”,导致模型需要不断适应新分布,训练不稳定。
LayerNorm通过“对每个样本的每个序列位置做归一化”,让输入分布保持稳定。
LayerNorm的公式推导:
对输入X
(维度[batch_size, seq_len, d_model]
),按“特征维度”(d_model
)做归一化:
μi=1dmodel∑k=1dmodelXi,k \mu_i = \frac{1}{d_model} \sum_{k=1}^{d_model} X_{i,k} μi=dmodel1k=1∑dmodelXi,k
σi2=1dmodel∑k=1dmodel(Xi,k−μi)2 \sigma_i^2 = \frac{1}{d_model} \sum_{k=1}^{d_model} (X_{i,k} - \mu_i)^2 σi2=dmodel1k=1∑dmodel(Xi,k−μi)2
X^i,k=Xi,k−μiσi2+ϵ \hat{X}_{i,k} = \frac{X_{i,k} - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}} X^i,k=σi2+ϵXi,k−μi
LayerNorm(X)=γX^i,k+β \text{LayerNorm}(X) = \gamma \hat{X}_{i,k} + \beta LayerNorm(X)=γX^i,k+β
μ_i
:第i个位置特征的均值;σ_i^2
:第i个位置特征的方差;ε
:小常数(如1e-6),避免分母为0;γ
、β
:可学习的缩放和偏移参数(维度d_model
),让模型能调整归一化后的分布(若γ=√d_model
、β=0
,则等价于不调整)。
LayerNorm vs BatchNorm:为什么选LayerNorm?
BatchNorm是“按批次维度归一化”(对每个批次的所有样本同一特征做归一化),适合图像(固定尺寸,批次无关);
LayerNorm是“按样本维度归一化”(对每个样本的所有特征做归一化),适合序列(序列长度可变,每个样本的分布独立)。
比如处理不同长度的句子,BatchNorm会因批次中序列长度不同而失效,LayerNorm则不受影响。
5.3 位置编码(Positional Encoding):给序列“加时间戳”
自注意力本身没有“时序信息”——比如“猫坐在垫子上”和“垫子坐在猫上”,自注意力会计算相同的权重,但语义完全相反。
位置编码的作用是“给每个位置添加独特的位置信息”,让模型能区分不同位置。
正弦余弦位置编码:原论文方案
原论文用正弦和余弦函数生成位置编码,公式如下(pos
是位置索引,i
是特征维度索引):
PEpos,2i=sin(pos100002i/dmodel) PE_{pos, 2i} = \sin\left( \frac{pos}{10000^{2i/d_model}} \right) PEpos,2i=sin(100002i/dmodelpos)
PEpos,2i+1=cos(pos100002i/dmodel) PE_{pos, 2i+1} = \cos\left( \frac{pos}{10000^{2i/d_model}} \right) PEpos,2i+1=cos(100002i/dmodelpos)
PE
:位置编码矩阵,维度[seq_len, d_model]
;2i
、2i+1
:偶数维度用正弦,奇数维度用余弦(保证相邻位置的编码差异);10000
:控制位置编码的周期(位置越远,编码差异越大)。
为什么用正弦余弦?
- 可扩展性:无需训练,可生成任意长度的位置编码(即使测试时序列长度超过训练时);
- 相对位置信息:
PE_{pos+k}
可表示为PE_pos
的线性组合(比如sin(a+b)=sin a cos b + cos a sin b
),让模型能学习相对位置关系; - 计算高效:无需额外参数,直接生成。
直观理解:不同位置的编码向量在高维空间中是正交的,模型能通过学习区分“pos=1”和“pos=5”的差异。
六、Transformer的底层逻辑:为什么能成为大模型基石?
理解了各个模块,我们需要回到核心问题:为什么Transformer能支撑起GPT、BERT、LLaMA这些千亿参数的大模型? 关键在于它的三个底层优势:
6.1 并行计算:大模型训练的“效率引擎”
大模型的训练需要处理海量数据(如万亿tokens),效率是关键。Transformer的并行性体现在两个层面:
- 序列内并行:自注意力对所有位置的处理是独立的,可同时计算整个序列的注意力权重;
- 层间并行:多层Transformer可通过模型并行(不同层放在不同GPU上)或数据并行(不同批次放在不同GPU上)加速训练。
相比之下,RNN的串行计算无法支撑大模型——即使是10亿参数的RNN,训练速度也会慢到无法接受。
6.2 可扩展性:从百万到千亿参数的“架构基础”
Transformer的模块结构极其简洁,所有模块(多头注意力、前馈网络)都是“可堆叠”的:
- 横向扩展:增加多头数
h
(如从8头到16头),提升多关系捕捉能力; - 纵向扩展:增加层数
N
(如从6层到96层),提升特征抽象能力; - 维度扩展:增加
d_model
(如从512到12288),提升每个位置的特征表达能力。
这种“模块化+可堆叠”的设计,让模型能通过增加参数持续提升性能(符合“缩放定律”),而RNN/LSTM因结构限制,参数增加到一定程度后性能会饱和。
6.3 预训练-微调范式:大模型生态的“核心载体”
Transformer的结构天然适合“预训练-微调”范式:
- 预训练阶段:用海量无标注数据(如互联网文本)训练一个通用Transformer模型(如BERT、GPT),学习语言的通用规律;
- 微调阶段:用少量标注数据(如分类任务的1000条样本)微调预训练模型,快速适配具体任务(文本分类、机器翻译、问答)。
这种范式的核心是“Transformer能通过预训练学习到通用的特征表示”,而RNN/LSTM因泛化能力弱,预训练效果远不如Transformer。
七、实战:用PyTorch实现自注意力与多头注意力(附代码+可视化)
理论讲完,我们用PyTorch实现核心模块,并用注意力热力图可视化,直观理解注意力的分布。
7.1 环境准备
# 安装依赖(已安装可跳过)
pip install torch torchvision pandas numpy matplotlib seaborn
7.2 实现自注意力模块
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
def __init__(self, d_model, d_k):
super().__init__()
self.d_k = d_k # 每个头的维度
# 定义Q、K、V的线性变换权重
self.W_Q = nn.Linear(d_model, d_k)
self.W_K = nn.Linear(d_model, d_k)
self.W_V = nn.Linear(d_model, d_k)
def forward(self, X):
# X: [batch_size, seq_len, d_model]
batch_size, seq_len, _ = X.shape
# 步骤1:生成Q、K、V
Q = self.W_Q(X) # [batch_size, seq_len, d_k]
K = self.W_K(X) # [batch_size, seq_len, d_k]
V = self.W_V(X) # [batch_size, seq_len, d_k]
# 步骤2:计算缩放点积得分
score = torch.matmul(Q, K.transpose(-2, -1)) # [batch_size, seq_len, seq_len]
score = score / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32)) # 缩放
# 步骤3:Softmax得到注意力权重
attn_weight = F.softmax(score, dim=-1) # [batch_size, seq_len, seq_len]
# 步骤4:加权求和得到输出
output = torch.matmul(attn_weight, V) # [batch_size, seq_len, d_k]
return output, attn_weight # 返回输出和注意力权重(用于可视化)
7.3 实现多头注意力模块
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, h=8):
super().__init__()
self.h = h # 多头数
self.d_k = d_model // h # 每个头的维度(d_model必须能被h整除)
assert d_model % h == 0, "d_model must be divisible by h"
# 定义Q、K、V的线性变换(共享权重,后续拆分)
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
# 输出线性变换
self.W_O = nn.Linear(d_model, d_model)
# 自注意力模块(每个头共用)
self.self_attn = SelfAttention(d_model, self.d_k)
def forward(self, X):
# X: [batch_size, seq_len, d_model]
batch_size, seq_len, d_model = X.shape
# 步骤1:生成Q、K、V并拆分多头
Q = self.W_Q(X) # [batch_size, seq_len, d_model]
K = self.W_K(X)
V = self.W_V(X)
# 拆分维度:[batch_size, seq_len, d_model] → [batch_size, h, seq_len, d_k]
Q_split = Q.view(batch_size, seq_len, self.h, self.d_k).transpose(1, 2)
K_split = K.view(batch_size, seq_len, self.h, self.d_k).transpose(1, 2)
V_split = V.view(batch_size, seq_len, self.h, self.d_k).transpose(1, 2)
# 步骤2:每个头计算自注意力
attn_outputs = []
attn_weights = []
for i in range(self.h):
q = Q_split[:, i, :, :] # [batch_size, seq_len, d_k]
k = K_split[:, i, :, :]
v = V_split[:, i, :, :]
output, weight = self.self_attn(q) # 每个头的输出和权重
attn_outputs.append(output)
attn_weights.append(weight)
# 步骤3:合并所有头的输出
attn_outputs = torch.stack(attn_outputs, dim=1) # [batch_size, h, seq_len, d_k]
# 转置+拼接:[batch_size, seq_len, h*d_k] = [batch_size, seq_len, d_model]
attn_outputs = attn_outputs.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
# 步骤4:输出线性变换
output = self.W_O(attn_outputs) # [batch_size, seq_len, d_model]
# 合并注意力权重(用于可视化,取所有头的平均)
attn_weights = torch.mean(torch.stack(attn_weights, dim=1), dim=1) # [batch_size, seq_len, seq_len]
return output, attn_weights
7.4 注意力权重可视化(案例:句子“猫坐在垫子上”)
我们用一个简单案例展示注意力分布,步骤如下:
- 构造句子的嵌入向量(用随机嵌入模拟,实际中用预训练词嵌入);
- 输入多头注意力模块,得到注意力权重;
- 用热力图可视化权重分布。
import matplotlib.pyplot as plt
import seaborn as sns
# 1. 构造输入数据
sentence = ["猫", "坐", "在", "垫子", "上"] # 输入句子
seq_len = len(sentence)
d_model = 512 # 嵌入维度
batch_size = 1
# 随机生成嵌入向量(模拟词嵌入)
torch.manual_seed(42) # 固定随机种子,结果可复现
X = torch.randn(batch_size, seq_len, d_model) # [1, 5, 512]
# 2. 初始化多头注意力模块
multi_head_attn = MultiHeadAttention(d_model=d_model, h=8)
output, attn_weight = multi_head_attn(X) # attn_weight: [1, 5, 5]
# 3. 提取注意力权重(取第一个batch)
attn_weight = attn_weight[0].detach().cpu().numpy() # [5, 5]
# 4. 绘制热力图
plt.figure(figsize=(10, 8))
sns.heatmap(
attn_weight,
annot=True, # 显示数值
cmap="YlOrRd", # 颜色映射
xticklabels=sentence, # x轴标签(被关注的词)
yticklabels=sentence # y轴标签(关注其他词的词)
)
plt.title("Multi-Head Attention Weight Heatmap (Sentence: 猫坐在垫子上)")
plt.xlabel("Attended Words (Key)")
plt.ylabel("Query Words")
plt.show()
可视化结果分析:
热力图中,颜色越红、数值越大,说明注意力权重越高。预期结果:
- “猫”(y=0)对“垫子”(x=3)的权重最高(如0.6),符合“猫坐在垫子上”的语义;
- “垫子”(y=3)对“猫”(x=0)和“上”(x=4)的权重较高;
- “在”(y=2)对所有词的权重较低(功能词,语义贡献小)。
7.5 用Hugging Face验证:加载BERT查看注意力
实际应用中,我们常用Hugging Face的transformers
库加载预训练模型,直接查看注意力权重:
from transformers import BertTokenizer, BertModel
# 1. 加载BERT模型和Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertModel.from_pretrained('bert-base-chinese', output_attentions=True) # 输出注意力权重
# 2. 输入句子并编码
sentence = "猫坐在垫子上"
inputs = tokenizer(sentence, return_tensors="pt") # 编码为token ID
# 3. 前向传播,获取注意力权重
with torch.no_grad():
outputs = model(**inputs)
attentions = outputs.attentions # 所有层的注意力权重,形状:[12层, batch_size, 12头, seq_len, seq_len]
# 4. 提取第1层第1个头的注意力权重(可调整层和头)
layer_idx = 0 # 第1层
head_idx = 0 # 第1个头
attn_weight = attentions[layer_idx][0][head_idx].detach().cpu().numpy() # [seq_len, seq_len]
# 5. 获取token标签(BERT会添加[CLS]和[SEP])
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # ['[CLS]', '猫', '坐', '在', '垫', '子', '上', '[SEP]']
# 6. 绘制热力图
plt.figure(figsize=(12, 10))
sns.heatmap(
attn_weight,
annot=True,
cmap="YlOrRd",
xticklabels=tokens,
yticklabels=tokens
)
plt.title(f"BERT Attention Weight (Layer {layer_idx+1}, Head {head_idx+1})")
plt.xlabel("Attended Tokens")
plt.ylabel("Query Tokens")
plt.show()
结果分析:
BERT的注意力权重更符合语言规律,比如:
- “垫”(token=4)和“子”(token=5)的权重极高(因为“垫子”是一个词);
- “猫”(token=1)对“垫”(token=4)、“子”(token=5)的权重较高;
[CLS]
(token=0)对所有词的权重较均匀(负责整合整个句子的语义)。
八、南木的学习路径建议:从入门到精通Transformer
很多人学Transformer会陷入“公式劝退”或“代码跑不通”的困境,我结合自己的经验,整理了一条“循序渐进”的学习路径:
阶段1:基础铺垫(1~2周)
- 数学基础:回顾线性代数(矩阵乘法、转置)、概率论(Softmax、交叉熵)、微积分(梯度下降)——不用深钻,能理解公式含义即可;
- 工具基础:掌握PyTorch核心操作(张量、模型定义、前向传播),熟悉Hugging Face
transformers
库的基本使用(加载模型、编码数据); - 前置知识:了解RNN/LSTM的缺陷(并行性、长序列依赖),明白Transformer的设计动机。
阶段2:核心拆解(2~3周)
- 第一步:吃透自注意力:
- 手动计算一个简单案例(如3个词的句子),推导Q/K/V生成、得分计算、Softmax、加权求和的全过程;
- 用PyTorch实现自注意力模块,调整
d_k
观察缩放的影响;
- 第二步:理解多头注意力:
- 重点搞懂“拆分-合并”的维度变化(比如
d_model=512
→h=8
→d_k=64
); - 实现多头注意力,对比单头和多头的注意力权重分布;
- 重点搞懂“拆分-合并”的维度变化(比如
- 第三步:掌握其他模块:
- 实现前馈网络,理解“维度升降”的作用;
- 推导LayerNorm的计算过程,对比BatchNorm的差异;
- 生成正弦余弦位置编码,可视化不同位置的编码向量。
阶段3:实战落地(3~4周)
- 入门项目:
- 用Transformer实现文本分类(如IMDB情感分析),熟悉Encoder的使用;
- 用注意力可视化工具(如
bertviz
)分析预训练模型的注意力分布;
- 进阶项目:
- 用Transformer实现机器翻译(Encoder-Decoder结构),理解Decoder的掩码机制;
- 微调预训练模型(如BERT微调做问答任务),掌握“预训练-微调”范式;
- 挑战项目:
- 实现简化版GPT(Decoder-only结构),理解自回归生成逻辑;
- 用ViT(Vision Transformer)做图像分类,理解Transformer在CV领域的应用。
阶段4:进阶扩展(长期)
- 大模型优化:学习稀疏注意力(如Longformer)、量化(如INT8量化)、蒸馏(模型压缩),解决大模型部署问题;
- 理论深入:阅读Transformer的改进论文(如GPT-2/3、BERT、T5),理解“自回归”“双向注意力”“编码器-解码器”的差异;
- 跨领域应用:探索Transformer在CV(ViT)、语音(Wav2Vec 2.0)、多模态(CLIP)的应用,理解“注意力是通用机制”。
九、常见误区与解答
误区1:自注意力一定是全局的?
错。自注意力可以通过“掩码”限制关注范围:
- Decoder的“未来掩码”:让第i个位置只能关注前i个位置;
- 稀疏注意力(如Longformer):只关注局部窗口+全局关键位置,降低长序列的计算量。
误区2:多头注意力头数越多越好?
错。头数增加会带来两个问题:
- 计算量剧增:头数h翻倍,计算量也翻倍(
O(h×seq_len²×d_k)
); - 边际效益递减:头数超过16后,模型性能提升不明显(原论文h=8,GPT-3 h=96是因为
d_model
更大)。
误区3:Transformer不需要时序信息?
错。Transformer本身没有时序信息,必须通过“位置编码”补充——如果去掉位置编码,模型无法区分“猫坐在垫子上”和“垫子坐在猫上”。
误区4:Transformer在CV领域不如CNN?
错。ViT(2020年)证明Transformer在CV领域可超越CNN:
- ViT将图像拆分为“图像块”(如16×16),视为“序列”输入Transformer;
- 大尺度预训练(如ImageNet-21K)后,ViT在图像分类、目标检测等任务上的性能超过CNN。
十、总结
Transformer的核心不是“复杂的公式”,而是“简洁高效的注意力机制”,它用自注意力解决了RNN的并行性和长序列依赖问题,用多头注意力提升了多关系捕捉能力,用可堆叠架构支撑了大模型的 scalability。
从技术发展来看,Transformer不仅是NLP领域的革命,更是AI的通用基础架构——它统一了NLP、CV、语音、多模态等领域,成为大模型时代的“基石”。
如果你觉得这篇文章有用,记得点赞+收藏,关注我(南木),后续会分享更多大模型实战干货(如GPT简化版实现、LLaMA微调)。如果在学习过程中遇到问题,欢迎在评论区留言,我会定期回复!
更多推荐
所有评论(0)