昇腾CANN ATB KV Cache 与 PagedAttention:显存碎片消除的完整方案
摘要:LLM推理面临的主要瓶颈是显存而非计算,特别是长上下文下KV Cache的二次增长导致显存不足。ATB采用PagedAttention和虚拟内存管理技术,将KV Cache分页存储,按需申请页面,有效解决显存碎片问题。标准KV Cache需要连续分配大块显存,而PagedAttention将显存划分为16KB页面,允许非连续分配,显著提高显存利用率。ATB实现包括全局页池管理、逻辑到物理地址
LLM 推理的最大瓶颈不是计算——是显存。长上下文下,KV Cache 的显存占用是二次增长的:seq_len=128K → KV Cache = 128K × 每层 KV 大小 = 128K × (2 × hidden × head_num) = 128K × 2 × 8192 × 32 = 32GB。加上模型参数(70B × 2bytes = 140GB)→ 总共 172GB → Ascend 910 只有 128GB → OOM。
ATB 用 PagedAttention + 虚拟内存管理解决这个问题:把 KV Cache 分页存储(Page Table),不连续分配,按需申请页面。像操作系统管理虚拟内存一样管理 KV Cache。
KV Cache 的显存碎片问题
标准 KV Cache 是连续分配的三维张量 [batch, seq_len, hidden, head_num×2]:
连续分配 KV Cache 的问题
请求 1:seq_len=128K → 需要 32GB 连续块
请求 2:seq_len=512 → 需要 16MB 连续块
请求 3:seq_len=64K → 需要 16GB 连续块
...
32GB + 16MB + 16GB + ... = 180GB > 128GB
即使空闲总量够(很多小请求释放后),但无法分配 32GB 连续块 → OOM
对比 PagedAttention:
PagedAttention 方式
KV Cache 被分成 16KB 的 pages(每 16KB=1 page)
请求 1 的 32GB → 分配 32GB/16KB = 2,097,152 pages
请求 2 的 16MB → 分配 16MB/16KB = 1024 pages
请求 3 的 16GB → 分配 16GB/16KB = 1,048,576 pages
...
page 不需要连续!碎片不再是问题——任何空闲 page 都能分配
ATB 的 PagedAttention 实现
// ascend-transformer-boost/memory/paged_attention.cpp
class PagedAttentionMemory {
private:
// 全局 page 池:所有请求共享
static constexpr int PAGE_SIZE = 16 * 1024; // 16KB
struct Page {
int id; // page 编号(全局唯一)
DevicePtr ptr; // page 在 HBM 上的地址
bool allocated; // 是否已分配
int ref_count; // 引用计数(多请求共享)
};
std::vector<Page> global_page_pool_; // 全局 page 池
int total_pages_; // 总 page 数 = HBM 大小 / PAGE_SIZE
// 每个请求的 page 表
struct PageTable {
std::vector<int> page_ids; // 虚拟地址 → 物理 page 映射
int num_pages; // 已分配 page 数
int seq_len; // 当前序列长度
};
std::unordered_map<int, PageTable> request_page_tables_; // request_id → page 表
public:
// ===== 分配 pages =====
Status AllocatePages(int request_id, int num_pages_needed) {
PageTable& pt = request_page_tables_[request_id];
for (int i = 0; i < num_pages_needed; i++) {
int page_id = FindFreePage();
if (page_id == -1) {
return Status::OUT_OF_MEMORY; // 没有空闲 page
}
// 分配 page
global_page_pool_[page_id].allocated = true;
global_page_pool_[page_id].ref_count = 1;
pt.page_ids.push_back(page_id);
pt.num_pages++;
}
pt.seq_len = num_pages_needed * (PAGE_SIZE / sizeof(float16) / (hidden * 2));
return Status::OK;
}
// ===== 逻辑地址 → 物理地址转换 =====
DevicePtr LogicalToPhysical(int request_id, int logical_offset) {
PageTable& pt = request_page_tables_[request_id];
// 计算逻辑偏移在哪一页和页内偏移
int page_index = logical_offset / PAGE_SIZE;
int offset_in_page = logical_offset % PAGE_SIZE;
// 从 page 表查询物理地址
int physical_page_id = pt.page_ids[page_index];
DevicePtr physical_page = global_page_pool_[physical_page_id].ptr;
return physical_page + offset_in_page;
}
// ===== 释放 pages(请求完成或溢出)=====
void FreePages(int request_id) {
PageTable& pt = request_page_tables_[request_id];
for (int page_id : pt.page_ids) {
global_page_pool_[page_id].ref_count--;
if (global_page_pool_[page_id].ref_count == 0) {
global_page_pool_[page_id].allocated = false; // 真正释放
}
}
pt.page_ids.clear();
pt.num_pages = 0;
}
};
PagedAttention 的 Attention 计算修改
标准注意力计算(完整 KV Cache):
// 标准 Attention
for (int k = 0; k < seq_len; k++) {
float score = dot(Q[token], K[k]); // Q 与 K 的一维点积
softmax_scores[k] = exp(score);
}
PagedAttention 计算(按页计算):
// ascend-transformer-boost/kernels/paged_attention_kernel.cpp
__aicore__ void PagedAttentionKernel(
GlobalTensor<float16>& Q, // [batch, num_heads, d_head]
GlobalTensor<float16>& K_pages, // [total_pages, page_size]
GlobalTensor<float16>& V_pages, // [total_pages, page_size]
GlobalTensor<float16>& output, // [batch, num_heads, d_head]
GlobalTensor<int>& page_table, // [request_id, max_pages]
int num_pages, int head_dim
) {
int request_id = blockIdx.x; // 每个 block 处理一个请求
int head_id = threadIdx.y; // 每个 thread 处理一个注意力头
// 初始化累加器
LocalTensor<float16> O_local(head_dim);
for (int d = 0; d < head_dim; d++) O_local[d] = 0.0f;
float max_val = -65504.0f;
float sum_exp = 0.0f;
// 逐页计算 Attention
for (int p = 0; p < num_pages; p++) {
int physical_page_id = page_table[request_id * MAX_PAGES + p];
// 加载一页的 K 和 V(连续访问——物理地址)
LocalTensor<float16> K_page(page_size);
LocalTensor<float16> V_page(page_size);
DataCopy(K_page, K_pages + physical_page_id * page_size, page_size);
DataCopy(V_page, V_pages + physical_page_id * page_size, page_size);
// 计算 QK^T 在这一页的分数
for (int i = 0; i < page_tokens; i++) {
// K_page[i] 是 K[token_i],计算 Q·K[token_i]
float score = 0.0f;
for (int d = 0; d < head_dim; d++) {
score += float(Q[head_id * head_dim + d]) *
float(K_page[i * head_dim + d]);
}
// Online softmax(逐页更新)
float exp_score = expf(score - max_val);
// 如果这一页有更大的 score → 重新标定累加器
if (score > max_val) {
float old_max = max_val;
max_val = score;
// 重新标定之前累加的 O_local 和 sum_exp
float correction = expf(old_max - max_val);
for (int d = 0; d < head_dim; d++) {
O_local[d] = O_local[d] * correction;
}
sum_exp = sum_exp * correction;
}
// 累加 V * softmax_score
for (int d = 0; d < head_dim; d++) {
O_local[d] += V_page[i * head_dim + d] * exp_score;
}
sum_exp += exp_score;
}
}
// 归一化
for (int d = 0; d < head_dim; d++) {
output[request_id * head_dim + d] = float16(O_local[d] / sum_exp);
}
}
PagedAttention 的关键:page 表中的物理地址是离散的,但每个 page 内部的访问是连续的。分页解决了碎片,不会降低注意力计算的性能(因为每个 page 内部依然是连续访问)。
page 分配的贪心策略
// ascend-transformer-boost/memory/page_allocator.cpp
class PageAllocator {
private:
int FindFreePage() {
// 贪心:找第一个空闲 page
for (int i = 0; i < total_pages_; i++) {
if (!global_page_pool_[i].allocated) {
return i;
}
}
return -1; // 无空闲
}
// 预取策略(预测下一个 page 位置)
int PrefetchNextPage(int current_page_id) {
// 如果当前 page 后一个也是该请求的 → 预取(减少延迟)
int next_page = current_page_id + 1;
if (next_page < total_pages_ &&
!global_page_pool_[next_page].allocated) {
PrefetchToCache(next_page); // 预取到 SRAM
}
}
public:
// 批量预取(所有已分配 page)
void PrefetchAllPages(int request_id) {
PageTable& pt = request_page_tables_[request_id];
for (int page_id : pt.page_ids) {
PrefetchToCache(page_id);
}
}
};
踩坑一:page 表查找的延迟
PagedAttention 需要频繁查 page 表(每次访问 K/V 都要逻辑→物理转换)。page 表本身在 HBM 中——每次查表都是 HBM 访问。
修复:把 page 表拷贝到 L1 缓存
// 加速 page 表查找
__aicore__ void FastPageLookup(
GlobalTensor<int>& page_table_in_hbm, // page 表在 HBM 中
LocalTensor<int>& page_table_in_l1, // 拷贝到 L1
int num_pages
) {
// 拷贝 page 表到 L1(一次性把所有 page 的映射都搬上来)
DataCopy(page_table_in_l1, page_table_in_hbm, num_pages * sizeof(int));
// 之后所有查表都在 L1 中——延迟 < 1 cycle(不是 HBM 的百 cycle)
}
L1 中的 page 表查表延迟:1 cycle。HBM 中查表延迟:~300 cycles。PagedAttention 每页查一次表——page=2MB → 查表延迟节省 = 2MB × (300-1) = ~600M cycles。
踩坑二:page 引用计数泄漏
多个请求可能共享相同的 K/V pages(如共享前缀)。引用计数减到 0 才真正释放。但如果忘记减引用计数——page 永远不释放 → 内存泄漏。
// 引用计数的正确管理
class RefCountManager {
public:
// 分配:ref_count = 1(新请求独占)
void AllocPage(int page_id) {
global_page_pool_[page_id].ref_count = 1;
}
// 共享:ref_count++(其他请求加入)
void SharePage(int page_id, int request_id) {
global_page_pool_[page_id].ref_count++;
// 记录哪几个请求在共享这个 page
shared_requests_[page_id].push_back(request_id);
}
// 释放:ref_count--(只有变成 0 才释放)
void ReleasePage(int page_id, int request_id) {
global_page_pool_[page_id].ref_count--;
if (global_page_pool_[page_id].ref_count == 0) {
// 真正释放:标记为可用
global_page_pool_[page_id].allocated = false;
shared_requests_[page_id].clear();
}
}
// 校验(多请求释放时的安全检查)
void ValidateRefCount(int page_id, int request_id) {
auto& shared = shared_requests_[page_id];
if (std::find(shared.begin(), shared.end(), request_id) == shared.end()) {
// 这个请求没有共享这个 page → 不应该减引用计数
throw RefCountError("request not in shared list");
}
}
};
踩坑三:page 表更新时的时序竞争
推理过程中,Decoder 生成新 token 时,KV Cache 需要扩展(添加新的 K, V)。如果此时上一个请求的 page 正在被 Attention 计算读 → 数据竞争。
方案:Copy-on-Write(CoW)
// Copy-on-Write page 更新
Status ExtendKVPage(int request_id, int new_page_id) {
PageTable& pt = request_page_tables_[request_id];
int old_page_id = pt.page_ids.back();
// 如果只有这个请求在用这个 page → 直接更新
if (global_page_pool_[old_page_id].ref_count == 1) {
// 无竞争:直接覆盖旧 page
global_page_pool_[old_page_id].allocated = true;
return Status::OK;
}
// 多个请求在共享这个 page → Copy-on-Write
// 分配新 page,拷贝旧内容,写入新数据
int new_page = FindFreePage();
if (new_page == -1) return Status::OUT_OF_MEMORY;
// CoW:拷贝旧 page 到新 page
memcpy(global_page_pool_[new_page].ptr,
global_page_pool_[old_page_id].ptr,
PAGE_SIZE);
// 在新 page 上追加 K,V 数据
WriteKV(global_page_pool_[new_page].ptr, new_K, new_V);
// 更新 page 表
pt.page_ids.back() = new_page;
// 释放旧 page 的引用
ReleasePage(old_page_id, request_id);
return Status::OK;
}
KV Cache 是 LLM 推理中最大的显存消耗者——128K 上下文下占 32GB。ATB 的 PagedAttention 把连续分配变成分页分配:page 池全局共享、page 表做逻辑→物理映射、Copy-on-Write 解决共享页的更新冲突。像操作系统管理虚拟内存一样管理 KV Cache——碎片不再导致 OOM。
更多推荐



所有评论(0)