第四章:大模型(LLM)

第九部分:最强开源大模型:Llama3 原理介绍与实现

第四节:Grouped Multi-Query Attention

这一节我们来讲解 Grouped Multi-Query Attention(分组多查询注意力),这是 Llama3 在推理效率和内存优化上的一个重要设计。


1️⃣ 背景:注意力机制的性能瓶颈

在标准 Transformer 的 多头注意力(Multi-Head Attention, MHA) 中:

  • 每个 head 都有独立的 Query (Q)、Key (K)、Value (V) 投影矩阵。

  • 计算和存储开销很大,尤其在推理时需要缓存所有 Key/Value(KV Cache)。

问题:
  • KV Cache 需要存储所有 head 的 K 和 V,显存消耗 随 head 数量线性增长

  • 推理时,读取大量 KV Cache 数据,导致 内存带宽瓶颈


2️⃣ Multi-Query Attention (MQA)

思路:减少 KV 的存储量。

  • 标准 MHA:每个 head 有独立的 K、V。

  • MQA:所有 head 共享同一份 K 和 V,只有 Q 仍然独立。

公式(假设有 hh 个注意力头):

  • MHA:

Q_i = XW_i^Q, \quad K_i = XW_i^K, \quad V_i = XW_i^V

  • MQA:

Q_i = XW_i^Q, \quad K = XW^K, \quad V = XW^V

这样 KV Cache 的存储需求从 O(h·d) 降到 O(d),大幅减小内存消耗。


3️⃣ Grouped Multi-Query Attention (GQA)

Llama3 使用的折中方案

  • 不是所有 head 共用同一组 K/V(像 MQA 那样过于极端)。

  • 而是把多个 Query head 分成 若干组(Group),每组共享一份 K/V。

例如:

  • 总共有 32 个 Query head。

  • 分成 8 个 group,每组 4 个 Query head。

  • 每组的 head 共享同一份 K/V。

这样:

  • 比标准 MHA 少很多 KV Cache(每组只存一份)。

  • 又比 MQA 表现更好(因为不同组能学到不同的 K/V 表达)。


4️⃣ 数学形式

设:

  • 总 head 数 = h

  • 分组数 = g

  • 每组包含 \frac{h}{g}个 Query head

公式:

Q_i = XW_i^Q, \quad K_j = XW_j^K, \quad V_j = XW_j^V \quad \text{for group } j = 1 \dots g

注意力计算时:

  • Query 仍然逐个 head 独立计算。

  • 但它们只会使用对应 group 的 K/V。

\text{Attention}(Q_i, K_{g(i)}, V_{g(i)})

其中 g(i) 表示 Query head i 所属的组。


5️⃣ PyTorch 伪代码示例

import torch
import torch.nn as nn

class GroupedMQA(nn.Module):
    def __init__(self, dim, num_heads, num_groups):
        super().__init__()
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.head_dim = dim // num_heads
        self.group_dim = dim // num_groups

        # Q 仍然独立
        self.W_q = nn.Linear(dim, dim, bias=False)
        # K/V 按组划分
        self.W_k = nn.Linear(dim, self.group_dim, bias=False)
        self.W_v = nn.Linear(dim, self.group_dim, bias=False)

    def forward(self, x):
        B, T, D = x.size()

        Q = self.W_q(x).view(B, T, self.num_heads, self.head_dim)
        K = self.W_k(x).view(B, T, self.num_groups, self.group_dim // self.num_groups)
        V = self.W_v(x).view(B, T, self.num_groups, self.group_dim // self.num_groups)

        # 将 head 映射到对应的 group
        # 实际实现中需要更复杂的 index/broadcast
        return Q, K, V

6️⃣ Llama3 中的应用

Llama3 采用 Grouped MQA,主要特点:

  1. 减少 KV Cache 占用

    • MHA:KV Cache ∝ head 数

    • GQA:KV Cache ∝ group 数(远小于 head 数)

  2. 加速推理

    • KV Cache 读取更少,内存带宽压力降低。

    • 在长序列生成时(8k, 16k tokens),效果更显著。

  3. 平衡性能与精度

    • MQA 虽然更省,但可能损失模型表达能力。

    • GQA 在效率与效果之间做了折中,实验结果表明 Llama3 依然保持了很强的语言建模能力。


7️⃣ 总结

  • MHA:每个 head 独立的 K/V,精度高但推理慢、显存大。

  • MQA:所有 head 共享同一份 K/V,速度快但可能损失表达能力。

  • GQA:多个 head 组成 group,每组共享 K/V,Llama3 的选择。

效果:KV Cache 占用减少数倍,推理速度显著提升,性能损失极小。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐