【端侧AI 与 C++】6. 使用更通用的推理引擎 ONNX Runtime (ORT) 跑通本地模型加载 - 封装
本文我们先将代码整理一下,将那个“一锅炖”的 main.cpp 拆解成专业的 C++ 类结构,方便我们后续的操作。
上篇文章我们本地跑通了更通用的本地推理框架 ONNX Runtime(ORT),并对代码进行了拆解和解读。
本文我们先将代码整理一下,将那个“一锅炖”的 main.cpp 拆解成专业的 C++ 类结构,方便我们后续的操作。
文章目录
0. 系列文章
1. 调整文件目录
按以下结构整理文件。
注意:我们将代码分为了 include (头文件) 和 src (源文件),这是 C++ 标准工程规范。
Week1_Refactor/
├── CMakeLists.txt # [修改] 构建规则变了
├── model/
│ └── yolov8n.onnx # 模型文件不动
├── third_party/
│ └── onnxruntime/ # 库文件不动
├── include/
│ └── YoloDetector.hpp # [新增] 只有函数的声明(菜单)
└── src/
├── YoloDetector.cpp # [新增] 函数的具体实现(厨房)
└── main.cpp # [修改] 变得很短,只负责点菜
2. 修改 CMakeLists.txt 文件
因为文件分家了,我们需要告诉 CMake 去哪里找 .h 头文件,以及编译哪些 .cpp 源文件。
cmake_minimum_required(VERSION 3.17)
project(YoloRefactor)
set(CMAKE_CXX_STANDARD 17)
# 1. 定义路径变量
set(ORT_HOME ${CMAKE_SOURCE_DIR}/third_party/onnxruntime)
set(INCLUDE_DIR ${CMAKE_SOURCE_DIR}/include)
set(SRC_DIR ${CMAKE_SOURCE_DIR}/src)
# 2. 包含头文件路径
# 既要包含 ORT 的头文件,也要包含我们自己写的 include 文件夹
include_directories(
${ORT_HOME}/include
${INCLUDE_DIR}
)
# 3. 链接库路径 (和 Week 1 一样)
link_directories(${ORT_HOME}/lib)
# 4. 收集 src 目录下所有的 .cpp 文件
file(GLOB SOURCES "${SRC_DIR}/*.cpp")
# 5. 生成可执行文件
add_executable(main ${SOURCES})
# 6. 链接 ORT 库
target_link_libraries(main onnxruntime)
3. 编写头文件 (include/YoloDetector.hpp)
头文件里要写的,就是我们想要暴露给外部调用的接口,也就是 main.cpp 函数中需要调用的接口。
这里,我们只需要在 main.cpp 中加载个模型,执行一下推理就可以了。其它操作 main.cpp 不需要知道。
所以,头文件中我们开放两个接口:
-
void loadModel(const std::string& model_path); // 加载模型
-
void detectFakeInput(); // 推理 (目前还是用假数据)
#pragma once // 防止头文件被重复引用
#include <vector>
#include <string>
#include <onnxruntime_cxx_api.h> // ORT 的头文件放这里
class YoloDetector {
public:
// 构造函数
YoloDetector();
// 1. 加载模型
// 参数:模型文件的路径
void loadModel(const std::string& model_path);
// 2. 推理 (目前还是用假数据)
// 我们把之前 main 函数里后半部分逻辑搬到这里
// 返回值:目前先不返回复杂结构,简单打印即可,所以用 void
void detectFakeInput();
private:
// 成员变量:这些变量需要在整个类的生命周期内存在
// Env 必须比 Session 活得久,所以放在最前面
Ort::Env env{nullptr};
Ort::Session session{nullptr};
// 我们把 SessionOptions 也存起来,方便以后扩展配置
Ort::SessionOptions session_options;
};
4. 编写实现文件 (src/YoloDetector.cpp)
#include "YoloDetector.hpp"
#include <iostream>
// 构造函数:初始化环境
YoloDetector::YoloDetector() : env(ORT_LOGGING_LEVEL_WARNING, "YoloRefactor") {
// 可以在这里做一些基础配置
session_options.SetIntraOpNumThreads(1);
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
}
// 实现加载模型
void YoloDetector::loadModel(const std::string& model_path) {
std::cout << "📦 Loading model from: " << model_path << "..." << std::endl;
try {
// 创建 Session
// 注意:这里使用类成员变量 session 和 env
session = Ort::Session(env, model_path.c_str(), session_options);
std::cout << "✅ Model loaded successfully!" << std::endl;
} catch (const Ort::Exception& e) {
std::cerr << "❌ ORT Exception: " << e.what() << std::endl;
exit(1); // 加载失败直接退出
}
}
// 实现推理逻辑 (Week 1 的核心逻辑搬运至此)
void YoloDetector::detectFakeInput() {
// --- 以下代码几乎原封不动来自之前的代码 ---
// 1. 准备输入信息
Ort::AllocatorWithDefaultOptions allocator;
// 获取输入名
auto input_name_ptr = session.GetInputNameAllocated(0, allocator);
std::string input_name = input_name_ptr.get();
// 获取输出名
auto output_name_ptr = session.GetOutputNameAllocated(0, allocator);
std::string output_name = output_name_ptr.get();
// 2. 构造假数据 [1, 3, 640, 640]
std::vector<int64_t> input_dims = {1, 3, 640, 640};
size_t input_tensor_size = 1 * 3 * 640 * 640;
std::vector<float> input_tensor_values(input_tensor_size);
// 填充 0.5
for (size_t i = 0; i < input_tensor_size; i++) input_tensor_values[i] = 0.5f;
// 3. 创建 Tensor
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
std::vector<const char*> input_node_names = { input_name.c_str() };
std::vector<const char*> output_node_names = { output_name.c_str() };
std::vector<Ort::Value> input_tensors;
input_tensors.push_back(Ort::Value::CreateTensor<float>(
memory_info,
input_tensor_values.data(),
input_tensor_size,
input_dims.data(),
input_dims.size()
));
// 4. 运行推理
std::cout << "⚡ Running inference with fake data..." << std::endl;
try {
auto output_tensors = session.Run(
Ort::RunOptions{nullptr},
input_node_names.data(),
input_tensors.data(),
1,
output_node_names.data(),
1
);
// 5. 打印结果维度
auto output_info = output_tensors[0].GetTensorTypeAndShapeInfo();
std::vector<int64_t> output_dims = output_info.GetShape();
std::cout << "✅ Inference finished!" << std::endl;
std::cout << "Output Shape: [";
for (size_t i = 0; i < output_dims.size(); i++) {
std::cout << output_dims[i] << (i < output_dims.size() - 1 ? ", " : "");
}
std::cout << "]" << std::endl;
} catch (const Ort::Exception& e) {
std::cerr << "❌ Runtime Exception: " << e.what() << std::endl;
}
}
5. 清爽的 src/main.cpp
#include <iostream>
#include "YoloDetector.hpp" // 只需要引入头文件
int main() {
std::cout << "--- Week 1 Refactor Demo ---" << std::endl;
// 1. 实例化对象 (相当于买了一台检测器)
YoloDetector detector;
// 2. 加载模型 (插上电源,装载程序)
// 确保 yolov8n.onnx 放在 build/../model/ 下,或者使用绝对路径
detector.loadModel("../model/yolov8n.onnx");
// 3. 运行测试 (按一下开关)
detector.detectFakeInput();
return 0;
}
测试运行与之前的程序输出一致:

6. 总结
本文将原来的代码进行了拆分,封装。这是工程化的关键一步:
- 封装 (Encapsulation):
main.cpp 不需要知道什么是 Ort::Session,也不需要知道输入是 640x640。它只管调用 loadModel 和 detect。
如果以后你要把 ONNX Runtime 换成 TensorRT,你只需要重写 src/YoloDetector.cpp,main.cpp 一行代码都不用改!
- 生命周期管理 (RAII):
在之前的代码中,如果 main 函数很长,env 和 session 变量会混在乱七八糟的逻辑里。
现在,env 是类的成员。只要 detector 对象存在,env 就存在。当 detector 被销毁(比如程序结束),C++ 会自动清理它们。
- 为后续操作铺路:
有了这个结构,后面的实践,我们只需要修改 detect 函数,把里面的“造假数据逻辑”替换成“真实的图像处理逻辑”,再增加一个私有的 preprocess 函数即可。代码结构依然保持清晰。

更多推荐


所有评论(0)