大家好,今天想和大家分享一个完整的计算机视觉项目——基于OpenMMLab生态构建一个高精度、可部署的草莓病害识别系统。本项目涵盖了数据准备、模型训练、模型压缩(剪枝)、模型转换(TensorRT)以及边缘设备(Jetson Nano)部署的全流程。希望这篇博客能给正在学习OpenMMLab或从事农业AI落地的朋友带来一些启发。

前言

在精准农业中,快速准确地识别作物病害对于保障产量至关重要。草莓作为一种高价值经济作物,其病害种类繁多(数据集包含77个类别),且存在严重的样本不平衡问题。这就要求我们的模型既要“认得准”常见病害,也要“抓得住”罕见病,同时还要能在资源受限的边缘设备(如机器狗搭载的Jetson Nano)上实时运行。

OpenMMLab提供了一套完整的工具链:MMPretrain用于训练、MMRazor用于压缩、MMDeploy用于部署。本文将基于这三个核心库,一步步搭建并优化一个草莓病害识别系统。

系统架构设计

整个系统遵循一个清晰的多阶段流程,如下图所示:

在这里插入图片描述

图1 系统蓝图

流程说明

  1. 数据获取与准备:下载Kaggle上的草莓病害数据集,并按标准格式划分训练/验证/测试集。
  2. 数据预处理与增强:定义包含MixUp、CutMix、RandAugment等策略的流水线,提升模型泛化能力。
  3. 模型训练:在本地NVIDIA GeForce RTX 4070上,使用MMPretrain训练ConvNeXt V2基线模型。
  4. 性能评估:在测试集上计算准确率、宏平均精确率/召回率/F1-score。
  5. 模型剪枝:利用MMRazor对训练好的模型进行L1-norm通道剪枝,减小体积。
  6. 模型转换:通过MMDeploy将剪枝后的模型转换为TensorRT引擎,以便在Jetson Nano上高效推理。

数据集选择与分析

我们选用Kaggle上的 “Strawberry Disease Classification” 数据集。该数据集包含超过9000张图像,覆盖77个类别,包括常见的“叶枯病”和“健康”,以及样本极少的“炭疽病”(仅34张)、“花螨”(29张)。严重的类别不平衡是本项目的主要挑战之一,为此我们在训练中引入了标签平滑、MixUp/CutMix等策略。

数据集准备脚本见后文代码部分,它按照7:2:1的比例将原始图像划分为train/val/test,并保持类别子目录结构。

视觉网络选择:ConvNeXt V2

经过综合评估,我们选择了ConvNeXt V2作为骨干网络。它是ConvNeXt的升级版,在纯卷积架构中融入了Transformer的设计思想,并通过自监督学习(FCMAE)和全局响应归一化(GRN)进一步提升了性能。相比同量级的Swin Transformer,ConvNeXt V2推理速度更快,精度更高。

在这里插入图片描述
图2 ConvNeXt性能表现与模型架构
在这里插入图片描述

图3 模型架构(ConvNeXt V1 vs. ConvNeXt V2)

我们选用convnext-v2-base变体,加载ImageNet-21k预训练权重,将其分类头调整为77类。

模型训练与优化策略

数据预处理与增强

在MMPretrain配置文件中,我们定义了如下流水线:

  • 标准化:使用ImageNet的mean/std对图像进行归一化(因为加载的是ImageNet预训练权重)。
  • RandomResizedCrop:随机裁剪并缩放到224×224。
  • RandomFlip:水平翻转。
  • MixUp & CutMix:两者结合,alpha=0.2,有效缓解过拟合和类别不平衡。
  • RandAugment:自动增强策略。
  • RandomErasing:随机擦除,提高遮挡鲁棒性。

这些增强策略的组合,让模型在样本量少的类别上也能学到鲁棒的特征。

系统配置与超参数优化

配置文件convnext-v2-base_32xb32_sd.py(见后文)包含了以下关键设置:

  • 迁移学习:冻结backbone的前两个stage,仅训练后面部分和分类头。
  • 优化器:AdamW,初始学习率设为5e-3(比默认稍高,以加速收敛)。
  • 学习率调度:前20个epoch线性warmup,之后余弦退火至1e-5。
  • 损失函数:标签平滑的交叉熵损失(label_smooth_val=0.1)。
  • 训练轮数:100个epoch,每1个epoch验证一次。
  • EMA:使用指数移动平均稳定模型。

由于本地GPU显存限制,将batch size从默认的64调整为32,同时为了加快数据加载,将num_workers设为8(CPU核心数足够)。

模型性能比较

我们在测试集上对比了ConvNeXt V2、ConvNeXt和Swin Transformer(三者参数量均在88M左右),结果如下:

表1 模型性能

模型 参数(M) Flops(G) Acc(%) Precision(%) Recall(%) F1(%) Time(s/iter)
ConvNeXt V2 88.72 15.38 92.96 87.83 85.40 84.98 0.9308
ConvNeXt 88.59 15.36 92.24 87.19 84.61 84.39 0.8716
Swin-Transformer 87.77 15.14 88.16 81.29 78.14 77.71 4.0897

可见ConvNeXt V2在各项指标上均领先,且推理速度远快于Swin,证明了选型的正确性。训练过程的准确率、F1分数和损失曲线如下:
在这里插入图片描述

图4 准确率变化图
在这里插入图片描述

图5 F1分数变化图
在这里插入图片描述
图6 训练损失值变化图

最终模型的预测效果示例:

在这里插入图片描述

图7 ConvNeXt V2预测结果图

模型压缩与性能评估

为了将模型部署到Jetson Nano上,我们需要进行压缩和转换。这里采用MMRazor的L1-norm通道剪枝,然后通过MMDeploy转换为TensorRT引擎。

L1-Norm剪枝

L1-norm剪枝基于“权重绝对值之和越小,该卷积核越不重要”的启发式思想,将不重要的通道移除。我们在训练好的模型上进行了50%比例的剪枝(即移除50%的通道),然后微调10个epoch。剪枝配置见后文prune_convnext-v2-base_sd.py

剪枝前后性能对比如下:

表2 剪枝前后模型推理性能变化

模型 Acc(%) Precision(%) Recall(%) F1(%) Time(s/iter) 加速比
ConvNeXt V2 92.96 87.83 85.40 84.98 0.9308 1.0x
ConvNeXt V2 pruned 91.43 86.23 82.43 81.84 0.5739 1.6x

精度仅下降1.53个百分点,推理速度却提升了60%,这个trade-off非常值得。

使用MMDeploy转换TensorRT

将剪枝后的PyTorch模型转换为ONNX,再进一步转为TensorRT引擎(.engine)。我们在Jetson Nano上测试了不同后端的性能:

表3 转换前后模型推理性能变化

模型 Acc(%) Precision(%) Recall(%) F1(%) Time(s/iter) 加速比
ConvNeXt V2 (PyTorch) 92.96 87.83 85.40 84.98 0.9308 1.0x
ConvNeXt V2 pruned 91.43 86.23 82.43 81.84 0.5739 1.6x
ONNX 91.43 86.27 82.42 81.87 0.0698 13.3x
TensorRT engine 91.42 86.23 82.43 81.83 0.0249 37.3x

TensorRT引擎相比原始PyTorch模型加速37倍,延迟低至9.33ms/帧,吞吐量达到109 FPS,完全满足实时处理需求。

表4 不同后端模型的延迟与吞吐量

模型 平均延迟 (ms/帧) 吞吐量 (FPS)
ONNX 22.72 44
TensorRT engine 9.33 109

实际部署

硬件选择

  • 核心计算平台:NVIDIA Jetson Nano Developer Kit(Maxwell架构GPU,128核,4GB内存),功耗低,适合边缘AI。
  • 行走机构:四足机器狗,便于在农田中移动。

软件环境

在Jetson Nano上安装NVIDIA JetPack,配置Python环境,编译安装mmcv、mmengine和mmdeploy,并安装TensorRT Python API。具体步骤可参考OpenMMLab官方文档。

挑战与解决方案

  1. 散热管理:Jetson Nano满载运行时发热严重,会导致降频。我们加装了PWM调速风扇,主动散热,确保性能稳定。
  2. 供电不足:高功率模式下,若同时给摄像头、风扇等外设供电,可能引发系统重启。需使用5V/4A以上的稳定电源。
  3. 域漂移:实际农田的光照、背景与训练数据可能有差异,导致模型性能下降。解决方案是定期收集新数据并微调,或采用轻量级模型集成策略。

代码实现

以下是本项目用到的核心脚本和配置文件,供大家参考。

数据准备脚本 data_preparation.sh

#!/bin/bash
pip install kaggle
# 下载数据集
kaggle datasets download -d nizier193/strawberry-disease-classification -p./data --unzip
python split_dataset.py
echo "数据集准备完成。"

数据集分割脚本 split_dataset.py

import os
import shutil
import random
from tqdm import tqdm
import argparse

def split_dataset(base_dir, source_folder='union_dataset', dest_folder='strawberry_dataset', split_ratio=(0.7, 0.2, 0.1)):
    source_path = os.path.join(base_dir, source_folder)
    dest_path = os.path.join(base_dir, dest_folder)
    if not os.path.isdir(source_path):
        print(f"错误:源目录 '{source_path}' 不存在。")
        return

    os.makedirs(dest_path, exist_ok=True)
    train_dir = os.path.join(dest_path, 'train')
    val_dir = os.path.join(dest_path, 'val')
    test_dir = os.path.join(dest_path, 'test')
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    class_names = [d for d in os.listdir(source_path) if os.path.isdir(os.path.join(source_path, d))]
    print(f"在 '{source_path}' 中共找到 {len(class_names)} 个类别。")

    for class_name in class_names:
        os.makedirs(os.path.join(train_dir, class_name), exist_ok=True)
        os.makedirs(os.path.join(val_dir, class_name), exist_ok=True)
        os.makedirs(os.path.join(test_dir, class_name), exist_ok=True)

    for class_name in tqdm(class_names, desc="处理进度"):
        class_source_dir = os.path.join(source_path, class_name)
        images = [f for f in os.listdir(class_source_dir) if os.path.isfile(os.path.join(class_source_dir, f))]
        random.shuffle(images)
        total = len(images)
        train_end = int(total * split_ratio[0])
        val_end = train_end + int(total * split_ratio[1])
        train_images = images[:train_end]
        val_images = images[train_end:val_end]
        test_images = images[val_end:]

        for img in train_images:
            shutil.copy2(os.path.join(class_source_dir, img), os.path.join(train_dir, class_name, img))
        for img in val_images:
            shutil.copy2(os.path.join(class_source_dir, img), os.path.join(val_dir, class_name, img))
        for img in test_images:
            shutil.copy2(os.path.join(class_source_dir, img), os.path.join(test_dir, class_name, img))

    print("数据集划分完成!")

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_dir', type=str, default='data')
    parser.add_argument('--source', type=str, default='union_dataset')
    parser.add_argument('--dest', type=str, default='strawberry_dataset')
    args = parser.parse_args()
    split_dataset(base_dir=args.base_dir, source_folder=args.source, dest_folder=args.dest)

MMPretrain模型配置文件 convnext-v2-base_32xb32_sd.py

_base_ = [
    '../_base_/models/convnext_v2/base.py',
    '../_base_/datasets/imagenet_bs64_swin_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py',
]

model = dict(
    data_preprocessor=dict(
        num_classes=77,
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True,
    ),
    backbone=dict(
        frozen_stages=2,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='checkpoints/convnext_v2/convnext-v2-base_fcmae-in21k-pre_3rdparty_in1k_20230104-c48d16a5.pth',
            prefix='backbone',
        )),
    head=dict(
        num_classes=77,
        loss=dict(type='LabelSmoothLoss', label_smooth_val=0.1, num_classes=77),
    ),
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.2),
        dict(type='CutMix', alpha=0.2),
    ]),
)

data_root = 'data/strawberry_dataset'
train_dataloader = dict(
    num_workers=4,
    batch_size=32,
    dataset=dict(type='CustomDataset', data_root=data_root, data_prefix='train'))
val_dataloader = dict(
    num_workers=4,
    batch_size=32,
    dataset=dict(type='CustomDataset', data_root=data_root, data_prefix='val'))
test_dataloader = dict(
    batch_size=32,
    num_workers=4,
    dataset=dict(type='CustomDataset', data_root=data_root, data_prefix='test'))
val_evaluator = [
    dict(type='Accuracy', topk=(1, 5)),
    dict(type='SingleLabelMetric', average='macro', items=['precision', 'recall', 'f1-score'])
]
test_evaluator = val_evaluator

optim_wrapper = dict(optimizer=dict(lr=5e-3))
param_scheduler = [
    dict(type='LinearLR', start_factor=1e-3, by_epoch=True, end=20, convert_to_iter_based=True),
    dict(type='CosineAnnealingLR', eta_min=1e-5, by_epoch=True, begin=20)
]
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1)
custom_hooks = [dict(type='EMAHook', momentum=1e-4, priority='ABOVE_NORMAL')]

MMRazor剪枝配置文件 prune_convnext-v2-base_sd.py

_base_ = 'path/to/mmpretrain/configs/convnext_v2/convnext-v2-base_32xb32_sd.py'
custom_imports = dict(imports=['mmpretrain.models'], allow_failed_imports=False)
load_from = 'path/to/mmpretrain/work_dirs/convnext-v2-base_32xb32_sd/best_accuracy_top1_epoch_63.pth'

algorithm = dict(
    type='ChannelPruning',
    pruner=dict(
        type='L1NormPruner',
        pruning_schedule=dict(
            type='LinearPruningSchedule',
            start_ratio=0.0,
            end_ratio=0.5,
            start_epoch=0,
            end_epoch=10
        )    
    ),
    target_pruning_ratio=0.5,
    finetuner=dict(
        type='DefaultFinetuner',
        epochs=10,
        optimizer=dict(type='AdamW', lr=1e-5, weight_decay=0.05),
        lr_scheduler=dict(type='CosineAnnealingLR', T_max=10, eta_min=0.0)
    ),
    pruning_begin_epoch=0,
    pruning_end_epoch=10,
    early_stopping=dict(patience=3, min_delta=0.001)
)

default_hooks = dict(
    logger=dict(type='LoggerHook', interval=50),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1, save_best='auto'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='VisualizationHook', enable=False),
)

MMDeploy转换为TensorRT引擎的脚本

python3 tools/deploy.py \
    configs/mmpretrain/classification_tensorrt_dynamic-224x224-224x224.py \
    path/to/mmrazor/configs/prune_convnext-v2-base_sd.py \
     path/to/mmrazor/work_dirs/pruned_convnext-v2-base_sd/epoch_20.pth \
     path/to/mmpretrain/data/strawberry_dataset/val/aphid/aphid_13.jpg \
    --work-dir work_dir/tensorrt_dynamic \
    --device cuda:0

使用TensorRT引擎测试的脚本

python3 tools/test.py \
    configs/mmpretrain/classification_tensorrt_dynamic-224x224-224x224.py \
     path/to/mmrazor/configs/prune_convnext-v2-base_sd.py \
    --model work_dir/tensorrt_dynamic/end2end.engine \
    --show-dir work_dir/tensorrt_dynamic/outpu \
    --device cuda:0 \
    --speed-test \
    --work-dir work_dir/tensorrt_dynamic \
    --log2file work_dir/tensorrt_dynamic/output/eval.log

总结与展望

本文详细记录了基于OpenMMLab构建草莓病害识别系统的全过程。通过ConvNeXt V2+先进训练策略,我们在77类数据集上取得了92.96%的准确率;通过L1-norm剪枝,模型速度提升1.6倍而精度损失极小;通过TensorRT转换,最终在Jetson Nano上实现了109 FPS的实时推理。整个流程充分体现了OpenMMLab工具链的强大与便捷。

当然,实际落地仍有许多挑战,比如光照变化导致的域漂移、边缘设备的散热和供电等。未来可以考虑引入在线学习或联邦学习,让模型在部署后持续进化。希望这篇博客能为你的类似项目提供一些参考,也欢迎大家在评论区交流讨论。

参考资料

  1. MMPretrain Contributors. MMPretrain: OpenMMLab’s Pre-training Toolbox and Benchmark
  2. MMRazor Contributors. MMRazor: Model Compression Toolbox
  3. MMDeploy Contributors. MMDeploy: OpenMMLab Model Deployment Framework
  4. Strawberry Disease Classification Dataset. Kaggle
  5. Woo S, Debnath S, Hu R, et al. ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders. arXiv 2023.
  6. Liu Z, Mao H, Wu C Y, et al. A ConvNet for the 2020s. arXiv 2022.
  7. Liu Z, Lin Y, Cao Y, et al. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows. arXiv 2021.
  8. NVIDIA Jetson Nano. Product Page

作者:surpolo
日期:2026年2月
转载请注明出处

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐