一、什么是 KV Cache?

KV Cache(Key-Value 缓存)是 Transformer 模型在自回归推理过程中,为了避免重复计算而存储的中间状态。它是提高大模型推理速度的关键技术。

核心概念

  1. KV 指的是 Transformer 注意力机制中的 Key 和 Value 向量

  2. Cache 是指将这些向量缓存起来,供后续 token 生成时复用

  3. 目的:避免对已处理过的 token 重新计算 Key 和 Value

二、为什么需要 KV Cache?

问题:Transformer 的注意力计算

在 Transformer 中,每个 token 在注意力层需要:

  • Query(Q):查询向量

  • Key(K):键向量

  • Value(V):值向量

注意力分数计算公式:

Attention(Q, K, V) = softmax(Q·K^T / √d) · V

推理时的困境

没有缓存的情况

生成第1个token:计算 Token1 的 Q, K, V
生成第2个token:计算 Token1, Token2 的 Q, K, V  ← 重复计算Token1!
生成第3个token:计算 Token1, Token2, Token3 的 Q, K, V ← 重复计算Token1,2!

每次生成新 token 时,都需要为所有历史token重新计算K和V,计算量随序列长度平方增长。

有缓存的情况

生成第1个token:计算 Token1 的 Q, K, V,并缓存 K1, V1
生成第2个token:计算 Token2 的 Q, K, V,从缓存读取 K1, V1
生成第3个token:计算 Token3 的 Q, K, V,从缓存读取 K1, V1, K2, V2

只需为新 token 计算 K 和 V,历史 token 的 K, V 从缓存读取。

三、KV Cache 的数学表示

Transformer 层中的计算

对于第 l 层,输入 x

Q^l = x · W_Q^l  # [batch, seq_len, d_model] -> [batch, seq_len, d_k]
K^l = x · W_K^l  # [batch, seq_len, d_model] -> [batch, seq_len, d_k]
V^l = x · W_V^l  # [batch, seq_len, d_model] -> [batch, seq_len, d_v]

缓存内容

KV Cache 存储的是每个层的:

  • Key 矩阵 K^l:形状 [batch, seq_len, d_k]

  • Value 矩阵 V^l:形状 [batch, seq_len, d_v]

推理过程伪代码

class TransformerDecoderWithKVCache:
    def __init__(self, model, max_seq_len):
        self.model = model
        self.kv_cache = {
            'keys': torch.zeros(max_seq_len, num_layers, d_k),
            'values': torch.zeros(max_seq_len, num_layers, d_v)
        }
        self.cache_position = 0
    
    def generate_next_token(self, input_token):
        # 1. 前向传播到每一层
        for layer_idx, layer in enumerate(self.model.layers):
            # 计算当前token的Q, K, V
            q, k, v = layer.attention.qkv_projection(input_token)
            
            # 2. 更新缓存:将新token的K,V存入缓存
            self.kv_cache['keys'][self.cache_position, layer_idx] = k
            self.kv_cache['values'][self.cache_position, layer_idx] = v
            
            # 3. 注意力计算:使用当前Q和所有缓存的K,V
            # 从缓存获取到当前位置的所有K,V
            cached_keys = self.kv_cache['keys'][:self.cache_position+1, layer_idx]
            cached_values = self.kv_cache['values'][:self.cache_position+1, layer_idx]
            
            # 计算注意力
            attention_output = self.attention(q, cached_keys, cached_values)
            
            # 4. 继续前向传播
            input_token = layer.ffn(attention_output)
        
        self.cache_position += 1
        return output_logits

四、实际例子分析

案例:LLaMA-7B 模型

模型参数

  • 层数:32

  • 注意力头数:32

  • 每个头的维度:128

  • 上下文长度:4096

KV Cache 大小计算

# 每层的K/V矩阵大小
per_layer_kv_size = 2 * (d_model * d_k)  # K和V各一个

# 对于每个token,每层的缓存大小
per_token_per_layer = 2 * (num_heads * head_dim)  # 假设为 2 * (32 * 128) = 8192 个浮点数

# 所有层的缓存大小(单个token)
per_token_all_layers = per_token_per_layer * num_layers  # 8192 * 32 = 262,144 浮点数

# 浮点数大小(假设float16)
per_token_bytes = 262,144 * 2  # ≈ 524 KB

# 完整序列的缓存(4096个token)
full_sequence_bytes = 524 KB * 4096 ≈ 2.1 GB

内存占用分析

  • 模型权重:7B 参数,float16 格式 ≈ 14 GB

  • KV Cache(最大长度):≈ 2.1 GB

  • 总内存:约 16.1 GB

实际推理步骤示例

假设我们让模型生成 "The quick brown fox"

步骤1:输入 "The"

计算过程:
- 嵌入层:将 "The" 转换为向量
- 每一层:计算 "The" 的 Q1, K1, V1
- 保存:K1, V1 到 KV Cache
- 输出:预测下一个词的概率分布
- 选择:选择概率最高的词 "quick"

KV Cache 状态

Layer1: K1, V1
Layer2: K1, V1
...
Layer32: K1, V1

步骤2:输入 "quick"(当前序列:"The quick")

计算过程:
- 嵌入层:将 "quick" 转换为向量
- 每一层:
  * 计算 "quick" 的 Q2, K2, V2
  * 从缓存读取:Layer1: K1, V1
  * 注意力计算:Attention(Q2, [K1, K2], [V1, V2])
  * 保存:K2, V2 到 KV Cache
- 输出:预测下一个词
- 选择:"brown"

KV Cache 状态

Layer1: K1, V1, K2, V2
Layer2: K1, V1, K2, V2
...
Layer32: K1, V1, K2, V2

步骤3:输入 "brown"(当前序列:"The quick brown")

计算过程类似,但缓存中有3个token的K,V
注意力计算:Attention(Q3, [K1, K2, K3], [V1, V2, V3])

五、KV Cache 的优化技术

1. PagedAttention(vLLM)

  • 问题:传统KV Cache是连续内存,导致内存碎片

  • 解决方案:将KV Cache分页管理

  • 效果:提高内存利用率,支持更长的上下文

# 传统KV Cache(连续内存)
kv_cache = torch.zeros(max_len, num_layers, d_kv)

# PagedAttention(分页管理)
class KVCachePage:
    def __init__(self, page_size):
        self.keys = torch.zeros(page_size, d_kv)
        self.values = torch.zeros(page_size, d_kv)

# 管理多个页
kv_cache_pages = [KVCachePage(page_size) for _ in range(num_pages)]

2. Multi-Query Attention(MQA)

  • 传统:每个头有自己的K,V → 缓存大

  • MQA:多个头共享K,V → 缓存小

  • 内存节省:约减少为 1/num_heads

3. Grouped-Query Attention(GQA)

  • 介于MHA和MQA之间

  • 将头分组,组内共享K,V

  • 平衡效果和内存

4. KV Cache 量化

  • 将KV Cache从float16量化为int8

  • 内存减半,精度损失小

  • 公式KV_int8 = round(KV_fp16 / scale)

5. 滑动窗口注意力

  • 只缓存最近N个token的K,V

  • 适用于长文本,内存恒定

  • 缺点:无法处理长距离依赖

def sliding_window_kv_cache(kv_cache, new_k, new_v, window_size):
    # 添加新的K,V
    kv_cache.append((new_k, new_v))
    
    # 如果超过窗口大小,移除最旧的
    if len(kv_cache) > window_size:
        kv_cache.pop(0)
    
    return kv_cache

六、代码实现示例

简单KV Cache实现

import torch
import torch.nn as nn

class KVCache:
    def __init__(self, max_batch_size, max_seq_len, num_layers, num_heads, head_dim, dtype=torch.float16):
        self.max_seq_len = max_seq_len
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        
        # 初始化缓存
        self.key_cache = torch.zeros(
            max_batch_size, num_layers, max_seq_len, num_heads, head_dim,
            dtype=dtype
        )
        self.value_cache = torch.zeros(
            max_batch_size, num_layers, max_seq_len, num_heads, head_dim,
            dtype=dtype
        )
        
        # 当前序列长度
        self.seq_len = 0
    
    def update(self, layer_idx, new_key, new_value, batch_idx=0):
        """更新指定层的KV缓存"""
        # new_key形状: [batch, num_heads, seq_len=1, head_dim]
        # new_value形状: [batch, num_heads, seq_len=1, head_dim]
        
        self.key_cache[batch_idx, layer_idx, self.seq_len] = new_key.squeeze(2)
        self.value_cache[batch_idx, layer_idx, self.seq_len] = new_value.squeeze(2)
    
    def get(self, layer_idx, batch_idx=0):
        """获取指定层的KV缓存(到当前seq_len)"""
        keys = self.key_cache[batch_idx, layer_idx, :self.seq_len]
        values = self.value_cache[batch_idx, layer_idx, :self.seq_len]
        return keys, values
    
    def increment_seq_len(self):
        """增加序列长度"""
        self.seq_len += 1
        if self.seq_len > self.max_seq_len:
            raise ValueError(f"序列长度超过最大值 {self.max_seq_len}")
    
    def clear(self):
        """清空缓存"""
        self.seq_len = 0
        self.key_cache.zero_()
        self.value_cache.zero_()

在Transformer中的使用

class TransformerDecoderLayerWithKVCache(nn.Module):
    def __init__(self, d_model, num_heads, ff_dim):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, ff_dim)
        
    def forward(self, x, kv_cache, layer_idx, use_cache=False):
        # 自注意力
        q = self.self_attn.query_proj(x)
        
        if use_cache and kv_cache.seq_len > 0:
            # 有缓存:只计算当前token的K,V
            k = self.self_attn.key_proj(x)
            v = self.self_attn.value_proj(x)
            
            # 从缓存获取历史K,V
            past_keys, past_values = kv_cache.get(layer_idx)
            
            # 合并历史K,V和当前K,V
            keys = torch.cat([past_keys, k], dim=1)
            values = torch.cat([past_values, v], dim=1)
            
            # 更新缓存
            kv_cache.update(layer_idx, k, v)
        else:
            # 无缓存:计算所有token的K,V
            q, k, v = self.self_attn.qkv_proj(x)
            keys, values = k, v
            
            if use_cache:
                kv_cache.update(layer_idx, k, v)
        
        # 计算注意力
        attn_output = self.self_attn(q, keys, values)
        
        # FFN
        output = self.ffn(attn_output)
        
        return output

七、KV Cache 的挑战与解决方案

挑战1:内存占用大

解决方案

  1. 量化:FP16 → INT8,内存减半

  2. 压缩:稀疏化、低秩近似

  3. 选择性缓存:只缓存重要token

挑战2:长序列生成慢

解决方案

  1. FlashAttention:优化注意力计算

  2. 增量解码:仅计算新token

  3. 并行采样:一次生成多个候选

挑战3:批处理效率

解决方案

  1. 连续批处理:动态调整batch size

  2. vLLM的PagedAttention:高效内存管理

八、性能对比

方法 内存占用 推理速度 实现复杂度
无KV Cache 极慢(O(n²)) 简单
基础KV Cache 高(O(n)) 快(O(n)) 中等
PagedAttention 中等 很快 复杂
MQA/GQA 很快 中等

九、总结

KV Cache 是现代大语言模型推理的核心优化技术

  1. 工作原理:缓存历史token的Key和Value向量,避免重复计算

  2. 核心价值:将推理复杂度从 O(n²) 降到 O(n)

  3. 内存代价:需要额外存储所有历史token的K和V

  4. 优化方向:量化、压缩、高效内存管理

  5. 实际影响:决定了模型的最大上下文长度和推理速度

简单来说:KV Cache 就像是你读书时做的笔记。第一次读时做笔记(计算K,V),第二次需要时直接看笔记(从缓存读取),而不是重新读整本书(重新计算所有K,V)。

这就是为什么在 llama.cpp 等推理框架中,保持进程运行(KV Cache在内存中)比每次重启加载要快得多的根本原因。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐