一、 核心逻辑与维度定义的差异

1. 标准多头注意力机制 (MHA)
  • 逻辑:1对1映射。每个 Query Head 都有其独占对应的 Key 和 Value Head。
  • 维度计算:
    • Query: HHH 个头。
    • Key/Value: HHH 个头。
    • 映射关系:G=HG = HG=H
  • 计算瓶颈:在推理(Inference)阶段,每一步解码都需要加载所有的 KV Cache,导致内存带宽开销极大,成为大模型推理的主要瓶颈。
2. 多查询注意力机制 (MQA)
  • 逻辑:多对1映射。所有的 Query Heads 共享同一个 Key Head 和 同一个 Value Head。
  • 维度计算:
    • Query: HHH 个头。
    • Key/Value: 111 个头。
    • 映射关系:G=1G = 1G=1
  • 实现特点:将 KV Cache 的大小压缩了 HHH 倍,显著降低了内存读取量。但在计算注意力分数时,需要将这唯一的 KV Head 进行广播(Broadcast)或复制,以匹配 Query 的数量。
3. 分组查询注意力机制 (GQA)
  • 逻辑:多对多(分组)映射。这是 MHA 和 MQA 的插值方案。将 Query Heads 分成 GGG 组,每组内的 Query Heads 共享该组对应的 Key/Value Head。
  • 维度计算:
    • Query: HHH 个头。
    • Key/Value: GGG 个头(1<G<H1 < G < H1<G<H)。
    • 每组内的 Query 数量:H/GH / GH/G
  • 实现特点:KV Cache 的大小缩减了 H/GH/GH/G 倍。随着模型尺寸增大,GQA 允许按比例保留带宽和容量,避免了 MQA 过于激进的信息压缩。

二、 从 MHA 转换到 GQA/MQA 的计算步骤(Uptraining Recipe)

论文提出了一种将现有的 MHA 模型转换为 GQA/MQA 的具体计算方法,而非从头随机初始化。这在代码实现中通常涉及检查点转换(Checkpoint Conversion)。

关键算法:均值池化(Mean Pooling)
在代码实现中,不能简单地选择第一个 Head 或随机初始化,论文实验证明“均值池化”效果最佳。

  • 输入:原始 MHA 的 Key 投影矩阵 KMHAK_{MHA}KMHA 和 Value 投影矩阵 VMHAV_{MHA}VMHA,形状通常为 [Hidden_Dim, H, Head_Dim]。
  • GQA 转换步骤:
    1. HHH 个 Heads 划分为 GGG 个组。
    2. 对于第 ggg 个组(g∈{1...G}g \in \{1...G\}g{1...G}),取出该组对应的原始 MHA Heads。
    3. 对这些 Heads 的权重矩阵进行求平均(Mean Pool)操作,生成一个新的 Head。
    4. 结果:得到新的 Key/Value 投影矩阵,形状为 [Hidden_Dim, G, Head_Dim]。

三、 代码实现层面的逻辑对比(伪代码视角)

在推理阶段的 forward 函数中,三种机制在处理 Attention(Q, K, V) 时的张量操作逻辑如下:

1. 维度准备

假设输入 Batch 为 BBB,序列长度为 SSS

  • MHA: Q, K, V 的形状均为 [B, S, H, D]
  • MQA: Q[B, S, H, D]K, V[B, S, 1, D]
  • GQA: Q[B, S, H, D]K, V[B, S, G, D]
2. 注意力计算流程 (GQA 的核心实现)

GQA 的实现关键在于将 KV Heads 广播(repeat/broadcast)到与 Query Heads 相同的数量,以便进行矩阵乘法。

# 假设我们有 Q, K, V
# Q: [Batch, Seq, H, Head_Dim]
# K, V: [Batch, Seq, G, Head_Dim] (对于 MHA, G=H; 对于 MQA, G=1)

def grouped_query_attention(Q, K, V, H, G):
    # 1. 计算每组包含的 Query Heads 数量
    group_size = H // G 
    
    if G == 1:
        # --- MQA 逻辑 ---
        # K, V 只有 1 个头,需要广播到 H 个头
        # 实际代码中常用 torch.repeat_interleave 或 expand
        K_expanded = K.repeat(1, 1, H, 1) 
        V_expanded = V.repeat(1, 1, H, 1)
        
    elif G == H:
        # --- MHA 逻辑 ---
        # 是一一对应,无需广播
        K_expanded = K
        V_expanded = V
        
    else: # 1 < G < H
        # --- GQA 逻辑 ---
        # K, V 有 G 个头,每个头需要广播 group_size 次
        # 维度变换: [B, S, G, D] -> [B, S, G, group_size, D] -> [B, S, H, D]
        
        # 步骤 A: 插入维度以匹配分组
        K = K[:, :, :, None, :]  # [B, S, G, 1, D]
        V = V[:, :, :, None, :]
        
        # 步骤 B: 在组内广播 (Replicate within the group)
        K = K.expand(-1, -1, -1, group_size, -1) # [B, S, G, group_size, D]
        V = V.expand(-1, -1, -1, group_size, -1)
        
        # 步骤 C: 展平回 H 维度
        K_expanded = K.reshape(B, S, H, D)
        V_expanded = V.reshape(B, S, H, D)

    # 2. 标准 Attention 计算 (所有变体最终都归结为这一步)
    # Attention Score: Softmax(Q @ K_expanded.T / sqrt(D))
    # Output: Score @ V_expanded
    ...

总结

  • MHA 是 GQA 的特例(当 G=HG=HG=H),无需广播,计算最慢,显存占用最高。
  • MQA 是 GQA 的特例(当 G=1G=1G=1),广播倍数最大(HHH倍),计算最快,但质量受损。
  • GQA 通过引入分组参数 GGG,在代码实现上仅需增加一步“组内广播”(Expand per group),即可灵活调节推理速度与模型质量的平衡。
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐