【大模型基础架构与技术】如何理解MHA、MQA和GQA
逻辑:1对1映射。每个 Query Head 都有其独占对应的 Key 和 Value Head。维度计算:Query:HHH个头。Key/Value:HHH个头。GHG = HGH。计算瓶颈:在推理(Inference)阶段,每一步解码都需要加载所有的 KV Cache,导致内存带宽开销极大,成为大模型推理的主要瓶颈。MHA 是 GQA 的特例(当GHG=HGH),无需广播,计算最慢,显存占用最
·
一、 核心逻辑与维度定义的差异
1. 标准多头注意力机制 (MHA)
- 逻辑:1对1映射。每个 Query Head 都有其独占对应的 Key 和 Value Head。
- 维度计算:
- Query: HHH 个头。
- Key/Value: HHH 个头。
- 映射关系:G=HG = HG=H。
- 计算瓶颈:在推理(Inference)阶段,每一步解码都需要加载所有的 KV Cache,导致内存带宽开销极大,成为大模型推理的主要瓶颈。
2. 多查询注意力机制 (MQA)
- 逻辑:多对1映射。所有的 Query Heads 共享同一个 Key Head 和 同一个 Value Head。
- 维度计算:
- Query: HHH 个头。
- Key/Value: 111 个头。
- 映射关系:G=1G = 1G=1。
- 实现特点:将 KV Cache 的大小压缩了 HHH 倍,显著降低了内存读取量。但在计算注意力分数时,需要将这唯一的 KV Head 进行广播(Broadcast)或复制,以匹配 Query 的数量。
3. 分组查询注意力机制 (GQA)
- 逻辑:多对多(分组)映射。这是 MHA 和 MQA 的插值方案。将 Query Heads 分成 GGG 组,每组内的 Query Heads 共享该组对应的 Key/Value Head。
- 维度计算:
- Query: HHH 个头。
- Key/Value: GGG 个头(1<G<H1 < G < H1<G<H)。
- 每组内的 Query 数量:H/GH / GH/G。
- 实现特点:KV Cache 的大小缩减了 H/GH/GH/G 倍。随着模型尺寸增大,GQA 允许按比例保留带宽和容量,避免了 MQA 过于激进的信息压缩。
二、 从 MHA 转换到 GQA/MQA 的计算步骤(Uptraining Recipe)
论文提出了一种将现有的 MHA 模型转换为 GQA/MQA 的具体计算方法,而非从头随机初始化。这在代码实现中通常涉及检查点转换(Checkpoint Conversion)。
关键算法:均值池化(Mean Pooling)
在代码实现中,不能简单地选择第一个 Head 或随机初始化,论文实验证明“均值池化”效果最佳。
- 输入:原始 MHA 的 Key 投影矩阵 KMHAK_{MHA}KMHA 和 Value 投影矩阵 VMHAV_{MHA}VMHA,形状通常为 [Hidden_Dim, H, Head_Dim]。
- GQA 转换步骤:
- 将 HHH 个 Heads 划分为 GGG 个组。
- 对于第 ggg 个组(g∈{1...G}g \in \{1...G\}g∈{1...G}),取出该组对应的原始 MHA Heads。
- 对这些 Heads 的权重矩阵进行求平均(Mean Pool)操作,生成一个新的 Head。
- 结果:得到新的 Key/Value 投影矩阵,形状为 [Hidden_Dim, G, Head_Dim]。
三、 代码实现层面的逻辑对比(伪代码视角)
在推理阶段的 forward 函数中,三种机制在处理 Attention(Q, K, V) 时的张量操作逻辑如下:
1. 维度准备
假设输入 Batch 为 BBB,序列长度为 SSS。
- MHA:
Q,K,V的形状均为[B, S, H, D]。 - MQA:
Q为[B, S, H, D];K,V为[B, S, 1, D]。 - GQA:
Q为[B, S, H, D];K,V为[B, S, G, D]。
2. 注意力计算流程 (GQA 的核心实现)
GQA 的实现关键在于将 KV Heads 广播(repeat/broadcast)到与 Query Heads 相同的数量,以便进行矩阵乘法。
# 假设我们有 Q, K, V
# Q: [Batch, Seq, H, Head_Dim]
# K, V: [Batch, Seq, G, Head_Dim] (对于 MHA, G=H; 对于 MQA, G=1)
def grouped_query_attention(Q, K, V, H, G):
# 1. 计算每组包含的 Query Heads 数量
group_size = H // G
if G == 1:
# --- MQA 逻辑 ---
# K, V 只有 1 个头,需要广播到 H 个头
# 实际代码中常用 torch.repeat_interleave 或 expand
K_expanded = K.repeat(1, 1, H, 1)
V_expanded = V.repeat(1, 1, H, 1)
elif G == H:
# --- MHA 逻辑 ---
# 是一一对应,无需广播
K_expanded = K
V_expanded = V
else: # 1 < G < H
# --- GQA 逻辑 ---
# K, V 有 G 个头,每个头需要广播 group_size 次
# 维度变换: [B, S, G, D] -> [B, S, G, group_size, D] -> [B, S, H, D]
# 步骤 A: 插入维度以匹配分组
K = K[:, :, :, None, :] # [B, S, G, 1, D]
V = V[:, :, :, None, :]
# 步骤 B: 在组内广播 (Replicate within the group)
K = K.expand(-1, -1, -1, group_size, -1) # [B, S, G, group_size, D]
V = V.expand(-1, -1, -1, group_size, -1)
# 步骤 C: 展平回 H 维度
K_expanded = K.reshape(B, S, H, D)
V_expanded = V.reshape(B, S, H, D)
# 2. 标准 Attention 计算 (所有变体最终都归结为这一步)
# Attention Score: Softmax(Q @ K_expanded.T / sqrt(D))
# Output: Score @ V_expanded
...
总结
- MHA 是 GQA 的特例(当 G=HG=HG=H),无需广播,计算最慢,显存占用最高。
- MQA 是 GQA 的特例(当 G=1G=1G=1),广播倍数最大(HHH倍),计算最快,但质量受损。
- GQA 通过引入分组参数 GGG,在代码实现上仅需增加一步“组内广播”(Expand per group),即可灵活调节推理速度与模型质量的平衡。
更多推荐


所有评论(0)