摘要:在深度学习生态系统中,模型的跨框架互操作性至关重要。ONNX(Open Neural Network Exchange)作为开放的模型中间表示标准,已被 PyTorch、TensorFlow、MXNet 等主流框架广泛支持。其底层采用 Google Protobuf 进行高效序列化。GE(Graph Engine)作为 CANN 开源生态中的图编译器与执行器,提供了对 ONNX/Protobuf 模型的原生解析能力,能够将通用模型无缝转换为内部计算图,进而进行优化与执行。本文将深入剖析 GE 中 ONNX/Protobuf 解析器的设计架构、核心组件与实现细节,并通过一个完整的实战案例——从加载 ONNX 模型到构建内部图表示,逐步演示解析全过程。文中包含清晰的流程图、关键数据结构定义、解析算法伪代码及兼容性表格,帮助开发者理解模型导入机制,并为扩展自定义算子解析提供实践指南。


一、为什么需要 ONNX/Protobuf 解析?

1.1 模型互操作性的挑战

不同深度学习框架使用各自的模型格式:

  • PyTorch.pt / .pth
  • TensorFlow:SavedModel / .pb
  • Keras:HDF5

直接部署需针对每种格式开发专用加载器,维护成本高。

1.2 ONNX 的优势

ONNX 作为中间表示(IR)解决此问题:

  • 统一格式:所有框架可导出为 ONNX;
  • 语言无关:基于 Protobuf,支持多语言解析;
  • 算子标准化:定义了 150+ 标准算子(OpSet)。

✅ “一次导出,处处运行”。


二、GE 模型解析整体架构

GE 的解析器采用分层设计,解耦格式解析与语义构建。

ONNX .onnx 文件

Protobuf 反序列化

ONNX Graph 遍历

算子映射器

内部图构建器

GE 计算图

核心组件

组件 职责
ONNX Parser 加载 .onnx 文件并反序列化
Node Visitor 遍历 ONNX 节点
Op Mapper 将 ONNX 算子映射到 GE 内部算子
Graph Builder 构建 GE 的 ComputeGraph

三、ONNX 模型结构解析

3.1 ONNX Protobuf 定义

ONNX 模型的核心是 ModelProto

message ModelProto {
  string ir_version = 1;
  repeated OperatorSetIdProto opset_import = 8;
  string producer_name = 2;
  GraphProto graph = 7;  // ← 核心计算图
}

message GraphProto {
  string name = 1;
  repeated ValueInfoProto input = 10;
  repeated ValueInfoProto output = 11;
  repeated NodeProto node = 6;  // ← 算子列表
  repeated TensorProto initializer = 5; // ← 常量权重
}

3.2 关键数据结构

NodeProto(算子节点)
message NodeProto {
  repeated string input = 1;   // 输入张量名
  repeated string output = 2;  // 输出张量名
  string op_type = 3;          // 算子类型 (如 "Conv")
  string domain = 4;           // 域 (默认为空)
  repeated AttributeProto attribute = 5; // 属性 (如 kernel_shape)
}
AttributeProto(算子属性)
message AttributeProto {
  string name = 1;
  enum AttributeType {
    FLOAT = 1;
    INT = 2;
    STRING = 3;
    TENSOR = 4;
    GRAPH = 5;
    FLOATS = 6;
    INTS = 7;
    STRINGS = 8;
  }
  AttributeType type = 20;
  // oneof value {
  float f = 2;
  int64 i = 3;
  bytes s = 4;
  TensorProto t = 5;
  GraphProto g = 6;
  repeated float floats = 7;
  repeated int64 ints = 8;
  // }
}

🔑 oneof 确保属性值类型安全。


四、GE 解析器核心实现

4.1 解析入口:加载 ONNX 文件

// onnx_parser.cc
#include "onnx/onnx.pb.h"
#include "ge/graph_builder.h"

Status OnnxParser::Parse(const std::string& model_path, ComputeGraph* graph) {
  // 1. 读取文件
  std::ifstream input(model_path, std::ios::binary);
  std::string model_data((std::istreambuf_iterator<char>(input)),
                         std::istreambuf_iterator<char>());
  
  // 2. Protobuf 反序列化
  onnx::ModelProto onnx_model;
  if (!onnx_model.ParseFromString(model_data)) {
    return Status::ERROR("Failed to parse ONNX model");
  }
  
  // 3. 构建内部图
  return BuildGraph(onnx_model.graph(), graph);
}

⚠️ 需链接 libprotobuf 库。

4.2 图构建:遍历节点

// graph_builder.cc
Status GraphBuilder::BuildGraph(const onnx::GraphProto& onnx_graph, 
                                ComputeGraph* ge_graph) {
  // 1. 注册输入/输出
  for (const auto& input : onnx_graph.input()) {
    ge_graph->AddInput(input.name());
  }
  for (const auto& output : onnx_graph.output()) {
    ge_graph->AddOutput(output.name());
  }
  
  // 2. 处理常量权重 (initializers)
  std::unordered_map<std::string, Tensor> weights;
  for (const auto& tensor : onnx_graph.initializer()) {
    weights[tensor.name()] = ConvertOnnxTensor(tensor);
  }
  
  // 3. 遍历算子节点
  for (const auto& node : onnx_graph.node()) {
    RETURN_IF_ERROR(ProcessNode(node, weights, ge_graph));
  }
  
  return Status::OK();
}

五、算子映射:ONNX → GE

5.1 映射表设计

GE 使用注册表模式管理算子映射:

// op_mapper_registry.h
class OpMapperRegistry {
 public:
  static OpMapperRegistry& Instance() {
    static OpMapperRegistry instance;
    return instance;
  }
  
  void Register(const std::string& onnx_op, OpMapperFunc func) {
    mappers_[onnx_op] = func;
  }
  
  Status Map(const onnx::NodeProto& node, ComputeGraph* graph) {
    auto it = mappers_.find(node.op_type());
    if (it == mappers_.end()) {
      return Status::ERROR("Unsupported op: " + node.op_type());
    }
    return it->second(node, graph);
  }

 private:
  std::unordered_map<std::string, OpMapperFunc> mappers_;
};

5.2 注册标准算子

在初始化时注册所有支持的算子:

// op_mappers.cc
void RegisterStandardOps() {
  auto& registry = OpMapperRegistry::Instance();
  registry.Register("Conv", MapConvOp);
  registry.Register("Relu", MapReluOp);
  registry.Register("MatMul", MapMatMulOp);
  // ... 其他算子
}

Status MapConvOp(const onnx::NodeProto& node, ComputeGraph* graph) {
  // 1. 提取属性
  auto attrs = ParseAttributes(node.attribute());
  int64_t group = GetAttr(attrs, "group", 1);
  std::vector<int64_t> kernel_shape = GetAttr(attrs, "kernel_shape");
  
  // 2. 创建 GE Conv 节点
  auto conv_node = graph->CreateNode("Conv");
  conv_node->SetAttr("kernel_shape", kernel_shape);
  conv_node->SetAttr("group", group);
  
  // 3. 连接输入/输出
  for (const auto& input : node.input()) {
    conv_node->AddInput(input);
  }
  for (const auto& output : node.output()) {
    conv_node->AddOutput(output);
  }
  
  return Status::OK();
}

💡 ParseAttributes 处理 AttributeProtooneof


六、实战:解析一个简单 ONNX 模型

6.1 导出 PyTorch 模型为 ONNX

# export_model.py
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=3)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        return self.relu(self.conv(x))

model = SimpleNet()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model, dummy_input, "simple_net.onnx",
    input_names=["input"], output_names=["output"],
    opset_version=11
)

生成的 ONNX 包含:

  • 1 个 Conv 节点
  • 1 个 Relu 节点
  • 1 个权重 initializer (conv.weight)

6.2 GE 解析代码

// main.cc
#include "ge/onnx_parser.h"
#include "ge/compute_graph.h"

int main() {
  ComputeGraph graph;
  OnnxParser parser;
  
  // 解析 ONNX 模型
  auto status = parser.Parse("simple_net.onnx", &graph);
  if (!status.ok()) {
    std::cerr << "Parse failed: " << status.message() << std::endl;
    return -1;
  }
  
  // 打印内部图
  std::cout << "Inputs: ";
  for (const auto& input : graph.inputs()) {
    std::cout << input << " ";
  }
  std::cout << "\nNodes:\n";
  for (const auto& node : graph.nodes()) {
    std::cout << "- " << node->type() << " (" 
              << node->inputs().size() << " inputs)\n";
  }
  
  return 0;
}

预期输出:

Inputs: input 
Nodes:
- Conv (2 inputs)
- Relu (1 input)

✅ 成功将 ONNX 转换为 GE 内部图。


七、高级特性:自定义算子与子图

7.1 自定义算子支持

对于非标准算子(如 MyCustomOp),可通过扩展映射表:

// custom_op_mapper.cc
Status MapMyCustomOp(const onnx::NodeProto& node, ComputeGraph* graph) {
  // 自定义解析逻辑
  auto custom_node = graph->CreateNode("MyCustomOp");
  // ... 设置属性 ...
  return Status::OK();
}

// 在初始化时注册
OpMapperRegistry::Instance().Register("MyCustomOp", MapMyCustomOp);

7.2 子图处理(控制流)

ONNX 支持 If/Loop 等控制流算子,其属性包含子图:

message NodeProto {
  // ...
  repeated AttributeProto attribute = 5;
}

message AttributeProto {
  // ...
  GraphProto g = 6;  // ← 子图
}

GE 递归解析子图:

Status MapIfOp(const onnx::NodeProto& node, ComputeGraph* graph) {
  auto if_node = graph->CreateNode("If");
  
  // 解析 then_branch 和 else_branch
  for (const auto& attr : node.attribute()) {
    if (attr.name() == "then_branch") {
      ComputeGraph then_graph;
      BuildGraph(attr.g(), &then_graph);  // ← 递归调用
      if_node->SetSubgraph("then", then_graph);
    }
    // ... else_branch ...
  }
  
  return Status::OK();
}

八、兼容性与 OpSet 管理

8.1 OpSet 版本挑战

ONNX 算子行为随 OpSet 版本演进:

  • BatchNormalization v9: 输出 1 个 tensor
  • BatchNormalization v14: 可选输出 5 个 tensors

GE 通过版本感知映射处理:

// op_mapper_registry.cc
void RegisterOpWithVersion(const std::string& op, int min_version, 
                          int max_version, OpMapperFunc func) {
  versioned_mappers_.push_back({op, min_version, max_version, func});
}

Status Map(const onnx::NodeProto& node, int opset_version, 
           ComputeGraph* graph) {
  for (const auto& mapper : versioned_mappers_) {
    if (mapper.op == node.op_type() && 
        opset_version >= mapper.min_version && 
        opset_version <= mapper.max_version) {
      return mapper.func(node, graph);
    }
  }
  return Status::ERROR("Unsupported op version");
}

8.2 支持的 OpSet 范围

OpSet 版本 支持状态 备注
1-10 基础支持 覆盖 CNN/RNN
11-15 完全支持 推荐使用
16-18 部分支持 新算子需扩展
≥19 实验性 需手动验证

📌 建议导出时指定 opset_version=15


九、性能优化:延迟加载与缓存

9.1 大模型挑战

LLM 的 ONNX 模型可达数十 GB,全量加载内存压力大。

GE 采用按需加载策略:

  • 先解析图结构(轻量级);
  • 权重在执行时才加载。
// lazy_weight_loader.h
class LazyWeightLoader {
 public:
  Tensor LoadWeight(const std::string& name) {
    if (cache_.count(name)) {
      return cache_[name];
    }
    // 从磁盘读取特定 tensor
    auto tensor = ReadTensorFromOnnxFile(model_path_, name);
    cache_[name] = tensor;
    return tensor;
  }
 private:
  std::unordered_map<std::string, Tensor> cache_;
};

9.2 内存映射加速

对于 SSD 存储,使用 mmap 避免拷贝:

// memory_mapped_file.cc
class MemoryMappedFile {
 public:
  MemoryMappedFile(const std::string& path) {
    file_ = open(path.c_str(), O_RDONLY);
    struct stat sb;
    fstat(file_, &sb);
    data_ = mmap(nullptr, sb.st_size, PROT_READ, MAP_PRIVATE, file_, 0);
  }
  
  const char* data() const { return static_cast<const char*>(data_); }
  
 private:
  int file_;
  void* data_;
};

十、调试与验证工具

10.1 模型可视化

GE 提供 ONNX 图转 DOT 工具:

ge-cli --onnx-model simple_net.onnx --export-dot graph.dot
dot -Tpng graph.dot -o graph.png

生成可视化计算图,便于调试。

10.2 数值一致性检查

解析后自动验证:

// 随机生成输入
auto input = GenerateRandomInput();

// 运行 ONNX Runtime 获取参考输出
auto ref_output = RunOnnxRuntime("simple_net.onnx", input);

// 运行 GE 获取测试输出
auto test_output = RunGeGraph(graph, input);

// 检查数值误差
assert(NearEqual(ref_output, test_output, 1e-5));

确保解析无精度损失。


十一、常见问题与解决方案

问题 原因 解决方案
“Unsupported op” 算子未注册 扩展 OpMapper 或升级 GE
权重加载慢 全量加载大模型 启用延迟加载
属性解析错误 OpSet 版本不匹配 指定兼容的 opset_version

十二、未来方向

  1. 动态形状支持:解析带符号的动态维度;
  2. 量化模型解析:原生支持 QLinear 算子;
  3. 多文件 ONNX:支持 external data 格式;
  4. Python 绑定:简化解析 API。

结语

ONNX/Protobuf 解析是连接外部模型与内部执行引擎的“桥梁”。GE 通过模块化、可扩展的解析架构,不仅支持标准算子,还为自定义扩展预留了空间。在 AI 模型日益多样化的今天,掌握模型解析技术,意味着你拥有了打通不同框架壁垒的能力。正如一句开源格言:“Interoperability is the key to innovation.” 而 GE 的解析器,正是那把开启互操作之门的钥匙。


探索 GE 源码与贡献解析器,请访问:

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐