第四章:大模型(LLM)

第九部分:最强开源大模型:Llama3 原理介绍与实现

第三节:加速推理:KV Cache


1️⃣ 背景:为什么需要 KV Cache?

在 Transformer 的推理阶段(inference),尤其是 自回归生成(autoregressive generation) 中,每一步都会依赖之前的全部 token。
例如生成到第 t 个 token 时,需要重新计算从 1 到 tt 的所有注意力结果。

这样会导致:

  • 重复计算:之前的 Key/Value 每次都重新算一遍。

  • 推理速度慢:序列越长,计算越大。

  • 显存占用大:需要保存越来越多的中间状态。

为了解决这个问题,Transformer 引入了 KV Cache(Key-Value Cache) 技术。


2️⃣ KV Cache 的基本原理

在自注意力(Self-Attention)机制中,每个 token 会被映射为三组向量:

Q = XW^Q, \quad K = XW^K, \quad V = XW^V

注意力计算公式为:

\text{Attention}(Q_t, K_{1:t}, V_{1:t}) = \text{Softmax}\left(\frac{Q_t K_{1:t}^\top}{\sqrt{d_k}}\right)V_{1:t}

  • Q_t:当前 token 的 query

  • K_{1:t}, V_{1:t}:所有已生成 token 的 key 和 value

在推理时:

  • 第一次生成时,计算并存储所有 token 的 K 和 V。

  • 之后生成新 token 时,只需计算该 token 的 Q_t, K_t, V_t,并把 K_t, V_t 追加到缓存(KV Cache) 中。

  • 不再重复计算之前的 Key/Value。

图示过程
  • 没有 KV Cache:每一步都重新计算 K_{1:t}, V_{1:t}

  • 有 KV Cache:只计算新增的 K_t, V_t,再和缓存拼接。

这样,推理的复杂度由 O(T²) 降为 O(T)(T 是序列长度)。


3️⃣ 数学推导

假设序列长度为 T,隐藏维度为 d。

没有 KV Cache:

每次生成第 t 个 token:

  • 需要重新计算前 t 个 Key/Value

  • 总复杂度:

\sum_{t=1}^T O(td) = O(T^2 d)

使用 KV Cache:
  • 第 1 步:计算 K_1, V_1,存入缓存

  • 第 2 步:只计算 K_2, V_2,拼接到缓存

  • ...

  • 第 t 步:只需计算当前 K_t, V_t,再与缓存结合

总复杂度:

O(Td)

推理效率大幅提升。


4️⃣ PyTorch 伪代码示例

import torch

class KVCache:
    def __init__(self, max_len, num_heads, head_dim, device):
        self.K = torch.zeros(max_len, num_heads, head_dim, device=device)
        self.V = torch.zeros(max_len, num_heads, head_dim, device=device)
        self.cur_len = 0

    def update(self, k_new, v_new):
        # 将新 token 的 KV 存入缓存
        self.K[self.cur_len] = k_new
        self.V[self.cur_len] = v_new
        self.cur_len += 1
        return self.K[:self.cur_len], self.V[:self.cur_len]

# 使用示例
cache = KVCache(max_len=2048, num_heads=32, head_dim=128, device="cuda")

for t in range(seq_len):
    q, k, v = project(x_t)  # 当前 token 的 Q,K,V
    K, V = cache.update(k, v)  # 拼接到缓存
    output = attention(q, K, V)  # 用缓存加速注意力

5️⃣ Llama3 中的优化技巧

Llama3 在 KV Cache 机制上做了一些改进以进一步加速:

  1. 高效缓存布局

    • 把 KV Cache 存成连续内存块,减少随机访问开销。

    • 利用张量并行和流水线并行来分布式存储 KV。

  2. 动态长度支持

    • 支持不同 batch 内 token 长度不一致。

    • 使用 padding mask 管理缓存。

  3. 分块 KV Cache

    • 将缓存切分成多个 block,以支持长上下文(如 Llama3 支持 8k、16k 上下文)。

    • 可以随时丢弃部分 block,实现滑动窗口注意力

  4. 显存优化

    • 半精度存储(FP16/BF16)减少显存开销。

    • 部分场景下可以采用量化 KV Cache(INT8/INT4)。


6️⃣ 总结

  • KV Cache 本质:缓存前序 token 的 Key/Value,避免重复计算。

  • 效果:推理复杂度由  O(T^2)\rightarrow O(T),加速明显。

  • 在 Llama3 中:配合高效缓存布局、分块机制、长上下文支持,使得大模型在实际推理中可行。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐