1. 什么是onnx格式

ONNX(Open Neural Network Exchange),开放神经网络交换,是一种模型中间表示,用于在各种深度学习训练和推理框架转换的一个中间表示格式。在实际业务中,可以使用Pytorch或者TensorFlow训练模型,导出成ONNX格式,然后在转换成目标设备上支撑的模型格式,比如TensorRT Engine、NCNN、MNN等格式。ONNX定义了一组和环境,平台均无关的标准格式,来增强各种AI模型的可交互性,开放性较强。一句话来说就是ONNX 是一个用于表示深度学习模型的开放标准格式。它的核心目标是解决 AI 生态系统中一个关键问题:框架碎片化。

2.为什么需要这种格式

在 AI 开发中,程师会使用各种不同的框架来训练模型,例如:

  • PyTorch: 以灵活性和易用性著称,深受研究人员喜爱。

  • TensorFlow: 拥有强大的生产环境部署工具链。

  • Keras: 高级 API,通常以 TensorFlow 为后端。

  • Scikit-learn: 传统的机器学习库。

  • PaddlePaddleMXNet 等等。

假设你在 PyTorch 中训练了一个非常优秀的图像分类模型,但现在你需要将它部署到一个使用 TensorFlow Serving 的生产环境中,或者部署到手机、嵌入式设备等资源受限的环境中。这时候就需要将pytorch的模型转化为onnx格式,然后被目标模型技术框架来加载。

同时ONNX 不仅仅是一个文件格式,它还与 ONNX Runtime 紧密相关。ONNX Runtime 是一个为高性能推理优化的跨平台引擎。即使你的模型来自 PyTorch,通过 ONNX Runtime 在 CPU、GPU 或其它硬件加速器上运行,可能比在原框架中运行得更快,因为它针对推理进行了大量优化。

3.如何实现模型格式转化(以pytorch2onnx为例)

import torch
import torchvision

# 加载预训练模型
model = torchvision.models.resnet18(pretrained=True)
model.eval()  # 设置为评估模式
# 这里的model可以是任意的pytorch框架训练出来的模型

# 创建示例输入
dummy_input = torch.randn(1, 3, 224, 224)

# 导出为ONNX
torch.onnx.export(
    model,                      # 要转换的模型
    dummy_input,                # 示例输入
    "resnet18.onnx",           # 输出文件名
    input_names=["input"],      # 输入节点名称
    output_names=["output"],    # 输出节点名称
    dynamic_axes={              # 动态维度(可选)
        "input": {0: "batch_size"},
        "output": {0: "batch_size"}
    },
    opset_version=11,          # ONNX算子集版本
    verbose=True               # 显示详细信息
)

print("PyTorch模型已成功转换为ONNX格式")

4. 验证转化后onnx是正常的

import onnx
import onnxruntime as ort
import numpy as np

# 加载ONNX模型
onnx_model = onnx.load("resnet18.onnx")
onnx.checker.check_model(onnx_model)  # 检查模型有效性
print("ONNX模型验证成功")

# 使用ONNX Runtime进行推理测试
session = ort.InferenceSession("resnet18.onnx")
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# 创建测试输入
test_input = np.random.randn(1, 3, 224, 224).astype(np.float32)

# 运行推理
results = session.run([output_name], {input_name: test_input})
print("推理测试成功,输出形状:", results[0].shape)

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐