除了使用trtexec工具实现onnx转tensorRT,还可以基于tensorRT的api,在使用api实现过程中会涉及到config的使用

TRT_LOGGER = trt.Logger()
builder = trt.Builder(logger)
config = builder.create_builder_config()

使用config可以实现精度的配置等功能,这里给出一些可以通过config实现的功能或配置

  • 最大工作空间大小:我们之前讨论过,用于确定优化和执行模型时可用的最大内存量。
  • 精度设置:设置模型运行的精度,如 FP32, FP16 或 INT8。
  • 层融合:启用层融合优化。 动态形状输入:支持动态输入形状。
  • 严格类型约束:对于某些算子,强制使用严格的数据类型。
  • 执行策略:设置不同的策略以优化模型执行。
  • 内存池:为特定类型的内存分配设置内存池。
# 创建一个 builder config
config = builder.create_builder_config()

# 最大工作空间大小
config.max_workspace_size = 1 << 30  # 1GB

# 启用 FP16 精度
config.set_flag(trt.BuilderFlag.FP16)

# 启用层融合优化
config.set_flag(trt.BuilderFlag.DISABLE_TIMING_CACHE)

# 支持动态输入形状
# 假设我们有一个动态输入形状的网络,我们需要定义其形状的范围
input_tensor = network.get_input(0)  # 假设第一个输入是我们想要设置动态形状的输入
profile = builder.create_optimization_profile()
profile.set_shape(input_tensor.name, (1, 3, 224, 224), (4, 3, 224, 224), (8, 3, 224, 224))
config.add_optimization_profile(profile)

# 严格类型约束
config.set_flag(trt.BuilderFlag.STRICT_TYPES)

# 设置执行策略
config.set_flag(trt.BuilderFlag.REFIT)

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐