前言

在现代 AI 推理与训练系统中,通用矩阵乘法(GEMM, General Matrix Multiply)作为计算密集型操作的核心,其性能直接决定了整个模型的吞吐与延迟。CANN 开源生态中的 catlass(CANN Templates for Linear Algebra Subroutines)项目,正是为解决这一关键瓶颈而生。它并非一个传统意义上的算子库,而是一个基于模板元编程(Template Metaprogramming)构建的高性能 GEMM 算子生成框架,通过高度抽象与参数化设计,实现了“一次编写、多场景复用、极致优化”的目标。

一、catlass 的定位:不是库,而是“算子工厂”

ops-nnops-math 等提供预编译算子二进制的仓库不同,catlass 本身不直接提供可执行的算子,而是提供一套可组合、可定制的模板代码。开发者通过指定数据类型、矩阵布局、分块策略(Tiling)、融合逻辑等模板参数,即可“实例化”出针对特定 Shape 和场景优化的 GEMM Kernel。

📌 核心价值

  • 白盒化:所有计算逻辑对开发者可见,支持局部修改与调试;
  • 泛化性:一套模板覆盖 FP16/BF16/INT8/INT4 等多种精度;
  • 高性能:在定制 Shape 下可达标杆性能的 0.98~1.2 倍(见 README 性能图)。

二、整体架构:四层抽象模型

  • 用户配置层examples/):提供具体算子样例(如 102_dynamic_optimized_matmul),定义输入输出、精度、融合逻辑;
  • 调度策略层:决定如何将大矩阵划分为小块(Tiling),以及如何映射到硬件计算单元(Core/Stream);
  • 计算模板层include/catlass/gemm/):核心 GEMM 实现,包含主循环、加载/存储、累加等模板;
  • 硬件指令层include/tla/):对底层向量/矩阵指令的封装,屏蔽硬件差异。

三、模板元编程基石:类型系统与编译期计算

catlass 大量使用 C++17/20 特性,尤其是模板元编程,将运行时决策转化为编译期常量,消除分支开销。

3.1 类型特征萃取(Type Traits)

通过 using 和模板特化,定义数据类型的属性:

// include/catlass/gemm/numeric_types.hpp
template<typename T>
struct NumericTraits;

template<>
struct NumericTraits<half> {
    static constexpr int bits = 16;
    static constexpr bool is_float = true;
    using AccumulatorType = float; // 累加器类型
};

template<>
struct NumericTraits<int8_t> {
    static constexpr int bits = 8;
    static constexpr bool is_float = false;
    using AccumulatorType = int32_t;
};

这使得后续模板可根据 ElementAElementB 自动推导出 ElementAccumulator,无需手动指定。

3.2 编译期常量表达式(constexpr)

Tiling 参数(如 BLOCK_M, BLOCK_N, BLOCK_K)全部以 constexpr 形式定义,确保在编译期展开:

// examples/102_dynamic_optimized_matmul/config.hpp
struct MatmulConfig {
    static constexpr int BLOCK_M = 128;
    static constexpr int BLOCK_N = 256;
    static constexpr int BLOCK_K = 64;
    static constexpr int THREADS_PER_BLOCK = 512;
};

这些值直接用于数组大小声明、循环展开等,避免运行时计算。


四、核心 GEMM 模板:分块、流水与融合

GEMM 的核心在于高效利用片上内存(L1/L2 Cache)和计算单元。catlass 通过 分块(Blocking) + 软件流水(Software Pipelining) 实现高吞吐。

4.1 主循环模板结构

// include/catlass/gemm/kernel/gemm_kernel.hpp
template<typename GemmTraits>
__global__ void GemmKernel(
    typename GemmTraits::Params params) {

    // 1. 初始化片上内存(Shared Memory)
    __shared__ typename GemmTraits::SmemLayoutA smem_A;
    __shared__ typename GemmTraits::SmemLayoutB smem_B;

    // 2. 计算当前线程块负责的全局坐标
    int block_m = blockIdx.y * GemmTraits::BLOCK_M;
    int block_n = blockIdx.x * GemmTraits::BLOCK_N;

    // 3. 主循环:沿 K 维迭代
    #pragma unroll
    for (int k = 0; k < params.K; k += GemmTraits::BLOCK_K) {
        // 3.1 异步加载 A、B 到 Shared Memory
        LoadTile(smem_A, params.ptr_A + ..., threadIdx);
        LoadTile(smem_B, params.ptr_B + ..., threadIdx);

        __syncthreads();

        // 3.2 执行计算(调用 TLA 指令)
        ComputeTile(params.accum, smem_A, smem_B, threadIdx);

        __syncthreads();
    }

    // 4. 写回结果 C
    StoreTile(params.ptr_C + ..., params.accum, threadIdx);
}

此模板通过 GemmTraits 注入所有行为,实现高度泛化。

4.2 融合算子支持:随路量化(On-the-fly Quantization)

catlass 支持在 GEMM 计算后立即进行量化或激活函数融合,避免额外 Kernel 启动开销。例如,在 v1.3.0 中新增的 FixPipe 支持随路量化:

// include/catlass/gemm/tile/tile_copy.hpp
template<typename ElementInput, typename ElementOutput>
struct FixPipe {
    __device__ __forceinline__ void operator()(
        ElementOutput& out, const ElementInput& in) {
        // 编译期判断是否启用量化
        if constexpr (std::is_same_v<ElementOutput, int8_t>) {
            out = __float2int_rn(in * scale); // 量化
        } else {
            out = static_cast<ElementOutput>(in); // 直通
        }
    }
};

通过 if constexpr,编译器会自动剔除未使用的分支,生成最优指令流。


五、硬件抽象层:TLA 与指令封装

include/tla/(Tensor Linear Algebra)目录提供了对底层向量/矩阵指令的统一抽象,是 catlass 实现硬件无关性的关键。

5.1 矩阵乘累加指令封装

// include/tla/mma/mma_sm80.hpp (示例)
template<typename LayoutA, typename LayoutB>
__device__ __forceinline__ void mma_sync(
    FragmentC& d, const FragmentA& a, const FragmentB& b, const FragmentC& c) {
    // 映射到具体的 PTX 指令,如 mma.sync.aligned.m16n8k16.f16.f16.f16.f32
    asm volatile(
        "mma.sync.aligned.m16n8k16.f16.f16.f16.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};"
        : "=f"(d.x), "=f"(d.y), "=f"(d.z), "=f"(d.w)
        : "f"(a.x), "f"(a.y), "f"(a.z), "f"(a.w),
          "f"(b.x), "f"(b.y),
          "f"(c.x), "f"(c.y), "f"(c.z), "f"(c.w)
    );
}

尽管此处以 CUDA PTX 为例示意,但在 CANN 生态中,TLA 会映射到 NPU 专用的向量/矩阵指令集,通过 pto-isa 仓库定义。

5.2 数据布局抽象

支持多种内存布局(RowMajor, ColumnMajor, NZ 等),通过模板特化实现高效访存:

template<typename Element, int M, int N>
struct Layout {
    __device__ __forceinline__ Element* get_ptr(int m, int n) {
        return data_ + m * stride_ + n;
    }
};

六、开发体验:从模板到可执行算子

catlass 提供完整的开发工具链,降低使用门槛:

  1. 快速入门docs/quickstart.md 指导编译 examples/00_basic_matmul
  2. Tiling 自动调优tools/tuner/ 提供自动搜索最优分块参数;
  3. Python 绑定examples/python_extension/ 展示如何通过 PyBind11 暴露算子给 Python。

构建脚本示例

# scripts/build.sh
cd examples/102_dynamic_optimized_matmul
mkdir build && cd build
cmake .. -DCATLASS_ROOT=../../..
make -j8
./matmul_example  # 运行

构建系统自动链接 catlass 模板头文件,无需预编译库。


七、性能与验证:精度与吞吐保障

catlass 强调“正确性优先”,所有算子均通过严格测试:

  • 精度测试tests/ 目录包含 FP16/BF16/INT8 精度对比;
  • 泛化测试:支持随机 Shape 验证;
  • 性能基线:README 中展示 vs cuBLAS/cuDNN 的性能比。

例如,在 Issue #126 修复的文档中明确要求:“所有 PR 必须通过二级冒烟测试与算子泛化测试”。


八、社区共建与未来方向

catlass 由 CANN 社区联合华南理工大学、科大讯飞等团队共同维护,采用标准开源治理流程(TSC/PMC/SIG)。未来重点包括:

  • 稀疏 GEMM 支持:已在 tests/sparse_matmul 中初步实现;
  • 动态 Shape 优化:减少 Padding 开销;
  • 更高级融合:支持 Attention、LayerNorm 等复合算子模板。

结语

CANN catlass 通过精妙的模板元编程设计,将高性能 GEMM 算子的开发从“手写汇编”的黑盒模式,转变为“参数化配置 + 白盒调试”的工程化范式。它不仅是 CANN 生态的性能基石,更是开源社区协同创新的典范——让每一位开发者都能站在巨人的肩膀上,定制属于自己的极致算子。

对于追求极致性能的 AI 系统工程师、编译器开发者或 HPC 研究者而言,深入理解 catlass 的设计哲学与实现细节,无疑是掌握现代 AI 加速器编程的关键一步。


相关链接:

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐