量化感知训练:如何恢复低精度模型的准确性

深度学习模型在训练后,通常会使用多种压缩技术进行优化,以加速其在边缘设备和数据中心上的部署。其中最常见的方法是后训练量化(Post-Training Quantization, PTQ),它通过数值缩放技术将模型权重近似为低精度数据类型(如 INT8、INT4 等)。PTQ 的优点在于其无需重新训练,简单高效。然而,当模型对量化误差特别敏感时,PTQ 可能会导致显著的精度下降,从而影响模型的实际可用性。

为了解决这个问题,更高级的策略应运而生:量化感知训练(Quantization Aware Training, QAT)和量化感知蒸馏(Quantization Aware Distillation, QAD)。这两种方法通过在训练过程中模拟量化操作,使模型能够适应低精度算术,从而在量化后恢复甚至超越原始模型的精度。

本文将深入探讨 QAT 和 QAD 的原理,并结合 NVIDIA 的技术栈,展示如何通过这些方法实现高效的低精度模型部署。


PTQ、QAT 与 QAD 的决策流程

在模型部署前,我们可以根据精度要求和可用数据,遵循以下决策流程来选择最合适的量化方法。

  • 首先评估 PTQ:如果 PTQ 可以达到可接受的精度,则这是最简单且最快的方法,直接进行部署。
  • 如果 PTQ 精度不足:如果 PTQ 导致精度显著下降,那么下一步就是考虑 QAT。QAT 通过在训练中引入量化模拟,通常可以恢复大部分甚至全部精度。
  • 对于更复杂或精度要求更高的场景:可以考虑使用 QAD。QAD 在 QAT 的基础上,引入一个高精度的教师模型,通过知识蒸馏的方式指导低精度学生模型,进一步提升其精度和泛化能力。

这三种方法的目标都是在牺牲最小精度的同时,尽可能地减小模型大小和提高推理速度,从而为不同的硬件和应用场景提供最合适的部署方案。


量化感知训练(QAT)的工作原理

QAT 的核心思想是在训练过程中模拟量化操作的影响,使得模型的权重和激活能够适应低精度表示。

QAT 通常在模型预训练完成后进行,通过一个额外的微调阶段来实现。这个阶段通常只需要原始训练周期的一小部分。例如,对于大型语言模型(LLM),QAT 微调的时间甚至可以少于原始预训练时间的 1%。

在 QAT 中,前向传播使用“假量化”(fake quantized)的权重和激活值。这意味着低精度值在较高的数据类型中通过量化/反量化(quantize/dequantize, Q/DQ)操作来表示。

这种方法允许 QAT 自然地集成到现有的高精度训练流程中:

  • 前向传播:使用假量化操作模拟低精度计算,将量化误差暴露给损失函数。
  • 反向传播:梯度仍然使用高精度数据类型计算。量化操作被建模为直通估计(Straight-Through Estimator, STE),即在反向传播时将其视为恒等函数,这样梯度可以无阻碍地流过。

通过这种方式,QAT 允许模型学习如何适应舍入和截断误差,从而恢复精度。虽然这可能会带来一些额外的训练开销,但它提供了一个稳定且实用的训练过程,可以产生高精度的量化推理模型。


NVIDIA 技术栈下的 QAT 实现

NVIDIA 在其技术生态中提供了强大的工具来简化 QAT 的实现,其中最核心的是 NVIDIA TensorRT

NVIDIA TensorRT Model Optimizer 是一个关键组件,它提供了对多种量化格式(如 INT8、FP8、NVFP4)的支持。结合 PyTorch 和 Hugging Face 等流行框架,开发者可以轻松地将 QAT 集成到他们的模型训练和部署工作流中。

示例代码:使用 PyTorch 和 TensorRT Model Optimizer 进行 QAT

这个示例展示了如何在一个 PyTorch 模型上应用 QAT。虽然这是一个简化版本,但它演示了核心概念:在模型中插入 QuantStubDeQuantStub,并在微调过程中使用 torch.quantization API。

import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub

# 定义一个简单的模型
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.quant = QuantStub() # 量化操作的占位符
        self.conv = nn.Conv2d(3, 16, 3, 1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2)
        self.dequant = DeQuantStub() # 反量化操作的占位符

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.dequant(x)
        return x

# 实例化模型
model = SimpleNet()
model.train() # 切换到训练模式

# 准备 QAT
model.qconfig = torch.quantization.get_default_qconfig('fbgemm') # 使用FBGEMM的默认量化配置
torch.quantization.prepare_qat(model, inplace=True) # 准备进行QAT

# 假设我们有训练数据
# train_data = ...
# for data, target in train_data:
#     output = model(data)
#     loss = ...
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()

# 完成 QAT 微调后,将模型转换为量化版本
# quantized_model = torch.quantization.convert(model.eval())

深入:TensorRT 与 QAT 的协同作用

  • 模型优化:TensorRT 能够识别和融合 Q/DQ 操作,将其转换为高效的量化计算,从而充分利用 NVIDIA GPU 的低精度 Tensor Core。
  • FP8 和 NVFP4:随着 NVIDIA Blackwell 等新架构的推出,FP8 和 NVFP4 等低精度格式变得越来越重要。TensorRT-LLM 和其他框架将原生支持这些格式,结合 QAT 可以实现更高的压缩率和性能,同时保持卓越的精度。

量化感知蒸馏(QAD)的工作原理

**量化感知蒸馏(QAD)是 QAT 的一个增强版本,它将知识蒸馏(Knowledge Distillation)**的概念引入了量化训练过程。

在 QAD 中,一个高精度的教师模型(Teacher Model)和一个低精度的学生模型(Student Model)协同工作。

  • 学生模型:学生模型的计算是“假量化”的。
  • 教师模型:教师模型保持全精度。

QAD 的目标是使学生模型的输出与教师模型的输出对齐。任何由量化引入的失配都会直接暴露给蒸馏损失(distillation loss)。这个损失函数结合了标准的 QAT 损失和蒸馏损失,允许学生模型的低精度权重和激活向教师模型的行为进行调整。

QAD 在大型语言模型(LLM)的量化中尤其有效,因为它能帮助学生模型继承教师模型的复杂知识,从而在低精度下依然保持高质量的生成和理解能力。例如,在 Super Llama Nemotron 推理基准测试中,QAD 可以在保持基准精度的同时,在 Math-500 和 AIME 2024 等任务上超越 PTQ 的表现。


结论

量化感知训练(QAT)和量化感知蒸馏(QAD)是解决后训练量化(PTQ)中精度下降问题的强大技术。通过在训练过程中模拟量化操作,它们使模型能够主动适应低精度计算,从而在保持小模型体积和高推理性能的同时,有效恢复甚至提升模型精度。

借助 NVIDIA TensorRT 和 TensorRT Model Optimizer 等专业工具,开发者可以轻松地将这些先进的量化技术集成到他们的 AI 工作流中。特别是在 LLM 领域,结合 QAT/QAD 与 NVIDIA GPU 的低精度计算能力,可以极大地加速模型的部署,为更广泛的应用场景提供高性能、高效率的解决方案。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐