ONNX模型格式详解
ONNX(开放神经网络交换)是一种跨平台的深度学习模型中间表示格式,用于解决不同AI框架间的兼容性问题。它允许将PyTorch、TensorFlow等框架训练的模型转换为统一格式,便于部署到不同环境(如TensorRT、移动设备等)。通过torch.onnx.export可轻松将PyTorch模型转为ONNX格式,并使用ONNXRuntime进行高效推理。ONNX不仅实现框架间互操作性,还能通过优
1. 什么是onnx格式
ONNX(Open Neural Network Exchange),开放神经网络交换,是一种模型中间表示,用于在各种深度学习训练和推理框架转换的一个中间表示格式。在实际业务中,可以使用Pytorch或者TensorFlow训练模型,导出成ONNX格式,然后在转换成目标设备上支撑的模型格式,比如TensorRT Engine、NCNN、MNN等格式。ONNX定义了一组和环境,平台均无关的标准格式,来增强各种AI模型的可交互性,开放性较强。一句话来说就是ONNX 是一个用于表示深度学习模型的开放标准格式。它的核心目标是解决 AI 生态系统中一个关键问题:框架碎片化。
2.为什么需要这种格式
在 AI 开发中,程师会使用各种不同的框架来训练模型,例如:
-
PyTorch: 以灵活性和易用性著称,深受研究人员喜爱。
-
TensorFlow: 拥有强大的生产环境部署工具链。
-
Keras: 高级 API,通常以 TensorFlow 为后端。
-
Scikit-learn: 传统的机器学习库。
-
PaddlePaddle、MXNet 等等。
假设你在 PyTorch 中训练了一个非常优秀的图像分类模型,但现在你需要将它部署到一个使用 TensorFlow Serving 的生产环境中,或者部署到手机、嵌入式设备等资源受限的环境中。这时候就需要将pytorch的模型转化为onnx格式,然后被目标模型技术框架来加载。
同时ONNX 不仅仅是一个文件格式,它还与 ONNX Runtime 紧密相关。ONNX Runtime 是一个为高性能推理优化的跨平台引擎。即使你的模型来自 PyTorch,通过 ONNX Runtime 在 CPU、GPU 或其它硬件加速器上运行,可能比在原框架中运行得更快,因为它针对推理进行了大量优化。
3.如何实现模型格式转化(以pytorch2onnx为例)
import torch
import torchvision
# 加载预训练模型
model = torchvision.models.resnet18(pretrained=True)
model.eval() # 设置为评估模式
# 这里的model可以是任意的pytorch框架训练出来的模型
# 创建示例输入
dummy_input = torch.randn(1, 3, 224, 224)
# 导出为ONNX
torch.onnx.export(
model, # 要转换的模型
dummy_input, # 示例输入
"resnet18.onnx", # 输出文件名
input_names=["input"], # 输入节点名称
output_names=["output"], # 输出节点名称
dynamic_axes={ # 动态维度(可选)
"input": {0: "batch_size"},
"output": {0: "batch_size"}
},
opset_version=11, # ONNX算子集版本
verbose=True # 显示详细信息
)
print("PyTorch模型已成功转换为ONNX格式")
4. 验证转化后onnx是正常的
import onnx
import onnxruntime as ort
import numpy as np
# 加载ONNX模型
onnx_model = onnx.load("resnet18.onnx")
onnx.checker.check_model(onnx_model) # 检查模型有效性
print("ONNX模型验证成功")
# 使用ONNX Runtime进行推理测试
session = ort.InferenceSession("resnet18.onnx")
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
# 创建测试输入
test_input = np.random.randn(1, 3, 224, 224).astype(np.float32)
# 运行推理
results = session.run([output_name], {input_name: test_input})
print("推理测试成功,输出形状:", results[0].shape)
更多推荐
所有评论(0)