注意力机制实现代码


class Attention(nn.Module):
    def __init__(self, args: LMConfig):
        super().__init__()
        # 1. 配置 kv 头和每个头的维度
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads # KV的头数
        self.n_local_heads = args.n_heads # Q头数
        self.n_local_kv_heads = self.n_kv_heads # kv头数
        assert self.n_local_heads % self.n_local_kv_heads == 0 # Q头数必须是KV头数的整数倍
        self.n_rep = self.n_local_heads // self.n_local_kv_heads # 每个kv头重复次数(带几个Q头)
        self.head_dim = args.dim // args.n_heads # 每个头的维度

        # 2.定义线性投影层
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) # Wq:(dim, n_heads * head_dim)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) # Wk / Wv:(dim, kv_heads * head_dim)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) # Wk / Wv:(dim, kv_heads * head_dim)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias = False) # Wo:合并多头输出 (n_heads * head_dim, dim)

        # 3. Dropout 层
        self.attn_dropout = nn.Dropout(args.dropout) # 默认无元素丢弃
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout

        # 4. 是否使用 FlashAttention (高效注意力)
        self.flash = hasattr(F, "scaled_dot_product_attention") and args.flash_attn
        
        # 5. 生成因果掩码矩阵
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1) # 上三角进行掩码
        self.register_buffer("mask", mask, persistent = False) # 将掩码矩阵注册到缓存中,不参与训练
        
    def forward(self, x: torch.Tensor, 
                pos_cis: torch.Tensor, 
                past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                use_cache = False
                ):
        bsz, seq_len, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, pos_cis)
        
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0],xk], dim=-1)
            xv = torch.cat([past_key_value[1],xv], dim=-1)
        past_kv = (xk, xv) if use_cache else None
        
        xq, xk, xv = (
            xq.transpose(1 ,2), # (batch, seq, heads, dim)->(batch, heads, seq, dim)
            repeat_kv(xk, self.n_rep).transpose(1,2),
            repeat_kv(xv, self.n_rep).transpose(1,2)            
        )
        if self.flash and seq_len != 1:
            dropout_p = self.dropout if self.training else 0.0
            F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True,
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            scores += self.mask[:, :, :seq_len, :seq_len]
            scores = F.softmax(scores.float(), dim = -1).type_as(xq)
            output = scores @ xv
        
        output = output.transpose(1,2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.wo(output))
        return output, past_kv
        
     

📌 代码解释

Attention 模块实现了:

  • Grouped-Query Attention (GQA):减少 KV 头数以节省显存
  • Rotary Position Embedding (RoPE):注入相对位置信息
  • KV Cache:支持自回归生成时的高效推理
  • FlashAttention 回退机制:训练/推理时自动选择最高效路径
  • 因果掩码(Causal Masking):确保自回归性质(不能看未来)

🧱 1. 初始化 __init__

def __init__(self, args: LMConfig):
    super().__init__()

(1) 头数配置

self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
  • 若未指定 n_kv_heads,默认等于 n_heads(即标准 Multi-Head Attention)
  • 否则使用指定值(如 LLaMA-2 7B 中 n_heads=32, n_kv_heads=8
self.n_local_heads = args.n_heads
self.n_local_kv_heads = self.n_kv_heads
  • 为清晰起见,定义本地变量(分布式训练中可能不同,此处单机等价)
assert self.n_local_heads % self.n_local_kv_heads == 0
  • GQA 要求:Q 头数必须是 KV 头数的整数倍,以便分组共享
self.n_rep = self.n_local_heads // self.n_local_kv_heads
  • 每个 KV 头被多少个 Q 头共享(如 32 / 8 = 4)
self.head_dim = args.dim // args.n_heads
  • 每个注意力头的维度(如 dim=4096, n_heads=32head_dim=128

✅ 注意:即使使用 GQA,head_dim 仍由 Q 头数 决定,K/V 使用相同 head_dim


(2) 线性投影层

self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
  • 投影到 Q 空间:[dim] → [n_heads × head_dim]
  • 无偏置(bias=False):主流大模型惯例,减少参数
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
  • 投影到 KV 空间:维度为 n_kv_heads × head_dim可能小于 Q 的维度
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
  • 多头融合投影:将拼接后的多头输出映射回原始维度
  • 输入:n_heads × head_dim = dim,输出:dim,但是可学习的线性变换

(3) Dropout 配置

self.attn_dropout = nn.Dropout(args.dropout)
self.resid_dropout = nn.Dropout(args.dropout)
self.dropout = args.dropout
  • attn_dropout:作用于注意力权重(softmax 后或前,取决于路径)
  • resid_dropout:作用于最终输出(残差连接前)
  • 两者目的均为正则化,防止过拟合

(4) FlashAttention 支持检测

self.flash = hasattr(F, "scaled_dot_product_attention") and args.flash_attn
  • F = torch.nn.functional
  • 检查 PyTorch ≥ 2.0 且用户启用 flash_attn
  • 若满足,则优先使用高效内核

(5) 因果掩码(Causal Mask)

mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
self.register_buffer("mask", mask, persistent=False)
  • 创建 [1, 1, L, L]-inf 张量
  • triu(diagonal=1):保留上三角(不含主对角线)为 -inf,其余为 0
  • register_buffer(..., persistent=False)
    • 自动随模型移动设备(GPU/CPU)
    • 不保存到 state_dict(因可重建,节省磁盘)

✅ 示例(L=4):

[[0, -inf, -inf, -inf],
 [0,   0,  -inf, -inf],
 [0,   0,    0,  -inf],
 [0,   0,    0,    0]]

🔄 2. 前向传播 forward

def forward(self, x: torch.Tensor, 
            pos_cis: torch.Tensor, 
            past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
            use_cache = False):

(1) 输入解析

bsz, seq_len, _ = x.shape
  • x: [B, S, dim],如 [2, 512, 4096]

(2) 线性投影 + 多头拆分

xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
  • xq: [B, S, n_heads × head_dim]
  • xk/xv: [B, S, n_kv_heads × head_dim]
xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
xv = xv.view(bsz, seq_len, self.n_kv_heads, self.head_dim)
  • 拆分为多头格式:[B, S, H, D]

✅ 注意:xv 也应被 view,原文漏写(应补充)


(3) 应用 RoPE(旋转位置编码)

xq, xk = apply_rotary_emb(xq, xk, pos_cis)
  • apply_rotary_emb 是外部函数,对 Q/K 应用复数旋转
  • 保留相对位置信息,优于绝对位置编码
  • 不改变张量形状

(4) KV Cache 处理(推理加速)

if past_key_value is not None:
    xk = torch.cat([past_key_value[0], xk], dim=1)
    xv = torch.cat([past_key_value[1], xv], dim=1)
  • past_key_value[0]: [B, past_len, n_kv_heads, head_dim]
  • xk: [B, current_len, n_kv_heads, head_dim]
  • 拼接后:[B, total_len, n_kv_heads, head_dim]
  • ⚠️ 注意:应为 dim=1(序列维度),原文 dim=-1 错误!

❌ 原文 bug:

xk = torch.cat([past_key_value[0], xk], dim=-1)  # 错!

应改为:

xk = torch.cat([past_key_value[0], xk], dim=1)   # 正确!
xv = torch.cat([past_key_value[1], xv], dim=1)
past_kv = (xk, xv) if use_cache else None
  • use_cache=True,返回更新后的 KV 缓存供下一次使用

(5) 调整维度 + GQA 扩展

xq, xk, xv = (
    xq.transpose(1, 2),  # [B, S, H_q, D] → [B, H_q, S, D]
    repeat_kv(xk, self.n_rep).transpose(1, 2),  # [B, S, H_kv, D] → [B, H_q, S, D]
    repeat_kv(xv, self.n_rep).transpose(1, 2)
)
  • transpose(1, 2):调整为标准注意力输入格式 [B, H, S, D]
  • repeat_kv(xk, n_rep):将 KV 头从 H_kv 扩展到 H_q(每个 KV 头复制 n_rep 次)

repeat_kv 典型实现:

def repeat_kv(x, n_rep):
    if n_rep == 1: return x
    B, S, H, D = x.shape
    x = x.unsqueeze(2).expand(B, S, n_rep, H, D)
    return x.reshape(B, S, H * n_rep, D)

(6) 注意力计算(双路径)

🔸 路径 A:FlashAttention(高效)
if self.flash and seq_len != 1:
    dropout_p = self.dropout if self.training else 0.0
    output = F.scaled_dot_product_attention(
        xq, xk, xv,
        attn_mask=None,
        dropout_p=dropout_p,
        is_causal=True,
    )
  • is_causal=True自动处理因果掩码,无需传 mask
  • dropout_p:训练时启用,推理时为 0
  • ⚠️ seq_len != 1:可能为兼容性限制(实际可移除)
🔸 路径 B:标准注意力(显式计算)
else:
    scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
    scores += self.mask[:, :, :seq_len, :seq_len]
    scores = F.softmax(scores.float(), dim=-1).type_as(xq)
    scores = self.attn_dropout(scores)  # 原文漏写!应加此行
    output = scores @ xv
  • 计算 QK^T / √d
  • 加因果掩码(适配当前序列长度)
  • Softmax + Dropout + 加权求和 V
  • ⚠️ 原文漏掉 attn_dropout,应补充

(7) 输出整合

output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
  • output: [B, H, S, D]transpose[B, S, H, D]reshape[B, S, dim]
output = self.resid_dropout(self.wo(output))
  • wo:融合多头信息(可学习)
  • resid_dropout:残差连接前正则化

(8) 返回

return output, past_kv
  • output: [B, S, dim]
  • past_kv: 更新后的 KV 缓存(若 use_cache=True

⚠️ 原文中的潜在 Bug 总结

位置 问题 修正
KV Cache 拼接 dim=-1 应为 dim=1
标准注意力路径 缺少 attn_dropout 应加 scores = self.attn_dropout(scores)
xv 拆分 view 应加 xv = xv.view(...)

✅ 维度变化全流程示例(B=2, S=4, dim=4096, H=32, H_kv=8, D=128)

步骤 张量 shape
输入 x x [2, 4, 4096]
wq/wk/wv xq/xk/xv [2,4,4096], [2,4,1024], [2,4,1024]
view xq/xk/xv [2,4,32,128], [2,4,8,128], [2,4,8,128]
RoPE shape 不变
KV Cache (若有) xk/xv [2, total_len, 8, 128]
repeat_kv + transpose xq/xk/xv [2,32,4,128], [2,32,total_len,128], …
Attention 输出 output [2,32,4,128]
整合 output [2,4,4096]

🏁 总结

这段代码是一个高度优化、功能完整的现代注意力实现,融合了:

  • GQA(内存效率)
  • RoPE(位置感知)
  • KV Cache(推理加速)
  • FlashAttention(计算效率)
  • 因果掩码(自回归约束)

尽管存在几处小 bug(已在上文指出),但整体结构清晰、符合工业级大模型设计范式,是理解 LLaMA、Gemma 等模型注意力机制的优秀参考。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐