Self-Attention原理和实现代码(Pytorch实现)
本文介绍了Self-Attention机制的原理与实现。Self-Attention通过计算序列元素间的相关性权重,动态关注输入序列的不同部分,是Transformer架构的核心组件。文章详细阐述了其计算流程,包括生成Q/K/V向量、计算注意力分数、权重归一化和加权聚合输出。同时提供了单头注意力和多头注意力的PyTorch实现代码,其中多头注意力通过并行多个注意力头来捕获不同表示子空间的信息。Se
一、Self-Attention原理
Self-Attention(自注意力)机制是Transformer架构中的核心组件,主要用于捕捉序列中不同位置元素之间的依赖关系。其核心思想是通过计算序列中每个元素与其他元素之间的相关性(注意力权重),然后根据这些权重对序列信息进行加权聚合。计算过程主要分为以下几步:
-
输入表示:
对于输入序列 X X X中的每个元素,通过三个不同的权重矩阵 W Q W^Q WQ、 W K W^K WK、 W V W^V WV,分别生成:查询向量(Query): Q = X ⋅ W Q Q = X \cdot W^Q Q=X⋅WQ
键向量(Key): K = X ⋅ W K K = X \cdot W^K K=X⋅WK
值向量(Value): V = X ⋅ W V V = X \cdot W^V V=X⋅WV
其中,Q、K、V 承担不同的角色。
Q (Query): 查询向量,表示当前需要关注的内容。例如在机器翻译中,解码器当前位置的隐状态作为 Query,用于查询源语言的相关信息。
K (Key): 键向量,表示待匹配的索引。Key 与 Value 关联,用于计算与 Query 的相似度。
V (Value): 值向量,存储实际需要提取的信息。相似度计算后,Value 会按权重聚合生成输出。 -
计算注意力分数:
通过计算 Q Q Q与 K K K的点积,得到元素间的相似度分数: Scores = Q ⋅ K ⊤ \text{Scores} = Q \cdot K^{\top} Scores=Q⋅K⊤ 为了稳定梯度,分数会除以缩放因子 d k \sqrt{d_k} dk( d k d_k dk是 K K K的维度)。 -
生成注意力权重:
对分数应用Softmax函数,得到归一化的注意力权重: AttentionWeight = Softmax ( Q ⋅ K ⊤ d k ) \text{AttentionWeight} = \text{Softmax}\left(\frac{Q \cdot K^{\top}}{\sqrt{d_k}}\right) AttentionWeight=Softmax(dkQ⋅K⊤) -
加权聚合输出:
用注意力权重对 V V V加权求和,得到最终输出: Output = AttentionWeight ⋅ V \text{Output} = \text{AttentionWeight} \cdot V Output=AttentionWeight⋅V
综上所述,Self-Attention的完整公式如下:
Attention(Q,K,V) = Softmax ( Q ⋅ K ⊤ d k ) ⋅ V \text{Attention(Q,K,V)} = \text{Softmax}\left(\frac{Q \cdot K^{\top}}{\sqrt{d_k}}\right) \cdot V Attention(Q,K,V)=Softmax(dkQ⋅K⊤)⋅V
核心作用:
Self-Attention 让模型能够动态关注输入序列的不同部分(例如句子中与当前词相关的其他词),从而更好地理解上下文依赖关系。
二、Self-Attention代码实现
1.单头注意力
单头注意力的代码实现
代码如下(示例):
class Attention(nn.Module):
def __init__(self, d_model, head_size, context_length, dropout=0.1):
'''
d_model为输入序列的语义维度
head_size为一个注意力头在语义维度占据的大小,head_size = d_model/num_heads
context_len为输入序列的长度
'''
super().__init__()
self.head_size = head_size
self.Wq = nn.Linear(d_model, head_size, bias=False)
self.Wk = nn.Linear(d_model, head_size, bias=False)
self.Wv = nn.Linear(d_model, head_size, bias=False)
self.register_buffer('mask', torch.tril(torch.ones(context_length, context_length)))
self.dropout = nn.Dropout(dropout)
def forward(self, x):
B, T, C = x.shape
q = self.Wq(x)
k = self.Wk(x)
v = self.Wv(x)
weights = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_size)
weights = weights.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
return weights @ v
2.多头注意力
多头注意力的核心思想:使用多组( h h h 个头)不同的查询 Q Q Q、键 K K K、值 V V V 投影(线性变换),让模型能够并行地从不同的表示子空间(子空间维度 d k d_k dk, d v d_v dv, d m o d e l / h d_{model}/h dmodel/h)学习信息。每个头学习不同的关注模式。
多头注意力的计算流程
假设:
输入维度: d m o d e l d_{model} dmodel(例如 512)
头数: h h h(例如 8)
每个头的维度: d k = d v = d m o d e l / h d_k = d_v = d_{model} / h dk=dv=dmodel/h(例如 512 / 8 = 64)
步骤:
-
线性投影(生成 h 组 Q, K, V):
对原始的查询 Q Q Q、键 K K K、值 V V V(维度均为 d m o d e l d_{model} dmodel)分别应用 h h h 组不同的线性变换(权重矩阵)。
得到 h h h 组投影后的查询 Q i Q_i Qi、键 K i K_i Ki、值 V i V_i Vi,每组维度为 d k d_k dk, d k d_k dk, d v d_v dv(通常 d k = d v = d m o d e l / h d_k = d_v = d_{model}/h dk=dv=dmodel/h)。
Q i = Q W i Q , K i = K W i K , V i = V W i V for i = 1 , . . . , h Q_i = Q W_i^Q, \quad K_i = K W_i^K, \quad V_i = V W_i^V \quad \text{for } i = 1, ..., h Qi=QWiQ,Ki=KWiK,Vi=VWiVfor i=1,...,h 其中 W i Q , W i K , W i V W_i^Q, W_i^K, W_i^V WiQ,WiK,WiV 是可学习的投影矩阵。 -
并行计算缩放点积注意力:
对每一组投影后的 Q i , K i , V i Q_i, K_i, V_i Qi,Ki,Vi,独立计算缩放点积注意力: head i = Attention ( Q i , K i , V i ) = softmax ( Q i K i T d k ) V i \text{head}_i = \text{Attention}(Q_i, K_i, V_i) = \text{softmax}(\frac{Q_i K_i^T}{\sqrt{d_k}}) V_i headi=Attention(Qi,Ki,Vi)=softmax(dkQiKiT)Vi -
拼接多头输出:
将 h h h 个注意力头计算的结果 head 1 , head 2 , . . . , head h \text{head}_1, \text{head}_2, ..., \text{head}h head1,head2,...,headh(每个维度为 d v d_v dv)拼接起来,得到一个维度为 h × d v = d m o d e l h \times d_v = d{model} h×dv=dmodel 的向量。 MultiHead ( Q , K , V ) = Concat ( head 1 , head 2 , . . . , head h ) \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, ..., \text{head}_h) MultiHead(Q,K,V)=Concat(head1,head2,...,headh) -
最终线性投影(可选,但常用):
将拼接后的结果通过另一个线性变换 W O W^O WO 投影到最终的输出维度(通常是 d m o d e l d_{model} dmodel)。 Output = MultiHead ( Q , K , V ) W O \text{Output} = \text{MultiHead}(Q, K, V) W^O Output=MultiHead(Q,K,V)WO 其中 W O W^O WO 是维度为 d m o d e l × d m o d e l d_{model} \times d_{model} dmodel×dmodel 的可学习权重矩阵。
多头注意力的代码实现
多头注意力的代码实现方法十分简单,只需使用到上面单头注意力的代码即可。
代码如下:
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, head_size, context_length, dropout=0.1):
super().__init__()
self.heads = nn.ModuleList([
Attention(d_model, head_size, context_length, dropout)
for _ in range(num_heads)
])
self.projection_layer = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
head_outputs = torch.cat([head(x) for head in self.heads], dim=-1)
return self.dropout(self.projection_layer(head_outputs))
总结
以上就是今天要讲的内容,本文仅仅简单介绍了Self-Attention的原理和简单实现,Self-Attention是大模型的核心机制需要认真学习。
更多推荐


所有评论(0)