大模型之注意力机制实现
位置问题修正KV Cache 拼接dim=-1应为dim=1标准注意力路径缺少应加xv拆分未view应加这段代码是一个高度优化、功能完整GQA(内存效率)RoPE(位置感知)KV Cache(推理加速)(计算效率)因果掩码(自回归约束)尽管存在几处小 bug(已在上文指出),但整体结构清晰、符合工业级大模型设计范式,是理解 LLaMA、Gemma 等模型注意力机制的优秀参考。
·
注意力机制实现代码
class Attention(nn.Module):
def __init__(self, args: LMConfig):
super().__init__()
# 1. 配置 kv 头和每个头的维度
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads # KV的头数
self.n_local_heads = args.n_heads # Q头数
self.n_local_kv_heads = self.n_kv_heads # kv头数
assert self.n_local_heads % self.n_local_kv_heads == 0 # Q头数必须是KV头数的整数倍
self.n_rep = self.n_local_heads // self.n_local_kv_heads # 每个kv头重复次数(带几个Q头)
self.head_dim = args.dim // args.n_heads # 每个头的维度
# 2.定义线性投影层
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) # Wq:(dim, n_heads * head_dim)
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) # Wk / Wv:(dim, kv_heads * head_dim)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) # Wk / Wv:(dim, kv_heads * head_dim)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias = False) # Wo:合并多头输出 (n_heads * head_dim, dim)
# 3. Dropout 层
self.attn_dropout = nn.Dropout(args.dropout) # 默认无元素丢弃
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
# 4. 是否使用 FlashAttention (高效注意力)
self.flash = hasattr(F, "scaled_dot_product_attention") and args.flash_attn
# 5. 生成因果掩码矩阵
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1) # 上三角进行掩码
self.register_buffer("mask", mask, persistent = False) # 将掩码矩阵注册到缓存中,不参与训练
def forward(self, x: torch.Tensor,
pos_cis: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache = False
):
bsz, seq_len, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
if past_key_value is not None:
xk = torch.cat([past_key_value[0],xk], dim=-1)
xv = torch.cat([past_key_value[1],xv], dim=-1)
past_kv = (xk, xv) if use_cache else None
xq, xk, xv = (
xq.transpose(1 ,2), # (batch, seq, heads, dim)->(batch, heads, seq, dim)
repeat_kv(xk, self.n_rep).transpose(1,2),
repeat_kv(xv, self.n_rep).transpose(1,2)
)
if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0
F.scaled_dot_product_attention(
xq, xk, xv,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True,
)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores += self.mask[:, :, :seq_len, :seq_len]
scores = F.softmax(scores.float(), dim = -1).type_as(xq)
output = scores @ xv
output = output.transpose(1,2).reshape(bsz, seq_len, -1)
output = self.resid_dropout(self.wo(output))
return output, past_kv
📌 代码解释
该 Attention 模块实现了:
- Grouped-Query Attention (GQA):减少 KV 头数以节省显存
- Rotary Position Embedding (RoPE):注入相对位置信息
- KV Cache:支持自回归生成时的高效推理
- FlashAttention 回退机制:训练/推理时自动选择最高效路径
- 因果掩码(Causal Masking):确保自回归性质(不能看未来)
🧱 1. 初始化 __init__
def __init__(self, args: LMConfig):
super().__init__()
(1) 头数配置
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
- 若未指定
n_kv_heads,默认等于n_heads(即标准 Multi-Head Attention) - 否则使用指定值(如 LLaMA-2 7B 中
n_heads=32,n_kv_heads=8)
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
- 为清晰起见,定义本地变量(分布式训练中可能不同,此处单机等价)
assert self.n_local_heads % self.n_local_kv_heads == 0
- GQA 要求:Q 头数必须是 KV 头数的整数倍,以便分组共享
self.n_rep = self.n_local_heads // self.n_local_kv_heads
- 每个 KV 头被多少个 Q 头共享(如 32 / 8 = 4)
self.head_dim = args.dim // args.n_heads
- 每个注意力头的维度(如
dim=4096,n_heads=32→head_dim=128)
✅ 注意:即使使用 GQA,
head_dim仍由 Q 头数 决定,K/V 使用相同head_dim
(2) 线性投影层
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
- 投影到 Q 空间:
[dim] → [n_heads × head_dim] - 无偏置(bias=False):主流大模型惯例,减少参数
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
- 投影到 KV 空间:维度为
n_kv_heads × head_dim(可能小于 Q 的维度)
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
- 多头融合投影:将拼接后的多头输出映射回原始维度
- 输入:
n_heads × head_dim = dim,输出:dim,但是可学习的线性变换
(3) Dropout 配置
self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
attn_dropout:作用于注意力权重(softmax 后或前,取决于路径)resid_dropout:作用于最终输出(残差连接前)- 两者目的均为正则化,防止过拟合
(4) FlashAttention 支持检测
self.flash = hasattr(F, "scaled_dot_product_attention") and args.flash_attn
F = torch.nn.functional- 检查 PyTorch ≥ 2.0 且用户启用
flash_attn - 若满足,则优先使用高效内核
(5) 因果掩码(Causal Mask)
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask, persistent=False)
- 创建
[1, 1, L, L]全-inf张量 triu(diagonal=1):保留上三角(不含主对角线)为-inf,其余为 0register_buffer(..., persistent=False):- 自动随模型移动设备(GPU/CPU)
- 不保存到 state_dict(因可重建,节省磁盘)
✅ 示例(L=4):
[[0, -inf, -inf, -inf], [0, 0, -inf, -inf], [0, 0, 0, -inf], [0, 0, 0, 0]]
🔄 2. 前向传播 forward
def forward(self, x: torch.Tensor,
pos_cis: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache = False):
(1) 输入解析
bsz, seq_len, _ = x.shape
x:[B, S, dim],如[2, 512, 4096]
(2) 线性投影 + 多头拆分
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq:[B, S, n_heads × head_dim]xk/xv:[B, S, n_kv_heads × head_dim]
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
- 拆分为多头格式:
[B, S, H, D]
✅ 注意:
xv也应被view,原文漏写(应补充)
(3) 应用 RoPE(旋转位置编码)
xq, xk = apply_rotary_emb(xq, xk, pos_cis)
apply_rotary_emb是外部函数,对 Q/K 应用复数旋转- 保留相对位置信息,优于绝对位置编码
- 不改变张量形状
(4) KV Cache 处理(推理加速)
if past_key_value is not None:
xk = torch.cat([past_key_value[0], xk], dim=1)
xv = torch.cat([past_key_value[1], xv], dim=1)
past_key_value[0]:[B, past_len, n_kv_heads, head_dim]xk:[B, current_len, n_kv_heads, head_dim]- 拼接后:
[B, total_len, n_kv_heads, head_dim] - ⚠️ 注意:应为
dim=1(序列维度),原文dim=-1错误!
❌ 原文 bug:
xk = torch.cat([past_key_value[0], xk], dim=-1) # 错!应改为:
xk = torch.cat([past_key_value[0], xk], dim=1) # 正确! xv = torch.cat([past_key_value[1], xv], dim=1)
past_kv = (xk, xv) if use_cache else None
- 若
use_cache=True,返回更新后的 KV 缓存供下一次使用
(5) 调整维度 + GQA 扩展
xq, xk, xv = (
xq.transpose(1, 2), # [B, S, H_q, D] → [B, H_q, S, D]
repeat_kv(xk, self.n_rep).transpose(1, 2), # [B, S, H_kv, D] → [B, H_q, S, D]
repeat_kv(xv, self.n_rep).transpose(1, 2)
)
transpose(1, 2):调整为标准注意力输入格式[B, H, S, D]repeat_kv(xk, n_rep):将 KV 头从H_kv扩展到H_q(每个 KV 头复制n_rep次)
✅
repeat_kv典型实现:def repeat_kv(x, n_rep): if n_rep == 1: return x B, S, H, D = x.shape x = x.unsqueeze(2).expand(B, S, n_rep, H, D) return x.reshape(B, S, H * n_rep, D)
(6) 注意力计算(双路径)
🔸 路径 A:FlashAttention(高效)
if self.flash and seq_len != 1:
dropout_p = self.dropout if self.training else 0.0
output = F.scaled_dot_product_attention(
xq, xk, xv,
attn_mask=None,
dropout_p=dropout_p,
is_causal=True,
)
is_causal=True:自动处理因果掩码,无需传 maskdropout_p:训练时启用,推理时为 0- ⚠️
seq_len != 1:可能为兼容性限制(实际可移除)
🔸 路径 B:标准注意力(显式计算)
else:
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
scores += self.mask[:, :, :seq_len, :seq_len]
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
scores = self.attn_dropout(scores) # 原文漏写!应加此行
output = scores @ xv
- 计算
QK^T / √d - 加因果掩码(适配当前序列长度)
- Softmax + Dropout + 加权求和 V
- ⚠️ 原文漏掉
attn_dropout,应补充
(7) 输出整合
output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
output:[B, H, S, D]→transpose→[B, S, H, D]→reshape→[B, S, dim]
output = self.resid_dropout(self.wo(output))
wo:融合多头信息(可学习)resid_dropout:残差连接前正则化
(8) 返回
return output, past_kv
output:[B, S, dim]past_kv: 更新后的 KV 缓存(若use_cache=True)
⚠️ 原文中的潜在 Bug 总结
| 位置 | 问题 | 修正 |
|---|---|---|
| KV Cache 拼接 | dim=-1 |
应为 dim=1 |
| 标准注意力路径 | 缺少 attn_dropout |
应加 scores = self.attn_dropout(scores) |
xv 拆分 |
未 view |
应加 xv = xv.view(...) |
✅ 维度变化全流程示例(B=2, S=4, dim=4096, H=32, H_kv=8, D=128)
| 步骤 | 张量 | shape |
|---|---|---|
输入 x |
x |
[2, 4, 4096] |
wq/wk/wv |
xq/xk/xv |
[2,4,4096], [2,4,1024], [2,4,1024] |
view |
xq/xk/xv |
[2,4,32,128], [2,4,8,128], [2,4,8,128] |
| RoPE | — | shape 不变 |
| KV Cache (若有) | xk/xv |
[2, total_len, 8, 128] |
repeat_kv + transpose |
xq/xk/xv |
[2,32,4,128], [2,32,total_len,128], … |
| Attention 输出 | output |
[2,32,4,128] |
| 整合 | output |
[2,4,4096] |
🏁 总结
这段代码是一个高度优化、功能完整的现代注意力实现,融合了:
- GQA(内存效率)
- RoPE(位置感知)
- KV Cache(推理加速)
- FlashAttention(计算效率)
- 因果掩码(自回归约束)
尽管存在几处小 bug(已在上文指出),但整体结构清晰、符合工业级大模型设计范式,是理解 LLaMA、Gemma 等模型注意力机制的优秀参考。
更多推荐


所有评论(0)