注意力机制SelfAttention和CrossAttention
自注意力机制和跨注意力机制介绍
✅ 第一章:SelfAttention
1.1 SelfAttention 计算公式
📌 基本定义:
Self-Attention 是 Transformer 架构中的核心机制,它让模型在处理序列数据时能够关注到序列内部不同位置之间的依赖关系。
给定输入张量 X∈R(B,T,D)X \in \mathbb{R}^{(B, T, D)}X∈R(B,T,D),其中:
- BBB: Batch size(批量大小)
- TTT: Sequence length(序列长度)
- DDD: Embedding dimension(每个 token 的嵌入维度)
我们通过线性变换分别生成 Query、Key 和 Value 向量:
Q=XWQ,K=XWK,V=XWV Q = XW_Q,\quad K = XW_K,\quad V = XW_V Q=XWQ,K=XWK,V=XWV
然后计算注意力得分:
Attention(Q,K,V)=softmax(QKTDk)V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{D_k}}\right)V Attention(Q,K,V)=softmax(DkQKT)V
其中 $ D_k $ 是 Key 的维度(通常是 $ D / h $,如果使用多头注意力),用于缩放防止梯度消失。
最终输出为:
Y=Attention(Q,K,V) Y = \text{Attention}(Q, K, V) Y=Attention(Q,K,V)
输出形状仍为 Y∈R(B,T,D)Y \in \mathbb{R}^{(B, T, D)}Y∈R(B,T,D)
1.2 手写 SelfAttention 实现
以下是一个完整的 PyTorch 实现的 单头 Self-Attention 模块,并配有注释说明:
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
def __init__(self, embed_size):
super(SelfAttention, self).__init__()
self.embed_size = embed_size
# 定义 Q, K, V 的线性变换层
self.query = nn.Linear(embed_size, embed_size)
self.key = nn.Linear(embed_size, embed_size)
self.value = nn.Linear(embed_size, embed_size)
def forward(self, x):
"""
参数:
x: 输入张量,形状 (B, T, D)
返回:
out: 注意力输出,形状 (B, T, D)
"""
B, T, D = x.size()
# 线性变换得到 Q, K, V
Q = self.query(x) # (B, T, D)
K = self.key(x) # (B, T, D)
V = self.value(x) # (B, T, D)
# 缩放点积注意力分数
scores = torch.bmm(Q, K.transpose(1, 2)) / (D ** 0.5) # (B, T, T)
# softmax 归一化
attn_weights = F.softmax(scores, dim=-1) # (B, T, T)
# 加权求和
out = torch.bmm(attn_weights, V) # (B, T, D)
return out
🔍 示例使用:
batch_size = 32
seq_len = 10
embed_dim = 512
model = SelfAttention(embed_dim)
x = torch.randn(batch_size, seq_len, embed_dim)
y = model(x)
print(y.shape) # 输出:torch.Size([32, 10, 512])
1.3 使用 PyTorch 内置模块:nn.MultiheadAttention
PyTorch 提供了内置的多头自注意力模块:nn.MultiheadAttention。这个模块实现了标准的多头注意力机制,支持并行计算多个注意力头。
📦 导入模块
import torch
import torch.nn as nn
🔧 初始化 MultiheadAttention 模块
embed_dim = 512 # 每个 token 的维度
num_heads = 8 # 多头注意力的头数
multihead_attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
⚠️ 注意输入格式:
query,key,value张量的形状必须是(T, B, D),即 序列长度优先
如果你的数据是 (B, T, D) 格式,需要先转置:
x = torch.randn(32, 10, 512) # shape: (B, T, D)
x = x.transpose(0, 1) # 转换为 (T, B, D)
🔁 前向传播示例
# 输入张量 (T, B, D)
query = key = value = torch.randn(10, 32, 512)
# 前向传播
attn_output, attn_weights = multihead_attn(query, key, value)
print("Output shape:", attn_output.shape) # (10, 32, 512)
print("Weights shape:", attn_weights.shape) # (32, 10, 10)
🧠 小贴士
attn_weights: 表示每个查询对所有键的关注程度,形状为(B, T, T)- 如果你只想使用自注意力,那么
query == key == value - 可以传入 mask 来屏蔽某些位置,如 padding 或 future positions(解码器中使用)
📝 总结对比
| 特性 | 手写 SelfAttention | nn.MultiheadAttention |
|---|---|---|
| 是否支持多头 | ❌ 单头 | ✅ 支持多头 |
| 易用性 | 需要自己实现 | 直接调用 |
| 灵活性 | 更灵活 | 固定结构 |
| 性能优化 | 非最优 | 经过优化 |
| 学习价值 | ✅ 有助于理解原理 | ✅ 快速构建模型 |
✅ 第二章:CrossAttention
2.1 CrossAttention 计算公式
📌 基本定义:
Cross-Attention 是注意力机制的一种变体,通常用于解码器中,让解码器能够关注编码器输出的信息。
给定两个输入张量:
- Query 来自解码器的隐藏状态 Q∈R(B,Tq,D)Q \in \mathbb{R}^{(B, T_q, D)}Q∈R(B,Tq,D)
- Key 和 Value 来自编码器的输出 K=V∈R(B,Tk,D)K = V \in \mathbb{R}^{(B, T_k, D)}K=V∈R(B,Tk,D)
计算方式如下:
CrossAttention(Q,K,V)=softmax(QKTDk)V \text{CrossAttention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{D_k}}\right)V CrossAttention(Q,K,V)=softmax(DkQKT)V
最终输出为:
Y=CrossAttention(Q,K,V) Y = \text{CrossAttention}(Q, K, V) Y=CrossAttention(Q,K,V)
输出形状为 Y∈R(B,Tq,D)Y \in \mathbb{R}^{(B, T_q, D)}Y∈R(B,Tq,D)
⚠️ 注意:Query 的序列长度 TqT_qTq 可以与 Key/Value 的序列长度 TkT_kTk 不同。
2.2 手写 CrossAttention 实现
以下是一个完整的 PyTorch 实现的 CrossAttention 模块,并配有注释说明:
import torch
import torch.nn as nn
import torch.nn.functional as F
class CrossAttention(nn.Module):
def __init__(self, embed_size):
super(CrossAttention, self).__init__()
self.embed_size = embed_size
# 定义 Q、K、V 的线性变换层
self.query = nn.Linear(embed_size, embed_size)
self.key = nn.Linear(embed_size, embed_size)
self.value = nn.Linear(embed_size, embed_size)
def forward(self, query_input, key_value_input):
"""
参数:
query_input: 解码器输出,形状 (B, Tq, D)
key_value_input: 编码器输出,形状 (B, Tk, D)
返回:
out: 跨注意力输出,形状 (B, Tq, D)
"""
B, Tq, D = query_input.size()
_, Tk, _ = key_value_input.size()
# 线性变换得到 Q, K, V
Q = self.query(query_input) # (B, Tq, D)
K = self.key(key_value_input) # (B, Tk, D)
V = self.value(key_value_input) # (B, Tk, D)
# 缩放点积注意力分数
scores = torch.bmm(Q, K.transpose(1, 2)) / (D ** 0.5) # (B, Tq, Tk)
# softmax 归一化
attn_weights = F.softmax(scores, dim=-1) # (B, Tq, Tk)
# 加权求和
out = torch.bmm(attn_weights, V) # (B, Tq, D)
return out
🔍 示例使用:
batch_size = 32
seq_len_decoder = 8 # 解码器序列长度
seq_len_encoder = 10 # 编码器序列长度
embed_dim = 512
model = CrossAttention(embed_dim)
query_input = torch.randn(batch_size, seq_len_decoder, embed_dim) # 解码器输出
key_value_input = torch.randn(batch_size, seq_len_encoder, embed_dim) # 编码器输出
y = model(query_input, key_value_input)
print(y.shape) # 输出:torch.Size([32, 8, 512])
2.3 使用 PyTorch 内置模块:nn.MultiheadAttention(用于 CrossAttention)
PyTorch 提供的 nn.MultiheadAttention 同样可以用于实现 多头 CrossAttention,只需将不同的 query, key, value 输入即可。
🔧 初始化 MultiheadAttention 模块
embed_dim = 512
num_heads = 8
multihead_attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
⚠️ 注意输入格式:
query,key,value张量的形状必须是(T, B, D),即 序列长度优先
如果你的数据是 (B, T, D) 格式,需要先转置:
query = torch.randn(seq_len_decoder, batch_size, embed_dim)
key = value = torch.randn(seq_len_encoder, batch_size, embed_dim)
🔁 前向传播示例
# 输入张量 (T, B, D)
query = torch.randn(8, 32, 512) # 解码器输出
key = value = torch.randn(10, 32, 512) # 编码器输出
# 前向传播
attn_output, attn_weights = multihead_attn(query, key, value)
print("Output shape:", attn_output.shape) # (8, 32, 512)
print("Weights shape:", attn_weights.shape) # (32, 8, 10)
🧠 小贴士
attn_weights: 表示每个查询对所有键的关注程度,形状为(B, Tq, Tk)- 如果你想实现 CrossAttention,只需让
query != key == value - 可传入 mask 来屏蔽某些位置(如 padding)
📝 总结对比
| 特性 | 手写 CrossAttention | nn.MultiheadAttention |
|---|---|---|
| 是否支持多头 | ❌ 单头 | ✅ 支持多头 |
| 易用性 | 需要自己实现 | 直接调用 |
| 灵活性 | 更灵活 | 固定结构 |
| 性能优化 | 非最优 | 经过优化 |
| 学习价值 | ✅ 有助于理解原理 | ✅ 快速构建模型 |
✅ 第三章:SelfAttention与CrossAttention的区别
在深度学习领域,特别是自然语言处理(NLP)和计算机视觉(CV)中,Self-Attention 和 Cross-Attention 是两种重要的机制,用于捕捉输入数据中的依赖关系。它们都是注意力机制的一部分,但应用于不同的场景。
Self-Attention
Self-Attention,也称为内部注意力(intra-attention),是一种涉及单个序列不同位置的注意力机制,目的是计算序列本身的表示。它允许每个位置都能注意到整个序列的所有位置来计算其表示。这在捕捉句子或序列内的长距离依赖关系时特别有用。
- 应用场景:主要用于捕捉输入序列内部的关系,如文本分类、机器翻译等。
- 工作原理:对于输入序列中的每个词,self-attention 会计算该词与其他所有词之间的相关性得分,然后根据这些得分对其他词进行加权求和,作为该词的新表示。
- 优点:能够并行化处理,相比RNN类模型训练速度更快;可以很好地捕捉长距离依赖关系。
Cross-Attention
Cross-Attention 则是涉及到两个不同序列之间的注意力机制。它通常用于当一个序列需要基于另一个序列的信息来调整自身的表示时。例如,在图像字幕生成任务中,图像特征和文本描述就是两个不同的序列,cross-attention 可以帮助模型基于图像特征来更好地生成相应的文本描述。
- 应用场景:多模态学习(比如图文匹配)、问答系统等,其中一个序列需要参考另一个序列的信息来做出决策。
- 工作原理:给定两个序列,cross-attention 机制会计算一个序列中的元素相对于另一序列中元素的相关性得分,并据此进行信息融合。
- 优点:能够在不同类型的输入之间建立关联,非常适合跨模态的任务。
总结
- Self-Attention 主要用于捕捉单一输入内部的复杂依赖关系,适合于文本理解等任务。
- Cross-Attention 则是在两个或多个输入之间建立联系,适合于跨域或多模态任务。
两者都是现代深度学习架构中非常重要的组成部分,特别是在Transformer架构中扮演着核心角色。了解何时使用哪种类型的注意力机制取决于你的具体应用需求。
更多推荐

所有评论(0)