这份《分类实战》教程主要围绕食物分类问题展开,通过迁移学习、数据增强和半监督学习等技术,展示了如何构建一个高效的图像分类模型。

以下是为您整理的模块化学习指南:


1. 项目背景与数据集构成

该实战项目的核心是处理一个包含 11类食物 的分类任务 。

  • 数据集划分

    • 有标签训练数据:280 * 11 = 3080 张 。

    • 无标签训练数据:6786 张(用于半监督学习)。

    • 验证集:30 * 11 = 330 张 。

    • 测试集:3347 张 。

  • 代码文件food_classification


2. 核心技术:迁移学习 (Transfer Learning)

教程强调在数据量较少时,迁移学习是最佳选择 。

  • 原理:利用“大佬”们在 ImageNet 等海量数据集(千万级以上)上训练好的模型 。这些模型已经具备了极强的特征提取能力 。

  • 做法

    • 特征提取器 (Features):保留预训练模型的卷积层权重 。

    • 分类器 (Classifier):将原本的 1000 类输出修改为适合本任务的 11 类 。

    • 微调 (Fine-tuning):在自己的小数据集上进行针对性训练 。

  • 优势:相比从零训练,使用预训练模型能显著提高准确率 。


3. 数据增强 (Data Augmentation)

为了提高模型的泛化能力,教程展示了对原始图像进行处理的方法 。

  • 通过对图像进行旋转、缩放、翻转或色彩变换,增加训练样本的多样性,防止模型过拟合 。


4. 半监督学习 (Semi-supervised Learning)

针对那 6786 张无标签数据,教程给出了处理方案 :

  1. 使用初步训练好的模型对无标签数据进行预测 。

  2. 设定阈值:当预测结果的置信度超过一定水平时,将其视为“伪标签”数据 。

  3. 将这些带有伪标签的数据加入训练集,进一步优化模型 。


5. 神经网络训练流程

一个标准的深度学习项目包含以下四个关键模块 :

模块 说明
Data (数据)

输入文件地址,输出数据结构(如 $X, Y$ 对),通常用 DataLoader 加载 。

Model (模型)

定义模型架构,输入 $x$ 得到预测值 $\hat{y}$ 。

HyperPara (超参)

包括学习率 (lr)、优化器(如 Adam/AdamW)、损失函数等 。

Process (训练)

通过 Training Data 计算 Loss 并更新梯度;通过 Val Data 验证效果 。


6. 其他练习数据集

教程最后提到了几个经典的入门数据集,供对比练习 :

  • MNIST / Fashion-MNIST:28x28 的小图,可以直接展开用全连接网络,也可以用卷积网络对比效果 。

  • CIFAR-10:包含飞机、汽车、鸟、猫等 10 类彩色图片 。

结合你上传的代码(特别是 main.pysimple_class.py),我为你详细拆解一下迁移学习半监督学习的具体代码操作:

一、 迁移学习 (Transfer Learning) 的操作

main.py 中,迁移学习主要通过 initialize_model 函数实现。

1. 加载预训练权重

代码中通常会有类似这样的调用:

Python

# main.py 中
model, input_size = initialize_model(model_name, 11, use_pretrained=True)
  • 操作原理:当 use_pretrained=True 时,PyTorch 会从官方服务器下载在 ImageNet(百万级图片)上练好的模型权重(比如 ResNet18)。

  • 为什么这么做:预训练模型已经学会了如何识别直线、圆圈、纹理等通用特征。

2. 修改分类头 (Classifier)

model_utils/model.py(虽然未直接上传此工具类,但逻辑在 main.py 的初始化中体现)中,通常会执行:

Python

# 假设是 ResNet18
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 11) # 将原有的 1000 类输出改为现在的 11 类食物分类
  • 操作:保留模型前面的“特征提取”部分,只替换最后一层“分类器”。

3. 设置不同的学习率(微调 Fine-tuning)

观察 main.py 的优化器设置:

Python

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
  • 操作:通常使用较小的学习率(如 1e-4),以保证在训练时不破坏预训练模型已经学好的特征,只是针对食物数据进行微调。


二、 半监督学习 (Semi-supervised Learning) 的操作

半监督学习的核心在于利用那 6786 张无标签数据。代码主要在 simple_class.py 中的 food_Dataset 类和 get_label 逻辑中。

1. 伪标签 (Pseudo-Labeling) 流程

其操作逻辑通常如下(参考 simple_class.py 中的 semi 模式):

  1. 第一步:初步训练

    先用 3080 张有标签的数据训练出一个初步的 model

  2. 第二步:预测无标签数据

    调用 get_label 函数(在 simple_class.py 中定义):

    Python

    @torch.no_grad()
    def get_label(model, loader, device, thres=0.99):
        model.eval()
        # 对无标签数据进行预测
        logits = model(img)
        probs = torch.softmax(logits, dim=-1)
        max_probs, labels = torch.max(probs, dim=-1)
    
        # 核心操作:筛选高置信度样本
        mask = max_probs > thres 
        # 只有预测概率大于 0.99 的图片才会被赋予“伪标签”
    
  3. 第三步:加入训练集

    simple_class.pyfood_Dataset 类中,当 mode=="semi" 时:

    Python

    if self.mode == "semi":
        # 如果该样本通过了阈值筛选,就返回这张图及其预测的伪标签
        return self.transform(img), self.label[idx]
    
2. 循环迭代
  • 操作:在 train_val 循环中,每过几个 Epoch,就重新用更新后的模型给无标签数据打一次分,筛选出更多可靠的“伪标签”数据加入训练,从而实现“数据越练越多”。


总结:你在学习时该看哪里?

  1. 想看迁移学习:看 main.py 里的 initialize_model 调用,以及它是如何通过 model_name 切换不同的大模型的。

  2. 想看半监督学习:看 simple_class.pyfood_Dataset 是如何处理 unlabeled 文件夹的,以及如何配合 softmax 概率进行筛选。

这两项技术结合,能让你在只有少量(3000多张)标注图片的情况下,达到接近甚至超过 80% 的准确率。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐