KV Cache 详解:大模型推理的核心优化技术
KVCache是Transformer模型推理优化的关键技术,通过缓存历史token的Key和Value向量避免重复计算,将推理复杂度从O(n²)降到O(n)。其核心原理是在生成每个新token时,只需计算当前token的K/V,历史token的K/V从缓存读取。虽然会占用额外内存(如LLaMA-7B模型约需2.1GB),但显著提升推理速度。主要优化技术包括分页管理(PagedAttention)
一、什么是 KV Cache?
KV Cache(Key-Value 缓存)是 Transformer 模型在自回归推理过程中,为了避免重复计算而存储的中间状态。它是提高大模型推理速度的关键技术。
核心概念
-
KV 指的是 Transformer 注意力机制中的 Key 和 Value 向量
-
Cache 是指将这些向量缓存起来,供后续 token 生成时复用
-
目的:避免对已处理过的 token 重新计算 Key 和 Value
二、为什么需要 KV Cache?
问题:Transformer 的注意力计算
在 Transformer 中,每个 token 在注意力层需要:
-
Query(Q):查询向量
-
Key(K):键向量
-
Value(V):值向量
注意力分数计算公式:
Attention(Q, K, V) = softmax(Q·K^T / √d) · V
推理时的困境
没有缓存的情况:
生成第1个token:计算 Token1 的 Q, K, V 生成第2个token:计算 Token1, Token2 的 Q, K, V ← 重复计算Token1! 生成第3个token:计算 Token1, Token2, Token3 的 Q, K, V ← 重复计算Token1,2!
每次生成新 token 时,都需要为所有历史token重新计算K和V,计算量随序列长度平方增长。
有缓存的情况:
生成第1个token:计算 Token1 的 Q, K, V,并缓存 K1, V1 生成第2个token:计算 Token2 的 Q, K, V,从缓存读取 K1, V1 生成第3个token:计算 Token3 的 Q, K, V,从缓存读取 K1, V1, K2, V2
只需为新 token 计算 K 和 V,历史 token 的 K, V 从缓存读取。
三、KV Cache 的数学表示
Transformer 层中的计算
对于第 l 层,输入 x:
Q^l = x · W_Q^l # [batch, seq_len, d_model] -> [batch, seq_len, d_k] K^l = x · W_K^l # [batch, seq_len, d_model] -> [batch, seq_len, d_k] V^l = x · W_V^l # [batch, seq_len, d_model] -> [batch, seq_len, d_v]
缓存内容
KV Cache 存储的是每个层的:
-
Key 矩阵 K^l:形状 [batch, seq_len, d_k]
-
Value 矩阵 V^l:形状 [batch, seq_len, d_v]
推理过程伪代码
class TransformerDecoderWithKVCache:
def __init__(self, model, max_seq_len):
self.model = model
self.kv_cache = {
'keys': torch.zeros(max_seq_len, num_layers, d_k),
'values': torch.zeros(max_seq_len, num_layers, d_v)
}
self.cache_position = 0
def generate_next_token(self, input_token):
# 1. 前向传播到每一层
for layer_idx, layer in enumerate(self.model.layers):
# 计算当前token的Q, K, V
q, k, v = layer.attention.qkv_projection(input_token)
# 2. 更新缓存:将新token的K,V存入缓存
self.kv_cache['keys'][self.cache_position, layer_idx] = k
self.kv_cache['values'][self.cache_position, layer_idx] = v
# 3. 注意力计算:使用当前Q和所有缓存的K,V
# 从缓存获取到当前位置的所有K,V
cached_keys = self.kv_cache['keys'][:self.cache_position+1, layer_idx]
cached_values = self.kv_cache['values'][:self.cache_position+1, layer_idx]
# 计算注意力
attention_output = self.attention(q, cached_keys, cached_values)
# 4. 继续前向传播
input_token = layer.ffn(attention_output)
self.cache_position += 1
return output_logits
四、实际例子分析
案例:LLaMA-7B 模型
模型参数:
-
层数:32
-
注意力头数:32
-
每个头的维度:128
-
上下文长度:4096
KV Cache 大小计算:
# 每层的K/V矩阵大小 per_layer_kv_size = 2 * (d_model * d_k) # K和V各一个 # 对于每个token,每层的缓存大小 per_token_per_layer = 2 * (num_heads * head_dim) # 假设为 2 * (32 * 128) = 8192 个浮点数 # 所有层的缓存大小(单个token) per_token_all_layers = per_token_per_layer * num_layers # 8192 * 32 = 262,144 浮点数 # 浮点数大小(假设float16) per_token_bytes = 262,144 * 2 # ≈ 524 KB # 完整序列的缓存(4096个token) full_sequence_bytes = 524 KB * 4096 ≈ 2.1 GB
内存占用分析:
-
模型权重:7B 参数,float16 格式 ≈ 14 GB
-
KV Cache(最大长度):≈ 2.1 GB
-
总内存:约 16.1 GB
实际推理步骤示例
假设我们让模型生成 "The quick brown fox"
步骤1:输入 "The"
计算过程: - 嵌入层:将 "The" 转换为向量 - 每一层:计算 "The" 的 Q1, K1, V1 - 保存:K1, V1 到 KV Cache - 输出:预测下一个词的概率分布 - 选择:选择概率最高的词 "quick"
KV Cache 状态:
Layer1: K1, V1 Layer2: K1, V1 ... Layer32: K1, V1
步骤2:输入 "quick"(当前序列:"The quick")
计算过程: - 嵌入层:将 "quick" 转换为向量 - 每一层: * 计算 "quick" 的 Q2, K2, V2 * 从缓存读取:Layer1: K1, V1 * 注意力计算:Attention(Q2, [K1, K2], [V1, V2]) * 保存:K2, V2 到 KV Cache - 输出:预测下一个词 - 选择:"brown"
KV Cache 状态:
Layer1: K1, V1, K2, V2 Layer2: K1, V1, K2, V2 ... Layer32: K1, V1, K2, V2
步骤3:输入 "brown"(当前序列:"The quick brown")
计算过程类似,但缓存中有3个token的K,V 注意力计算:Attention(Q3, [K1, K2, K3], [V1, V2, V3])
五、KV Cache 的优化技术
1. PagedAttention(vLLM)
-
问题:传统KV Cache是连续内存,导致内存碎片
-
解决方案:将KV Cache分页管理
-
效果:提高内存利用率,支持更长的上下文
# 传统KV Cache(连续内存)
kv_cache = torch.zeros(max_len, num_layers, d_kv)
# PagedAttention(分页管理)
class KVCachePage:
def __init__(self, page_size):
self.keys = torch.zeros(page_size, d_kv)
self.values = torch.zeros(page_size, d_kv)
# 管理多个页
kv_cache_pages = [KVCachePage(page_size) for _ in range(num_pages)]
2. Multi-Query Attention(MQA)
-
传统:每个头有自己的K,V → 缓存大
-
MQA:多个头共享K,V → 缓存小
-
内存节省:约减少为 1/num_heads
3. Grouped-Query Attention(GQA)
-
介于MHA和MQA之间
-
将头分组,组内共享K,V
-
平衡效果和内存
4. KV Cache 量化
-
将KV Cache从float16量化为int8
-
内存减半,精度损失小
-
公式:
KV_int8 = round(KV_fp16 / scale)
5. 滑动窗口注意力
-
只缓存最近N个token的K,V
-
适用于长文本,内存恒定
-
缺点:无法处理长距离依赖
def sliding_window_kv_cache(kv_cache, new_k, new_v, window_size):
# 添加新的K,V
kv_cache.append((new_k, new_v))
# 如果超过窗口大小,移除最旧的
if len(kv_cache) > window_size:
kv_cache.pop(0)
return kv_cache
六、代码实现示例
简单KV Cache实现
import torch
import torch.nn as nn
class KVCache:
def __init__(self, max_batch_size, max_seq_len, num_layers, num_heads, head_dim, dtype=torch.float16):
self.max_seq_len = max_seq_len
self.num_layers = num_layers
self.num_heads = num_heads
self.head_dim = head_dim
# 初始化缓存
self.key_cache = torch.zeros(
max_batch_size, num_layers, max_seq_len, num_heads, head_dim,
dtype=dtype
)
self.value_cache = torch.zeros(
max_batch_size, num_layers, max_seq_len, num_heads, head_dim,
dtype=dtype
)
# 当前序列长度
self.seq_len = 0
def update(self, layer_idx, new_key, new_value, batch_idx=0):
"""更新指定层的KV缓存"""
# new_key形状: [batch, num_heads, seq_len=1, head_dim]
# new_value形状: [batch, num_heads, seq_len=1, head_dim]
self.key_cache[batch_idx, layer_idx, self.seq_len] = new_key.squeeze(2)
self.value_cache[batch_idx, layer_idx, self.seq_len] = new_value.squeeze(2)
def get(self, layer_idx, batch_idx=0):
"""获取指定层的KV缓存(到当前seq_len)"""
keys = self.key_cache[batch_idx, layer_idx, :self.seq_len]
values = self.value_cache[batch_idx, layer_idx, :self.seq_len]
return keys, values
def increment_seq_len(self):
"""增加序列长度"""
self.seq_len += 1
if self.seq_len > self.max_seq_len:
raise ValueError(f"序列长度超过最大值 {self.max_seq_len}")
def clear(self):
"""清空缓存"""
self.seq_len = 0
self.key_cache.zero_()
self.value_cache.zero_()
在Transformer中的使用
class TransformerDecoderLayerWithKVCache(nn.Module):
def __init__(self, d_model, num_heads, ff_dim):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.ffn = FeedForward(d_model, ff_dim)
def forward(self, x, kv_cache, layer_idx, use_cache=False):
# 自注意力
q = self.self_attn.query_proj(x)
if use_cache and kv_cache.seq_len > 0:
# 有缓存:只计算当前token的K,V
k = self.self_attn.key_proj(x)
v = self.self_attn.value_proj(x)
# 从缓存获取历史K,V
past_keys, past_values = kv_cache.get(layer_idx)
# 合并历史K,V和当前K,V
keys = torch.cat([past_keys, k], dim=1)
values = torch.cat([past_values, v], dim=1)
# 更新缓存
kv_cache.update(layer_idx, k, v)
else:
# 无缓存:计算所有token的K,V
q, k, v = self.self_attn.qkv_proj(x)
keys, values = k, v
if use_cache:
kv_cache.update(layer_idx, k, v)
# 计算注意力
attn_output = self.self_attn(q, keys, values)
# FFN
output = self.ffn(attn_output)
return output
七、KV Cache 的挑战与解决方案
挑战1:内存占用大
解决方案:
-
量化:FP16 → INT8,内存减半
-
压缩:稀疏化、低秩近似
-
选择性缓存:只缓存重要token
挑战2:长序列生成慢
解决方案:
-
FlashAttention:优化注意力计算
-
增量解码:仅计算新token
-
并行采样:一次生成多个候选
挑战3:批处理效率
解决方案:
-
连续批处理:动态调整batch size
-
vLLM的PagedAttention:高效内存管理
八、性能对比
| 方法 | 内存占用 | 推理速度 | 实现复杂度 |
|---|---|---|---|
| 无KV Cache | 低 | 极慢(O(n²)) | 简单 |
| 基础KV Cache | 高(O(n)) | 快(O(n)) | 中等 |
| PagedAttention | 中等 | 很快 | 复杂 |
| MQA/GQA | 低 | 很快 | 中等 |
九、总结
KV Cache 是现代大语言模型推理的核心优化技术:
-
工作原理:缓存历史token的Key和Value向量,避免重复计算
-
核心价值:将推理复杂度从 O(n²) 降到 O(n)
-
内存代价:需要额外存储所有历史token的K和V
-
优化方向:量化、压缩、高效内存管理
-
实际影响:决定了模型的最大上下文长度和推理速度
简单来说:KV Cache 就像是你读书时做的笔记。第一次读时做笔记(计算K,V),第二次需要时直接看笔记(从缓存读取),而不是重新读整本书(重新计算所有K,V)。
这就是为什么在 llama.cpp 等推理框架中,保持进程运行(KV Cache在内存中)比每次重启加载要快得多的根本原因。
更多推荐


所有评论(0)