在使用 ONNX(Open Neural Network Exchange)模型时,可以通过 onnxruntime 或直接解析 ONNX 模型来获取输入和输出的信息(如名称、形状、数据类型等),并以 name 为键构建字典。


方法 1:使用 onnxruntime 获取输入/输出信息

  • 适用于实际推理场景,信息更直观。
  • 使用 InferenceSession.get_inputs()InferenceSession.get_outputs() 获取输入输出信息。
import onnxruntime as ort

# 加载 ONNX 模型
onnx_model_path = "model.onnx"
sess = ort.InferenceSession(onnx_model_path)

# 获取所有输入信息
inputs_info = sess.get_inputs()
input_dict = {input_info.name: {
    "shape": input_info.shape,
    "type": input_info.type,
    "dtype": input_info.type  # 可能需要进行类型转换
} for input_info in inputs_info}

# 获取所有输出信息
outputs_info = sess.get_outputs()
output_dict = {output_info.name: {
    "shape": output_info.shape,
    "type": output_info.type,
    "dtype": output_info.type  # 可能需要进行类型转换
} for output_info in outputs_info}

print("Inputs:", input_dict)
print("Outputs:", output_dict)

输出示例

Inputs: {
    'input1': {'shape': [1, 3, 224, 224], 'type': 'tensor(float)', 'dtype': 'tensor(float)'},
    'input2': {'shape': [1, 10], 'type': 'tensor(int64)', 'dtype': 'tensor(int64)'}
}

Outputs: {
    'output1': {'shape': [1, 1000], 'type': 'tensor(float)', 'dtype': 'tensor(float)'},
    'output2': {'shape': [1], 'type': 'tensor(int64)', 'dtype': 'tensor(int64)'}
}

方法 2:直接解析 ONNX 模型(不依赖 onnxruntime

  • 适用于模型分析,只静态解析模型的结构,不依赖 onnxruntime
  • 使用 onnx.load() 加载模型,然后从 graph.inputgraph.output 提取信息。
  • ONNX 使用 tensor(float) 等字符串表示类型,而 onnx.mapping.TENSOR_TYPE_TO_NP_TYPE 可以映射到 numpy 类型。
import onnx

# 加载 ONNX 模型
onnx_model = onnx.load("model.onnx")

# 获取输入信息
input_dict = {
    input.name: {
        "shape": [dim.dim_value for dim in input.type.tensor_type.shape.dim],
        "dtype": onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input.type.tensor_type.elem_type]
    }
    for input in onnx_model.graph.input
}

# 获取输出信息
output_dict = {
    output.name: {
        "shape": [dim.dim_value for dim in output.type.tensor_type.shape.dim],
        "dtype": onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[output.type.tensor_type.elem_type]
    }
    for output in onnx_model.graph.output
}

print("Inputs:", input_dict)
print("Outputs:", output_dict)

输出示例

Inputs: {
    'input1': {'shape': [1, 3, 224, 224], 'dtype': numpy.float32},
    'input2': {'shape': [1, 10], 'dtype': numpy.int64}
}

Outputs: {
    'output1': {'shape': [1, 1000], 'dtype': numpy.float32},
    'output2': {'shape': [1], 'dtype': numpy.int64}
}

总结

方法 适用场景 依赖库 获取方式
onnxruntime 推理时获取输入输出 onnxruntime sess.get_inputs(), sess.get_outputs()
直接解析 ONNX 静态分析模型结构 onnx onnx_model.graph.input, onnx_model.graph.output
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐