【第四章:大模型(LLM)】09.最强开源大模型:Llama3 原理介绍与实现-(4)Grouped Multi-Query Attention
第四章:大模型(LLM)
第九部分:最强开源大模型:Llama3 原理介绍与实现
第四节:Grouped Multi-Query Attention
这一节我们来讲解 Grouped Multi-Query Attention(分组多查询注意力),这是 Llama3 在推理效率和内存优化上的一个重要设计。
1️⃣ 背景:注意力机制的性能瓶颈
在标准 Transformer 的 多头注意力(Multi-Head Attention, MHA) 中:
-
每个 head 都有独立的 Query (Q)、Key (K)、Value (V) 投影矩阵。
-
计算和存储开销很大,尤其在推理时需要缓存所有 Key/Value(KV Cache)。
问题:
-
KV Cache 需要存储所有 head 的 K 和 V,显存消耗 随 head 数量线性增长。
-
推理时,读取大量 KV Cache 数据,导致 内存带宽瓶颈。
2️⃣ Multi-Query Attention (MQA)
思路:减少 KV 的存储量。
-
标准 MHA:每个 head 有独立的 K、V。
-
MQA:所有 head 共享同一份 K 和 V,只有 Q 仍然独立。
公式(假设有 hh 个注意力头):
-
MHA:
-
MQA:
这样 KV Cache 的存储需求从 O(h·d) 降到 O(d),大幅减小内存消耗。
3️⃣ Grouped Multi-Query Attention (GQA)
Llama3 使用的折中方案:
-
不是所有 head 共用同一组 K/V(像 MQA 那样过于极端)。
-
而是把多个 Query head 分成 若干组(Group),每组共享一份 K/V。
例如:
-
总共有 32 个 Query head。
-
分成 8 个 group,每组 4 个 Query head。
-
每组的 head 共享同一份 K/V。
这样:
-
比标准 MHA 少很多 KV Cache(每组只存一份)。
-
又比 MQA 表现更好(因为不同组能学到不同的 K/V 表达)。
4️⃣ 数学形式
设:
-
总 head 数 = h
-
分组数 = g
-
每组包含
个 Query head
公式:
注意力计算时:
-
Query 仍然逐个 head 独立计算。
-
但它们只会使用对应 group 的 K/V。
其中 g(i) 表示 Query head i 所属的组。
5️⃣ PyTorch 伪代码示例
import torch
import torch.nn as nn
class GroupedMQA(nn.Module):
def __init__(self, dim, num_heads, num_groups):
super().__init__()
self.num_heads = num_heads
self.num_groups = num_groups
self.head_dim = dim // num_heads
self.group_dim = dim // num_groups
# Q 仍然独立
self.W_q = nn.Linear(dim, dim, bias=False)
# K/V 按组划分
self.W_k = nn.Linear(dim, self.group_dim, bias=False)
self.W_v = nn.Linear(dim, self.group_dim, bias=False)
def forward(self, x):
B, T, D = x.size()
Q = self.W_q(x).view(B, T, self.num_heads, self.head_dim)
K = self.W_k(x).view(B, T, self.num_groups, self.group_dim // self.num_groups)
V = self.W_v(x).view(B, T, self.num_groups, self.group_dim // self.num_groups)
# 将 head 映射到对应的 group
# 实际实现中需要更复杂的 index/broadcast
return Q, K, V
6️⃣ Llama3 中的应用
Llama3 采用 Grouped MQA,主要特点:
-
减少 KV Cache 占用
-
MHA:KV Cache ∝ head 数
-
GQA:KV Cache ∝ group 数(远小于 head 数)
-
-
加速推理
-
KV Cache 读取更少,内存带宽压力降低。
-
在长序列生成时(8k, 16k tokens),效果更显著。
-
-
平衡性能与精度
-
MQA 虽然更省,但可能损失模型表达能力。
-
GQA 在效率与效果之间做了折中,实验结果表明 Llama3 依然保持了很强的语言建模能力。
-
7️⃣ 总结
-
MHA:每个 head 独立的 K/V,精度高但推理慢、显存大。
-
MQA:所有 head 共享同一份 K/V,速度快但可能损失表达能力。
-
GQA:多个 head 组成 group,每组共享 K/V,Llama3 的选择。
效果:KV Cache 占用减少数倍,推理速度显著提升,性能损失极小。
更多推荐
所有评论(0)