认知篇#15:MMPretrain
MMPretrain是基于PyTorch的开源深度学习预训练工具箱,由OpenMMLab项目开发。它整合了MMClassification和MMSelfSup项目,提供丰富的预训练主干网络(如ResNet、ViT)和训练策略,支持图像分类、检索、描述、视觉问答等多种任务。MMPretrain采用模块化设计,包含models、datasets、apis等核心组件,支持从预训练到微调的完整工作流程。安
一、MMPretrain是什么
MMPreTrain 是一款基于 PyTorch 的开源深度学习预训练工具箱。MMPretrain通过其模块化设计,让构建和训练模型变得清晰灵活。其典型的工作流程是:选择一个主干网络(Backbone) 在大型数据集上进行预训练,学习通用的视觉特征;然后,将这个预训练好的模型,通过微调(Fine-tuning)或直接用于特征提取,应用到各种下游任务中。具体的,可以总结为下表:
表2-1 MMPretrain特性说明表
| 特性分类 | 具体说明 |
| 核心定位 | 基于 PyTorch 的开源深度学习预训练工具箱,OpenMMLab 项目成员 |
| 技术渊源 | 源自 MMClassification 和 MMSelfSup 项目整合升级 |
| 核心功能 | 提供丰富的预训练主干网络(如 ResNet, ViT)和训练策略(有监督、自监督、多模态) |
| 支持任务 | 图像分类、图像描述、视觉问答、视觉定位、图像/文本检索 |
| 主要特点 | 模块化设计、高精度模型、训练技巧丰富、高效率与可扩展性 |
二、MMPretrain模型一般框架及工作流程
1、MMPretrain组件
MMPretrain通过其模块化设计,让构建和训练模型变得清晰灵活。其代码框架主要包含以下组件:
-
models:这是核心,分为backbones(主干网络,如ResNet、ViT,用于特征提取)、necks(连接主干和头部的结构,如特征融合)、heads(负责输出任务结果,如分类得分)、losses(损失函数)以及更高阶的模型定义(如classifier)。 -
datasets:支持多种数据集,并包含数据变换(如数据增强)管道。 -
apis:提供顶层的推理接口,支持多种任务的"开箱即用"的预测。
典型的工作流程是:选择一个主干网络(Backbone) 在大型数据集上进行预训练,学习通用的视觉特征;然后,将这个预训练好的模型,通过微调(Fine-tuning)或直接用于特征提取,应用到各种下游任务中。
MMPretrain支持大量经典的(如ResNet)和现代的(如Vision Transformer (ViT)、ConvNeXt)主干网络,并提供在ImageNet等大型数据集上预训练的高精度模型,可直接使用或微调。
2、MMPretrain功能
支持广泛的视觉任务:
图像分类与检索:基础图像分类,及以图搜图、以文搜图等检索任务。
图像描述(Image Caption):为给定图像生成文字描述。
视觉问答(Visual Question Answering):根据给定图片回答自然语言问题。
视觉定位(Visual Grounding):定位图像中与文本描述对应的区域。
当然,简单的分类、计数、分割等等功能也不在话下。
三、MMPretrain模型安装
1、环境配置
需要 Python 3.7+、CUDA 10.2+ 和 PyTorch 1.8+
由于采用云服务器进行模型训练,且环境配置方法与YOLO相同,故略
2、模型下载安装
可终端输入以下命令进行下载安装:
git clone https://github.com/open-mmlab/mmpretrain.git
cd mmpretrain
pip install -U openmim && mim install -e .
如需安装安装多模态支持,则:
mim install -e ".[multimodal]"
3、验证安装
同YOLO类似,需要通过一个例子来对环境搭建和模型下载安装进行验证,可输入:
python demo/image_demo.py demo/demo.JPEG resnet18_8xb32_in1k --device cpu
如果写成Python文件,则:
from mmpretrain import get_model, inference_model model = get_model('resnet18_8xb32_in1k', device='cpu') # 或者 device='cuda:0' inference_model(model, 'demo/demo.JPEG')
如看到输出一个字典,包含预测的标签、得分及类别名,则成功。
四、MMPretrain模型使用
1、理解配置文件(MMPretrain/Configs/)
整体文件的结构如下:
MMPretrain/
├── configs/
│ ├── _base_/ # primitive configuration folder
│ │ ├── datasets/ # primitive datasets
│ │ ├── models/ # primitive models
│ │ ├── schedules/ # primitive schedules
│ │ └── default_runtime.py # primitive runtime setting
│ ├── beit/ # BEiT Algorithms Folder
│ ├── mae/ # MAE Algorithms Folder
│ ├── mocov2/ # MoCoV2 Algorithms Folder
│ ├── resnet/ # ResNet Algorithms Folder
│ ├── swin_transformer/ # Swin Algorithms Folder
│ ├── vision_transformer/ # ViT Algorithms Folder
│ ├── ...
└── ...
可以通过继承一些基本配置文件轻松构建自己的训练配置文件。我们称这些被继承的配置文件为原始配置文件,如 base 文件夹中的文件一般仅作为原始配置文件。
(1)model
model = dict( type='ImageClassifier', # 主模型类型(对于图像分类任务,使用 `ImageClassifier`) backbone=dict( type='ResNet', # 主干网络类型 # 除了 `type` 之外的所有字段都来自 `ResNet` 类的 __init__ 方法 # 可查阅 https://mmpretrain.readthedocs.io/zh_CN/latest/api/generated/mmpretrain.models.backbones.ResNet.html depth=50, num_stages=4, # 主干网络状态(stages)的数目,这些状态产生的特征图作为后续的 head 的输入。 out_indices=(3, ), # 输出的特征图输出索引。 frozen_stages=-1, # 冻结主干网的层数 style='pytorch'), neck=dict(type='GlobalAveragePooling'), # 颈网络类型 head=dict( type='LinearClsHead', # 分类颈网络类型 # 除了 `type` 之外的所有字段都来自 `LinearClsHead` 类的 __init__ 方法 # 可查阅 https://mmpretrain.readthedocs.io/zh_CN/latest/api/generated/mmpretrain.models.heads.LinearClsHead.html num_classes=1000, in_channels=2048, loss=dict(type='CrossEntropyLoss', loss_weight=1.0), # 损失函数配置信息 topk=(1, 5), # 评估指标,Top-k 准确率, 这里为 top1 与 top5 准确率 ))
-
type:算法类型,支持多种任务:例如第2行type='ImageClassifier'指的是图像分类任务;对于自监督任务,有多种类型的算法,例如MoCoV2,BEiT,MAE等;对图像检索任务,通常为ImageToImageRetriever。通常使用type字段 来指定组件的类,并使用其他字段来传递类的初始化参数。 -
backbone=dict(……):主干网络设置,主干网络为主要的特征提取网络,比如ResNet,Swin Transformer,Vision Transformer等等。 -
neck=dict(……):颈网络设置,颈网络主要是连接主干网和头网络的中间部分,比如GlobalAveragePooling等。 -
head=dict(……):头网络设置,头网络主要是与具体任务关联的部件,如图像分类、自监督训练等。 -
括号里:是type后函数的各个参数设置。
type='ResNet', # 主干网络类型 # 除了 `type` 之外的所有字段都来自 `ResNet` 类的 __init__ 方法 # 可查阅 https://mmpretrain.readthedocs.io/zh_CN/latest/api/generated/mmpretrain.models.backbones.ResNet.html depth=50, num_stages=4, # 主干网络状态(stages)的数目,这些状态产生的特征图作为后续的 head 的输入。 out_indices=(3, ), # 输出的特征图输出索引。 frozen_stages=-1, # 冻结主干网的层数 style='pytorch'
(2)data
dataset_type = 'ImageNet' # 预处理配置 data_preprocessor = dict( # 输入的图片数据通道以 'RGB' 顺序 mean=[123.675, 116.28, 103.53], # 输入图像归一化的 RGB 通道均值 std=[58.395, 57.12, 57.375], # 输入图像归一化的 RGB 通道标准差 to_rgb=True, # 是否将通道翻转,从 BGR 转为 RGB 或者 RGB 转为 BGR ) train_pipeline = [ dict(type='LoadImageFromFile'), # 读取图像 dict(type='RandomResizedCrop', scale=224), # 随机放缩裁剪 dict(type='RandomFlip', prob=0.5, direction='horizontal'), # 随机水平翻转 dict(type='PackInputs'), # 准备图像以及标签 ] test_pipeline = [ dict(type='LoadImageFromFile'), # 读取图像 dict(type='ResizeEdge', scale=256, edge='short'), # 缩放短边尺寸至 256px dict(type='CenterCrop', crop_size=224), # 中心裁剪 dict(type='PackInputs'), # 准备图像以及标签 ] # 构造训练集 dataloader train_dataloader = dict( batch_size=32, # 每张 GPU 的 batchsize num_workers=5, # 每个 GPU 的线程数 dataset=dict( # 训练数据集 type=dataset_type, data_root='data/imagenet', ann_file='meta/train.txt', data_prefix='train', pipeline=train_pipeline), sampler=dict(type='DefaultSampler', shuffle=True), # 默认采样器 persistent_workers=True, # 是否保持进程,可以缩短每个 epoch 的准备时间 ) # 构造验证集 dataloader val_dataloader = dict( batch_size=32, num_workers=5, dataset=dict( type=dataset_type, data_root='data/imagenet', ann_file='meta/val.txt', data_prefix='val', pipeline=test_pipeline), sampler=dict(type='DefaultSampler', shuffle=False), persistent_workers=True, ) # 验证集评估设置,使用准确率为指标, 这里使用 topk1 以及 top5 准确率 val_evaluator = dict(type='Accuracy', topk=(1, 5)) test_dataloader = val_dataloader # test dataloader 配置,这里直接与 val_dataloader 相同 test_evaluator = val_evaluator # 测试集的评估配置,这里直接与 val_evaluator 相同
数据原始配置文件主要包括预处理设置、dataloader 以及评估器等设置:
-
data_preprocessor: 模型输入预处理配置,与model.data_preprocessor相同,但优先级更低。 -
train_evaluator | val_evaluator | test_evaluator: 构建评估器 -
train_dataloader | val_dataloader | test_dataloader: 构建 dataloader-
batch_size: 每个 GPU 的 batch size -
num_workers: 每个 GPU 的线程数 -
sampler: 采样器配置 -
dataset: 数据集配置-
type: 数据集类型, MMPretrain 支持ImageNet、Cifar等数据集 -
pipeline: 数据处理流水线
-
-
数据集结构和YOLO类似,文本文件与图片文件。
(3)schedules策略
optim_wrapper = dict( # 使用 SGD 优化器来优化参数 optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)) # 学习率参数的调整策略 # 'MultiStepLR' 表示使用多步策略来调度学习率(LR)。 param_scheduler = dict( type='MultiStepLR', by_epoch=True, milestones=[30, 60, 90], gamma=0.1) # 训练的配置,迭代 100 个 epoch,每一个训练 epoch 后都做验证集评估 # 'by_epoch=True' 默认使用 `EpochBaseLoop`, 'by_epoch=False' 默认使用 `IterBaseLoop` train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=1) # 使用默认的验证循环控制器 val_cfg = dict() # 使用默认的测试循环控制器 test_cfg = dict() # 通过默认策略自动缩放学习率,此策略适用于总批次大小 256 # 如果你使用不同的总批量大小,比如 512 并启用自动学习率缩放 # 我们将学习率扩大到 2 倍 auto_scale_lr = dict(base_batch_size=256)
训练策略原始配置文件主要包括预优化器设置(SGD)和训练、验证及测试的循环控制器(LOOP):
-
optim_wrapper: 优化器装饰器配置信息,我们使用优化器装饰配置优化进程。-
optimizer: 支持pytorch所有的优化器。 -
paramwise_cfg: 根据参数的类型或名称设置不同的优化参数。 -
accumulative_counts: 积累几个反向传播后再优化参数,你可以用它通过小批量来模拟大批量。
-
-
param_scheduler: 学习率策略,你可以指定训练期间的学习率和动量曲线。 -
train_cfg | val_cfg | test_cfg: 训练、验证以及测试的循环执行器配置。
(4)default_runtime
本部分主要包括保存权重策略、日志配置、训练参数、断点权重路径和工作目录等等。根据文档,可以直接继承。
2、继承与修改配置文件
与YOLO一样,只需要在原始模型的基础上进行调整即可,当然,部分内容可以忽略,部分内容需要修改。
更多推荐

所有评论(0)