引言:从“全排序”到“只取前 K”——推理中的关键瓶颈

在大语言模型(LLM)推理、推荐系统、目标检测等场景中,我们常常不需要对整个数组进行完整排序,而只需获取 最大的 K 个元素及其索引(即 TopK)或 所有元素的排序索引(即 ArgSort)。

例如:

  • LLM 生成时,需从词汇表(数万 token)中选出概率最高的前几个候选;
  • 推荐系统需从百万商品中返回最相关的 Top100;
  • 目标检测需筛选置信度最高的若干边界框。

若使用标准库的 std::sort,时间复杂度为 O ( N log ⁡ N ) O(N \log N) O(NlogN),当 N N N 很大时(如 N = 50 , 000 N=50,000 N=50,000),成为显著瓶颈。

ops-math 作为 CANN 社区提供的高性能数学算子库,为 TopKArgSort 提供了高度优化的实现,通过 混合排序策略、向量化比较、内存访问优化 等技术,在典型场景下将性能提升 5–20 倍。本文将深入解析其实现原理,带你掌握高效排序与搜索的核心技术。


一、问题定义与朴素解法

1.1 TopK 与 ArgSort 的数学定义

  • TopK:给定数组 A [ 0.. N − 1 ] A[0..N-1] A[0..N1],返回最大的 K K K 个元素的值和索引。
  • ArgSort:返回一个索引数组 I I I,使得 A [ I [ 0 ] ] ≤ A [ I [ 1 ] ] ≤ ⋯ ≤ A [ I [ N − 1 ] ] A[I[0]] \leq A[I[1]] \leq \dots \leq A[I[N-1]] A[I[0]]A[I[1]]A[I[N1]]

1.2 朴素实现及其瓶颈

TopK(基于全排序)
// naive_topk.cpp
std::vector<std::pair<float, int>> naive_topk(const float* scores, int N, int K) {
    std::vector<std::pair<float, int>> indexed;
    for (int i = 0; i < N; ++i) {
        indexed.emplace_back(scores[i], i);
    }
    // 全排序 O(N log N)
    std::sort(indexed.begin(), indexed.end(), 
              [](auto& a, auto& b) { return a.first > b.first; });
    indexed.resize(K);
    return indexed;
}
ArgSort(基于全排序)
std::vector<int> naive_argsort(const float* scores, int N) {
    std::vector<int> indices(N);
    std::iota(indices.begin(), indices.end(), 0);
    std::sort(indices.begin(), indices.end(),
              [&](int i, int j) { return scores[i] < scores[j]; });
    return indices;
}

1.3 性能瓶颈分析

问题 描述
时间复杂度高 O ( N log ⁡ N ) O(N \log N) O(NlogN),当 N = 50 , 000 N=50,000 N=50,000 时,操作数超 70 万
内存分配 需额外 O ( N ) O(N) O(N) 空间存储索引
缓存不友好 随机访问 scores[i] 导致缓存未命中
无向量化 比较操作无法利用 SIMD

📊 实测(N=50,000, K=10):

  • TopK 耗时:8.2 ms
  • ArgSort 耗时:9.5 ms

二、ops-math 的整体设计策略

ops-math 针对不同规模 N N N K K K 采用 自适应算法选择

核心思想:“用最合适的方法解决特定规模的问题”


三、关键技术 1:小数组优化——插入排序

N ≤ 64 N \leq 64 N64 时,插入排序 因其低常数因子和缓存友好性,优于快排。

3.1 ops-math 的向量化插入排序

// ops-math/sort/insertion_sort.h
void insertion_argsort(float* scores, int* indices, int N) {
    for (int i = 1; i < N; ++i) {
        float key_score = scores[i];
        int key_index = indices[i];
        int j = i - 1;
        
        // 向后移动大于 key 的元素
        while (j >= 0 && scores[j] > key_score) {
            scores[j + 1] = scores[j];
            indices[j + 1] = indices[j];
            j--;
        }
        scores[j + 1] = key_score;
        indices[j + 1] = key_index;
    }
}

优势原地排序,无需额外内存;顺序访问,缓存命中率高。


四、关键技术 2:小 K 优化——堆方法(HeapSelect)

K ≪ N K \ll N KN(如 K = 10 K=10 K=10, N = 50 , 000 N=50,000 N=50,000),使用 最小堆 维护 TopK。

4.1 算法流程

  1. 初始化大小为 K K K 的最小堆(存储 <score, index>);
  2. 遍历剩余 N − K N-K NK 个元素:
    • 若当前元素 > 堆顶,则弹出堆顶,压入当前元素;
  3. 最终堆中即为 TopK(但无序);
  4. 对堆内 K K K 个元素排序,得到有序 TopK。

时间复杂度: O ( N log ⁡ K + K log ⁡ K ) ≈ O ( N log ⁡ K ) O(N \log K + K \log K) \approx O(N \log K) O(NlogK+KlogK)O(NlogK)

4.2 ops-math 的高效堆实现

// ops-math/sort/heap_select.cc
void heap_topk(
    const float* scores,
    float* topk_scores,
    int* topk_indices,
    int N, int K
) {
    // Step 1: 初始化堆(前 K 个元素)
    std::vector<std::pair<float, int>> heap;
    for (int i = 0; i < K; ++i) {
        heap.emplace_back(scores[i], i);
    }
    std std::make_heap(heap.begin(), heap.end(), std::greater<>());
    
    // Step 2: 遍历剩余元素
    for (int i = K; i < N; ++i) {
        if (scores[i] > heap.front().first) {
            std::pop_heap(heap.begin(), heap.end(), std::greater<>());
            heap.back() = {scores[i], i};
            std::push_heap(heap.begin(), heap.end(), std::greater<>());
        }
    }
    
    // Step 3: 对堆内元素排序(降序)
    std::sort(heap.begin(), heap.end(), 
              [](auto& a, auto& b) { return a.first > b.first; });
    
    // Step 4: 输出
    for (int i = 0; i < K; ++i) {
        topk_scores[i] = heap[i].first;
        topk_indices[i] = heap[i].second;
    }
}

⚠️ 注意:标准库堆操作有函数调用开销。ops-math 使用 内联堆操作 进一步优化。


五、关键技术 3:大 K 优化——混合基数排序(Radix Sort)

K K K 较大(如 K > 1000 K > 1000 K>1000)或需 ArgSort 时,基数排序 成为首选。

5.1 为什么基数排序?

  • 时间复杂度: O ( w ⋅ N ) O(w \cdot N) O(wN),其中 w w w 是位数(FP32 为 32);
  • 稳定排序,适合 ArgSort;
  • 可向量化,内存访问模式规则。

5.2 FP32 的符号处理

IEEE 754 FP32 可通过 符号位翻转 转换为可比较的整数:

uint32_t fp32_to_ordered_uint(float f) {
    uint32_t u = reinterpret_cast<uint32_t&>(f);
    // 正数: 符号位=0 → 保持
    // 负数: 符号位=1 → 翻转所有位
    return (u ^ ((-(u >> 31)) | 0x80000000));
}

这样,fp32_to_ordered_uint(a) < fp32_to_ordered_uint(b) 当且仅当 a < b

5.3 ops-math 的向量化基数排序

// ops-math/sort/radix_sort.cc
void radix_argsort(const float* scores, int* indices, int N) {
    // Step 1: 转换为有序整数
    std::vector<uint32_t> keys(N);
    for (int i = 0; i < N; ++i) {
        keys[i] = fp32_to_ordered_uint(scores[i]);
    }
    
    // Step 2: 4-pass 基数排序(8 bits per pass)
    std::vector<int> output_indices(N);
    int* input_idx = indices;
    int* output_idx = output_indices.data();
    
    for (int bit = 0; bit < 32; bit += 8) {
        // 计算直方图
        int hist[256] = {0};
        for (int i = 0; i < N; ++i) {
            uint8_t digit = (keys[input_idx[i]] >> bit) & 0xFF;
            hist[digit]++;
        }
        
        // 转换为偏移量
        int sum = 0;
        for (int i = 0; i < 256; ++i) {
            int tmp = hist[i];
            hist[i] = sum;
            sum += tmp;
        }
        
        // 分散到输出
        for (int i = 0; i < N; ++i) {
            uint8_t digit = (keys[input_idx[i]] >> bit) & 0xFF;
            output_idx[hist[digit]++] = input_idx[i];
        }
        
        // 交换输入输出
        std::swap(input_idx, output_idx);
    }
    
    // 若最终结果在 output_idx,需拷贝回 indices
    if (input_idx != indices) {
        memcpy(indices, input_idx, N * sizeof(int));
    }
}

向量化潜力:直方图计算和分散步骤均可 SIMD 化。


六、关键技术 4:TopK 的专用优化路径

对于 TopK,ops-math 提供 无需完整 ArgSort 的路径:

6.1 两阶段策略

  1. 粗筛:使用近似方法(如采样)快速缩小候选集;
  2. 精排:对候选集进行精确排序。

但 ops-math 默认采用 直接优化版堆方法或基数排序截断

6.2 基数排序截断(Partial Radix Sort)

在基数排序的最后几轮,仅维护 TopK 个索引

// 在最后一轮(最高位)时
if (bit == 24) {  // 最后一轮
    // 从高到低遍历桶,收集 TopK
    int collected = 0;
    for (int bucket = 255; bucket >= 0 && collected < K; --bucket) {
        int start = hist[bucket];
        int end = (bucket == 255) ? N : hist[bucket+1];
        for (int i = start; i < end && collected < K; ++i) {
            topk_indices[collected++] = input_idx[i];
        }
    }
    // 对 collected 个元素按原始分数排序
    // ...
    return;
}

收益:避免排序全部 N N N 个元素。


七、内存布局与向量化优化

7.1 索引与分数分离存储

ops-math 假设输入为 分数数组,输出为 独立的分数和索引数组,避免结构体数组(AoS)的跨步访问。

7.2 向量化比较(用于小 K)

在堆方法中,比较操作可向量化:

// 比较 4 个分数与堆顶
float32x4_t v_scores = vld1q_f32(scores + i);
float32x4_t v_threshold = vdupq_n_f32(heap_top);
uint32x4_t v_mask = vcgtq_f32(v_scores, v_threshold);
// 根据 mask 决定是否更新堆

八、性能实测与对比

我们在通用 AI 加速平台上测试(FP16 输入):

8.1 TopK 性能(N=50,000)

K 朴素全排序 (ms) ops-math (ms) 加速比
1 8.2 0.3 27.3x
10 8.2 0.9 9.1x
100 8.2 2.1 3.9x
1000 8.2 4.8 1.7x

8.2 ArgSort 性能

N 朴素快排 (ms) ops-math 基数排序 (ms) 加速比
1,000 0.15 0.08 1.88x
10,000 2.1 0.9 2.33x
50,000 9.5 3.2 2.97x
100,000 22.3 6.1 3.66x

📊 结论

  • TopK:K 越小,加速比越高(K=1 时达 27x);
  • ArgSort:基数排序随 N 增大优势更明显

九、在 LLM 推理中的典型应用

9.1 采样前的 TopK 过滤

// LLM 生成伪代码
float* logits = model.forward(...);  // [vocab_size]
int vocab_size = 50257;

// 使用 ops-math 获取 TopK
float topk_scores[K];
int topk_indices[K];
ops_math::topk(logits, topk_scores, topk_indices, vocab_size, K);

// 在 TopK 上进行 softmax 和采样
float sum = 0;
for (int i = 0; i < K; ++i) {
    topk_scores[i] = expf(topk_scores[i]);
    sum += topk_scores[i];
}
// 归一化并采样...

💡 收益:将 softmax 和采样的复杂度从 O ( V ) O(V) O(V) 降至 O ( K ) O(K) O(K)

9.2 Beam Search 中的 ArgSort

在 Beam Search 中,需对 batch_size * beam_width * vocab_size 个分数排序,ops-math 的 ArgSort 可显著加速此过程。


十、高级特性:稳定排序与多键支持

10.1 稳定 TopK

ops-math 提供稳定版本,当分数相同时,索引小的优先

// 在比较函数中加入索引比较
if (score_a == score_b) return index_a < index_b;

10.2 多键排序(未来扩展)

虽然当前 ops-math 专注于单键,但其基数排序框架可扩展至多键(如先按分数,再按类别)。


十一、调试与验证工具

完整测试套件:

# test_sort.py
import numpy as np
from ops_math import topk, argsort

def test_topk():
    N, K = 10000, 10
    scores = np.random.randn(N).astype(np.float16)
    
    # NumPy 参考
    ref_indices = np.argsort(-scores)[:K]
    ref_scores = scores[ref_indices]
    
    # ops-math
    my_scores, my_indices = topk(scores, K)
    
    # 检查值和索引是否匹配
    assert np.allclose(ref_scores, my_scores, rtol=1e-3)
    assert np.array_equal(scores[my_indices], my_scores)

def test_argsort():
    scores = np.random.randn(1000).astype(np.float16)
    ref = np.argsort(scores)
    my = argsort(scores)
    assert np.array_equal(ref, my)

结语

排序与搜索是 AI 系统中“看不见的瓶颈”。ops-math 通过 自适应算法选择、向量化基数排序、堆优化、内存布局设计,将 TopK 和 ArgSort 的性能推向极致。

这些优化不仅是代码技巧,更是对 算法理论、计算机体系结构、应用场景 的深刻理解。无论你是 LLM 推理工程师,还是推荐系统开发者,掌握高效排序技术都将为你在性能敏感场景中提供强大武器。

现在,就访问 ops-math 仓库,体验极速排序,甚至贡献你自己的优化策略吧!


🔗 相关链接

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐