LLM 推理的最大瓶颈不是计算——是显存。长上下文下,KV Cache 的显存占用是二次增长的:seq_len=128K → KV Cache = 128K × 每层 KV 大小 = 128K × (2 × hidden × head_num) = 128K × 2 × 8192 × 32 = 32GB。加上模型参数(70B × 2bytes = 140GB)→ 总共 172GB → Ascend 910 只有 128GB → OOM。

ATB 用 PagedAttention + 虚拟内存管理解决这个问题:把 KV Cache 分页存储(Page Table),不连续分配,按需申请页面。像操作系统管理虚拟内存一样管理 KV Cache。

KV Cache 的显存碎片问题

标准 KV Cache 是连续分配的三维张量 [batch, seq_len, hidden, head_num×2]:

连续分配 KV Cache 的问题

请求 1:seq_len=128K → 需要 32GB 连续块
请求 2:seq_len=512 → 需要 16MB 连续块
请求 3:seq_len=64K → 需要 16GB 连续块
...

32GB + 16MB + 16GB + ... = 180GB > 128GB
即使空闲总量够(很多小请求释放后),但无法分配 32GB 连续块 → OOM

对比 PagedAttention:

PagedAttention 方式

KV Cache 被分成 16KB 的 pages(每 16KB=1 page)
请求 1 的 32GB → 分配 32GB/16KB = 2,097,152 pages
请求 2 的 16MB → 分配 16MB/16KB = 1024 pages
请求 3 的 16GB → 分配 16GB/16KB = 1,048,576 pages
...

page 不需要连续!碎片不再是问题——任何空闲 page 都能分配

ATB 的 PagedAttention 实现

// ascend-transformer-boost/memory/paged_attention.cpp

class PagedAttentionMemory {
private:
    // 全局 page 池:所有请求共享
    static constexpr int PAGE_SIZE = 16 * 1024;  // 16KB

    struct Page {
        int id;               // page 编号(全局唯一)
        DevicePtr ptr;         // page 在 HBM 上的地址
        bool allocated;        // 是否已分配
        int ref_count;         // 引用计数(多请求共享)
    };

    std::vector<Page> global_page_pool_;  // 全局 page 池
    int total_pages_;                    // 总 page 数 = HBM 大小 / PAGE_SIZE

    // 每个请求的 page 表
    struct PageTable {
        std::vector<int> page_ids;     // 虚拟地址 → 物理 page 映射
        int num_pages;                  // 已分配 page 数
        int seq_len;                    // 当前序列长度
    };

    std::unordered_map<int, PageTable> request_page_tables_;  // request_id → page 表

public:
    // ===== 分配 pages =====
    Status AllocatePages(int request_id, int num_pages_needed) {
        PageTable& pt = request_page_tables_[request_id];

        for (int i = 0; i < num_pages_needed; i++) {
            int page_id = FindFreePage();
            if (page_id == -1) {
                return Status::OUT_OF_MEMORY;  // 没有空闲 page
            }

            // 分配 page
            global_page_pool_[page_id].allocated = true;
            global_page_pool_[page_id].ref_count = 1;
            pt.page_ids.push_back(page_id);
            pt.num_pages++;
        }

        pt.seq_len = num_pages_needed * (PAGE_SIZE / sizeof(float16) / (hidden * 2));
        return Status::OK;
    }

    // ===== 逻辑地址 → 物理地址转换 =====
    DevicePtr LogicalToPhysical(int request_id, int logical_offset) {
        PageTable& pt = request_page_tables_[request_id];

        // 计算逻辑偏移在哪一页和页内偏移
        int page_index = logical_offset / PAGE_SIZE;
        int offset_in_page = logical_offset % PAGE_SIZE;

        // 从 page 表查询物理地址
        int physical_page_id = pt.page_ids[page_index];
        DevicePtr physical_page = global_page_pool_[physical_page_id].ptr;

        return physical_page + offset_in_page;
    }

    // ===== 释放 pages(请求完成或溢出)=====
    void FreePages(int request_id) {
        PageTable& pt = request_page_tables_[request_id];

        for (int page_id : pt.page_ids) {
            global_page_pool_[page_id].ref_count--;
            if (global_page_pool_[page_id].ref_count == 0) {
                global_page_pool_[page_id].allocated = false;  // 真正释放
            }
        }

        pt.page_ids.clear();
        pt.num_pages = 0;
    }
};

PagedAttention 的 Attention 计算修改

标准注意力计算(完整 KV Cache):

// 标准 Attention
for (int k = 0; k < seq_len; k++) {
    float score = dot(Q[token], K[k]);  // Q 与 K 的一维点积
    softmax_scores[k] = exp(score);
}

PagedAttention 计算(按页计算):

// ascend-transformer-boost/kernels/paged_attention_kernel.cpp

__aicore__ void PagedAttentionKernel(
    GlobalTensor<float16>& Q,          // [batch, num_heads, d_head]
    GlobalTensor<float16>& K_pages,    // [total_pages, page_size]
    GlobalTensor<float16>& V_pages,    // [total_pages, page_size]
    GlobalTensor<float16>& output,     // [batch, num_heads, d_head]
    GlobalTensor<int>& page_table,     // [request_id, max_pages]
    int num_pages, int head_dim
) {
    int request_id = blockIdx.x;  // 每个 block 处理一个请求
    int head_id = threadIdx.y;    // 每个 thread 处理一个注意力头

    // 初始化累加器
    LocalTensor<float16> O_local(head_dim);
    for (int d = 0; d < head_dim; d++) O_local[d] = 0.0f;

    float max_val = -65504.0f;
    float sum_exp = 0.0f;

    // 逐页计算 Attention
    for (int p = 0; p < num_pages; p++) {
        int physical_page_id = page_table[request_id * MAX_PAGES + p];

        // 加载一页的 K 和 V(连续访问——物理地址)
        LocalTensor<float16> K_page(page_size);
        LocalTensor<float16> V_page(page_size);

        DataCopy(K_page, K_pages + physical_page_id * page_size, page_size);
        DataCopy(V_page, V_pages + physical_page_id * page_size, page_size);

        // 计算 QK^T 在这一页的分数
        for (int i = 0; i < page_tokens; i++) {
            // K_page[i] 是 K[token_i],计算 Q·K[token_i]
            float score = 0.0f;
            for (int d = 0; d < head_dim; d++) {
                score += float(Q[head_id * head_dim + d]) *
                         float(K_page[i * head_dim + d]);
            }

            // Online softmax(逐页更新)
            float exp_score = expf(score - max_val);

            // 如果这一页有更大的 score → 重新标定累加器
            if (score > max_val) {
                float old_max = max_val;
                max_val = score;

                // 重新标定之前累加的 O_local 和 sum_exp
                float correction = expf(old_max - max_val);
                for (int d = 0; d < head_dim; d++) {
                    O_local[d] = O_local[d] * correction;
                }
                sum_exp = sum_exp * correction;
            }

            // 累加 V * softmax_score
            for (int d = 0; d < head_dim; d++) {
                O_local[d] += V_page[i * head_dim + d] * exp_score;
            }
            sum_exp += exp_score;
        }
    }

    // 归一化
    for (int d = 0; d < head_dim; d++) {
        output[request_id * head_dim + d] = float16(O_local[d] / sum_exp);
    }
}

PagedAttention 的关键:page 表中的物理地址是离散的,但每个 page 内部的访问是连续的。分页解决了碎片,不会降低注意力计算的性能(因为每个 page 内部依然是连续访问)。

page 分配的贪心策略

// ascend-transformer-boost/memory/page_allocator.cpp

class PageAllocator {
private:
    int FindFreePage() {
        // 贪心:找第一个空闲 page
        for (int i = 0; i < total_pages_; i++) {
            if (!global_page_pool_[i].allocated) {
                return i;
            }
        }
        return -1;  // 无空闲
    }

    // 预取策略(预测下一个 page 位置)
    int PrefetchNextPage(int current_page_id) {
        // 如果当前 page 后一个也是该请求的 → 预取(减少延迟)
        int next_page = current_page_id + 1;
        if (next_page < total_pages_ &&
            !global_page_pool_[next_page].allocated) {
            PrefetchToCache(next_page);  // 预取到 SRAM
        }
    }

public:
    // 批量预取(所有已分配 page)
    void PrefetchAllPages(int request_id) {
        PageTable& pt = request_page_tables_[request_id];
        for (int page_id : pt.page_ids) {
            PrefetchToCache(page_id);
        }
    }
};

踩坑一:page 表查找的延迟

PagedAttention 需要频繁查 page 表(每次访问 K/V 都要逻辑→物理转换)。page 表本身在 HBM 中——每次查表都是 HBM 访问。

修复:把 page 表拷贝到 L1 缓存

// 加速 page 表查找
__aicore__ void FastPageLookup(
    GlobalTensor<int>& page_table_in_hbm,  // page 表在 HBM 中
    LocalTensor<int>& page_table_in_l1,    // 拷贝到 L1
    int num_pages
) {
    // 拷贝 page 表到 L1(一次性把所有 page 的映射都搬上来)
    DataCopy(page_table_in_l1, page_table_in_hbm, num_pages * sizeof(int));

    // 之后所有查表都在 L1 中——延迟 < 1 cycle(不是 HBM 的百 cycle)
}

L1 中的 page 表查表延迟:1 cycle。HBM 中查表延迟:~300 cycles。PagedAttention 每页查一次表——page=2MB → 查表延迟节省 = 2MB × (300-1) = ~600M cycles。

踩坑二:page 引用计数泄漏

多个请求可能共享相同的 K/V pages(如共享前缀)。引用计数减到 0 才真正释放。但如果忘记减引用计数——page 永远不释放 → 内存泄漏。

// 引用计数的正确管理
class RefCountManager {
public:
    // 分配:ref_count = 1(新请求独占)
    void AllocPage(int page_id) {
        global_page_pool_[page_id].ref_count = 1;
    }

    // 共享:ref_count++(其他请求加入)
    void SharePage(int page_id, int request_id) {
        global_page_pool_[page_id].ref_count++;
        // 记录哪几个请求在共享这个 page
        shared_requests_[page_id].push_back(request_id);
    }

    // 释放:ref_count--(只有变成 0 才释放)
    void ReleasePage(int page_id, int request_id) {
        global_page_pool_[page_id].ref_count--;

        if (global_page_pool_[page_id].ref_count == 0) {
            // 真正释放:标记为可用
            global_page_pool_[page_id].allocated = false;
            shared_requests_[page_id].clear();
        }
    }

    // 校验(多请求释放时的安全检查)
    void ValidateRefCount(int page_id, int request_id) {
        auto& shared = shared_requests_[page_id];
        if (std::find(shared.begin(), shared.end(), request_id) == shared.end()) {
            // 这个请求没有共享这个 page → 不应该减引用计数
            throw RefCountError("request not in shared list");
        }
    }
};

踩坑三:page 表更新时的时序竞争

推理过程中,Decoder 生成新 token 时,KV Cache 需要扩展(添加新的 K, V)。如果此时上一个请求的 page 正在被 Attention 计算读 → 数据竞争。

方案:Copy-on-Write(CoW)

// Copy-on-Write page 更新
Status ExtendKVPage(int request_id, int new_page_id) {
    PageTable& pt = request_page_tables_[request_id];
    int old_page_id = pt.page_ids.back();

    // 如果只有这个请求在用这个 page → 直接更新
    if (global_page_pool_[old_page_id].ref_count == 1) {
        // 无竞争:直接覆盖旧 page
        global_page_pool_[old_page_id].allocated = true;
        return Status::OK;
    }

    // 多个请求在共享这个 page → Copy-on-Write
    // 分配新 page,拷贝旧内容,写入新数据
    int new_page = FindFreePage();
    if (new_page == -1) return Status::OUT_OF_MEMORY;

    // CoW:拷贝旧 page 到新 page
    memcpy(global_page_pool_[new_page].ptr,
           global_page_pool_[old_page_id].ptr,
           PAGE_SIZE);

    // 在新 page 上追加 K,V 数据
    WriteKV(global_page_pool_[new_page].ptr, new_K, new_V);

    // 更新 page 表
    pt.page_ids.back() = new_page;

    // 释放旧 page 的引用
    ReleasePage(old_page_id, request_id);

    return Status::OK;
}

KV Cache 是 LLM 推理中最大的显存消耗者——128K 上下文下占 32GB。ATB 的 PagedAttention 把连续分配变成分页分配:page 池全局共享、page 表做逻辑→物理映射、Copy-on-Write 解决共享页的更新冲突。像操作系统管理虚拟内存一样管理 KV Cache——碎片不再导致 OOM。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐