模型量化详细介绍

这篇文章将深入的探索在AI领域,尤其是大语言模型(LLM)落地实践中,无法绕开的关键技术——模型量化(Model Quantization)


一、为什么需要模型量化

在开始之前,我们不妨先思考,为什么我们需要对这些强大的模型进行量化?

1.1 模型规模爆炸式增长带来的显存/算力挑战

从BERT的几亿参数,到GPT-3的1750亿,再到如今动辄万亿参数的巨无霸模型,我们见证了“大力出奇迹“的时代,更大的模型通常意味着更强的性能,着毋庸置疑。

但硬币的另一面是,它们实在是太贵了!

我们先来看显存方面,这是最直接的瓶颈,以FP32(32位浮点数)精度为例,存储10亿(1B)参数就需要 1B * 32 bits / 8 = 4G的显存。一个70亿参数的模型(比如LLaMA-7B),仅仅是加载weights就需要大约28G的显存,这还不包括推理过程中产生的中间激活值、KV Cache等。一块NVIDIA A100/H100 GPU的80G显存,在万亿参数模型面前也显得捉襟见肘。这也就是为什么很多开源社区都会优先发布量化版本,否则大多数Developer连模型都加载不进去。

其次是算力方面,模型的计算量(FLOPs)与其参数量和输入数据量大致成正比。模型越大,推理一次所需要的时间就越久,对计算芯片的要求也越高。还有带宽方面,在推理的时候,模型权重需要从显存加载到计算单元。模型越大,需要传输的数据就越多,内存带宽就越容易成为瓶颈,尤其是在访存密集型的LLM推理中。

以及能耗方面,大模型推理不仅仅是“慢”,而且“费电”。就比如GPT-3的推理,每输出一个token就需要上百毫焦的能耗,整个长文本生成下来就是几度电。这对企业来说,能耗直接转成电费成本。

除了这些方面,我们还可以往应用层面想想,很多实际应用场景并不在数据中心,而是在边缘设备(手机、车载终端,IoT等等),这些设备的内存、算力、带宽都远低于GPU服务器,如果不量化,几乎不可能部署。举个例子,苹果的CoreML、安卓的NNAPI、甚至微信小程序的AI SDK,都要求模型轻量化。

1.2 全精度(FP32/FP16)推理的局限

在深度学习早期,通常使用FP32(单精度浮点数)进行训练和推理。它可以提供很高的精度和稳定性,但每个参数占用4个字节,显存占用极高,是名副其实的“吞金兽”。那个时候为了缓解压力,业界转向了FP16(半精度浮点数)还有BF16(bfloat16),它们将每个参数的存储从4字节减少到两个字节,显存占用和计算需求直接减半。这个改进让GPT-3这样级别的模型训练成为可能,也奠定了大模型发展的基础。目前,FP16/BF16混合精度训练可以说是已经成为大模型训练的标配。

但就算是这样,随着模型规模突破千亿,FP16的红利也逐渐见顶。比如一个1750亿参数的模型,用FP16存储权重就需要约350G的显存,这远远超出了任何单卡GPU的承受范围。换句话说,FP16已经无法满足未来的大模型部署需求。我们希望能在更便宜的显卡、甚至边缘设备上运行模型,而FP16在这些场景下几乎是不可想象的。

1.3 量化的基本思想

既然16位的浮点数还是太大,我们能不能用更少的Bit来表示模型中的数值?比如8位整数(INT8),甚至4位整数(INT4)?

当然可以,这就是量化的核心思想。

量化,本质上是一种信息压缩技术。它可以把高精度的浮点数映射到一个范围更小、精度更低的整数集合中。这个过程就像我们用“优、良、中、差”去粗略地评价学生成绩,而不是使用精确到小数点后两位的分数。虽然损失了一部分信息,但我们还是可以大致了解每个学生的水平。

在常见的对称/非对称量化中,浮点数x_float和量化后的整数x_int的关系是:
x i n t = r o u n d ( x f l o a t s c a l e ) + z e r o _ p o i n t x_{int}=round(\frac{x_{float}}{scale}) + zero\_point xint=round(scalexfloat)+zero_point

  • Scale(尺度因子):一个浮点数,用于缩放。它决定了量化之后的整数所能表示的浮点数范围
  • Zero_Point(零点):一个整数偏移量,用来保证浮点数的0可以映射到某个整数值

通过这个映射关系,我们可以将模型中大量的FP32或FP16权重和激活值,转成INT8或INT4格式,从而实现模型的压缩。

这里补充个小知识点:为什么需要zero_point?

加入不用zero_point,那x_int = round(x_float / scale),这个时候x_float=0就会被强制映射到整数0。但问题是,有些量化方案要求量化后的整数必须在**非负范围(比如 [0, 255] 的uint8)**,如果数据分布是不对称的(比如[-2.3, 6.7]),那会直接让0对应整数0就会失真,因为整数是最小值不能表示负数。这个时候zero_point就派上用场了,它相当于“平移”映射关系,把浮点的0对应到整数范围中的某个值。

举个例子:

​	加入我们要把浮点数范围 [-2.3, 6.7] 映射到 [0, 255]

​	计算scale ≈ [6.7 - (-2.3)] / 255 ≈ 0.035

​	如果没有zero_point,0会被映射到0,可问题是负数 -2.3 表示不了,因为整数只能 大于等于 0

​	引入zero_point:

​		令zero_point = round (0 - (-2.3) / scale) = 66

​		这样一来 **[-2.3 -> 0] [0-> 66] [6.7 ->255]**,完美的把整个区间对齐了。

1.4 量化的收益

一旦模型被成功量化部署,它将带来立竿见影的好处:

1.显存节省:

这是最直观的收益。从 FP16 到 INT8,模型体积减半;进一步到 INT4,则缩小为原来的 1/4。举个例子:一个 70 亿参数的模型(LLaMA-7B),FP16 需要约 28GB 显存,而 INT4 量化后只需 7GB——这意味着不必依赖昂贵的 A100/H100,消费级显卡也能轻松驾驭。

2.延迟降低:

现代 GPU 和专用 AI 芯片(如 Google TPU、Apple Neural Engine)对整型计算有专门优化,INT8 算力往往远高于 FP16。同样,模型体积变小也意味着内存带宽压力减轻,数据传输更快。这两点叠加,使得量化后模型的推理速度在同等硬件环境下往往能提升 1.5~3 倍

3.能耗降低:

整型运算的能耗远低于浮点运算;数据传输量减少,也进一步降低了电力消耗。在大规模服务场景下,这种优化能为企业节省数百万美元级别的电费开支。同时,它也符合“绿色 AI”的发展趋势,减少了 AI 技术对环境的负担。

4.更广泛的部署:

量化技术让强大的 AI 模型得以落地到资源受限的边缘设备,比如手机、车载系统、IoT 节点。这也是为什么主流的移动端推理框架(CoreML、NNAPI、TensorRT Lite)都内置了量化支持。没有量化,AI 就停留在实验室;有了量化,它才能走向千家万户。

二、模型量化基础知识

刚才,我们理解了“为什么”需要量化,接下来,我们将深入探讨“是什么”。

2.1 量化分类

根据量化操作介入模型生命周期的时机,我们可以将其分为两大类:

  • 训练后量化(Post-Traning Quantization,PTQ)
  • 显化感知训练(Quantization-Aware Traning,QAT)

2.1.1 训练后量化(PTQ)

顾名思义,PTQ是在模型已经训练完成之后进行的一种量化方法。它就像是给一位已经学有所成的“大师”(预训练模型)配上一副“轻量级装备”。

流程如下:

开发者 预训练模型 (FP32) 校准数据集 (Representative) 统计分析器 量化器 量化模型 (INT8) 步骤1: 模型准备阶段 1. 获取预训练好的全精度模型 32位浮点权重 步骤2: 数据准备阶段 2. 准备有代表性的校准数据集 小样本,代表性强 步骤3: 统计分析阶段 3. 开始统计分析 获取校准数据 返回数据样本 前向传播校准数据 返回权重和激活值 记录数值分布范围 Min/Max统计 分布特征分析 步骤4: 参数计算阶段 4. 分析统计信息 计算Scale和Zero-Point 返回最佳量化参数 Scale = (max-min)/255 Zero-Point优化 步骤5: 模型转换阶段 5. 启动量化转换 获取原始权重 返回FP32权重数据 应用量化参数转换 W_int = round(W_fp32/Scale + Zero-Point) 生成INT8量化权重 模型大小压缩75% 推理速度提升 量化完成!从FP32模型成功转换为INT8量化模型 开发者 预训练模型 (FP32) 校准数据集 (Representative) 统计分析器 量化器 量化模型 (INT8)

PTQ不需要重新训练模型,流程非常高效,而且通常只需要少量无标签的校准数据。但缺点是,由于模型在训练时对量化毫不知情,量化过程中的误差可能会被累积放大,导致精度下降,尤其是在进行低比特量化(比如INT4)时,比如:INT8在很多CV任务中几乎无损,但是在LLM上可能略有性能下降。所以,PTQ更适合作为一种快速部署方案。当任务对精度要求很高的场景,比较推荐使用QAT或者是更新的量化算法(比如GPTQ、AWQ等等)。

2.1.2 显化感知训练(QAT)

相比PTQ,QAT是一种更为“主动”的策略。它在模型训练(或微调)的阶段就引入了量化操作的模拟。这就好比是从学徒阶段就开始让它们适应那套“轻量级装备”。

流程如下:

训练器 神经网络模型 (FP32) 伪量化模块 (Fake Quantization) 损失函数 优化器 训练数据 QAT训练循环开始 获取训练批次数据 返回输入数据和标签 前向传播阶段 - 伪量化过程 开始前向传播 传递FP32权重 Step1: FP32 → INT8量化 W_int8 = round(W_fp32/Scale + Zero-Point) Step2: INT8 → FP32反量化 W_fake = (W_int8 - Zero-Point) * Scale 返回伪量化后的FP32权重 使用伪量化权重继续计算 输出预测结果 损失计算阶段 计算预测值与真实标签的损失 返回损失值 损失包含量化误差影响 反向传播阶段 - 梯度感知量化误差 开始反向传播 传播梯度 梯度经过伪量化模块 梯度能感知到量化误差 引导权重向量化友好方向调整 传递调整后的梯度 传递最终梯度 参数更新阶段 更新模型权重 权重朝着量化友好的 分布方向进行调整 训练循环完成 - 获得量化鲁棒模型 重复上述过程直至收敛,权重和激活分布逐渐适应量化 导出最终训练完成的模型 移除伪量化模块,转换为真正的INT8量化模型 获得对量化误差鲁棒的INT8/INT4部署模型 训练器 神经网络模型 (FP32) 伪量化模块 (Fake Quantization) 损失函数 优化器 训练数据

这里重点解释为什么在Step1中量化之后,Step2中又要进行反量化?

要弄明白这个问题,我们首先要回顾一下量化的目的。我们想让模型最终在INT8里面跑,所以第一步肯定是把FP32压缩到INT8(W_int8)。那为什么又需要反量化呢?

问题来了,训练和推理过程中的前向计算、反向传播、梯度下降全都是在FP32(或FP16/BF16)下实现的。如果我们真的只保留INT8权重,那就很难继续训练或者模拟计算图,所以我们就需要把刚才压缩后的INT8再映射回FP32,也即:
W f a k e = ( W i n t 8 − Z e r o P o i n t ) ∗ S c a l e W_{fake}=(W_{int8}-ZeroPoint) * Scale Wfake=(Wint8ZeroPoint)Scale
注意,这里的W_fake看起来是浮点数,但已经带上了量化误差。

我们可以把它们类比成:

  • 量化:把一段音频从无损WAV(FP32)压缩成MP3(INT8)
  • 反量化:解压成WAV格式(FP32),但是音质已经损失了(有失真)

为什么要这样?因为播放器(训练框架)只能播放WAV(FP32),它不认识MP3。


QAT关键技术——伪量化

问题1:什么是伪量化?

听名字我们也大概能猜出来,伪量化是一种模拟量化的操作。

核心过程是在模型的浮点数(FP32)计算图中,插入一个 量化 + 反量化 的模块。但是在计算的时候还是用FP32的浮点运算(包括梯度计算),但数值已经掺杂了量化误差,就像真正量化后那样。这让模型在训练的时候逐渐适应这些误差,提高部署时的量化鲁棒性。

伪量化和真量化有什么区别呢?真量化是直接把权重还有激活值转成INT8(当然,可能是更低的精度),存储还有计算的时候都用整数类型,推理的时候使用整数运算单元,速度更快,内存消耗也更小,但是如果没有经过训练适应,模型的精度可能大幅下降,因为量化误差会累积。而伪量化它是只在训练时模拟,不改变底层计算类型,这样做的好处是梯度计算保持连续和精确(因为还是FP32),模型可以优化最小量化误差。

通俗来说,这是“预防针”,训练时模拟低精度环境,避免部署时“生病”(精度崩盘)。如果不做,部署后模型可能需要重新训练或手动调参,很麻烦。

问题2:这种方案可以解决模型过大导致部署困难吗?

答案是可以,但不是直接解决!

因为伪量化本身发生在训练阶段,训练时计算图还是FP32,所以训练过程的内存还有计算需求都没变(模型还是大的FP32版本)。它没有直接缩小模型大小或加速训练。

它真正的价值是在部署阶段,通过伪量化训练出来的模型,可以安全地转换成INT8版本。模型大小缩小约4倍(FP32 -> INT8),内存占用少而且推理速度快(用整数计算)。这直接解决部署困难,比如在手机、边缘设备上运行大模型时,内存不足或者是速度太慢的问题。

问题3:为什么训练时还在FP32计算图中?

因为**直接用INT8训练有技术难题,也即量化是非连续操作(round函数不可导),梯度无法正常传播,导致优化失败。**使用伪量化可以很巧妙的规避这个问题,用FP32进行计算,但装成带误差的低精度,让训练保持高效。等训练结束后,再把模型导出成INT8格式。这个时候,计算图就变成了INT8,这才真正享受到低精度好处。这里我们解释一下为什么用INT8训练会导致梯度无法正常传播:

我们知道梯度是函数在某点的导数,描述函数值随输入变化的速率,为了梯度存在,函数在该点必须可导。而round函数的性质是它把连续的浮点数砍掉小数部分映射到整数,加入图像是阶梯状的函数,那在每个整数区间内,输出恒定不变。如图所示:

在这里插入图片描述
导数分析:

  • 在阶梯区间内部,函数值恒定 ⇒ 导数 = 0
  • 在阶梯的边界点,函数跳跃 ⇒ 导数不存在

在这里插入图片描述

而神经网络训练依赖链式法则:
φ L φ x = φ L φ y φ y φ x \frac{\varphi L}{\varphi x} = \frac{\varphi L}{\varphi y} \frac{\varphi y}{\varphi x} φxφL=φyφLφxφy
如果y = round(x),那么
φ y φ x = 0 \frac{\varphi y}{\varphi x} = 0 φxφy=0
(在大部分区间)或不存在(跳跃点),导致
φ L φ x \frac{\varphi L}{\varphi x} φxφL
等于0或无法计算。间接的导致优化器收到的梯度信息几乎没用,无法有效训练。

从训练层面来看,我们在前向传播的时候已经执行了量化 + 反量化。在反向传播时,梯度会因为这些操作从而变成0或者是无法正确计算,导致模型无法有效优化(梯度消失)。如果没有解决这个问题,伪量化就无法在训练中使用(深度学习依赖梯度下降来更新参数)。

这个时候就需要用到STE这个方法了!

STE是一种梯度近似方法,最初是在二值化网络中提出来的,后来广泛用于QAT。STE的核心原理就是在前向传播的时候正常应用量化操作(引入误差),但是在反向传播的时候会忽略量化操作的不可微分部分,直接“直通”梯度。就好像量化操作是一个恒等函数,梯度直接从输出直接传到输入,而不被量化“阻挡”。数学上,量化函数Q(x)的梯度近似为:
φ L φ x ≈ φ L φ Q ( x ) ∗ 1 \frac{\varphi L}{\varphi x} \approx \frac{\varphi L}{\varphi Q(x)} * 1 φxφLφQ(x)φL1
(其中1表示直通)

简单来说,STE是伪量化的“桥梁”,伪量化依赖STE来使得整个过程可训练。没有STE,伪量化就无法工作。在框架如PyTorch的FakeQuantize模块中,内部就用到了STE。流程如下:

  • 前向:x_fp32 -> Quantize(x) -> Dequantize -> 输出 (带误差的FP32)
  • 反向:梯度从输出“直通”到x_fp32,忽略Quantize的非微分。

这让模型在FP32计算图中模拟INT8效果,同时保持训练的连续性。

但这种方法也有它的局限性,STE是一种粗糙近似,有时会导致次优优化(因为忽略了量化梯度的真实影响)。为此,我们可以考虑高级变体如LSQ或者是Gumbel-Softmax来改进。


QAT优点是精度更高,通常可以达到与原始FP16模型非常接近甚至无损的性能。但缺点是它需要完整的训练流程、代码和数据,计算成本远远高于PTQ,而且它需要有标签的训练数据集。

2.2 核心概念

2.2.1 权重 & 激活量化

  • 权重(Weight)量化指的是对模型的参数(比如卷积核、全连接层的权重矩阵)进行量化。由于权重在推理过程中是固定不变的,所以量化可以在模型部署前一次性完成,带来显著的存储压缩还有一定程度上的计算加速,而且对精度影响相对可控。
  • 激活(Activation)量化针对推理过程中产生的中间特征图。由于激活值依赖输入数据而动态变化,其量化参数往往需要依赖校准数据集进行统计,或者在运行时动态计算。激活量化可以进一步提升推理速度还有显存效率,但也更容易引入精度损失,因此是量化研究中的难点。

通常情况下,我们会选择对权重和激活值同时量化,以获得最优的加速和压缩收益。

2.2.2 对称 & 非对称量化

这个概念描述了我们如何将一个浮点数范围映射到整数范围。

量化方式 映射方式 公式 零点 (Zero-point) 适用场景 优缺点
对称量化 浮点范围 [-M, M] → 整数范围 [-127, 127] FP32 ≈ Scale × INT8 固定为 0 权重或激活值分布以 0 为中心(如 Tanh) 计算简单;但对偏移分布数据效果差
非对称量化 浮点范围 [min, max] → 整数范围 [0, 255] 或 [-128, 127] FP32 ≈ Scale × (INT8 - Zero_point) 可调,用于对齐浮点 0 分布不以 0 为中心的数据(如 ReLU) 适用性广;计算更复杂

2.2.3 Per-Tensor & Per-Channel 量化

方式 粒度 特点 优缺点 典型应用
Per-Tensor (Per-Layer) 整个权重张量只用一组 Scale 和 Zero-Point 实现简单,计算开销小 - 优点:速度快、实现容易
- 缺点:当不同通道分布差异大时,精度损失严重
边缘设备上的轻量模型,加速场景
Per-Channel (Per-Axis) 沿某个维度(通常是输出通道)独立计算 Scale 和 Zero-Point 每个通道单独适配分布 - 优点:精度更高,能适应分布差异
- 缺点:计算开销略大,实现复杂
大模型(LLM)权重量化的主流方案

2.2.4 量化位宽对比

位宽 压缩率(相对FP32) 精度风险 应用场景
INT8 低,几乎无损 工业界主流,服务器/GPU/TPU 部署
INT4 中,高精度层敏感 大模型推理加速(需 QAT 或混合策略)
INT2 16× 高,非常容易失真 学术研究,极端低比特实验
Binary(±1) 32× 极高,损失严重 特殊应用(边缘设备、定制加速芯片)

2.2.5 混合量化

策略 描述 优点 缺点 典型应用
全局统一量化 整个模型使用相同的位宽(如全 INT8) 实现简单,部署方便 某些关键层精度下降明显 小模型,资源受限场景
分层混合量化 不同层使用不同位宽(如 Embedding/输出层用 INT8,其他层用 INT4) 精度更高,兼顾存储和速度 需要额外分析和调优 LLM 推理部署(GPTQ、AWQ 等)
算子级混合量化 同一层中,不同算子或不同张量用不同位宽(如权重 INT4,激活 INT8) 最细粒度优化,效果最佳 实现复杂,对硬件要求高 前沿研究方向,硬件协同设计

2.2.6 静态量化和动态量化

静态量化是PyTorch量化里面最常见的一种方式,它的特点是:

  • 权重量化
    • 模型的权重参数在离线阶段就转成int8,存储占用大幅下降
  • 激活量化(需要校准)
    • 激活值(输入、中间输出)需要在推理的时候也变成int8,但激活的范围不像权重那样固定,所以需要事先 用一小部分数据跑一遍模型,收集激活的分布范围,得到量化参数(scale和zeropoint),这个过程就i叫做校准
  • 推理时全程int8
    • 前向计算的时候,输入、权重、激活都用int8表示,只有少量操作需要回到fp32(比如softmax、layernorm等),大多数算子可以用int8加速

举个例子:

假设你有一层卷积:

conv = nn.Conv2d(3, 16, kernel_size=3)

权重量化:conv.weight从 FP32 → INT8(离线完成)

激活量化:输入图片经过 校准数据 收集min/max → 算scale/zeropoint → 输入在推理时自动从 FP32 → INT8

卷积计算:用int8进行矩阵乘法,速度块,显存小

对比动态量化

  • 权重量化
    • 权重参数在模型转换时就直接存成int8,这是和静态量化一样的地方
  • 激活量化(动态发生)
    • 激活值在推理时才会动态计算量化参数,不需要提前用校准数据去估计,每次运行都会动态地从fp32转int8,然后再计算
  • 推理时流程
    • 输入张量保持fp32→在算子执行前量化为int8,和int8做矩阵乘法,输出的时候再转回fp32

总结如下:

对比点 静态量化 (Static) 动态量化 (Dynamic)
权重 提前量化成 int8 提前量化成 int8
激活 需要校准数据,提前定死范围 推理时动态量化,不需要校准
精度 较高 略低(因为范围是动态估的)
速度 更快(全程 int8) 较快,但激活转换开销大
适用模型 CNN、Transformer RNN/LSTM、NLP 模型

2.3 量化精度对比

精度 比特数 表示范围 (典型) 优势 劣势
FP32 32 约 ±3.4e38 精度高,范围广,训练稳定 显存/带宽占用大,计算慢
FP16 16 约 ±65504 显存/计算减半,成为训练主流 范围小,可能出现数值溢出
BF16 16 约 ±3.4e38 范围同 FP32,动态范围大,适合训练 精度相对 FP16 较低
INT8 8 [-128, 127] 或 [0, 255] 性能提升显著,硬件支持广泛,精度损失可控 对离群值敏感,PTQ 精度下降可能明显
INT4 4 [-8, 7] 或 [0, 15] 极致压缩,显存占用极低 精度损失较大,需要高级 PTQ 算法
Binary 1 {-1, 1} 或 {0, 1} 压缩比最高,可用位运算加速 精度损失巨大,目前研究为主

三、量化在主流框架中的实现

在前两节已经介绍了“为什么”和“是什么”,这一节将重点围绕当前主流的框架(PyTorch(训练与研究)、TensorRT(高性能推理部署)以及HuggingFace(大模型生态))来重点介绍“怎么做?”

3.1 PyTorch FX Graph Mode Quantization

PyTorch 的 FX 图模式量化是一种自动化的后训练量化 (PTQ) 技术,它将浮点模型 (FP32) 转换为低精度版本(通常为 INT8),以实现更快的推理、更小的模型大小和更低的内存使用。它利用 torch.fx 框架进行符号跟踪,创建模型的图表示。这允许自动处理量化逻辑,包括算子融合、观察者插入和模块替换。与 Eager 模式量化(需要手动注解)不同,FX 模式更自动化,但要求模型可符号跟踪(例如,没有动态控制流)。该过程默认针对静态量化,其中权重和激活值在使用校准数据提前量化。它的核心思想是:先将模型转化为静态计算图,再对图进行修改。这意味着我们既能在图级别做性能分析,也能在图级别做量化替换。

来看一个示例,我们可以用过FX的interpreter机制,在不改动模型源码的基础上,在每个节点上插桩,统计运行耗时:

import torch
import torch.fx
import torchvision.models as models
import statistics, tabulate, time
from typing import Any, Dict, List
from torch.fx import Interpreter

# 以resnet18模型为例
rn18 = models.resnet18()

# 现在我们已经拿到了模型,我们想要更深入地检查它的性能。也就是说,对于接下来的调用,模型的哪些部分耗时最长?
input = torch.randn(5, 3, 224, 224)

"""
接下来,我们将创建一个继承自 torch.fx.Interpreter 的类。虽然 symbolic_trace 生成的 GraphModule 会编译 Python 代码,
并在调用 GraphModule 时运行,但另一种运行 GraphModule 的方式是逐个执行计算图中的节点(Node)。这正是 Interpreter 
提供的功能:它会逐节点地解释执行计算图。

通过继承 Interpreter,我们可以重写其中的部分功能,并加入我们所需的性能分析行为。这样,我们就能得到一个对象:它可以接收模
型作为输入,执行模型一次或多次,然后统计模型整体及各个部分在运行过程中的耗时情况。
"""
class ProfilingInterpreter(Interpreter):
    def __init__(self, model: torch.nn.Module):
        gm = torch.fx.symbolic_trace(model)
        super().__init__(gm)

        # 整个模型每次的耗时
        self.total_runtime_sec: List[float] = []

        # 每个node执行所花费的时间
        self.runtimes_sec: Dict[torch.fx.Node, List[float]] = {}

    # 重写run方法 (整个模型)
    def run(self, *args):
        # 记录模型开始运行的时间
        t_start = time.time()
        # 通过Interpreter.run()运行模型
        return_val = super().run(*args)
        # 记录模型运行完成的时间
        t_end = time.time()
        # 总耗时
        self.total_runtime_sec.append(t_end - t_start)
        return return_val
    
    # 重写run_node方法(单一节点)
    def run_node(self, n: torch.fx.Node) -> Any:
        # 记录开始的时间
        t_start = time.time()
        # 交给上层Interpreter.run_node()运行
        return_val = super().run_node(n)
        # 记录结束的时间
        t_end = time.time()
        # 存储所花费的时间
        self.runtimes_sec.setdefault(n, [])
        self.runtimes_sec[n].append(t_end - t_start)
        return return_val

    # 数据归纳
    def summary(self, should_sort: bool = False) -> str:
        # 收纳每个节点的信息
        node_summaries: List[List[Any]] = []
        # 计算整个模型的平均运行时间
        mean_total_runtime = statistics.mean(self.total_runtime_sec)

        # 统计每个节点的时间信息
        for node, runtimes in self.runtimes_sec.items():
            # 计算每个节点的平均运行时间
            mean_runtime = statistics.mean(runtimes)
            # 转成百分比的形式
            pct_total = mean_runtime / mean_total_runtime * 100
            node_summaries.append(
                [node.op, str(node), mean_runtime, pct_total]
            )

        # 为了方便查看哪个节点耗时最久,我们可以选择进行排序
        if should_sort:
            node_summaries.sort(key=lambda s: s[2], reverse=True)

        # 适用tabulate库创建一个可视化的表格
        headers: List[str] = [
            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'
        ]
        
        return tabulate.tabulate(node_summaries, headers=headers)
    
interp = ProfilingInterpreter(rn18)
interp.run(input)
print(interp.summary(True))

运行结果(截取一部分):

Op type        Op                       Average runtime (s)    Pct total runtime
----------------------------------------------------------------------------------
call_module    maxpool                          0.00910187             7.1365
call_module    conv1                            0.00890756             6.98415
call_module    layer1_1_conv2                   0.00760531             5.9631
call_module    layer1_0_conv1                   0.00600076             4.70501
call_module    layer4_0_conv2                   0.00600004             4.70445
call_module    layer1_1_conv1                   0.00582886             4.57023
call_module    layer2_0_conv2                   0.00528502             4.14383
call_module    layer2_1_conv1                   0.00464106             3.63891
call_module    layer3_0_conv2                   0.00403881             3.16671
call_module    layer2_0_conv1                   0.00402904             3.15905
call_module    layer1_0_bn2                     0.0037961              2.97641
call_module    layer2_1_conv2                   0.00358534             2.81116
call_module    layer3_1_conv1                   0.00308013             2.41504
call_module    layer3_0_conv1                   0.00307703             2.41261
call_module    layer4_1_conv2                   0.00305629             2.39634
call_module    layer4_0_conv1                   0.00304508             2.38756
call_module    layer4_1_conv1                   0.00301886             2.36699

[!CAUTION]

我们运行多次,发现每次都是maxpool的耗时最长,这是为什么?

因为Conv在PyTorch里面一般调用的是高度优化过的cuDNN内核(GPU)或者是MKL/DNNL内核(CPU),所以速度很快。但Maxpool不一样,它的运算虽然看起来简单,无非就是取窗口最大值,但是它没有太多可并行的矩阵乘法结构,不想Conv那样可以直接用GEMM/FFT优化。而且Maxpool的内存访问模式是不规则的,意思就是它每次取最大值都要从不同的位置去读取数据,这就导致cache利用率差,很容易受限于内存带宽,最后结果就是GPU算力根本用不上,反而卡在数据搬运还有比较操作上。

3.1.1 QuickStart

我们再来看一个典型的PTQ流程:

import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torchvision.models import mobilenet_v2

# 1. 准备模型和数据
model = mobilenet_v2(pretrained=True).eval()
# calibration_data 是一个用于校准的 DataLoader
calibration_data = ... 

# 2. 配置量化参数
# 使用默认的 'fbgemm' 后端配置,适用于 x86 CPU
qconfig = get_default_qconfig("fbgemm")
qconfig_mapping = {"": qconfig}

# 3. 准备模型:嵌入 Observer
model_to_quantize = prepare_fx(model, qconfig_mapping, example_inputs=(torch.randn(1, 3, 224, 224),))

# 4. 校准模型
# 将校准数据喂给模型,让 Observer 收集信息
with torch.no_grad():
    for images, _ in calibration_data:
        model_to_quantize(images)

# 5. 转换模型:真正的量化
# convert_fx 会利用 Observer 收集到的信息,将模型转换为真正的量化模型
# 比如,将 nn.Linear 替换为 nn.quantized.Linear
quantized_model = convert_fx(model_to_quantize)

接下来,我们把这部分代码拆开来看,它们底层到底做了什么?怎么完成量化的?

第一步、准备模型和数据

模型以评估方式加载(eval()),以禁用训练特定行为,比如dropout或批准化更新。这个时候还没有开始量化,这还只是基线FP32模型。校准数据是一个数据集(比如:验证数据的子集),代表典型输入。从内部来看,PyTorch期望这是一个 DataLoader ,产生批量的张量(比如,对于MobileNetV2, 图像形状是 [batch_size, 3, 224, 224])。这些数据对于静态量化至关重要,因为它能够帮助估计激活范围。

第二步、配置量化参数

get_default_qconfig(“fbgemm”) (默认是x86)返回一个QConfig对象,这个对象内部定义了激活值还有权重的量化方案。对于fbgemm(Facebook通用矩阵乘法后端),底层源码:

torch.ao.quantization.qconfig.py -> def get_default_qconfig(...):
    
if version == 0:
        if backend == 'fbgemm':
            qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
                              weight=default_per_channel_weight_observer)
                              
default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args(
    dtype=torch.qint8, qscheme=torch.per_channel_symmetric
)

class HistogramObserver(UniformQuantizationObserverBase):
    r"""
    The module records the running histogram of tensor values along with
    min/max values. ``calculate_qparams`` will calculate scale and zero_point.

    Args:
        bins: Number of bins to use for the histogram
        upsample_rate: Factor by which the histograms are upsampled, this is
                       used to interpolate histograms with varying ranges across observations
        dtype: dtype argument to the `quantize` node needed to implement the
               reference model spec
        qscheme: Quantization scheme to be used
        reduce_range: Reduces the range of the quantized data type by 1 bit
        eps: Epsilon value for float32, Defaults to `torch.finfo(torch.float32).eps`.

    The scale and zero point are computed as follows:

    1. Create the histogram of the incoming inputs.
        The histogram is computed continuously, and the ranges per bin change
        with every new tensor observed.
    2. Search the distribution in the histogram for optimal min/max values.
        The search for the min/max values ensures the minimization of the
        quantization error with respect to the floating point model.
    3. Compute the scale and zero point the same way as in the
        :class:`~torch.ao.quantization.MinMaxObserver`
    """

I、权重观察者:PerChannelMinMaxObserver(每通道量化,以在卷积/线性层中获得更好的准确性)

源码解释:

Observer module for computing the quantization parameters based on the running per channel min and max values.

This observer uses the tensor min/max statistics to compute the per channel quantization parameters. The module records the running minimum and maximum of incoming tensors, and uses this statistic to compute the quantization parameters.

​ 这个模块干嘛用的?

​ 核心作用:它像一个”数据监视器“,在模型处理数据时,偷偷记录数据的range(也就是最小值和最大值),然后用这些信息来 计算”量化参数“

​ 为什么叫Observer?因为它不会改变数据,只是在”观察“和”记录“传入的数据(这些数据叫”张量“,你可以想象成多维数组,比 如图像或矩阵)。它就像一个旁观者,边看边记笔记。

​ 展开解释详细流程:

​ 1.记录运行中的最小/最大值:

​ 当模型在运行的时候(比如训练或者是测试的时候),数据(张量)会一波波地进来。

​ 这个模块会实时跟踪这些数据的”运行最小值“和”运行最大值“。”运行“意思是它不是一次性看所有数据,而是边跑边更新 记录(比如用移动平均或简单累积的方式)

​ 重点是”每通道“(per channel):数据往往有多个“通道”,比如图像用rgb三个通道,或者Neural Network层有多个特 征通道。这个模块会单独为每个通道计算最小/最大值,而不是一股脑儿混在一起。这样更精确,因为不同通道的数据 范围可能不一样。

​ 2.用统计信息计算量化参数:

​ 这个模块会用记录下来的min(最小值)和max(最大值)作为”统计信息“

​ 然后,根据这些信息,算出量化参数。具体来说,量化参数包括”缩放比例(scale)和零点偏移(zeropoint)“。关于 (这两个参数上面已经解释过了,这里不做过多的介绍)。计算公式和MinMaxObserver一致。

​ 为什么需要这个模块?

​ 在模型量化时,直接压缩数据可能会丢失精度,导致模型变差。这个Observer通过”观察数据范围“来灵活的调整参数,让压 缩更温和。

II、激活观察者:HistogramObserver

通俗点来解释的话,这个观察者就像一个”数据分布分析师“。它不仅仅记录min和max,而是构建数据的直方图(histogram,统计每个数值区间的出现频率)

为什么用直方图呢?激活值往往分布不均匀(比如很多值集中在0附近,少数极端值),简单用min/max可能会被异常值影响,导致量化参数不准。优点是可以减少量化误差,尤其在激活值动态范围大的时候,可以让量化后模型准确率更高。缺点是计算比较慢(需要维护直方图),但在校准阶段影响不大。直方图可以捕捉整体分布(比如用percentile,如0.1% 和 99.9%的值),计算更精确的scale和zeropoint。它的计算和MinMaxObserver是一致。详细计算步骤:

1.创建传入输入的直方图:

干啥:当数据一批批进来时,这个观察者会构建一个直方图。直方图把数据分成很多“小桶(bin)”,每个桶统计落在某个数值范围内的数据有多少。

关键点:这个直方图不是一次性建好的,而是连续计算的,也就是说边跑边更新。每当有新数据进来,直方图就会动态调整每个桶范围(ranges per bin 会变化),让它更准确地反映整体数据分布。

通俗比喻:想象一下你在数一堆硬币的面值,不是一次性全看完,而是边看边记。开始桶可能是“0-10元” “10-20元”,但如果后面来了很多大面值,桶就会自动调整成“0-50元” “51-100元”,以更好地覆盖所有情况。这样一来,直方图就越来越准。

2.在直方图分布中搜索最优的min/max值:

干啥:用建好的直方图,搜索数据分布,找出最佳的min和max。不是简单取极端值,而是优化选择,确保量化后的误差最小(相对于原浮点模型)

咋搜:它会分析直方图,看哪里数据最集中,忽略少数异常值(比如一个超大或超小的值)。目标是让压缩后的数据尽可能保留原时信息的“精华”,最小化误差

通俗比喻:直方图像一张地图,显示数据人群在哪里密集。你不选地图的最边边角角,而是找一个可以覆盖 99% 人群的最佳边界。这样,压缩的时候就不会因为一两个离群者而浪费资源,让整体模型效果接近原版。

3.用和MinMaxObserver一样的方式计算scale和zeropoint

干啥:一旦找到了min和max,就用这些值去计算scale还有zeropoint。计算公式和MinMaxObserver完全相同:

为什么一样:前两步是HistogramObserver的优化方案,但最后计算参数的部分借用了MinMaxObserver的简单公式。这结合了复杂分析和简单计算

III、MinMaxObserver(每张张量量化)

源码解释:

Observer module for computing the quantization parameters based on the running min and max values.

This observer uses the tensor min/max statistics to compute the quantization parameters. The module records the running minimum and maximum of incoming tensors, and uses this statistic to compute the quantization parameters.

这个模块的的核心原理、用途还有计算方法都和 PerChannelMinMaxObserver 类似。这里重点介绍它们的不同点:

MinMaxObserver:

  • 计算范围:针对”整个张量“作为一个整体。也就是说它不管你的数据有多少维度或通道,它只算一个全局的min和max,然后得出一个共享的scale和zeropoint。

  • 适用场景:适合激活值的量化,因为激活值通常是per-tensor的(所有通道共享参数)。简单而且计算还快,但缺陷是如果数据在不同通道差异范围大,可能会导致量化不准(比如某个通道数据很大,其它的却很小,那压缩之后通道细节就丢的多)

  • 通俗比喻:想象一下我们用一个篮子装水果,我们把各种各样的水果都往一个篮子里面筛,如果水果都一样,那还好一点,但各种各样的水果就不一样了,那其中一些水果被别的类型的水果挤压坏的可能性是要高于前者的。MinMaxObserver的计算范围相当于就是所有水果用同一个”挤压规则“

  • 优点是实现起来简单,内存占用比较少。但缺点是对通道间差异大的数据,准确性可能差。

  • 计算公式:

    • 运行最小/最大值的计算(running min/max)

      x_min(运行最小值):

      如果x_min初始为None:x_min = min(X) (X表示输入的张量)

      否则:x_min = min(当前 x_min, min(X))

      x_max(运行最大值):

      如果x_max初始为None:x_max = max(X) (X表示输入的张量)

      否则:x_max = max(当前 x_max, max(X))

      通俗解释:这是个”累计更新“过程。第一次看到数据的时候,直接取X的min/max。以后每次数据进来,就比较更新,确保running min是所有见过数据的最小,running max是最大。像在记录历史最低/最高温度。

      例子计算:

      ​ X = [-5.0, -3.0, 0.0, 2.0, 4.0]

      ​ 初始 running_min = None,所以 running_min = min(X) = -5.0

      ​ 初始 running_max = None,所以 running_max = max(X) = 4.0

      ​ 如果再来新 X(如 [ -6.0, 3.0 ]),则 running_min = min(-5.0, -6.0) = -6.0;running_max = max(4.0, 3.0) = 4.0

    • scale(s)和 zeropoint (z) 的计算

      对称量化
      s = 2 ∗ m a x ( ∣ x m i n ∣ , x m a x ) Q m a x − Q m i n s =\frac{2 * max(|x_{min}|, x_{max})}{Q_{max} - Q_{min}} s=QmaxQmin2max(xmin,xmax)
      z = 0(如果dtype是qint8),否则z = 128(对于其它类型,比如quint8)

      非对称量化
      s = x m a x − x m i n Q m a x − Q m i n s = \frac{x_{max} - x_{min}}{Q_{max} - Q_{min}} s=QmaxQminxmaxxmin

      z = Q m i n − r o u n d ( x m i n s ) z = Q_{min}-round(\frac{x_{min}}{s}) z=Qminround(sxmin)

      这里的Q_min和Q_max是量化数据类型的最小/最大值(比如qint8: Q_min = -128, Q_max = 127)

      通俗解释:

      对称量化:假设数据对称分布(正负均衡),用范围的最大绝对值计算 s,让零点固定在中间(z=0 或 128)。适合权重,常用于简化计算。

      非对称量化:不假设对称,用全范围计算 s,然后用 round(四舍五入)调整 z 来偏移零点。适合激活值,能更好地处理偏斜数据。

      整体:s 决定压缩比例(大范围数据挤到小整数范围),z 调整“零”的位置,避免负值问题。Q_max - Q_min 通常是 255(对于 8-bit)。

      例子计算(用上面 running_min = -5.0, running_max = 4.0;Q_min = -128, Q_max = 127):

      对称量化:

      • s = 2 * max(|-5.0|, 4.0) / (127 - (-128)) = 2 * 5.0 / 255 ≈ 10 / 255 ≈ 0.0392
      • z = 0 (qint8)

      非对称量化:

      • s = (4.0 - (-5.0)) / 255 = 9.0 / 255 ≈ 0.0353
      • z = -128 - round(-5.0 / 0.0353) = -128 - round(-141.64) = -128 - (-142) = 14 (注意:round(-141.64) ≈ -142,因为 Python round 向偶数靠,但这里近似)

PerChannelMinMaxObserver:

  • 计算范围:针对“每个通道”分开计算。数据张量往往有多个通道(比如图像的 RGB 通道,或神经网络层的多个特征通道),它会为每个通道单独记录 min 和 max,然后每个通道得出自己的 scale 和 zero_point。
  • 适用场景:特别适合权重的量化,尤其是Conv或者是线性层,因为权重通道间差异大(per-channel可以更好的保留细节)。在PyTorch的默认配置中,权重常用这个,激活值常用MinMaxObserver。
  • 通俗比喻:还是举我们刚才装水果的例子,PerChannelMinMaxObserver 就像在篮子内部给每种水果隔离出了一个单独的适合它的存储空间,根据水果的特性去存储。
  • 优点是更精确,尤其在深度模型中,能减少量化误差,提高模型准确率。缺点是计算复杂,参数更多(每个通道一个 scale/zeropoint),内存占用更大。

我们可以用下面的示例代码来更好的感受一下它们的区别:

import torch
from torch.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver

data = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[10.0, 20.0], [30.0, 40.0]]])  # 通道1小,通道2大

# MinMaxObserver: 整体 min=1, max=40 → 一个scale
obs1 = MinMaxObserver()
obs1(data)
print(obs1.calculate_qparams())  # 输出一个scale和zero_point

# PerChannelMinMaxObserver: 通道1 min=1 max=4, 通道2 min=10 max=40 → 两个scale
obs2 = PerChannelMinMaxObserver(ch_axis=0)
obs2(data)
print(obs2.calculate_qparams())  # 输出两个scale和zero_point

输出结果:

(tensor([0.1569]), tensor([0], dtype=torch.int32))	# MinMaxObserver
(tensor([0.0157, 0.1569]), tensor([0, 0], dtype=torch.int32)) # PerChannelMinMaxObserver

我们可以看到MinMaxObserver只给了一个参数对,而PerChannelMinMaxObserver给了多个(按通道数)。

在PyTorch文档中建议:权重 per-channel,激活 per-tensor,这是默认最佳实践。

IV、量化方案:激活值为仿射(非对称,适用缩放和零点),权重为对称

V、数据类型:torch.qint8(有符号int8)或torch.quint8(无符号)


有关get_default_qconfig的到这里结束。接着看代码:
在这里插入图片描述
qconfig_mapping是一个字典,将模块名称或类型映射到QConfig。这里 {“”: qconfig} 在全局应用默认配置(空字符串表示所有模块)

底层机制:这个配置将会影响后续的Observer(统计收集器)的嵌入方式。后端(fbgemm)确保与 x86 硬件加速兼容。

prepare_fx中所接收的qconfig_mapping在源码中的体现(这里出于简便考虑,使用的是字典类型的,但在实际中,建议使用QConfigMapping,可读性更高)

第三步、准备模型,嵌入观察者

在这里插入图片描述
源码中关于prepare_fx的解释:

Prepare a model for post training quantization

内部工作流程:

prepare_fx 会基于 torch.fx 的 symbolic_trace,用 example_inputs 跟踪模型的前向计算图,得到一个 GraphModule。

源码中关于 example_inputs的解释:

example_inputs (Tuple[Any, …]): Example inputs for forward function of the model, Tuple of positional args (keyword args can be passed as positional args as well).

在这个图中,它会根据 qconfig_mapping 识别可量化的算子模式(如 Conv、Linear),并在相应位置插入 Observer


我们可以追踪源码,查看底层干了些啥工作(注意:这里所讲解的源码只包含model还有qconfig_mapping这两个参数的底层执行)?

首先,通过追踪内部prepare_fx方法,我们发现它底层是调用了_prepare_fx【Internal helper function for prepare_fx】方法,在这个方法的内部,它先调用了这个函数:

torch.ao.quantization.quantize_fx.py -> _prepare_fx(...):

_swap_ff_with_fxff(model)

def _swap_ff_with_fxff(model: torch.nn.Module) -> None:
    r""" Swap FloatFunctional with FXFloatFunctional
    """
    modules_to_swap = []
    for name, module in model.named_children():
        if isinstance(module, torch.ao.nn.quantized.FloatFunctional):
            modules_to_swap.append(name)
        else:
            _swap_ff_with_fxff(module)

    for name in modules_to_swap:
        del model._modules[name]
        model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional()

在这个函数的内部它会遍历模型的所有子模块,如果当前子模块是FloatFunctional类型,则将它记录下来,反之则递归进入子模块继续搜索,最后,再将找到的FloatFunctional 替换成 FXFloatFunctional.

什么是FloatFunctional?在PyTorch里面,有一些算子没有独立的 nn.Module形式,比如:add、mul、cat,这些通常会被包装到torch.ao.nn.quantized.FloatFunctional里面,用来”寄存“运算,比如:

self.ff = torch.ao.nn.quantized.FloatFunctional()
out = self.ff.add(x, y)

这样做是为了让量化流程可以追踪到这些操作,否则它们只是裸函数调用(没发插Observer)

FxFloatFunctional又是啥?在FX(torch.fx)量化里面,模型会被符号追踪(symbolic tracing),但是普通的FloatFunctional无法追踪,因为它们并不是标准的 nn.Module 算子节点

所以PyTorch提供了一个专门的替代品:FxFloatFunctional,它的行为和API与FloatFunctional 是一样的,但内部实现是FX友好的,能在FX Graph里面正确生成节点。


接着往下,它执行了这步操作:

torch.ao.quantization.quantize_fx.py -> _prepare_fx(...):

preserved_attr_names = prepare_custom_config.preserved_attributes
preserved_attrs = {attr: getattr(model, attr) for attr in preserved_attr_names if hasattr(model, attr)}

看起来,它好像是从模型里面挑一些需要保留的属性,保存到一个字典里面。为什么要做这步操作?因为在 prepare_fx/convert_fx的过程中,模型会被FX trace 成一个新的GraphModule,这个过程中一些原始模块的属性可能会丢失。所以PyTorch用 preserved_attrs 机制,把指定的属性(比如qconfigactivation_post_process等量化相关配置)单独拎出来保存。后面再把这些属性赋回 FX GraphModule,保证量化需要的信息不丢失。


继续往下走:

torch.ao.quantization.quantize_fx.py -> _prepare_fx(...):

graph_module = GraphModule(model, tracer.trace(model))

@compatibility(is_backward_compatible=True)
    def trace(
        self,
        root: Union[torch.nn.Module, Callable[..., Any]],
        concrete_args: Optional[Dict[str, Any]] = None,
    ) -> Graph:
        """
        Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root``
        can either be an ``nn.Module`` instance or a Python callable.

        Note that after this call, ``self.root`` may be different from the ``root`` passed
        in here. For example, when a free function is passed to ``trace()``, we will
        create an ``nn.Module`` instance to use as the root and add embedded constants
        to.


        Args:

            root (Union[Module, Callable]): Either a ``Module`` or a function to be
                traced through. Backwards-compatibility for this parameter is
                guaranteed.
            concrete_args (Optional[Dict[str, any]]): Concrete arguments that should
                not be treated as Proxies. This parameter is experimental and
                its backwards-compatibility is *NOT* guaranteed.

        Returns:

            A ``Graph`` representing the semantics of the passed-in ``root``.
        """

@compatibility(is_backward_compatible=True)
class GraphModule(torch.nn.Module):
    """
    GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a
    ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated
    from that ``graph``.

    .. warning::

        When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically
        regenerated. However, if you edit the contents of the ``graph`` without reassigning
        the ``graph`` attribute itself, you must call ``recompile()`` to update the generated
        code.
    """

这步操作的意义是什么?为什么要做这步操作?再解答这个问题之前,我们需要先回顾一下PyTorch的FX框架的核心思想,它就是为了能把 nn.Module 里面的计算逻辑“抽取”成一个可操作的计算图,这样后续就可以对图中的节点进行修改、插入、替换,最后再生成一个新的 nn.Module。在torch.ao.quantization.quantize_fx.py这个文件中,它需要做的事:

  • 获取模型的计算图(trace阶段)
  • 在计算图中插入量化/反量化节点
  • 重新生成一个可执行的模型类

有了这些知识后,我们再来拆解这行代码:

torch.ao.quantization.quantize_fx.py -> _prepare_fx(...):

graph_module = GraphModule(model, tracer.trace(model))

我们先来看tracer.trace(model),这里的tracer是一个FX的Tracer对象,它会运行一遍你的模型,但是输入的张量不是普通张量,而是特殊的Proxy对象。Proxy对象会记录下算子调用,比如:

x = torch.relu(x)
y = self.conv(x)

这些操作不会真的执行,而是被记录成Graph节点,结果就是生成了一个Graph对象,描述了模型的前向计算逻辑。

在源码中关于Proxy对象的解释:

``Proxy`` objects are ``Node`` wrappers that flow through the
    program during symbolic tracing and record all the operations
    (``torch`` function calls, method calls, operators) that they touch
    into the growing FX Graph.

    If you're doing graph transforms, you can wrap your own ``Proxy``
    method around a raw ``Node`` so that you can use the overloaded
    operators to add additional things to a ``Graph``.

    ``Proxy`` objects cannot be iterated. In other words, the symbolic
    tracer will throw an error if a ``Proxy`` is used in a loop or as
    an ``*args``/``**kwargs`` function argument.

    There are two main ways around this:
    1. Factor out the untraceable logic into a top-level function and
    use ``fx.wrap`` on it.
    2. If the control flow is static (i.e. the loop trip count is
    based on some hyperparameter), the code can be kept in its original
    position and refactored into something like::

        for i in range(self.some_hyperparameter):
            indexed_item = proxied_value[i]

我们简单的来总结一下:

Proxy对象的作用:

Proxy是Node的封装,在符号化追踪(symbolic tracing)过程中,它会在程序里面“流动”。每当Proxy遇到一个操作(比如torch函数调用、方法调用、运算符操作),它就会把这个操作记录到正在构建的FX Graph里面。

但是要注意的是Proxy对象不能被迭代,也就是说它不能写在for循环里面,也不能当作 *args 或 **kwargs 的参数,如果这么用,符号追踪会报错。

意义:

Proxy实际上就是桥梁:

  • 对外:在Python代码里面看起来“还能跑”,但其实没跑
  • 对内:它偷偷记录所有操作到Graph里面,供后续GraphModule使用

举个例子:

我们写一个简单的forward函数

def forward(self, x):
	return torch.relu(self.linear(x))

在trace的过程中:

输入 x 会变成Proxy(node1)

调用 self.linear(x) 时,不会真的执行卷积,而是:

  • 在Graph里面新增一个 call_module 节点(node2)
  • 返回值是 Proxy(node2)

再调用torch.relu(node2)时:

  • 在Graph里面新增一个 call_function 节点(node3)
  • 返回值是Proxy(node3)

最后Graph会长这样:

x (placeholder)
↓
linear (call_module)
↓
relu (call_function)
↓
output (return)

有关Proxy对象的更多信息,可以去PyTorch的官网,找这篇文档:torch/fx/OVERVIEW.md

继续看这步操作

graph_module = GraphModule(model, graph)

GraphModule是一个继承自nn.Module的特殊类,它接收一个Graph,并且会自动生成:

  • graph属性(计算图本身)
  • forward方法(由graph自动转译成可执行Python代码)
  • code属性(forward方法的源代码字符串)

这样我们就拿到了一个新的可执行模型类,但它的forward不是我们自己手写的,而是从 Graph 自动编译出来的。

forward自动编译?怎么做到的?

我们来看看GraphModule中的源码:

torch.fx.graph_module.py.GraphModule.graph(...):

@graph.setter
def graph(self, g : Graph) -> None:
    """
    Set the underlying ``Graph`` for this ``GraphModule``. This will internally
    recompile the ``GraphModule`` so that the generated ``forward()`` function
    corresponds to ``g``
    """
    assert isinstance(g, Graph), f'Expected a Graph instance, but got {type(g)}'
    self._graph = g
    g.owning_module = self
    self.recompile()

首先,它会检查传进来的g一定是FX定义的Graph对象,如果不是,立刻抛出错误,避免后面用错类型。

接着再把传进来的Graph保存到GraphModule内部属性_graph,GraphModule的执行逻辑完全依赖_graph,所以这里就是正式替换掉原来的计算图。

  • g.owning_module = self

这行代码实际上就是建立双向绑定,因为Graph里面有一个属性owning_module,用来指向“这个图属于哪个GraphModule"。这样在处理节点的时候,可以通过 node.graph.owning_module直接访问到外层的GraphModule,比如量化或剪枝的时候,修改Graph之后,还能知道它属于哪个模块。

重点是recompile()这个函数

torch.fx.graph_module.py.GraphModule.recompile(...):

@compatibility(is_backward_compatible=True)
def recompile(self) -> PythonCode:
    if isinstance(self._graph._codegen, _PyTreeCodeGen):
        self._in_spec = self._graph._codegen.pytree_info.in_spec
        self._out_spec = self._graph._codegen.pytree_info.out_spec
    python_code = self._graph.python_code(root_module='self')
    self._code = python_code.src

    cls = type(self)
    cls.forward = _forward_from_src(self._code, python_code.globals)

    cls_call = cls.__call__ if "__call__" in vars(cls) else None

    if '_wrapped_call' not in vars(cls):
        cls._wrapped_call = _WrappedCall(cls, cls_call)  # type: ignore[attr-defined]

    def call_wrapped(self, *args, **kwargs):
        return self._wrapped_call(self, *args, **kwargs)

    cls.__call__ = call_wrapped

    return python_code

这里就不一行一行代码过了,这里我们直接进入到关键代码 self._graph.python_code()

torch.fx.graph.py.Graph.python_code(...):

@compatibility(is_backward_compatible=True)
def python_code(self, root_module: str, *, verbose: bool = False) -> PythonCode:

我们先来拆解一下核心流程:**”起名 → 临时改repr → 逐节点生成代码 → 打包返回“**这四个步骤,接下来看看这四个步骤在源码中怎么实现的:

第一步:建立一个全新的命名空间 (_Namespace())

torch.fx.graph.py.Graph.python_code(...):
    
	namespace = _Namespace()

因为生成Python源码的时候会出现两类名字:局部变量名(每个Node的结果)和全局变量名(比如torch、operator、某些函数对象)

需要确保:

  • 唯一性:不重复、不遮蔽
  • 一致性:同一个对象每次引用用的都是同一个名字

(如果对这个不了解的,建议学下软件工程)

而且在注释里面也解释了:不能直接复用 node.name,因为它是在旧的命名空间里面生成的,为了让”局部 + 全局“都能统一管理,必须整一个全新命名空间。

第二步:临时覆盖Node的repr,让它产生”合法变量名“

torch.fx.graph.py.Graph.python_code(...):

	def node_repr(n: Node):
            return namespace.create_name(n.name, n)

    @contextmanager
    def override_node_repr(graph: Graph):
        orig_repr_fns = {}
        for node in graph.nodes:
            orig_repr_fns[node] = node._repr_fn
            node._repr_fn = node_repr
        try:
            yield None
        finally:
            # restore the original repr functions
            for node in graph.nodes:
                node._repr_fn = orig_repr_fns[node]

生成源码的时候,FX会频繁用repr(node)来把”某个节点的值“ 插入到代码字符串里面。

这里通过上下文管理器,把每个Node_repr_fn暂时替换成我们定制的 node_repr:

  • node_repr会向刚才的namespace要一个合法且唯一的变量名(比如把x变成x_1, x_2 …)

这样在任何地方调用 repr(node),都会得到一个可直接放进Python源码的变量名。

退出 with 块后,会把原本的 _repr_fn 恢复(不污染外部)

这一步就相当于怎么把Node变成代码里面的变量名这件事,统一交给了namespace管控

第三步:进入代码生成器,按拓扑顺序把Gtaph变成Python源码

	torch.fx.graph.py.Graph.python_code(...):

	with override_node_repr(self):
            return self._python_code(root_module, namespace, verbose=verbose)

torch.fx.graph.py.Graph._python_code(...):        

def _python_code(self, root_module: str, namespace: _Namespace, *, verbose: bool = False) -> PythonCode:
    return self._codegen._gen_python_code(self.nodes, root_module, namespace, verbose=verbose)

真正把Graph→源码字符串的工作,是交给 self.codegen_gen_python_code(…) 【这个函数的源码就不看了,太多了,这里总结一下】完成的,这一步会按图中节点顺序(拓扑序)依次吐出Python代码行,并同时准备好执行这段代码所需的globals字典。

常见的Node类型到源码的映射大致是:

  • placeholder: 生成forward(self, ...)的函数签名(包含位置/关键字参数、默认值等)
  • call_function:生成形如 v = some_func(arg1, arg2, **kw)
  • call_method:生成 v = obj.method(arg1, ...)
  • call_module:生成 v = self.submod(arg1, ...),其中 submod 是通过 root_module(通常是 'self')+ 目标的限定名 找到的子模块
  • get_attr:生成 v = self.some_attr,从根模块上取参数、buffer 或常量
  • output:生成 return ...(若是 pytree 输出,会由 _PyTreeCodeGen 额外处理展开/重构)

还有,codegen 会把源码里面用到的外部对象(比如torch、operator等)放入PythonCode.globals里面,这样后续exec的时候就可以正常解析这些名字

第四步:返回PythonCode(src, globals)

python_code() 返回的是 PythonCode 对象,包含:

  • src:完整的 forward 源码字符串
  • globals:执行这段源码需要的全局符号表(名字 → 对象)

之后在 GraphModule.recompile() 中:

torch.fx.graph_module.py.GraphModule.recompile(...):

python_code = self._graph.python_code(root_module='self')
self._code = python_code.src
cls.forward = _forward_from_src(self._code, python_code.globals)

会用 _forward_from_src(内部本质上是一次受控的 exec)把 src 编译成真正的 forward 函数对象,并挂到类上

简单示例:

import torch
import torch.fx as fx
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linaer = nn.Linear(6, 4)

    def forward(self, x):
        return torch.relu(self.linaer(x))
    
m = MyModule()
traced = fx.symbolic_trace(m)

print(traced.code)
print(traced.graph)

输出:

def forward(self, x):
    linaer = self.linaer(x);  x = None
    relu = torch.relu(linaer);  linaer = None
    return relu
    
graph():
    %x : [#users=1] = placeholder[target=x]
    %linaer : [#users=1] = call_module[target=linaer](args = (%x,), kwargs = {})
    %relu : [#users=1] = call_function[target=torch.relu](args = (%linaer,), kwargs = {})
    return relu
torch.ao.quantization.quantize_fx.py._prepare_fx(...):

graph_module = GraphModule(model, tracer.trace(model))

这一步大部分细节都已经讲清楚了,接着往下走:

 torch.ao.quantization.quantize_fx.py._prepare_fx(...):

_attach_meta_to_node_if_not_exist(graph_module)

def _attach_meta_to_node_if_not_exist(model: GraphModule):
	for node in model.graph.nodes:
        if not hasattr(node, "meta"):
            node.meta = {}

这步操作是在干嘛?

我们先回顾一下,在FX里面,模型会被”拆“成一张计算图(Graph),计算图由很多Node组成,每个Node就是一步运算(比如调用conv2d、relu等),但是这些Node默认只保存:

  • 这个操作是什么(op)
  • 用了哪个函数/模块(target)
  • 输入输出(args、kwargs)
  • 名字(name)

它不保存额外的信息,那这样就有问题了,因为后续我们可能需要往每个节点里面记录额外的信息,比如:shape、dtype、profiling(耗时信息),所以我们需要在每个Node里面开辟一块新的”空间“来存储这些东西,这样我们在分析、优化、量化、裁剪模型的时候,就可以直接从node.meta里面拿到这些辅助信息


接下来的步骤就是算子融合了:

After Fusion
FusedConvBnReLU
Before Fusion
ReLU
Conv2d
BatchNorm2d

为什么要做融合呢?

  • 性能优化
    • 避免多次kernel launch
    • 减少显存/内存带宽消耗
  • 数值稳定性
    • BN 的缩放 (γ) 和偏移 (β) 可直接折叠进 Conv 权重和偏置,避免额外运算误差
  • 图结构更简洁
    • 有利于后续量化 pass(observer 插入更容易),减少冗余节点

我们再来看看底层源码怎么做的:

torch.ao.quantization.quantize_fx.py._prepare_fx(...):

graph_module = _fuse_fx(
        graph_module,
        is_qat,
        fuse_custom_config,
        backend_config)

_fuse_fx这个函数内部调用的实际是fuse_fx函数执行算子融合,所以,我们直接跳转到fuse_fx函数算子融合的部分:

torch.ao.quantization.fx.fuse.py.fuse(...):

named_modules = dict(model.named_modules())

这一步是拿到模型中的所有模块,也就相当于拿到{"layer1.conv1": Conv2d(...), "layer1.bn1": BatchNorm2d(...), ...}的映射

torch.ao.quantization.fx.fuse.py.fuse(...):

if backend_config is None:
        backend_config = get_native_backend_config()

在开始之前,我们需要先了解什么是BackendConfig?

它是量化后端的配置对象,告诉PyTorch:

  • 哪些算子(Conv、Linear、BN、ReLU、Embedding等)会被支持量化
  • 这些算子支持什么量化数据类型(int8、fp16等)
  • 哪些算子组合(pattern)可以被识别并替换成更高效的量化内核

而get_native_backend_config()这个函数做了两件事:

  1. 定义算子支持的量化数据类型(conv 用 int8,linear 可用 int8/fp16,embedding 还能用 4bit)。
  2. 把算子 pattern 注册进 backend_config,例如:
  • conv + relu
  • conv + bn + relu
  • linear + relu
  • embedding 量化

这些pattern会作为后续fusion匹配机制

最终返回一个带有所有pattern → config的BackendConfig("native")对象

不过这还有一个细节点:

get_native_backend_config(…)内部可没有直接告诉我们 pattern 列表,那问题来了,pattern从哪来?

在 PyTorch 量化流程里,pattern 主要通过 BackendPatternConfig 来描述。
get_native_backend_config() 里这些 _get_conv_configs(...)_get_linear_configs(...) 等函数,其实 内部是创建了一堆 BackendPatternConfig,而 BackendPatternConfig 除了 dtype 信息,还包含:

fuser_method

  • 定义了怎么把多个算子 fuse 成一个。
  • 比如 (Conv2d, BatchNorm2d) → fuse 成 ConvBn2d

fused_module

  • 直接指定融合之后用哪个 module。
  • 比如 (Conv2d, ReLU)ConvReLU2d

qat_module / reference_quantized_module

  • 定义了 QAT 和量化推理时对应的模块。

如下:

torch.ao.quantization.backend_config._common_operator_config_utils.py._get_conv_configs:

# (2) Conv + relu
        # -----------------
        # 2.1 conv module + relu fusion configs
        # conv relu fusion, conv module + relu module
        conv_configs.append(
            BackendPatternConfig((convs.root, torch.nn.ReLU))
                .set_dtype_configs(dtype_configs)  # noqa: E131
                .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
                .set_fused_module(convs.fused_conv_relu))
        # conv relu fusion, conv module + functional relu
        conv_configs.append(
            BackendPatternConfig((convs.root, F.relu))
                .set_dtype_configs(dtype_configs)  # noqa: E131
                .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu))
                .set_fused_module(convs.fused_conv_relu))
        # 2.2 conv module + relu fused module configs
        # conv relu, fused module
        conv_configs.append(
            BackendPatternConfig(convs.fused_conv_relu)
                .set_observation_type(observation_type)  # noqa: E131
                .set_dtype_configs(dtype_configs)
                .set_root_module(convs.root)
                .set_reference_quantized_module(convs.reference)
                .set_qat_module(convs.relu_qat))
        # conv relu, qat fused module
        conv_configs.append(
            BackendPatternConfig(convs.relu_qat)
                .set_observation_type(observation_type)  # noqa: E131
                .set_dtype_configs(dtype_configs)
                .set_root_module(convs.root)
                .set_reference_quantized_module(convs.reference))
        # 2.3 functional conv + relu configs
        # conv relu, functional conv + relu module
        conv_configs.append(
            BackendPatternConfig((convs.func, torch.nn.ReLU))
                .set_observation_type(observation_type)  # noqa: E131
                .set_dtype_configs(dtype_configs))
        # conv relu, functional conv + functional relu
        conv_configs.append(
            BackendPatternConfig((convs.func, F.relu))
                .set_observation_type(observation_type)  # noqa: E131
                .set_dtype_configs(dtype_configs))

        # fused conv relu
        conv_configs.append(
            BackendPatternConfig(convs.fused_conv_relu)
                .set_dtype_configs(dtype_configs)  # noqa: E131
                .set_qat_module(convs.relu_qat))

        conv_configs.append(
            BackendPatternConfig(convs.relu_qat)
                .set_dtype_configs(dtype_configs)  # noqa: E131
                .set_root_module(convs.root)
                .set_reference_quantized_module(convs.reference))

拿到backend_config之后:

torch.ao.quantization.fx.fuse.py.fuse(...):

fusion_pattern_to_fuse_handler_cls = 
_sorted_patterns_dict(_get_fusion_pattern_to_fuse_handler_cls(backend_config))

fusion_pattern_to_fuse_handler_cls定义了可融合的算子模式,为什么要这个东西呢?因为在量化之前,PyTorch FX需要把一些常见的算子组合融合成一个算子,比如:

  • Conv2d → BatchNorm2d → ReLU

如果不融和,会是3个节点,融合后就是一个新的Conv2d(参数里已经合并BN)+ ReLU

那系统怎么知道哪些模式能融合呢?那就是通过 fusion_patter_to_fuse_handler_cls 这张表

我们再来深入一下这张表是怎么拿到的:

torch.ao.quantization.fx.fuse_handler.py._get_fusion_pattern_to_fuse_handler_cls:

def _get_fusion_pattern_to_fuse_handler_cls(
        backend_config: BackendConfig) -> Dict[Pattern, Callable]:
    fusion_pattern_to_fuse_handlers: Dict[Pattern, Callable] = {}
    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        if config.fuser_method is not None:
            # TODO: is this logic right?
            fusion_pattern_to_fuse_handlers[pattern] = DefaultFuseHandler
    return fusion_pattern_to_fuse_handlers

这一步实际上就是把backend_config里面的pattern转换成”pattern → handler类“的映射

整体逻辑:

遍历backend_config里面的所有pattern

如果某个pattern有定义 fuser_method(说明它支持融合,比如Conv + BN + ReLU → FusedConv),就加到结果字典里面

默认绑定DefaultFuseHandler,这个handler知道怎么把多个算子节点合并成一个fused节点

转换完成之后,在外层还做了一个排序的操作,这是为了确保长pattern优先匹配

  • 比如 (ReLU, (BN, Conv2d)) 要优先于 (BN, Conv2d)
  • 否则可能先把 (BN, Conv2d) fuse 掉,就没机会 fuse (ReLU, BN, Conv2d)

用一句话总结这几步流程就是:

这一套逻辑就是从backend_config里面拿到算子融合规则,挑选出能融合的pattern,按照长度排序,等FX Graph执行的时候用handler把这些pattern替换成高效的fused节点

当然,除了获取fusion_pattern_to_fuse_handler_cls以外,这里还有:

fuser_method_mapping = get_fuser_method_mapping(backend_config)
fusion_pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
fusion_pattern_to_extra_inputs_getter = get_fusion_pattern_to_extra_inputs_getter(backend_config)

不做过多介绍,接着往下看:

torch.ao.quantization.fx.fuse.py.fuse(...):

# find fusion
fusion_pairs = _find_matches(
    model, model.graph, fusion_pattern_to_fuse_handler_cls)

现在我们手里已经有了fusion_pattern_to_fuse_handler_cls,它是一个字典:

  • key:pattern
  • value:handler class

它定义了哪些pattern可以融合,谁来管融合逻辑

_find_matches(...)为了在模型的FX Graph里面,找出和这些模式匹配的子图,并建立映射表

换句话说:我们现在有一个模型的计算图(FX Graph),它是由一堆节点(Node)组成,比如:

x → Conv2d → BatchNorm2d → ReLU → y

这里每个箭头就是图里面的一个节点

现在我们知道某些pattern可以融合,比如:

  • (Conv2d, BatchNorm2d) → ConvBn2d
  • (Conv2d, ReLU) → ConvReLU2d

这些模式就很像模板

_find_matches(...)在干啥?它会沿着这个计算图去扫描,看哪些地方可以套上这些模板:

它可能发现Conv2d + BatchNorm2d能匹配上模式 (Conv2d, BatchNorm2d)

也可能发现 Conv2d + ReLU 能匹配上 (Conv2d, ReLU)

一旦匹配成功,它就会记下来:

  • 哪些节点被匹配了
  • 属于哪个pattern
  • 用哪个handler来做融合

这就是所谓的建立映射表

就像在一段文字里面找关键字:

我昨天吃了炸鸡,然后喝了可乐

pattern:

”炸鸡+可乐“ → 属于”热量高的食物“

匹配后我们就能标记出:”炸鸡“和”可乐“是一个pattern,后面要合并处理

在_find_matches(…)里面,首先会拿到所有模块,方便后面用node去查实际的nn.Module

torch.ao.quantization.fx.fuse.py._find_matches(...):

modules = dict(root.named_modules())

然后遍历graph的节点(倒序),这里为什么用倒叙呢?

倒叙是因为pattern一般是往前看,比如 (Conv → BN → ReLU),从ReLU往前看可以更好地找到匹配

torch.ao.quantization.fx.fuse.py._find_matches(...):

for node in reversed(graph.nodes):
        if node.name not in match_map:
            for pattern, fuse_handler_cls in pattern_to_fuse_handler_cls.items():
                matched_node_pattern: List[Node] = []
                if _is_match(modules, node, pattern):
                    apply_match(pattern, node, (node, pattern, fuse_handler_cls(node)), 												matched_node_pattern, node_to_subpattern)
                    break

循环内部的_is_match会检查这个node以及它的输入们,是否符合某个pattern的结构

Node的输入:

假设有一个FX Graph:

x -> conv -> bn -> relu -> output

在 FX 里,Node 的输入就是它依赖的上游节点。

  • bn 这个 Node 的输入就是 conv 的输出
  • relu 这个 Node 的输入就是 bn 的输出

在源码里,这些信息存在 node.args 里。
比如:

  • relu_node.args[0] == bn_node
  • bn_node.args[0] == conv_node

所以当我们说**“检查这个 node 以及它的输入们”**时,意思就是:
不仅要检查当前节点是不是符合模式,还要沿着它的输入链条,看前面连着的节点是不是一起构成了某个 pattern

_is_match(…):

核心逻辑

torch.ao.quantization.fx.match_utils.py._is_match(...):

if isinstance(pattern, tuple):
    self_match, *arg_matches = pattern
  • 如果 pattern 是个 tuple(比如 (torch.nn.Conv2d, torch.nn.ReLU)),就拆成:
    • self_match = 最后一个(当前)节点应该匹配的类型
    • arg_matches = 输入节点应该匹配的模式

举个例子:
pattern = (Conv2d, ReLU)

  • self_match = Conv2d
  • arg_matches = [ReLU]

等等!注意这里的顺序其实取决于定义,在 PyTorch 的实现里,通常是 (root, subpattern) 这种结构,所以会反着检查输入链。

匹配当前节点

torch.ao.quantization.fx.match_utils.py._is_match(...):

if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
    if node.op != 'call_module': return False
    if not type_before_parametrizations(modules[node.target]) == self_match:
        return False

意思是:
如果 pattern 里要求是 torch.nn.Conv2d,那这个 node 必须是一个 call_module 节点,而且它绑定的 modules[node.target] 的类型要真的是 Conv2d

同理:

  • 如果 pattern 里是个函数(比如 F.relu),那就要匹配 call_function 节点。
  • 如果是字符串(比如 'add'),就要匹配 call_method 节点。

检查输入

torch.ao.quantization.fx.match_utils.py._is_match(...):

if not arg_matches:
    return True
if len(arg_matches) != len(node.args):
    return False
return all(_is_match(modules, node, arg_match, max_uses=1) 
           for node, arg_match in zip(node.args, arg_matches))

这里就是关键了:

  • 如果 pattern 是单个算子,没有输入要求,直接返回 True。
  • 否则,拿 node.args(也就是它的 输入节点们),一一和 arg_matches 去递归匹配。

这就是「检查 node 以及它的输入们」的真正含义:

  • 当前节点必须是 ReLU;
  • 它的输入节点必须是 Conv2d;
  • Conv2d 节点如果还有输入要求,就继续往前递归匹配。

如果匹配成功,就在apply_match(…)函数中递归展开:

  • 如果pattern是tuple,就递归匹配子节点
  • 反之直接把这个node记下来,最终得到一个 matched_node_pattern,里面包含实际匹配到的FX Node列表
torch.ao.quantization.fx.fuse.py._find_matches(...).apply_match(...):

def apply_match(pattern, node, match, matched_node_pattern, node_to_subpattern):
        if isinstance(pattern, tuple):
            ...
        else:
           ...

记录结果:

match_map[node.name] = (
    root_node,      # 模式的起点(比如 Conv)
    pattern,        # 模式本身 (Conv, ReLU)
    matched_node_pattern,  # 实际图里匹配到的 Node 列表
    handler,        # 对应的 FuseHandler
    node_to_subpattern
)

返回结果

最后 _find_matches(…) 返回一个match_map,相当于融合候选表

  • key = 节点名
  • value = (root_node, pattern, matched_node_pattern, handler, node_to_subpattern)

也即:这个节点属于某个 pattern,pattern 的根节点是谁,pattern 涉及哪些 node,要用哪个 handler 来 fuse

这基本上就是_find_matches函数了,接着回到 fuse 函数:

在拿到了融合候选的信息之后,我们需要遍历计算图中的节点:

torch.ao.quantization.fx.fuse.py.fuse(...):

for node in model.graph.nodes:
    maybe_last_node, pattern, matched_node_pattern, obj, node_to_subpattern = \
        fusion_pairs.get(node.name, (None, None, None, None, None))

其中每个node会查询它是不是某个融合pattern的最后一个节点(maby_last_node)

如果是最后一个节点:

拿到pattern细节和对应的fuse handler,开始进行融合:

torch.ao.quantization.fx.fuse.py.fuse(...):

if maybe_last_node is node:
    root_node_getter = fusion_pattern_to_root_node_getter.get(pattern, default_root_node_getter)
    root_node = root_node_getter(matched_node_pattern)
    extra_inputs_getter = fusion_pattern_to_extra_inputs_getter.get(pattern, None)
    extra_inputs = extra_inputs_getter(matched_node_pattern) if extra_inputs_getter else []
    
    env[node.name] = obj.fuse(
        load_arg, named_modules, fused_graph, root_node, extra_inputs, matched_node_pattern,
        fuse_custom_config, fuser_method_mapping, is_qat)

融合的时候需要找到这个融合pattern的root(通常是Conv2d)

obj.fuse(...)调用handler,把原来的多个模块(Conv + BN + ReLU)合并为一个”融合模块“:

  • 比如Conv2d + BN → 新的Conv2d(权重和偏置会提前融合BN参数)

env[node.name]会记录替换后的节点

如果不是最后一个节点:

这种情况就说明这个node不属于可融合pattern,就原样复制就行了

最后再构建新的GraphModule

torch.ao.quantization.fx.fuse.py.fuse(...):

model = GraphModule(model, fused_graph)
return model

最终返回一个新的GraphModule,图结构和部分节点已经被替换成融合后的版本

torch.ao.quantization.quantize_fx.py._prepare_fx(…) 里面的 _fuse_fx(…)完整调用上面已经讲清楚了,再执行完_fuse_fx(…)函数之后,也就是算子融合过程,接下来我们往这个优化过后的图里面插入Observer节点,这些节点会在模型跑一遍数据的时候,收集激活值/权重的分布信息,得到 scale / zeropoint

torch.ao.quantization.quantize_fx.py._prepare_fx(...)

prepared = prepare(
        graph_module,
        qconfig_mapping,
        is_qat,
        tracer.node_name_to_scope,
        example_inputs=example_inputs,
        prepare_custom_config=prepare_custom_config,
        _equalization_config=_equalization_config,
        backend_config=backend_config,
        is_standalone_module=is_standalone_module,
)

接下来就要用到我们前面的qconfig_mapping,还记得这玩意吧,它里面存了:

  • activation observer
  • weight oberver
  • dtype
  • 量化方案

prepare(…)具体流程

调用方 prepare()函数 qconfig_mapping 计算图(Graph) 节点(Node) Observer构造器 第一阶段:遍历节点并设置QConfig 调用prepare(model, qconfig_mapping) 获取下一个Node 返回Node对象 查找Node对应的QConfig 返回QConfig对象 node.meta["qconfig"] = qconfig QConfig信息已保存到meta中 返回None 跳过该节点(如Softmax、Reshape等) alt [找到匹配的QConfig] [未找到QConfig] loop [遍历图中的每个Node] 第二阶段:插入Observer节点 检查node.meta["qconfig"] 跳过该节点,不插入Observer 需要为该节点插入Observer 获取QConfig.activation() 构造activation Observer 返回Observer实例 在Node输出处插入Observer节点 获取QConfig.weight() 构造weight Observer 返回Observer实例 在Node权重输入处插入Observer节点 par [处理激活值量化] [处理权重量化] Observer节点已插入到计算图中 alt [QConfig为空] [QConfig存在] loop [再次遍历图中的每个Node] 返回带有Observer的模型 准备阶段完成,模型可用于校准 调用方 prepare()函数 qconfig_mapping 计算图(Graph) 节点(Node) Observer构造器

源码分析

同样,我们可以来看看prepare(…)内部做了些啥:

我们发现它首先会去检查我们传进去的qconfig_mapping是什么类型,如果是Dict,它这里会触发一个警告,提示我们这种写法已经过时:

if isinstance(qconfig_mapping, Dict):
        warnings.warn(
            "Passing a QConfig dictionary to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a QConfigMapping instead.")
        qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping)

继续追踪,查看from_dict内部做了什么:

@classmethod
    def from_dict(cls, qconfig_dict: Dict[str, Any]) -> QConfigMapping:
        """
        Create a ``QConfigMapping`` from a dictionary with the following keys (all optional):

            "" (for global QConfig)

            "object_type"

            "module_name_regex"

            "module_name"

            "module_name_object_type_order"

        The values of this dictionary are expected to be lists of tuples.
        """
        conf = cls()
        if _GLOBAL_DICT_KEY in qconfig_dict:
            conf.set_global(qconfig_dict[_GLOBAL_DICT_KEY])
        for object_type, qconfig in qconfig_dict.get(_OBJECT_TYPE_DICT_KEY, []):
            conf.set_object_type(object_type, qconfig)
        for module_name_regex, qconfig in qconfig_dict.get(_MODULE_NAME_REGEX_DICT_KEY, []):
            conf.set_module_name_regex(module_name_regex, qconfig)
        for module_name, qconfig in qconfig_dict.get(_MODULE_NAME_DICT_KEY, []):
            conf.set_module_name(module_name, qconfig)
        for module_name, object_type, index, qconfig in qconfig_dict.get(_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []):
            conf.set_module_name_object_type_order(module_name, object_type, index, qconfig)
        return conf

我们可以看到,我们虽然传入的是字典,但是它在内部还是调用了方法,帮我们将其转成一个QConfigMapping对象。

在我们拿到这个对象之后还有个问题,也就是用户传入的qconfig_mapping可能还想在别的地方复用,如果prepare_fx在里面直接in-place给它改了,那就会导致用户拿到的原对象被”污染“,比如原来你只设了 set_global(config),结果函数里面偷偷加了别的规则,之后你再用这个qconfig_mapping去quantize另一个模型的时候,就会出现一些莫名其妙的行为。为了解决这个问题,在源码中做了深拷贝的操作,这也就是这步操作的意义:

qconfig_mapping = copy.deepcopy(qconfig_mapping)

为了保证量化配置和fusion后的模型对齐,我们需要更新QConfig以支持融合:

_update_qconfig_for_fusion(model, qconfig_mapping)

遍历 model 的graph,根据fusion pattern 调整对应模块的qconfig

  • 就比如Conv + BN 融合之后,我们只对Conv保留qconfig,BN的qconfig会被清理掉

为了快速查找每个节点该用哪个qconfig,我们需要将qconfig_mapping展平成字典:

flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)

把QConfigMapping转换成key → qconfig的字典,方便查找,key可以是模块路径、模块类型、函数名等

之后为了为每个节点确定量化配置,我们需要做QConfig传播:

propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())

内部会遍历model.graph的节点,根据 flattened_qconfig_dict 给每个节点分配合适的qconfig,结果会写进node.meta["qconfig"] 或类似字段里

如果是QAT量化,那这里需要做额外的处理:

if is_qat:
    module_to_qat_module = get_module_to_qat_module(backend_config)
    _qat_swap_modules(model, module_to_qat_module)
    _update_qconfig_for_qat(qconfig_mapping, backend_config)
  • 把model里面的某些模块替换成QAT版本(比如nn.Conv2d → nn.qat.Conv2d
  • 根据 backend_config 更新 qconfig 以适配 QAT

在前面我们已经讲过了,QAT它内部进行了伪量化的操作,所以这里的操作是为了让模型进入带 FakeQuant 的 QAT 状态

在插入Observers之前,想想还要干嘛?还有什么遗漏的?

  1. 每个节点该用什么量化配置(qconfig)?——这是单点层面的规则分发
  2. 哪些节点需要组合在一起处理?——这是pattern层面的归并和分派

所以就有了下面这两步操作:先为每个节点贴上qconfig标签,再把这些标签与匹配到的算子组合(pattern)绑定起来,交给 QuantizeHandler 统一处理

node_name_to_qconfig = _generate_node_name_to_qconfig(...)

这一层它只管单个Node,它的结果是:

{
  "conv1": qconfig_for_conv,
  "relu1": qconfig_for_relu,
  "add":   None,    # 不需要量化
  ...
}

可以理解成“给每个node打个量化规则标签”

matches_without_qconfig = _find_matches(...)
node_name_to_match_result_with_qconfig = {}
for node_name, match_without_qconfig in matches_without_qconfig.items():
    match_with_qconfig = (*match_without_qconfig, node_name_to_qconfig[node_name])

这一层不是单个node,而是算子组合,_find_matches(...)得到的结果最初只是“这些节点组合成了一个pattern,应该交给哪个QuantizeHandler处理”,但handler自己还要知道对应的量化配置,不然它不知道用什么Observer,所以就要把上一层得到的 node_name_to_qconfig信息附加到match结果里面

如果还是不理解的话,我们再来看一个例子:

第一层:每个士兵(Node)都发一个装备清单(QConfig)

第二层:找出哪些士兵组成一个作战小队(Pattern),然后把每个人的装备清单都带上,交给指挥官(Handler)来统一安排

完成这些操作之后,就可以往里面加入Observers了:

result_node = insert_observers_for_model(
        model,
        node_name_to_match_result_with_qconfig,
        node_name_to_qconfig,
        prepare_custom_config,
        equalization_node_name_to_qconfig,
        backend_config,
        observed_node_names,
        is_qat
    )

insert_observers_for_model(…) 具体怎么操作的,我们可以看这张图,这里就不细讲了,太多了:

insert_observers_for_model Model Graph QConfig Processing Observer Operations Module Operations 初始化阶段 dict(model.named_modules()) 初始化input_quantized_idxs, output_quantized_idxs 初始化processed_nodes = set() 初始化所有节点的target_dtype_info node.meta["target_dtype_info"] = _DEFAULT_FP32_QCONFIG loop [遍历model.graph.nodes] 建立placeholder_node_to_input_index映射 建立output_node_to_output_index映射 Step 1: 为匹配的节点模式设置observer构造器 获取(last_node, matched_node_pattern, pattern, qhandler, qconfig) _set_target_dtype_info_for_matched_node_pattern() 设置匹配节点的target_dtype_info loop [遍历node_name_to_match_result_with_qconfig] Step 2.1: 基于节点类型的特殊设置 设置QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO all_node_args_have_no_tensors()检查 设置input_act和output_act observer为None alt [args_have_no_tensors == True] 设置QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO alt [node.op == "placeholder" && 在input_quantized_idxs中] [node.op in ("call_module", "call_method", "call_function")] [node.op == "output" && 在output_quantized_idxs中] loop [遍历model.graph.nodes] Step 2.2: 传播已知节点的dtype propagate_dtypes_for_known_nodes() Step 3: 检查backend支持并重置不支持的配置 重置processed_nodes = set() _is_pattern_dtype_config_and_qconfig_supported_by_backend() _set_target_dtype_info_for_matched_node_pattern(fp32_placeholder_qconfig) alt [不支持 && output_act_dtype不是特殊类型] loop [再次遍历node_name_to_match_result_with_qconfig] Step 4: 实际插入observers nodes_before_observation = list(model.graph.nodes) 初始化custom_module_names_already_swapped = set() 重置inputs_seen_counter, outputs_seen_counter pass - 不需要观察 从node_name_to_match_result_with_qconfig获取匹配信息 计算skip_inserting_observers条件 _is_pattern_dtype_config_and_qconfig_supported_by_backend() dict(model.named_modules()) - 重新获取 _add_matched_node_name_to_set(observed_node_names) 检查is_quantized_branch逻辑 get_fusion_pattern_to_root_node_getter() root_node_getter(matched_node_pattern) _maybe_insert_input_observers_for_node() 插入输入observers _maybe_insert_input_equalization_observers_for_node() 插入equalization observers alt [node是root_node (输入节点)] _insert_dequant_stubs_for_custom_module_lstm_output() 插入DeQuantStubs并处理嵌套tuple _swap_custom_module_to_observed() alt [node.target未swap过] _maybe_insert_output_observer_for_node() 插入输出observer 更新所有用户节点: user_node.replace_input_with() _maybe_make_input_output_share_observers() _remove_output_observer() alt [返回False] alt [is_general_tensor_value_op或_is_reuse_input_qconfig] _swap_custom_module_to_observed() alt [qhandler.is_custom_module()且未swap过] alt [maybe_output_obs_node不为None] alt [_is_custom_module_lstm()] [一般情况] alt [node是last_node (最后节点)] _maybe_insert_observers_before_graph_output() 在图输出前插入observers alt [node.op != 'output'] [node.op == "output"] alt [不skip且backend支持] alt [node.op == 'placeholder'] [node.op in ('call_module', 'call_method', 'call_function', 'output')] inputs_seen_counter += 1 outputs_seen_counter += 1 results_node = node alt [node.op == 'placeholder'] [node.op == 'output'] loop [遍历nodes_before_observation] return results_node insert_observers_for_model Model Graph QConfig Processing Observer Operations Module Operations

至此,我们的第三步:准备模型、嵌入观察者;大部分细节都已经交代清楚了

第四步、校准模型

在这里插入图片描述
第三步结束之后,模型里面已经加入了Observers节点,这些Observers节点的作用:

  • 统计激活值的分布信息

这些统计结果之后会用来确定量化的scale、zeropoint、数据类型范围如何映射回浮点数范围

等等,好像少了点东西,**Observers节点不是还能统计权重的分布信息吗?**怎么在这里没出现?

在前面,我们已经提过了,权重和激活值的处理方式不一样

权重的量化参数

权重张量是静态已知的(训练好的参数),它不会像激活一样在运行时不断变化

权重Observer也会参与运算min/max,它们一开始就可以直接根据张量本身去算 min/max → scale/zeropoint,不需要我们喂数据,所以这部分不依赖校准阶段

激活的量化参数

激活值取决于输入数据分布,范围未知且动态

必须靠Observers节点在前向推理过程中统计数据(最小值、最大值、分布直方图)

所以才需要“校准模型”,喂数据进去让Observers收集信息

那为什么要跑一遍数据呢?

因为Observer 节点不会自己凭空知道张量范围,必须要通过真实数据来观察

所以我们需要拿一批 校准数据(calibration_data) —— 通常是训练集或验证集里抽出来的一小部分,送进模型跑一遍(forward)

在 forward 的过程中,Observer 节点会不断更新自己的统计量,直到分布信息比较稳定

而校准的目的正是为了减少误差

  • 如果不用校准,量化范围是随便设的,很可能把绝大多数激活值挤到int8的边界之外,误差就会非常大
  • 有了校准之后,int8的范围会根据真实数据来自适应,这样量化时的数值映射更准确
第五步、转换模型
# 5. 转换模型:真正的量化
quantized_model = convert_fx(model_to_quantize)

convert_fx函数会将已经加入Observers节点和经过校准的GraphModule转换成可部署的量化模型(reference 或 non-reference),核心流程:

第一步:根据Observer生成显示量化/反量化节点

  • 这一步会遍历整个计算图,把激活后处理(activation_post_process)或DeQuantStub 替换成对应的quantize/ dequantize调用(或者是分解形式的quantize API,当 is_decomposed=True)。如果用户指定了要量化输入/输出,也会在图的开头或者是结尾插入或者是删除 dequantize 节点 (dequantize 节点的作用就是把 INT8 的量化张量还原成 FP32 浮点张量,方便后续继续做浮点计算)

  • torch.ao.quantization.fx.convert.py.convert(...):
    
    for node in list(model.graph.nodes):
            if node.op == 'placeholder':
               ...
            elif node.op == "output":
                ...
            elif node.op == "call_module":
                ...
    

第二步:加权算子替换为量化等价模块

  • 对于被观测到的带权算子(比如 nn.Linear、nn.Conv2d、融合单元ConvReLU等),将会根据node对应的qconfig、 backend_config还有observed_node_names,替换成后端参考量化模块(比如nn.quantized.Linear),或者是保留 浮点回退。权 重的打包/按通道处理是由后端(比如fbgemm) 所决定

  • 源码片段:

  • torch.ao.quantization.fx.convert.py.convert(...):
        
    elif node.op == "call_module":
    	...
    	elif type_before_parametrizations(mod) in set(
                        root_module_classes).union(qat_module_classes).union(fused_module_classes):
                # extra check for fused module classes to make sure they are fused module classes
                # of target modules
                if type_before_parametrizations(mod) in fused_module_classes and \
                   type_before_parametrizations(mod[0]) not in root_module_classes:  # type: ignore[index]
                    continue
                convert_weighted_module(
                    node, modules, observed_node_names, node_name_to_qconfig, backend_config, is_decomposed)
    

第三步:处理自定义/独立子模块

  • 在了解这一步之前,我们需要先了解两个概念:
  • "standalone"子模块
    • 我们从前面已经知道了在PyTorch的 FX 量化流程里面,模型会被分解成很多子模块(submodule),比如Conv2d、Linear等等,甚至是某些Sequential组合。其中 standalone 子模块指的是可以被单独拿出来量化、转换、推理的子模块单元,它们内部逻辑相对完整,输入输出清晰,就比如 Conv2d + ReLU 这种模块就可以被打包成一个 standalone 单元,直接替换成量化版本
  • 等化(equalization)
    • 等化在这里指的是 weight equalization (权重等化),是量化前的一种 预处理/校正步骤。我们在第四步的时候提到了校准激活值,接下来我们还需要对权重和激活值进行等化的操作。那为什么需要等化呢?因为不同层的权重分布范围差异很大(比如一个卷积层权重大,另一个卷积层权重小),量化的时候可能会导致某些层信息丢失更严重,所以我们等化的目的就是为了重新缩放各层的权重和激活值,使得它们数值分布更加的均衡,从而减少量化误差

在convert_fx中,支持将用户自定义的 observed module 映射到对应的量化实现,并支持"standalone"子模块单元的独立转换与 静态量化。对需要等化(equalization)的节点,会先调整observer/权重再转换

源码片段:

elif _is_observed_standalone_module(mod):
    convert_standalone_module(
        node, modules, model, is_reference, backend_config)

elif type_before_parametrizations(mod) in custom_module_classes:
    convert_custom_module(
        node, model.graph, modules,
        custom_module_class_mapping,
        statically_quantized_custom_module_nodes)

第四步:融合与图级重写

  • 这里指的是还没有进行算子融合,那可以在准备阶段(或者是convert前的可选fuse步骤)合并算子以避免量化中间体(比如Conv+ReLU作为整体量化单元),转换之后对图进行死码消除、常量折叠等优化以提升推理效率

  • 死码消除

    • 这个是很常见的术语,死码也就是永远不会用到的计算

      • 比如:

      • x = conv(input)
        y = relu(x)
        z = some_unused_op(y)
        return y
        
      • some_unused_op(y)就算死码,因为它的结果没被用到,优化的时候会自动删除它

      • 我们在平时写代码,也要注重这一部分

  • 常量折叠

    • 如果图里面有只依赖常量的计算,可以提前在编译阶段算好,避免推理的时候又算一次

      • 比如:

      • y = x * (2 + 3)
        
      • 常量折叠后直接变成: y = x * 5

      • 在量化场景里面,有些量化参数计算出来就是固定的,就可以常量折叠,这样可以减轻推理时的负担

  • 源码片段:

  • # remove deadcode after converting observers to quant/dequant ops
    model.graph.eliminate_dead_code()
    model = GraphModule(model, model.graph)
    

第五步:后端下沉与优化

这一步主要是对于非reference模式,执行下沉/后端特定变换(比如lower_to_fbgemm)以生成fbgemm/内核友好的形式(打包权重、替换为backend ops),不支持的子图回退为FP32

这句话信息量有点多,我们来一个词一个词的抠:

  • 非reference模式

    • Reference模式
      • 指的是一个通用的量化IR(中间表示),只是把量化算子按规范放进去,还没做具体的后端优化
    • 非Reference模式
      • 这个主要是为了针对特定硬件或者是特定的库去做优化,比如Intel CPU用 fbgemm,ARM CPU/GPU用 qnnpack
  • 下沉/后端特定变换

    • 下沉就是把“通用表示” → 能直接跑在后端内核上的形式
    • 举例:
      • 原来是 quantized::conv2d 这种通用量化算子
      • 下沉后会变成 fbgemm 专用的 conv内核调用,并且把权重提前打包成 fbgemm 格式
  • 打包权重、替换backend ops

    • 打包权重
      • 不同后端对权重排布有特殊要求,比如fbgemm会把 linear.weight 重新排列成自己最擅长计算的内存布局
    • 替换backend ops
      • 把通用的 quantized::linear 换成 quantized::fbgemm_linear
  • 不支持的子图回退成FP32

    • 有些算子后端根本不支持量化版本(比如softmax、layernorm这类非线性操作),那就直接回退回FP32版本计算,这样保证模型可以完整运行,而不会因为量化缺算子而报错
  • 源码片段:

  • if not is_reference:
        model = lower_to_fbgemm(model, node_name_to_qconfig, node_name_to_scope)
    

第六步:清理并返回

这一步就是收尾了,在内部将会清理掉qconfig、observer这些训练/准备时才需要的东西,还要删除没用的子模块,去掉临时信息,最后返回一个干净的量化模型

源码片段:

if _remove_qconfig_flag:
        _remove_qconfig(model)
    model.delete_all_unused_submodules()
    model.meta.pop("_observed_graph_module_attrs", None)
    return model

到这里,整个底层分析就彻底结束了,当然其中还有一些细节还是没有介绍到位

3.2 NIVIDA TensorRT INT8 Calibration

当模型需要在NIVIDA GPU上追求极致的推理性能时,TensorRT是不二之选,这是一个专门用于高性能推理的优化器和运行时,它的INT8量化功能非常强大

TensorRT的量化本质上是PTQ,其核心是 Callibration(校准)过程

部署流程:

  1. **模型转换:**首先,我们需要将模型转换成 ONNX 格式,这是一个通用的模型表示格式
  2. **构建TensroRT引擎:**使用TensorRT的 trtexec 工具或Python API来构建推理引擎,在构建时,需要指定int8模式
  3. **实现IInt8Calibarator接口:**这是关键,我们需要编写一个校准器类,它负责提供校准数据。TensorRT会调用这个类,获取一批批校准数据,送入模型,以确定各层激活值的动态范围
  4. **校准过程:**TensorRT内部会使用获取到的校准数据进行校准。它采用了一种基于信息熵的算法来寻找最佳的阈值,从而最小化原始FP32分布和量化后INT8分布之间的信息损失(KL散度),这通常比简单的Min/Max校准方法精度更高
  5. **生成校准表:**校准完成之后,TensorRT会生成一个“校准表”文件,这个文件缓存了每层的量化参数
  6. **构建并部署INT8引擎:**有了校准表,我们就可以快速地构建一个完全优化的INT8推理引擎,在后续部署时,可以直接加载这个引擎,无需再次校准
用户 原始模型 (PyTorch/TensorFlow) ONNX模型 IInt8Calibrator 校准器 校准数据集 TensorRT引擎 校准表文件 INT8推理引擎 1. 模型转换阶段 准备训练好的模型 转换模型格式 导出为ONNX格式 2. 校准器实现阶段 实现IInt8Calibrator接口 准备代表性校准数据 配置数据读取方式 3. 校准过程阶段 启动校准过程 (--int8模式) 调用getBatchSize() 返回批次大小 调用getBatch() 读取一批数据 返回数据批次 提供校准数据 前向推理收集激活值 计算激活值分布统计 loop [校准数据迭代] 4. 量化参数优化阶段 基于KL散度的熵校准算法 为每层寻找最佳量化阈值 最小化FP32与INT8分布差异 5. 校准表生成阶段 生成校准表文件 调用writeCalibrationCache() 保存量化参数到文件 6. 引擎构建与部署阶段 构建INT8优化引擎 读取缓存的校准参数 加载量化参数 生成优化的INT8推理引擎 7. 生产部署阶段 部署并执行INT8推理 返回推理结果 用户 原始模型 (PyTorch/TensorFlow) ONNX模型 IInt8Calibrator 校准器 校准数据集 TensorRT引擎 校准表文件 INT8推理引擎

有关KL散度(简单介绍):

这个算法的本质上是衡量两个概率分布之间的差异,如果两个分布一模一样,KL散度等于0,差异越大,则KL散度越大(没有上界),而且它是不对称的,即KL(P||Q) ≠ KL(Q||P)

举个例子:

假设P是真实的分布,Q是模型预测的分布

KL散度告诉我们:如果我们用Q来近似P,那损失了多少信息

数学公式:

  • 离散分布

    • D K L ( P ∣ ∣ Q ) = ∑ x P ( x ) l o g P ( x ) Q ( x ) D_{KL}(P||Q) = \sum_{x}P(x)log\frac{P(x)}{Q(x)} DKL(P∣∣Q)=xP(x)logQ(x)P(x)
  • 连续分布

    • D K L ( P ∣ ∣ Q ) = ∫ P ( x ) l o g P ( x ) Q ( x ) d x D_{KL}(P||Q) = \int P(x)log\frac{P(x)}{Q(x)} dx DKL(P∣∣Q)=P(x)logQ(x)P(x)dx

P(x):真实分布

Q(x):近似分布(模型预测)

举个直观的例子:

假设我们有两个分布:

真实分布P(抛硬币概率):P = [0.5, 0.5]

模型预测分布Q:Q = [0.9, 0.1]

计算:
D K L ( P ∣ ∣ Q ) = 0.5 ∗ l o g 0.5 0.9 + 0.5 ∗ l o g 0.5 0.1 D_{KL}(P||Q) = 0.5 * log\frac{0.5}{0.9} + 0.5*log\frac{0.5}{0.1} DKL(P∣∣Q)=0.5log0.90.5+0.5log0.10.5
大概结果是0.51,说明差异比较大

应用场景

  • 量化校准

KL散度常用来找量化的最佳阈值:

比较“量化前的激活分布P”和“量化后的激活分布Q”,找到可以最小化KL散度的量化范围,这样可以减少信息损失

  • 机器学习/深度学习

Loss函数:比如在分类任务里面,交叉熵损失其实就是带权KL散度

GAN:判别器和生成器的优化目标里面也会涉及分布差异度量

3.3 HuggingFace Optium & bitsandbytes

对于广大LLM“玩家”来说,HuggingFace生态是事实上的标准。为了简化量化流程,HuggingFace推出了 Optium 库,它集成了多种硬件后端的优化工具。而bitesandbytes则是目前在HuggingFace生态中最流行的、实现“开箱即用”动态量化的底层库

应用方式:

使用bitssandbytes进行量化极其简单,通常只需要在加载模型的时候添加几个参数:

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# 配置量化信息
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,  # 将权重转成 4bit 表示
    bnb_4bit_quant_type="nf4",  # 设置 4bit 的量化类型
    bnb_4bit_compute_dtype="float16",   # 指定推理计算时使用的精度
    bnb_4bit_use_double_quant=True  # 开启双重量化
)

# 加载模型
model = AutoModelForCausalLM.from_pretrained(
    "openai-community/gpt2-x1",
    quantization_config=quantization_config,
    device_map="auto"
)

# tokenizer
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2-x1")
inputs = tokenizer("Once upon a time, there was a magical forest", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

双重量化:权重量化一次,量化参数本身也量化一次

背后发生了什么?

bitsandbytes实现了一种复杂的PTQ技术,特别是针对大模型中的离散值问题。它结合了分块量化和动态量化。对于权重矩阵,它会将其分成小块,为每小块计算独立的量化参数,从而更精确地处理数值分布。这种方式极大地保留了模型的精度

3.4 常见的坑和对策

  • 校准数据集不足或不具代表性
    • 现象:模型在测试集上精度骤降
    • 原因:校准数据未能覆盖模型在真实业务场景中遇到的数值范围,导致量化参数有偏
    • 对策:使用与最终应用场景分布一致的、更多样化的数据进行校准
  • 权重或激活值分布异常
    • 现象:个别层精度损失严重
    • 原因:权重或激活值中存在一些极端的大数值,导致量化范围被拉得很大,大部分数值的精度因此被带偏
    • 对策:使用更高级的量化算法,比如bitsandbytes的分块量化,或者采用混合精度策略,将包含离群值的“敏感层”保持在FP16
  • 推理精度下降
    • 现象:模型整体性能不达标
    • 原因:量化误差累积
    • 对策:
      • 首先尝试Per-Channel量化替代Per-Tensor量化
      • 如果PTQ不行,考虑使用QAT,通过微调来恢复精度
      • 分析是哪些层导致了精度下降,对它们进行混合精度处理

我们在实际中选择量化策略需要考虑的因素也比较多,我将它们总结如下:

明确底线、评估硬件、高低尝试、评测基准、KV Cache

  • 明确底线
    • 你的应用场景对精度的要求有多高?是一个要求答案极其精确的医疗或金融问答机器人,还是一个用于创意写作的助手?前者无法接受低精度带来的任何偏差,而后者却可以
  • 评估硬件
    • 你的目标部署环境是什么?是拥有80G H100的云服务器,还是只有8G VRAM的办公电脑?硬件预算直接决定了你的选择上限
  • 高低尝试
    • 如果条件允许,总是先从高精度作为baseline。如果显存不足,首先尝试INT8,并进行充分的评测。如果INT8仍然i无法满足部署需求,再考虑INT4
  • 评测基准
    • 使用标准的学术benchmark(比如MMLU、Hellaswag等)和你们公司自己的业务评测集来量化精度损失,做出数据驱动的决策
  • KV Cache
    • 对于长文本生成任务,推理时的KV Cache占用显存非常可观,有时甚至超过模型权重本身。即使模型权重被量化了,KV Cache通常仍然以FP16存储。优化KV Cache是另一个重要的显存优化方向

四、低比特量化

在刚才,我们通过bitsandbytes体验了“开箱即用”的量化,这种方法(称之为朴素PTQ)虽然便捷,但在追求极致压缩时,精度下降会变得非常明显

想象一下,最简单直接的量化方法就是四舍五入,这非常直观,对吧

比如,我们有一组原始的、高精度的模型权重:

[ 1.7, 3.2, -0.9, 4.6 ]

如果我们把它们量化成整数,用四舍五入的方法,会得到:

[ 2, 3, -1, 5 ]

现在我们来看看误差在哪:

  • 1.7变成2,产生了+0.3的误差
  • 3.2变成3,产生了-0.2的误差
  • -0.9变成-1,产生了-0.1的误差
  • 4.6变成5,产生了+0.4的误差

问题就出现在这里,在这种方法中,我们处理1.7产生了+0.3的误差后,就拍拍屁股把它忘了,接着独立地去处理3.2,每一个数字的量化都是一个孤立的事件。如果权重很少很少影响还很大,但是现在的模型,哪个不是巨无霸,最后不出现问题才怪呢

为了解决这个问题,学术界和工业界的研究者们提出了许多更先进的PTQ算法,在这里将简单的介绍两种:GPTQ和AWQ,它们有着一个共同的目标:在低比特下,实现更高的模型精度

4.1 GPTQ

GPTQ(Generative Pre-trained Transformer Quantization)是一种先进的PTQ方法,它的核心思想非常巧妙:与其将所有权重独立地进行四舍五入量化,不如以一种更灵活的方式来处理量化误差

我们可以举一个团队合作的例子来直观的了解它的核心思想:

假设我们要把一堆石料(原始权重)运到山顶,但我们每次只能运一小块(量化后的权重)

  • **朴素量化:**我们每次都只管把一块石头运走,不考虑剩下的,结果就是,最后一共运上去的石料总量比原始的总量相差了一大截
  • **GPTQ:**我们每次运走一块石头,都会看看剩下的石头。如果这次运走的这块比它应有的分量少了一点,我们就会把这个“亏欠”的部分,分摊到剩下的那些还没运的那些石头上。这样我们后面的每次运输,都会多运或少运一点点,来弥补之前的亏欠

这个“补偿”的过程,就是GPTQ的精髓。它不是独立地处理每一个权重,而是把量化误差看作是一个需要被整体优化的系统问题

但是我们要清楚,GPTQ并不是简单地把误差加到下一列上。它更像一个智能的优化过程。它每次只量化矩阵中的一列(或者一小块)。在量化这一列的时候,同时会计算并考虑它对后面未量化部分的影响。然后,它会找到一个最佳的量化值,使得当前列和后续所有列的整体误差最小

GPTQ的逐列优化的核心实际上就是近似求解

我们知道要找到一个量化矩阵W_q使得W_q · X 与 原始的 W · X的均方误差最小化,这是一个复杂的数学问题,直接求解较为麻烦

GPTQ怎么做的?它并没有试图一次性解决所有问题,相反,它采用了一种近似解法,也就是所谓的逐列优化

简单来说,GPTQ将一个大问题分解成多个小问题,它会一列一列地处理权重矩阵W,在处理每一列的时候,它会:

  • 逐列量化:GPTQ 从权重矩阵的第一列开始,逐列进行量化
  • 计算误差:在量化每一列时,它会计算当前列量化前后产生的误差
  • Hessian 矩阵的近似:GPTQ 会计算一个近似的 Hessian 矩阵。这个近似的 Hessian 矩阵实际上是通过输入数据 X 和权重矩阵 W 计算出来的。它提供了关于不同权重对模型输出影响的敏感度信息
  • 利用Hessian矩阵进行优化:在量化当前列时,GPTQ 会使用这个近似的 Hessian 矩阵的逆矩阵。这个逆矩阵会告诉我们,为了补偿当前列的量化误差,应该如何最有效地调整后续列的量化值。它确保了调整的方向是正确的,并且能最大程度地减少对全局输出的影响

这就是GPTQ逐列优化的精髓,它利用了OBQ的算法,在量化每一列时,会同时计算一个更新矩阵(或者说更新向量),这个更新矩阵包含了当前列量化误差的信息。然后,它用这个更新矩阵去修改后面所有未量化的列,从而指导它们在量化时做出更好的决策

优点:

  • **高精度:**在3-bit和4-bit量化上,GPTQ通常比朴素PTQ的量化精度要更高,它处理了量化误差
  • **推理速度快:**GPTQ的量化结果是静态的INT权重,可以打包成高效的推理格式(比如Marlin核),在推理时没有额外的反量化开销

缺点:

  • **量化过程慢:**GPTQ的量化过程需要计算Hessian矩阵的逆,虽然有近似,但仍然比朴素的量化方法耗时得多

4.2 AWQ

AWQ的研究者们观察到了一个有趣的现象:在LLM中,并不是所有的权重都同等重要,那些与显著激活值相乘的权重,对模型的性能影响更大

核心思想:

普通量化会一视同仁地压缩所有通道,结果有些幅度很大的通道在量化时损失特别严重,性能下降明显。AWQ通过校准数据,先找出那些幅度异常大的通道,把相关的权重标记为重要权重。在量化时,对这些重要权重给予更高的保真度(比如保留精度,或者特殊缩放处理),从而减少量化误差

做法:

  • 识别重要权重:通过观察一小部分校准数据,AWQ会分析激活值的分布,找到那些数值幅度异常大的“离群特征通道”。与这些通道相关的权重,就是重要权重

    • 这里容易混淆的两个概念:
      • 异常值:通常指的是统计意义上“极端、不符合分布规律“的点,比如一堆值都在-1,1,结果突然冒出一个100,这个一般被认为是噪声或异常
      • 离散特征通道:AWQ里并不是把这些通道当作错误或异常来排除,而是把它们视为对模型影响极大的关键特征
    • 为什么AWQ要找数值幅度异常大的通道
      • 在Neural Network里面,每一层都会输出很多通道,有些通道的激活值分布相对均匀,不会对后续造成特别大的影响,但有些通道的值在某些样本上特别大(比如远大于其它通道的均值范围),这说明这些通道的信息非常强,模型高度依赖它们来做决策
      • 换句话说:
      • 异常值 → 可能是噪声,要去掉
      • 离群通道 → 模型的关键点,要特别对待
  • 激活感知缩放:AWQ不直接量化权重,而是引入一个缩放因子s,它对权重进行 W' = W / s 操作,同时对相邻层的激活值进行X' = X * s操作。这样从计算结果上看,模型的等价性并没有被破坏:

    • W ′ ∗ X ′ = ( W s ) ∗ ( X ∗ s ) = W ∗ X W' * X' = (\frac{W}{s})*(X*s)=W*X WX=(sW)(Xs)=WX
  • 保护重要权重:AWQ会精心选择这个缩放因子s,使得那些重要权重在缩放后,其数值范围变得更小,从而在量化时的精度损失也更小,它相当于牺牲了一部分不重要权重的量化精度,来换取重要权重的高保真度

    • 重点:如何选择s?
      • 如果我们设置s=10,那原来权重 =100 变成10,很容易量化
      • 但同时原来权重0.01变成0.001,可能在量化时就直接丢失精度
    • AWQ的策略:
      • 既然有些权重(重要权重)对模型至关重要,就用缩放因子去压小它们,让量化更精确
      • 不重要的权重即使被量化掉一些,也不会显著影响模型性能

五、量化 + 剪枝 + 蒸馏

在前面四章,我们从底层原理到上层应用详细的探讨了模型量化,但是在追求极致模型效率的道路上,量化并非唯一利器

在最后一章,我们可以试着站在一个更高的维度,审视模型压缩领域的“三剑客”:量化(Quantization)、剪枝(Pruning)和蒸馏(Distillation)

我们先简单回顾这三种技术的核心思想:

量化

  • 核心思想:降低模型中数值的表示精度,用更少的比特数来存储和计算权重和激活值
  • 作用对象:模型中的每一个数值(权重和激活值)
  • 效果:
    • 减小模型体积:直接降低存储空间
    • 加速计算:利用硬件对整型运算的支持
    • 降低功耗:整型运算能耗更低

剪枝
在这里插入图片描述

  • 核心思想:移除模型中冗余或不重要的部分。这些部分可以是单个权重、神经元,甚至是整个网络层
  • 作用对象:模型的结构
  • 效果:
    • 减小模型参数量:直接减少需要存储和计算的参数
    • 降低FLOPs:如果进行结构化剪枝,可以实实在在地减少运算次数
    • 可能加速推理:需要特定的稀疏计算库或硬件支持才能获得显著加速

知识蒸馏
在这里插入图片描述

  • 核心思想:将一个大型、复杂的“教师模型”(Teacher Model)的知识,迁移到一个小型的“学生模型”(Student Model)
  • 作用对象:模型的知识传递
  • 效果:
    • 获得一个更小的模型:学生模型的结构可以远小于教师模型
    • 提升小模型性能:学生模型通过模仿教师模型的输出(logits)或中间表示,可以学到比单独训练时更多的“软知识”,从而达到远超其参数量所对应的性能

为什么要结合使用?

这三种技术并非相互排斥,恰恰相反,它们从不同维度对模型进行优化,存在极强的互补性

1.剪枝/蒸馏先行,量化收尾:

  • 剪枝/蒸馏主要作用于模型的宏观结构和知识容量。我们可以先通过知识蒸馏获得一个性能优越的小型学生模型,或者通过剪枝得到一个稀疏但高效的网络结构。这个阶段的目标是确定一个最佳的、更小的模型骨架
  • 量化则作用于模型的微观数值表示。在确定了模型骨架后,再对这个更小的模型进行量化,可以进一步压缩其体积并加速运算
  • 好处:直接量化一个巨大的模型可能会因为模型本身的冗余度而导致次优的结果。先通过剪枝或蒸馏去除冗余,可以让量化过程更加高效,因为量化的对象本身已经是一个更“精华”的模型

2、量化对剪枝的友好性:

剪枝后的模型权重通常是稀疏的(包含大量的0)。在某些硬件平台上,对稀疏矩阵进行计算可能并不会比稠密矩阵快多少,但如果我们将这个稀疏的矩阵再进行量化,一方面模型体积进一步减小,另一方面,硬件可以利用更高效的稀疏INT8/INT4计算核心,从而实现真正的加速

3、蒸馏指导量化/剪枝

在进行QAT或对剪枝后的模型进行微调时,可以引入知识蒸馏。让教师模型的输出作为额外的监督信号,可以帮助学生模型(即正在被量化或剪枝的模型)更好地恢复精度,缓解因压缩带来的性能损失


至此,模型量化系列就告一段落了!
请记住,模型优化没有“银弹”。它更像一门艺术,需要工程师根据具体的业务场景、硬件限制和性能目标,灵活地组合使用量化、剪枝、蒸馏等多种工具,在模型大小、推理速度和预测精度之间,找到那个最佳的平衡点!

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐