从训练到部署:基于OpenMMLab的草莓病害识别系统实战
在精准农业中,快速准确地识别作物病害对于保障产量至关重要。草莓作为一种高价值经济作物,其病害种类繁多(数据集包含77个类别),且存在严重的样本不平衡问题。这就要求我们的模型既要“认得准”常见病害,也要“抓得住”罕见病,同时还要能在资源受限的边缘设备(如机器狗搭载的Jetson Nano)上实时运行。OpenMMLab提供了一套完整的工具链:MMPretrain用于训练、MMRazor用于压缩、MM
大家好,今天想和大家分享一个完整的计算机视觉项目——基于OpenMMLab生态构建一个高精度、可部署的草莓病害识别系统。本项目涵盖了数据准备、模型训练、模型压缩(剪枝)、模型转换(TensorRT)以及边缘设备(Jetson Nano)部署的全流程。希望这篇博客能给正在学习OpenMMLab或从事农业AI落地的朋友带来一些启发。
前言
在精准农业中,快速准确地识别作物病害对于保障产量至关重要。草莓作为一种高价值经济作物,其病害种类繁多(数据集包含77个类别),且存在严重的样本不平衡问题。这就要求我们的模型既要“认得准”常见病害,也要“抓得住”罕见病,同时还要能在资源受限的边缘设备(如机器狗搭载的Jetson Nano)上实时运行。
OpenMMLab提供了一套完整的工具链:MMPretrain用于训练、MMRazor用于压缩、MMDeploy用于部署。本文将基于这三个核心库,一步步搭建并优化一个草莓病害识别系统。
系统架构设计
整个系统遵循一个清晰的多阶段流程,如下图所示:

图1 系统蓝图
流程说明:
- 数据获取与准备:下载Kaggle上的草莓病害数据集,并按标准格式划分训练/验证/测试集。
- 数据预处理与增强:定义包含MixUp、CutMix、RandAugment等策略的流水线,提升模型泛化能力。
- 模型训练:在本地NVIDIA GeForce RTX 4070上,使用MMPretrain训练ConvNeXt V2基线模型。
- 性能评估:在测试集上计算准确率、宏平均精确率/召回率/F1-score。
- 模型剪枝:利用MMRazor对训练好的模型进行L1-norm通道剪枝,减小体积。
- 模型转换:通过MMDeploy将剪枝后的模型转换为TensorRT引擎,以便在Jetson Nano上高效推理。
数据集选择与分析
我们选用Kaggle上的 “Strawberry Disease Classification” 数据集。该数据集包含超过9000张图像,覆盖77个类别,包括常见的“叶枯病”和“健康”,以及样本极少的“炭疽病”(仅34张)、“花螨”(29张)。严重的类别不平衡是本项目的主要挑战之一,为此我们在训练中引入了标签平滑、MixUp/CutMix等策略。
数据集准备脚本见后文代码部分,它按照7:2:1的比例将原始图像划分为train/val/test,并保持类别子目录结构。
视觉网络选择:ConvNeXt V2
经过综合评估,我们选择了ConvNeXt V2作为骨干网络。它是ConvNeXt的升级版,在纯卷积架构中融入了Transformer的设计思想,并通过自监督学习(FCMAE)和全局响应归一化(GRN)进一步提升了性能。相比同量级的Swin Transformer,ConvNeXt V2推理速度更快,精度更高。

图2 ConvNeXt性能表现与模型架构
图3 模型架构(ConvNeXt V1 vs. ConvNeXt V2)
我们选用convnext-v2-base变体,加载ImageNet-21k预训练权重,将其分类头调整为77类。
模型训练与优化策略
数据预处理与增强
在MMPretrain配置文件中,我们定义了如下流水线:
- 标准化:使用ImageNet的mean/std对图像进行归一化(因为加载的是ImageNet预训练权重)。
- RandomResizedCrop:随机裁剪并缩放到224×224。
- RandomFlip:水平翻转。
- MixUp & CutMix:两者结合,alpha=0.2,有效缓解过拟合和类别不平衡。
- RandAugment:自动增强策略。
- RandomErasing:随机擦除,提高遮挡鲁棒性。
这些增强策略的组合,让模型在样本量少的类别上也能学到鲁棒的特征。
系统配置与超参数优化
配置文件convnext-v2-base_32xb32_sd.py(见后文)包含了以下关键设置:
- 迁移学习:冻结backbone的前两个stage,仅训练后面部分和分类头。
- 优化器:AdamW,初始学习率设为5e-3(比默认稍高,以加速收敛)。
- 学习率调度:前20个epoch线性warmup,之后余弦退火至1e-5。
- 损失函数:标签平滑的交叉熵损失(label_smooth_val=0.1)。
- 训练轮数:100个epoch,每1个epoch验证一次。
- EMA:使用指数移动平均稳定模型。
由于本地GPU显存限制,将batch size从默认的64调整为32,同时为了加快数据加载,将num_workers设为8(CPU核心数足够)。
模型性能比较
我们在测试集上对比了ConvNeXt V2、ConvNeXt和Swin Transformer(三者参数量均在88M左右),结果如下:
表1 模型性能
| 模型 | 参数(M) | Flops(G) | Acc(%) | Precision(%) | Recall(%) | F1(%) | Time(s/iter) |
|---|---|---|---|---|---|---|---|
| ConvNeXt V2 | 88.72 | 15.38 | 92.96 | 87.83 | 85.40 | 84.98 | 0.9308 |
| ConvNeXt | 88.59 | 15.36 | 92.24 | 87.19 | 84.61 | 84.39 | 0.8716 |
| Swin-Transformer | 87.77 | 15.14 | 88.16 | 81.29 | 78.14 | 77.71 | 4.0897 |
可见ConvNeXt V2在各项指标上均领先,且推理速度远快于Swin,证明了选型的正确性。训练过程的准确率、F1分数和损失曲线如下:
图4 准确率变化图
图5 F1分数变化图
图6 训练损失值变化图
最终模型的预测效果示例:

图7 ConvNeXt V2预测结果图
模型压缩与性能评估
为了将模型部署到Jetson Nano上,我们需要进行压缩和转换。这里采用MMRazor的L1-norm通道剪枝,然后通过MMDeploy转换为TensorRT引擎。
L1-Norm剪枝
L1-norm剪枝基于“权重绝对值之和越小,该卷积核越不重要”的启发式思想,将不重要的通道移除。我们在训练好的模型上进行了50%比例的剪枝(即移除50%的通道),然后微调10个epoch。剪枝配置见后文prune_convnext-v2-base_sd.py。
剪枝前后性能对比如下:
表2 剪枝前后模型推理性能变化
| 模型 | Acc(%) | Precision(%) | Recall(%) | F1(%) | Time(s/iter) | 加速比 |
|---|---|---|---|---|---|---|
| ConvNeXt V2 | 92.96 | 87.83 | 85.40 | 84.98 | 0.9308 | 1.0x |
| ConvNeXt V2 pruned | 91.43 | 86.23 | 82.43 | 81.84 | 0.5739 | 1.6x |
精度仅下降1.53个百分点,推理速度却提升了60%,这个trade-off非常值得。
使用MMDeploy转换TensorRT
将剪枝后的PyTorch模型转换为ONNX,再进一步转为TensorRT引擎(.engine)。我们在Jetson Nano上测试了不同后端的性能:
表3 转换前后模型推理性能变化
| 模型 | Acc(%) | Precision(%) | Recall(%) | F1(%) | Time(s/iter) | 加速比 |
|---|---|---|---|---|---|---|
| ConvNeXt V2 (PyTorch) | 92.96 | 87.83 | 85.40 | 84.98 | 0.9308 | 1.0x |
| ConvNeXt V2 pruned | 91.43 | 86.23 | 82.43 | 81.84 | 0.5739 | 1.6x |
| ONNX | 91.43 | 86.27 | 82.42 | 81.87 | 0.0698 | 13.3x |
| TensorRT engine | 91.42 | 86.23 | 82.43 | 81.83 | 0.0249 | 37.3x |
TensorRT引擎相比原始PyTorch模型加速37倍,延迟低至9.33ms/帧,吞吐量达到109 FPS,完全满足实时处理需求。
表4 不同后端模型的延迟与吞吐量
| 模型 | 平均延迟 (ms/帧) | 吞吐量 (FPS) |
|---|---|---|
| ONNX | 22.72 | 44 |
| TensorRT engine | 9.33 | 109 |
实际部署
硬件选择
- 核心计算平台:NVIDIA Jetson Nano Developer Kit(Maxwell架构GPU,128核,4GB内存),功耗低,适合边缘AI。
- 行走机构:四足机器狗,便于在农田中移动。
软件环境
在Jetson Nano上安装NVIDIA JetPack,配置Python环境,编译安装mmcv、mmengine和mmdeploy,并安装TensorRT Python API。具体步骤可参考OpenMMLab官方文档。
挑战与解决方案
- 散热管理:Jetson Nano满载运行时发热严重,会导致降频。我们加装了PWM调速风扇,主动散热,确保性能稳定。
- 供电不足:高功率模式下,若同时给摄像头、风扇等外设供电,可能引发系统重启。需使用5V/4A以上的稳定电源。
- 域漂移:实际农田的光照、背景与训练数据可能有差异,导致模型性能下降。解决方案是定期收集新数据并微调,或采用轻量级模型集成策略。
代码实现
以下是本项目用到的核心脚本和配置文件,供大家参考。
数据准备脚本 data_preparation.sh
#!/bin/bash
pip install kaggle
# 下载数据集
kaggle datasets download -d nizier193/strawberry-disease-classification -p./data --unzip
python split_dataset.py
echo "数据集准备完成。"
数据集分割脚本 split_dataset.py
import os
import shutil
import random
from tqdm import tqdm
import argparse
def split_dataset(base_dir, source_folder='union_dataset', dest_folder='strawberry_dataset', split_ratio=(0.7, 0.2, 0.1)):
source_path = os.path.join(base_dir, source_folder)
dest_path = os.path.join(base_dir, dest_folder)
if not os.path.isdir(source_path):
print(f"错误:源目录 '{source_path}' 不存在。")
return
os.makedirs(dest_path, exist_ok=True)
train_dir = os.path.join(dest_path, 'train')
val_dir = os.path.join(dest_path, 'val')
test_dir = os.path.join(dest_path, 'test')
os.makedirs(train_dir, exist_ok=True)
os.makedirs(val_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)
class_names = [d for d in os.listdir(source_path) if os.path.isdir(os.path.join(source_path, d))]
print(f"在 '{source_path}' 中共找到 {len(class_names)} 个类别。")
for class_name in class_names:
os.makedirs(os.path.join(train_dir, class_name), exist_ok=True)
os.makedirs(os.path.join(val_dir, class_name), exist_ok=True)
os.makedirs(os.path.join(test_dir, class_name), exist_ok=True)
for class_name in tqdm(class_names, desc="处理进度"):
class_source_dir = os.path.join(source_path, class_name)
images = [f for f in os.listdir(class_source_dir) if os.path.isfile(os.path.join(class_source_dir, f))]
random.shuffle(images)
total = len(images)
train_end = int(total * split_ratio[0])
val_end = train_end + int(total * split_ratio[1])
train_images = images[:train_end]
val_images = images[train_end:val_end]
test_images = images[val_end:]
for img in train_images:
shutil.copy2(os.path.join(class_source_dir, img), os.path.join(train_dir, class_name, img))
for img in val_images:
shutil.copy2(os.path.join(class_source_dir, img), os.path.join(val_dir, class_name, img))
for img in test_images:
shutil.copy2(os.path.join(class_source_dir, img), os.path.join(test_dir, class_name, img))
print("数据集划分完成!")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--base_dir', type=str, default='data')
parser.add_argument('--source', type=str, default='union_dataset')
parser.add_argument('--dest', type=str, default='strawberry_dataset')
args = parser.parse_args()
split_dataset(base_dir=args.base_dir, source_folder=args.source, dest_folder=args.dest)
MMPretrain模型配置文件 convnext-v2-base_32xb32_sd.py
_base_ = [
'../_base_/models/convnext_v2/base.py',
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]
model = dict(
data_preprocessor=dict(
num_classes=77,
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True,
),
backbone=dict(
frozen_stages=2,
init_cfg=dict(
type='Pretrained',
checkpoint='checkpoints/convnext_v2/convnext-v2-base_fcmae-in21k-pre_3rdparty_in1k_20230104-c48d16a5.pth',
prefix='backbone',
)),
head=dict(
num_classes=77,
loss=dict(type='LabelSmoothLoss', label_smooth_val=0.1, num_classes=77),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.2),
dict(type='CutMix', alpha=0.2),
]),
)
data_root = 'data/strawberry_dataset'
train_dataloader = dict(
num_workers=4,
batch_size=32,
dataset=dict(type='CustomDataset', data_root=data_root, data_prefix='train'))
val_dataloader = dict(
num_workers=4,
batch_size=32,
dataset=dict(type='CustomDataset', data_root=data_root, data_prefix='val'))
test_dataloader = dict(
batch_size=32,
num_workers=4,
dataset=dict(type='CustomDataset', data_root=data_root, data_prefix='test'))
val_evaluator = [
dict(type='Accuracy', topk=(1, 5)),
dict(type='SingleLabelMetric', average='macro', items=['precision', 'recall', 'f1-score'])
]
test_evaluator = val_evaluator
optim_wrapper = dict(optimizer=dict(lr=5e-3))
param_scheduler = [
dict(type='LinearLR', start_factor=1e-3, by_epoch=True, end=20, convert_to_iter_based=True),
dict(type='CosineAnnealingLR', eta_min=1e-5, by_epoch=True, begin=20)
]
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
custom_hooks = [dict(type='EMAHook', momentum=1e-4, priority='ABOVE_NORMAL')]
MMRazor剪枝配置文件 prune_convnext-v2-base_sd.py
_base_ = 'path/to/mmpretrain/configs/convnext_v2/convnext-v2-base_32xb32_sd.py'
custom_imports = dict(imports=['mmpretrain.models'], allow_failed_imports=False)
load_from = 'path/to/mmpretrain/work_dirs/convnext-v2-base_32xb32_sd/best_accuracy_top1_epoch_63.pth'
algorithm = dict(
type='ChannelPruning',
pruner=dict(
type='L1NormPruner',
pruning_schedule=dict(
type='LinearPruningSchedule',
start_ratio=0.0,
end_ratio=0.5,
start_epoch=0,
end_epoch=10
)
),
target_pruning_ratio=0.5,
finetuner=dict(
type='DefaultFinetuner',
epochs=10,
optimizer=dict(type='AdamW', lr=1e-5, weight_decay=0.05),
lr_scheduler=dict(type='CosineAnnealingLR', T_max=10, eta_min=0.0)
),
pruning_begin_epoch=0,
pruning_end_epoch=10,
early_stopping=dict(patience=3, min_delta=0.001)
)
default_hooks = dict(
logger=dict(type='LoggerHook', interval=50),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', interval=1, save_best='auto'),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='VisualizationHook', enable=False),
)
MMDeploy转换为TensorRT引擎的脚本
python3 tools/deploy.py \
configs/mmpretrain/classification_tensorrt_dynamic-224x224-224x224.py \
path/to/mmrazor/configs/prune_convnext-v2-base_sd.py \
path/to/mmrazor/work_dirs/pruned_convnext-v2-base_sd/epoch_20.pth \
path/to/mmpretrain/data/strawberry_dataset/val/aphid/aphid_13.jpg \
--work-dir work_dir/tensorrt_dynamic \
--device cuda:0
使用TensorRT引擎测试的脚本
python3 tools/test.py \
configs/mmpretrain/classification_tensorrt_dynamic-224x224-224x224.py \
path/to/mmrazor/configs/prune_convnext-v2-base_sd.py \
--model work_dir/tensorrt_dynamic/end2end.engine \
--show-dir work_dir/tensorrt_dynamic/outpu \
--device cuda:0 \
--speed-test \
--work-dir work_dir/tensorrt_dynamic \
--log2file work_dir/tensorrt_dynamic/output/eval.log
总结与展望
本文详细记录了基于OpenMMLab构建草莓病害识别系统的全过程。通过ConvNeXt V2+先进训练策略,我们在77类数据集上取得了92.96%的准确率;通过L1-norm剪枝,模型速度提升1.6倍而精度损失极小;通过TensorRT转换,最终在Jetson Nano上实现了109 FPS的实时推理。整个流程充分体现了OpenMMLab工具链的强大与便捷。
当然,实际落地仍有许多挑战,比如光照变化导致的域漂移、边缘设备的散热和供电等。未来可以考虑引入在线学习或联邦学习,让模型在部署后持续进化。希望这篇博客能为你的类似项目提供一些参考,也欢迎大家在评论区交流讨论。
参考资料
- MMPretrain Contributors. MMPretrain: OpenMMLab’s Pre-training Toolbox and Benchmark
- MMRazor Contributors. MMRazor: Model Compression Toolbox
- MMDeploy Contributors. MMDeploy: OpenMMLab Model Deployment Framework
- Strawberry Disease Classification Dataset. Kaggle
- Woo S, Debnath S, Hu R, et al. ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders. arXiv 2023.
- Liu Z, Mao H, Wu C Y, et al. A ConvNet for the 2020s. arXiv 2022.
- Liu Z, Lin Y, Cao Y, et al. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows. arXiv 2021.
- NVIDIA Jetson Nano. Product Page
作者:surpolo
日期:2026年2月
转载请注明出处
更多推荐

所有评论(0)