分类实战(迁移学习,半监督)
想看迁移学习:看main.py里的调用,以及它是如何通过model_name切换不同的大模型的。想看半监督学习:看里是如何处理unlabeled文件夹的,以及如何配合softmax概率进行筛选。这两项技术结合,能让你在只有少量(3000多张)标注图片的情况下,达到接近甚至超过 80% 的准确率。
这份《分类实战》教程主要围绕食物分类问题展开,通过迁移学习、数据增强和半监督学习等技术,展示了如何构建一个高效的图像分类模型。
以下是为您整理的模块化学习指南:
1. 项目背景与数据集构成
该实战项目的核心是处理一个包含 11类食物 的分类任务 。
-
数据集划分:
-
有标签训练数据:280 * 11 = 3080 张 。
-
无标签训练数据:6786 张(用于半监督学习)。
-
验证集:30 * 11 = 330 张 。
-
测试集:3347 张 。
-
-
代码文件:
food_classification。
2. 核心技术:迁移学习 (Transfer Learning)
教程强调在数据量较少时,迁移学习是最佳选择 。
-
原理:利用“大佬”们在 ImageNet 等海量数据集(千万级以上)上训练好的模型 。这些模型已经具备了极强的特征提取能力 。
-
做法:
-
特征提取器 (Features):保留预训练模型的卷积层权重 。
-
分类器 (Classifier):将原本的 1000 类输出修改为适合本任务的 11 类 。
-
微调 (Fine-tuning):在自己的小数据集上进行针对性训练 。
-
-
优势:相比从零训练,使用预训练模型能显著提高准确率 。
3. 数据增强 (Data Augmentation)
为了提高模型的泛化能力,教程展示了对原始图像进行处理的方法 。
-
通过对图像进行旋转、缩放、翻转或色彩变换,增加训练样本的多样性,防止模型过拟合 。
4. 半监督学习 (Semi-supervised Learning)
针对那 6786 张无标签数据,教程给出了处理方案 :
-
使用初步训练好的模型对无标签数据进行预测 。
-
设定阈值:当预测结果的置信度超过一定水平时,将其视为“伪标签”数据 。
-
将这些带有伪标签的数据加入训练集,进一步优化模型 。
5. 神经网络训练流程
一个标准的深度学习项目包含以下四个关键模块 :
| 模块 | 说明 |
| Data (数据) |
输入文件地址,输出数据结构(如 $X, Y$ 对),通常用 |
| Model (模型) |
定义模型架构,输入 $x$ 得到预测值 $\hat{y}$ 。 |
| HyperPara (超参) |
包括学习率 (lr)、优化器(如 Adam/AdamW)、损失函数等 。 |
| Process (训练) |
通过 Training Data 计算 Loss 并更新梯度;通过 Val Data 验证效果 。 |
6. 其他练习数据集
教程最后提到了几个经典的入门数据集,供对比练习 :
-
MNIST / Fashion-MNIST:28x28 的小图,可以直接展开用全连接网络,也可以用卷积网络对比效果 。
-
CIFAR-10:包含飞机、汽车、鸟、猫等 10 类彩色图片 。
结合你上传的代码(特别是 main.py 和 simple_class.py),我为你详细拆解一下迁移学习和半监督学习的具体代码操作:
一、 迁移学习 (Transfer Learning) 的操作
在 main.py 中,迁移学习主要通过 initialize_model 函数实现。
1. 加载预训练权重
代码中通常会有类似这样的调用:
Python
# main.py 中
model, input_size = initialize_model(model_name, 11, use_pretrained=True)
-
操作原理:当
use_pretrained=True时,PyTorch 会从官方服务器下载在 ImageNet(百万级图片)上练好的模型权重(比如 ResNet18)。 -
为什么这么做:预训练模型已经学会了如何识别直线、圆圈、纹理等通用特征。
2. 修改分类头 (Classifier)
在 model_utils/model.py(虽然未直接上传此工具类,但逻辑在 main.py 的初始化中体现)中,通常会执行:
Python
# 假设是 ResNet18
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 11) # 将原有的 1000 类输出改为现在的 11 类食物分类
-
操作:保留模型前面的“特征提取”部分,只替换最后一层“分类器”。
3. 设置不同的学习率(微调 Fine-tuning)
观察 main.py 的优化器设置:
Python
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
-
操作:通常使用较小的学习率(如
1e-4),以保证在训练时不破坏预训练模型已经学好的特征,只是针对食物数据进行微调。
二、 半监督学习 (Semi-supervised Learning) 的操作
半监督学习的核心在于利用那 6786 张无标签数据。代码主要在 simple_class.py 中的 food_Dataset 类和 get_label 逻辑中。
1. 伪标签 (Pseudo-Labeling) 流程
其操作逻辑通常如下(参考 simple_class.py 中的 semi 模式):
-
第一步:初步训练
先用 3080 张有标签的数据训练出一个初步的
model。 -
第二步:预测无标签数据
调用
Pythonget_label函数(在simple_class.py中定义):@torch.no_grad() def get_label(model, loader, device, thres=0.99): model.eval() # 对无标签数据进行预测 logits = model(img) probs = torch.softmax(logits, dim=-1) max_probs, labels = torch.max(probs, dim=-1) # 核心操作:筛选高置信度样本 mask = max_probs > thres # 只有预测概率大于 0.99 的图片才会被赋予“伪标签” -
第三步:加入训练集
在
Pythonsimple_class.py的food_Dataset类中,当mode=="semi"时:if self.mode == "semi": # 如果该样本通过了阈值筛选,就返回这张图及其预测的伪标签 return self.transform(img), self.label[idx]
2. 循环迭代
-
操作:在
train_val循环中,每过几个 Epoch,就重新用更新后的模型给无标签数据打一次分,筛选出更多可靠的“伪标签”数据加入训练,从而实现“数据越练越多”。
总结:你在学习时该看哪里?
-
想看迁移学习:看
main.py里的initialize_model调用,以及它是如何通过model_name切换不同的大模型的。 -
想看半监督学习:看
simple_class.py里food_Dataset是如何处理unlabeled文件夹的,以及如何配合softmax概率进行筛选。
这两项技术结合,能让你在只有少量(3000多张)标注图片的情况下,达到接近甚至超过 80% 的准确率。
更多推荐


所有评论(0)