上篇文章我们本地跑通了更通用的本地推理框架 ONNX Runtime(ORT),并对代码进行了拆解和解读。

本文我们先将代码整理一下,将那个“一锅炖”的 main.cpp 拆解成专业的 C++ 类结构,方便我们后续的操作。

0. 系列文章

1. 调整文件目录

按以下结构整理文件。

注意:我们将代码分为了 include (头文件) 和 src (源文件),这是 C++ 标准工程规范。

Week1_Refactor/
├── CMakeLists.txt          # [修改] 构建规则变了
├── model/
│   └── yolov8n.onnx        # 模型文件不动
├── third_party/
│   └── onnxruntime/        # 库文件不动
├── include/
│   └── YoloDetector.hpp    # [新增] 只有函数的声明(菜单)
└── src/
    ├── YoloDetector.cpp    # [新增] 函数的具体实现(厨房)
    └── main.cpp            # [修改] 变得很短,只负责点菜

2. 修改 CMakeLists.txt 文件

因为文件分家了,我们需要告诉 CMake 去哪里找 .h 头文件,以及编译哪些 .cpp 源文件。

cmake_minimum_required(VERSION 3.17)
project(YoloRefactor)

set(CMAKE_CXX_STANDARD 17)

# 1. 定义路径变量
set(ORT_HOME ${CMAKE_SOURCE_DIR}/third_party/onnxruntime)
set(INCLUDE_DIR ${CMAKE_SOURCE_DIR}/include)
set(SRC_DIR ${CMAKE_SOURCE_DIR}/src)

# 2. 包含头文件路径
# 既要包含 ORT 的头文件,也要包含我们自己写的 include 文件夹
include_directories(
    ${ORT_HOME}/include
    ${INCLUDE_DIR}
)

# 3. 链接库路径 (和 Week 1 一样)
link_directories(${ORT_HOME}/lib)

# 4. 收集 src 目录下所有的 .cpp 文件
file(GLOB SOURCES "${SRC_DIR}/*.cpp")

# 5. 生成可执行文件
add_executable(main ${SOURCES})

# 6. 链接 ORT 库
target_link_libraries(main onnxruntime)

3. 编写头文件 (include/YoloDetector.hpp)

头文件里要写的,就是我们想要暴露给外部调用的接口,也就是 main.cpp 函数中需要调用的接口。

这里,我们只需要在 main.cpp 中加载个模型,执行一下推理就可以了。其它操作 main.cpp 不需要知道。

所以,头文件中我们开放两个接口:

  • void loadModel(const std::string& model_path); // 加载模型

  • void detectFakeInput(); // 推理 (目前还是用假数据)

#pragma once // 防止头文件被重复引用
#include <vector>
#include <string>
#include <onnxruntime_cxx_api.h> // ORT 的头文件放这里

class YoloDetector {
public:
    // 构造函数
    YoloDetector();
    
    // 1. 加载模型
    // 参数:模型文件的路径
    void loadModel(const std::string& model_path);

    // 2. 推理 (目前还是用假数据)
    // 我们把之前 main 函数里后半部分逻辑搬到这里
    // 返回值:目前先不返回复杂结构,简单打印即可,所以用 void
    void detectFakeInput();

private:
    // 成员变量:这些变量需要在整个类的生命周期内存在
    // Env 必须比 Session 活得久,所以放在最前面
    Ort::Env env{nullptr};
    Ort::Session session{nullptr};
    
    // 我们把 SessionOptions 也存起来,方便以后扩展配置
    Ort::SessionOptions session_options;
};

4. 编写实现文件 (src/YoloDetector.cpp)

#include "YoloDetector.hpp"
#include <iostream>

// 构造函数:初始化环境
YoloDetector::YoloDetector() : env(ORT_LOGGING_LEVEL_WARNING, "YoloRefactor") {
    // 可以在这里做一些基础配置
    session_options.SetIntraOpNumThreads(1);
    session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
}

// 实现加载模型
void YoloDetector::loadModel(const std::string& model_path) {
    std::cout << "📦 Loading model from: " << model_path << "..." << std::endl;
    
    try {
        // 创建 Session
        // 注意:这里使用类成员变量 session 和 env
        session = Ort::Session(env, model_path.c_str(), session_options);
        std::cout << "✅ Model loaded successfully!" << std::endl;
    } catch (const Ort::Exception& e) {
        std::cerr << "❌ ORT Exception: " << e.what() << std::endl;
        exit(1); // 加载失败直接退出
    }
}

// 实现推理逻辑 (Week 1 的核心逻辑搬运至此)
void YoloDetector::detectFakeInput() {
    // --- 以下代码几乎原封不动来自之前的代码 ---
    
    // 1. 准备输入信息
    Ort::AllocatorWithDefaultOptions allocator;
    // 获取输入名
    auto input_name_ptr = session.GetInputNameAllocated(0, allocator);
    std::string input_name = input_name_ptr.get();
    // 获取输出名
    auto output_name_ptr = session.GetOutputNameAllocated(0, allocator);
    std::string output_name = output_name_ptr.get();

    // 2. 构造假数据 [1, 3, 640, 640]
    std::vector<int64_t> input_dims = {1, 3, 640, 640};
    size_t input_tensor_size = 1 * 3 * 640 * 640;
    std::vector<float> input_tensor_values(input_tensor_size);
    // 填充 0.5
    for (size_t i = 0; i < input_tensor_size; i++) input_tensor_values[i] = 0.5f;

    // 3. 创建 Tensor
    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    std::vector<const char*> input_node_names = { input_name.c_str() };
    std::vector<const char*> output_node_names = { output_name.c_str() };
    
    std::vector<Ort::Value> input_tensors;
    input_tensors.push_back(Ort::Value::CreateTensor<float>(
        memory_info, 
        input_tensor_values.data(), 
        input_tensor_size, 
        input_dims.data(), 
        input_dims.size()
    ));

    // 4. 运行推理
    std::cout << "⚡ Running inference with fake data..." << std::endl;
    try {
        auto output_tensors = session.Run(
            Ort::RunOptions{nullptr}, 
            input_node_names.data(), 
            input_tensors.data(), 
            1, 
            output_node_names.data(), 
            1
        );

        // 5. 打印结果维度
        auto output_info = output_tensors[0].GetTensorTypeAndShapeInfo();
        std::vector<int64_t> output_dims = output_info.GetShape();

        std::cout << "✅ Inference finished!" << std::endl;
        std::cout << "Output Shape: [";
        for (size_t i = 0; i < output_dims.size(); i++) {
            std::cout << output_dims[i] << (i < output_dims.size() - 1 ? ", " : "");
        }
        std::cout << "]" << std::endl;

    } catch (const Ort::Exception& e) {
        std::cerr << "❌ Runtime Exception: " << e.what() << std::endl;
    }
}

5. 清爽的 src/main.cpp

#include <iostream>
#include "YoloDetector.hpp" // 只需要引入头文件

int main() {
    std::cout << "--- Week 1 Refactor Demo ---" << std::endl;

    // 1. 实例化对象 (相当于买了一台检测器)
    YoloDetector detector;

    // 2. 加载模型 (插上电源,装载程序)
    // 确保 yolov8n.onnx 放在 build/../model/ 下,或者使用绝对路径
    detector.loadModel("../model/yolov8n.onnx");

    // 3. 运行测试 (按一下开关)
    detector.detectFakeInput();

    return 0;
}

测试运行与之前的程序输出一致:

在这里插入图片描述

6. 总结

本文将原来的代码进行了拆分,封装。这是工程化的关键一步:

  • 封装 (Encapsulation):

main.cpp 不需要知道什么是 Ort::Session,也不需要知道输入是 640x640。它只管调用 loadModel 和 detect。

如果以后你要把 ONNX Runtime 换成 TensorRT,你只需要重写 src/YoloDetector.cpp,main.cpp 一行代码都不用改!

  • 生命周期管理 (RAII):

在之前的代码中,如果 main 函数很长,env 和 session 变量会混在乱七八糟的逻辑里。

现在,env 是类的成员。只要 detector 对象存在,env 就存在。当 detector 被销毁(比如程序结束),C++ 会自动清理它们。

  • 为后续操作铺路:

有了这个结构,后面的实践,我们只需要修改 detect 函数,把里面的“造假数据逻辑”替换成“真实的图像处理逻辑”,再增加一个私有的 preprocess 函数即可。代码结构依然保持清晰。

在这里插入图片描述

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐