获取ONNX模型的输入/输出信息
在使用 ONNX(Open Neural Network Exchange)模型时,可以通过。或直接解析 ONNX 模型来获取输入和输出的信息(如名称、形状、数据类型等),并以。
·
在使用 ONNX(Open Neural Network Exchange)模型时,可以通过 onnxruntime 或直接解析 ONNX 模型来获取输入和输出的信息(如名称、形状、数据类型等),并以 name 为键构建字典。
方法 1:使用 onnxruntime 获取输入/输出信息
- 适用于实际推理场景,信息更直观。
- 使用
InferenceSession.get_inputs()和InferenceSession.get_outputs()获取输入输出信息。
import onnxruntime as ort
# 加载 ONNX 模型
onnx_model_path = "model.onnx"
sess = ort.InferenceSession(onnx_model_path)
# 获取所有输入信息
inputs_info = sess.get_inputs()
input_dict = {input_info.name: {
"shape": input_info.shape,
"type": input_info.type,
"dtype": input_info.type # 可能需要进行类型转换
} for input_info in inputs_info}
# 获取所有输出信息
outputs_info = sess.get_outputs()
output_dict = {output_info.name: {
"shape": output_info.shape,
"type": output_info.type,
"dtype": output_info.type # 可能需要进行类型转换
} for output_info in outputs_info}
print("Inputs:", input_dict)
print("Outputs:", output_dict)
输出示例
Inputs: {
'input1': {'shape': [1, 3, 224, 224], 'type': 'tensor(float)', 'dtype': 'tensor(float)'},
'input2': {'shape': [1, 10], 'type': 'tensor(int64)', 'dtype': 'tensor(int64)'}
}
Outputs: {
'output1': {'shape': [1, 1000], 'type': 'tensor(float)', 'dtype': 'tensor(float)'},
'output2': {'shape': [1], 'type': 'tensor(int64)', 'dtype': 'tensor(int64)'}
}
方法 2:直接解析 ONNX 模型(不依赖 onnxruntime)
- 适用于模型分析,只静态解析模型的结构,不依赖
onnxruntime。 - 使用
onnx.load()加载模型,然后从graph.input和graph.output提取信息。 - ONNX 使用
tensor(float)等字符串表示类型,而onnx.mapping.TENSOR_TYPE_TO_NP_TYPE可以映射到numpy类型。
import onnx
# 加载 ONNX 模型
onnx_model = onnx.load("model.onnx")
# 获取输入信息
input_dict = {
input.name: {
"shape": [dim.dim_value for dim in input.type.tensor_type.shape.dim],
"dtype": onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input.type.tensor_type.elem_type]
}
for input in onnx_model.graph.input
}
# 获取输出信息
output_dict = {
output.name: {
"shape": [dim.dim_value for dim in output.type.tensor_type.shape.dim],
"dtype": onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[output.type.tensor_type.elem_type]
}
for output in onnx_model.graph.output
}
print("Inputs:", input_dict)
print("Outputs:", output_dict)
输出示例
Inputs: {
'input1': {'shape': [1, 3, 224, 224], 'dtype': numpy.float32},
'input2': {'shape': [1, 10], 'dtype': numpy.int64}
}
Outputs: {
'output1': {'shape': [1, 1000], 'dtype': numpy.float32},
'output2': {'shape': [1], 'dtype': numpy.int64}
}
总结
| 方法 | 适用场景 | 依赖库 | 获取方式 |
|---|---|---|---|
onnxruntime |
推理时获取输入输出 | onnxruntime |
sess.get_inputs(), sess.get_outputs() |
| 直接解析 ONNX | 静态分析模型结构 | onnx |
onnx_model.graph.input, onnx_model.graph.output |
更多推荐


所有评论(0)