目录

一、引言

二、项目背景与数据集介绍

2.1 项目背景

2.2 数据集结构

三、完整代码实现(附逐行注释)

3.1 环境准备与依赖导入

3.2 自动生成train.txt和test.txt文件(自动获取类别列表)

代码解析:

3.3 自定义数据集类(继承Dataset)

代码解析:

3.4 数据预处理与加载(DataLoader)

代码解析:

3.5 模型构建(CNN网络定义)

代码解析:

3.6 模型训练与测试

代码解析:

3.7 输入图片预测功能(核心扩展)

代码解析:

四、训练与测试流程

4.1 训练模型

4.2 测试模型

4.3 输入图片预测

五、常见问题与优化建议

5.1 常见问题

5.2 优化建议

六、总结


一、引言

在计算机视觉领域,图像分类是核心技术之一,而食物分类作为其垂直细分场景,广泛应用于餐饮推荐、健康管理等生活场景。PyTorch凭借其灵活的动态计算图和丰富的生态库,成为实现图像分类任务的首选框架。本文将以​​本地自定义食物图像数据集​​为对象,手把手演示如何从数据准备、模型构建到训练测试,最终实现“输入图片路径→输出食物类别”的全流程实战。


二、项目背景与数据集介绍

2.1 项目背景

食物分类的核心目标是通过模型识别图像中的食物种类(如薯条、八宝粥、骨肉相连等)。与通用图像分类(如ImageNet)不同,食物分类的数据集通常规模较小但类别更聚焦(本案例包含20类食物),因此需要更精细的数据组织和模型调优。

2.2 数据集结构

用户提供的项目数据集结构清晰,完全符合PyTorch的Dataset加载规范。通过图1-5的文件管理器界面,我们可以直观看到以下层级关系:

2、卷积神经网络/
└─ data/
   └─ 食物分类/
      └─ food_dataset/
         ├─ train/          # 训练集(包含20类食物的子文件夹)
         │  ├─ 八宝粥/
         │  │  ├─ img_八宝粥罐_22.jpeg
         │  │  └─ ...(共6张)
         │  ├─ 巴旦木/
         │  └─ ...(共20类)
         └─ test/           # 测试集(结构同train)

​关键特点​​:

  • 训练集(train)和测试集(test)按类别分文件夹存储(如“八宝粥”“骨肉相连”)。
  • 无需额外标注文件(如CSV/JSON),标签可通过“子文件夹名称”自动生成(如“薯条”文件夹的索引为n,则该文件夹下所有图像的标签为n)。

三、完整代码实现(附逐行注释)

3.1 环境准备与依赖导入

首先需要安装必要的库,包括PyTorch、Pillow(图像处理)、matplotlib(可视化)等。本文假设已配置好PyTorch环境(支持CUDA或MPS加速)。

# 导入基础库:用于文件操作、数值计算等
import os  # 操作系统接口,用于路径遍历、文件操作
import numpy as np  # 数值计算库,用于数组操作

# 导入PyTorch核心库及数据加载工具
import torch  # PyTorch深度学习框架
from torch.utils.data import Dataset, DataLoader  # Dataset定义数据集,DataLoader批量加载数据

# 导入图像处理库:PIL用于打开、保存、显示图像
from PIL import Image  # Python Imaging Library,处理图像文件

# 导入PyTorch的图像预处理工具(缩放、转Tensor等)
from torchvision import transforms  # 包含图像变换的工具集
import torch.nn as nn  # 神经网络模块

3.2 自动生成train.txt和test.txt文件(自动获取类别列表)

在PyTorch中,Dataset类通常需要读取一个包含图像路径和标签的文本文件(如train.txt)。本节通过遍历文件夹结构,自动生成这两个文件,并​​动态获取类别名称列表​​(无需手动输入)。

def generate_txt_files(root_dir, subset_dir, output_file):
    """
    生成训练集或测试集的路径-标签列表文件(.txt),自动从目录中获取类别名称
    
    参数:
        root_dir (str): 数据集根目录(如'./食物分类/food_dataset')
        subset_dir (str): 子集名称('train'或'test')
        output_file (str): 输出文件路径(如'./train.txt')
    """
    subset_path = os.path.join(root_dir, subset_dir)  # 拼接子集完整路径(如'./食物分类/food_dataset/train')
    class_names = []  # 存储类别名称(自动从dirs获取)
    img_paths = []    # 存储图像路径
    labels = []       # 存储标签(类别索引)
    
    # 遍历子集目录,获取类别名称和图像路径(核心逻辑)
    for root, dirs, files in os.walk(subset_path):  # os.walk递归遍历目录
        if dirs:  # 当前目录是父级(如train文件夹下有八宝粥、巴旦木等子目录)
            class_names = dirs.copy()  # 捕获当前层级的所有类别名称(如['八宝粥', '巴旦木', ...])
        else:  # 当前目录是类别文件夹(如train/八宝粥)
            current_class = os.path.basename(root)  # 获取当前类别名称(如"八宝粥")
            if current_class not in class_names:
                continue  # 防止跨文件夹的类别干扰(理论上不会出现)
            label = class_names.index(current_class)  # 计算标签(类别索引,如"八宝粥"对应0)
            for file in files:  # 遍历当前类别下的所有图像文件
                img_path = os.path.join(root, file)  # 拼接完整图像路径(如'./train/八宝粥/img_八宝粥罐_22.jpeg')
                img_paths.append(img_path)  # 存储图像路径
                labels.append(label)       # 存储标签
    
    # 写入txt文件(格式:"图像路径 标签")
    with open(output_file, 'w', encoding='utf-8') as f:
        for img_path, label in zip(img_paths, labels):
            f.write(f"{img_path} {label}
")
    
    # 返回类别名称列表(供后续加载数据使用)
    return class_names

# 配置参数(根据用户实际路径调整)
dataset_root = r'.\食物分类\food_dataset'  # 数据集根目录(用户图2的food_dataset路径)
train_txt_path = './train.txt'            # 训练集txt文件输出路径
test_txt_path = './test.txt'              # 测试集txt文件输出路径

# 生成训练集和测试集的txt文件,并获取类别名称列表(自动从目录中提取)
print("正在生成train.txt和test.txt文件...")
train_class_names = generate_txt_files(dataset_root, 'train', train_txt_path)
test_class_names = generate_txt_files(dataset_root, 'test', test_txt_path)
print("txt文件生成完成!")
代码解析:
  • ​动态获取类别​​:通过os.walk遍历父目录(如train文件夹)时,dirs变量会自动包含该目录下的所有子目录名称(即所有食物类别名称)。例如,若train文件夹下有“八宝粥”“巴旦木”等子文件夹,dirs将返回['八宝粥', '巴旦木', ...]
  • ​标签生成​​:通过class_names.index(current_class)计算当前类别名称的索引(如“八宝粥”对应0,“巴旦木”对应1),确保标签与类别名称一一对应。
  • ​一致性验证​​:训练集和测试集的类别顺序必须一致(通过assert验证),否则会导致标签错位。

3.3 自定义数据集类(继承Dataset)

PyTorch的Dataset类是数据加载的核心,需要重写__len__(数据集大小)和__getitem__(按索引获取数据)方法。本节定义FoodDataset类,支持自动使用生成的类别名称。

class FoodDataset(Dataset):
    def __init__(self, txt_path, transform=None, class_names=None):
        """
        初始化食物分类数据集(自动使用生成的类别名称)
        
        参数:
            txt_path (str): 图像路径-标签列表文件的路径(如'train.txt')
            transform (callable): 图像预处理变换(如缩放、转Tensor)
            class_names (list): 类别名称列表(自动从generate_txt_files获取)
        """
        self.txt_path = txt_path  # txt文件路径(如'./train.txt')
        self.transform = transform  # 图像预处理变换(如Resize、ToTensor)
        self.class_names = class_names  # 直接使用生成的类别名称列表(如['八宝粥', '巴旦木', ...])
        self.img_paths = []    # 存储图像路径
        self.labels = []       # 存储标签(类别索引)
        
        # 读取txt文件并解析数据(逐行读取用户生成的train.txt/test.txt)
        with open(self.txt_path, 'r', encoding='utf-8') as f:
            for line in f.readlines():
                parts = line.strip().split(' ')  # 按空格分割路径和标签
                if len(parts) == 2:  # 确保每行有路径和标签两部分
                    self.img_paths.append(parts[0])  # 存储图像路径
                    self.labels.append(int(parts[1]))  # 存储标签(转为整数)
    
    def __len__(self):
        """返回数据集的大小(图像数量)"""
        return len(self.img_paths)
    
    def __getitem__(self, idx):
        """
        按索引获取图像和标签(支持返回类别名称)
        
        参数:
            idx (int): 数据索引(如0代表train.txt中的第一行数据)
            
        返回:
            tuple: (图像Tensor, 标签Tensor, 类别名称)
        """
        # 获取图像路径和标签(如idx=0时,路径为train.txt的第一行路径)
        img_path = self.img_paths[idx]
        label = self.labels[idx]
        
        # 用PIL打开图像(支持JPEG格式),并确保为3通道(RGB)
        image = Image.open(img_path).convert('RGB')
        
        # 应用预处理变换(如缩放、转Tensor,后续data_transforms中定义)
        if self.transform:
            image = self.transform(image)
        
        # 标签转为Tensor(PyTorch的损失函数需要Tensor类型)
        label = torch.tensor(label, dtype=torch.long)
        
        # 返回图像、标签、类别名称(可选,方便后续可视化)
        return image, label, self.class_names[label]
代码解析:
  • ​初始化方法(__init__)​​:读取txt文件,将图像路径和标签分别存储到img_pathslabels列表中。class_names参数直接使用生成的类别名称列表,无需手动输入。
  • __len__方法​​:返回img_paths的长度,即数据集中图像的总数(如用户train文件夹下有1000张图像,则返回1000)。
  • __getitem__方法​​:
    • 根据索引idx获取对应的图像路径和标签。
    • PIL.Image.open打开图像,并通过.convert('RGB')确保为3通道(避免灰度图导致的通道数不一致问题)。
    • 应用预处理变换(如ResizeToTensor),将图像转换为模型需要的格式。
    • 返回图像Tensor、标签Tensor和类别名称(如“八宝粥”),方便后续可视化。

3.4 数据预处理与加载(DataLoader)

本节定义训练集和测试集的预处理变换,并通过DataLoader批量加载数据。

# 定义训练集和测试集的预处理变换(与模型输入要求一致)
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize([256, 256]),  # 调整图像大小为256x256(统一尺寸,避免批量运算错误)
        transforms.ToTensor(),          # 转换为Tensor(自动归一化到[0,1])
    ]),
    'test': transforms.Compose([
        transforms.Resize([256, 256]),  # 测试集无需数据增强,仅调整尺寸和转Tensor
        transforms.ToTensor(),
    ]),
}

# 加载训练集和测试集(使用生成的txt文件和类别名称)
training_data = FoodDataset(
    txt_path=train_txt_path,          # 训练集txt文件路径
    transform=data_transforms['train'],  # 应用训练集预处理(包含数据增强)
    class_names=train_class_names      # 使用生成的类别名称列表
)
test_data = FoodDataset(
    txt_path=test_txt_path,           # 测试集txt文件路径
    transform=data_transforms['test'],   # 应用测试集预处理(无数据增强)
    class_names=test_class_names       # 使用生成的类别名称列表(与训练集一致)
)

# 创建数据加载器(DataLoader,批量加载数据)
train_dataloader = DataLoader(
    dataset=training_data,    # 训练集数据集
    batch_size=64,            # 每批加载64张图像(平衡内存和效率)
    shuffle=True,             # 训练集打乱顺序(防止模型记忆数据顺序)
    num_workers=4             # 多线程加载数据(根据CPU核心数调整)
)
test_dataloader = DataLoader(
    dataset=test_data,        # 测试集数据集
    batch_size=64,            # 每批加载64张图像
    shuffle=False,            # 测试集不打乱顺序(便于结果分析)
    num_workers=4             # 多线程加载数据
)
代码解析:
  • ​预处理变换​​:
    • transforms.Resize([256, 256]):统一图像尺寸为256x256,确保所有图像能组成批量Tensor(若图像尺寸不一致,无法进行批量运算)。
    • transforms.ToTensor():将PIL图像转换为PyTorch的Tensor,并自动将像素值从[0, 255]归一化到[0, 1]
  • ​DataLoader​​:
    • batch_size=64:每批加载64张图像,根据GPU内存调整(若内存不足,可减小至32)。
    • shuffle=True(训练集):打乱数据顺序,避免模型因数据顺序固定而学习到无关模式(如总是先学习“八宝粥”再学习“薯条”)。
    • shuffle=False(测试集):保持数据顺序,便于按顺序评估模型在测试集上的表现。

3.5 模型构建(CNN网络定义)

本节定义一个简单的卷积神经网络(CNN),包含3个卷积层和2个全连接层,适用于食物分类任务。

class CNN(nn.Module):
    def __init__(self, num_classes):
        """
        初始化卷积神经网络(CNN)
        
        参数:
            num_classes (int): 类别数量(自动从类别名称列表获取)
        """
        super(CNN, self).__init__()  # 调用父类构造函数
        
        # 卷积层1:输入3通道(RGB),输出16通道,提取低级特征(如边缘、纹理)
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=3,   # 输入通道数(3表示RGB图像)
                out_channels=16, # 输出通道数(卷积核数量,生成16张特征图)
                kernel_size=5,   # 卷积核尺寸(5x5)
                stride=1,        # 步长(每次滑动1个像素)
                padding=2        # 填充(边缘填充2个像素,确保输出尺寸与输入一致)
            ),                  # 输出特征图尺寸:(batch_size, 16, 256, 256)
            nn.ReLU(),          # ReLU激活函数(引入非线性,提取复杂特征)
            nn.MaxPool2d(2)     # 最大池化层(2x2区域池化,输出尺寸减半:(batch_size, 16, 128, 128))
        )
        
        # 卷积层2:输入16通道,输出32通道,提取中级特征(如局部结构)
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=16,  # 输入通道数(来自conv1的输出)
                out_channels=32, # 输出通道数(32张特征图)
                kernel_size=5,   # 卷积核尺寸(5x5)
                stride=1,        # 步长
                padding=2        # 填充(保持尺寸)
            ),                  # 输出特征图尺寸:(batch_size, 32, 128, 128)
            nn.ReLU(),          # ReLU激活函数
            nn.Conv2d(
                in_channels=32,  # 输入通道数(来自上一层的输出)
                out_channels=32, # 输出通道数(32张特征图)
                kernel_size=5,   # 卷积核尺寸(5x5)
                stride=1,        # 步长
                padding=2        # 填充(保持尺寸)
            ),                  # 输出特征图尺寸:(batch_size, 32, 128, 128)
            nn.ReLU(),          # ReLU激活函数
            nn.MaxPool2d(2)     # 最大池化层(输出尺寸减半:(batch_size, 32, 64, 64))
        )
        
        # 卷积层3:输入32通道,输出128通道,提取高级语义特征(如整体形状)
        self.conv3 = nn.Sequential(
            nn.Conv2d(
                in_channels=32,  # 输入通道数(来自conv2的输出)
                out_channels=128,# 输出通道数(128张特征图)
                kernel_size=5,   # 卷积核尺寸(5x5)
                stride=1,        # 步长
                padding=2        # 填充(保持尺寸)
            ),                  # 输出特征图尺寸:(batch_size, 128, 64, 64)
            nn.ReLU()           # ReLU激活函数
        )
        
        # 全连接层:输入128 * 64 * 64(卷积后的特征图尺寸),输出类别数(如20)
        self.fc = nn.Linear(128 * 64 * 64, num_classes)
    
    def forward(self, x):
        """
        前向传播(输入图像,输出类别得分)
        
        参数:
            x (Tensor): 输入图像(形状:[batch_size, 3, 256, 256])
            
        返回:
            Tensor: 预测得分(形状:[batch_size, num_classes])
        """
        # 卷积层1 → 激活 → 池化
        x = self.conv1(x)  # 输出形状:[batch_size, 16, 128, 128]
        # 卷积层2 → 激活 → 卷积 → 激活 → 池化
        x = self.conv2(x)  # 输出形状:[batch_size, 32, 64, 64]
        # 卷积层3 → 激活
        x = self.conv3(x)  # 输出形状:[batch_size, 128, 64, 64]
        # 展平特征图(将三维特征图转为一维向量)
        x = x.view(x.size(0), -1)  # 输出形状:[batch_size, 128 * 64 * 64]
        # 全连接层输出预测得分
        x = self.fc(x)  # 输出形状:[batch_size, num_classes]
        return x
代码解析:
  • ​卷积层设计​​:
    • nn.Conv2d:二维卷积层,用于提取图像的空间特征。in_channels为输入通道数(3表示RGB图像),out_channels为输出通道数(即卷积核数量,生成对应数量的特征图),kernel_size为卷积核尺寸(5x5),stride为步长(1表示每次滑动1个像素),padding为填充(2表示在图像边缘填充2个像素,确保输出尺寸与输入一致)。
    • nn.ReLU():修正线性单元激活函数,引入非线性特性,使模型能拟合更复杂的特征(如食物的边缘、纹理)。
    • nn.MaxPool2d:最大池化层,通过取局部区域的最大值降低特征图尺寸(宽、高减半),减少计算量的同时保留主要特征(如食物的整体轮廓)。
  • ​全连接层​​:
    • nn.Linear(128 * 64 * 64, num_classes):将卷积后的特征图展平为一维向量(长度为128 * 64 * 64),然后通过全连接层映射到num_classes维的输出(对应num_classes类食物的分类得分)。

3.6 模型训练与测试

本节定义训练函数和测试函数,完成模型的优化和性能评估。

def train(dataloader, model, loss_fn, optimizer):
    """
    训练模型(单轮次)
    
    参数:
        dataloader (DataLoader): 训练集数据加载器
        model (nn.Module): 待训练的模型
        loss_fn (nn.Module): 损失函数(交叉熵)
        optimizer (optim.Optimizer): 优化器(Adam)
    """
    model.train()  # 开启训练模式(启用Dropout、BatchNorm等)
    batch_count = 0  # 批次计数器(用于打印进度)
    
    # 遍历训练集的每个批次
    for X, y in dataloader:
        # 将数据和标签移动到目标设备(如GPU)
        X, y = X.to(device), y.to(device)
        
        # 前向传播:计算预测值
        pred = model(X)  # 输入图像 → 输出得分
        
        # 计算损失:预测值与真实标签的交叉熵损失
        loss = loss_fn(pred, y)
        
        # 反向传播:优化参数
        optimizer.zero_grad()  # 清空梯度(避免累积)
        loss.backward()        # 计算梯度(从损失值反向传播到各层参数)
        optimizer.step()       # 更新参数(根据梯度调整参数值)
        
        # 统计并打印训练进度
        batch_count += 1
        if batch_count % 10 == 0:  # 每10个批次打印一次
            print(f"批次 {batch_count}, 损失: {loss.item():.4f}")

def test(dataloader, model, loss_fn):
    """
    测试模型性能(单轮次)
    
    参数:
        dataloader (DataLoader): 测试集数据加载器
        model (nn.Module): 待测试的模型
        loss_fn (nn.Module): 损失函数(交叉熵)
    """
    model.eval()  # 开启测试模式(禁用Dropout、BatchNorm等)
    test_loss = 0.0  # 总损失
    correct = 0      # 正确预测数
    total = 0        # 总样本数
    
    # 关闭梯度计算(节省内存)
    with torch.no_grad():
        # 遍历测试集的每个批次
        for X, y in dataloader:
            # 将数据和标签移动到目标设备(如GPU)
            X, y = X.to(device), y.to(device)
            
            # 前向传播:计算预测值
            pred = model(X)
            
            # 累计损失
            test_loss += loss_fn(pred, y).item()
            
            # 计算正确预测数
            _, predicted = torch.max(pred.data, 1)  # 获取预测类别(得分最高的类别)
            total += y.size(0)                      # 累计总样本数
            correct += (predicted == y).sum().item()  # 累计正确预测数
    
    # 计算平均损失和准确率
    avg_loss = test_loss / len(dataloader)  # 平均每批次损失
    accuracy = 100 * correct / total        # 准确率(百分比)
    
    # 打印测试结果
    print(f"测试结果: \n  准确率: {accuracy:.2f}%, \n  平均损失: {avg_loss:.4f}")
代码解析:
  • ​训练函数(train)​​:
    • model.train():开启训练模式,启用Dropout层(随机失活神经元)和BatchNorm层(计算当前批次的均值和方差),增强模型的泛化能力。
    • 遍历dataloader获取每个批次的图像(X)和标签(y),将数据和标签移动到目标设备(CPU/GPU)。
    • 前向传播计算预测值pred,通过损失函数loss_fn计算预测值与真实标签的损失loss
    • 反向传播loss.backward()计算梯度,优化器optimizer.step()更新模型参数。
    • 统计每个批次的损失值,按指定频率打印训练进度(如每10个批次打印一次)。
  • ​测试函数(test)​​:
    • model.eval():开启测试模式,禁用Dropout和BatchNorm的随机操作(使用训练阶段统计的均值和方差),确保测试结果的稳定性。
    • with torch.no_grad():关闭自动梯度计算,减少内存消耗(测试阶段无需更新参数)。
    • 遍历测试集,计算整体损失和准确率,评估模型对未见过数据的泛化能力。

3.7 输入图片预测功能(核心扩展)

本节实现“输入图片路径→读取→预处理→模型预测→输出结果”的全流程功能。

def predict_image(model, img_path, transform, class_names):
    """
    预测单张图片的食物类别(用户输入路径后调用此函数)
    
    参数:
        model (nn.Module): 训练好的CNN模型
        img_path (str): 用户输入的图片路径(如'./食物分类/food_dataset/test/骨肉相连/img_骨肉相连_05.jpg')
        transform (callable): 图像预处理变换(与训练时一致)
        class_names (list): 类别名称列表(自动从generate_txt_files获取)
    """
    try:
        # 1. 读取图片(支持绝对路径/相对路径)
        image = Image.open(img_path).convert('RGB')  # 确保3通道(避免灰度图错误)
    except FileNotFoundError:
        print(f"错误:路径不存在!请检查输入的路径是否正确:{img_path}")
        return
    except Exception as e:
        print(f"错误:无法打开图片 {img_path},原因:{e}")
        return
    
    # 2. 预处理(与训练时完全一致:Resize(256,256) + ToTensor)
    # 注意:模型输入要求为 [batch_size, channels, height, width],因此需要增加batch维度
    image_tensor = transform(image).unsqueeze(0)  # 形状:[1, 3, 256, 256]
    
    # 3. 模型预测(关闭训练模式,避免Dropout干扰)
    model.eval()  # 开启测试模式
    with torch.no_grad():  # 关闭梯度计算(节省内存)
        output = model(image_tensor)  # 前向传播,输出形状:[1, num_classes]
        _, predicted_idx = torch.max(output, 1)  # 获取得分最高的类别索引(形状:[1])
    
    # 4. 映射索引到类别名称(直接使用生成的class_names列表)
    predicted_class = class_names[predicted_idx.item()]  # 转换为具体类别名称(如"骨肉相连")
    
    # 5. 输出结果(用户友好的提示)
    print(f"预测结果:这张图片是 {predicted_class}")

# ---------------------- 主程序(用户交互入口) ----------------------
if __name__ == '__main__':
    # ---------------------- 5.1 配置参数(根据用户实际路径调整) ----------------------
    dataset_root = r'.\食物分类\food_dataset'  # 数据集根目录(用户图2的food_dataset路径)
    
    # ---------------------- 5.2 生成txt文件并获取类别名称(自动从目录中提取) ----------------------
    print("正在生成train.txt和test.txt文件...")
    train_txt_path = './train.txt'
    test_txt_path = './test.txt'
    train_class_names = generate_txt_files(dataset_root, 'train', train_txt_path)
    test_class_names = generate_txt_files(dataset_root, 'test', test_txt_path)
    print("txt文件生成完成!")
    
    # 验证训练集和测试集的类别顺序一致(避免标签错位)
    assert train_class_names == test_class_names, "错误:训练集和测试集的类别顺序不一致!"
    
    # ---------------------- 5.3 加载预处理变换(与训练时一致) ----------------------
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.ToTensor()
        ]),
        'test': transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.ToTensor()
        ])
    }
    
    # ---------------------- 5.4 加载模型(假设已训练并保存) ----------------------
    # 注意:用户需要先训练模型并保存权重(如使用torch.save(model.state_dict(), 'food_classifier.pth'))
    num_classes = len(train_class_names)  # 类别数自动从类别名称列表获取
    model = CNN(num_classes=num_classes)
    
    # 加载训练好的模型权重(替换为用户的实际路径)
    # model.load_state_dict(torch.load('food_classifier.pth'))
    # model.to('cuda' if torch.cuda.is_available() else 'cpu')  # 加载到GPU/CPU(可选)
    
    # ---------------------- 5.5 用户输入图片路径并预测 ----------------------
    print("
请输入要预测的图片路径(绝对路径或相对路径均可,例如:./食物分类/food_dataset/test/骨肉相连/img_骨肉相连_05.jpg):")
    user_input_path = input("路径:").strip()  # 获取用户输入的路径
    
    # 调用预测函数(使用测试集的预处理变换,因为测试时不需要数据增强)
    predict_image(
        model=model,
        img_path=user_input_path,
        transform=data_transforms['test'],
        class_names=train_class_names  # 使用训练集的类别名称(与test一致)
    )
代码解析:
  • ​用户输入处理​​:通过input("路径:").strip()获取用户输入的图片路径,自动去除首尾空格,支持绝对路径和相对路径。
  • ​异常处理​​:捕获FileNotFoundError(路径不存在)和通用异常(如非图片文件),提示用户检查路径或文件类型。
  • ​预处理一致性​​:使用与训练时完全相同的transforms.ComposeResize(256,256)ToTensor),确保模型输入的尺寸和数值范围与训练时一致。
  • ​类别映射​​:通过class_names[predicted_idx.item()]将预测索引映射为具体类别名称(如3对应“骨肉相连”),无需手动输入类别列表。

四、训练与测试流程

4.1 训练模型

  1. ​生成txt文件​​:运行generate_txt_files函数,生成train.txttest.txt,文件中包含图像路径和自动生成的标签。
  2. ​初始化模型​​:实例化CNN模型,类别数自动从train_class_names获取(如20类)。
  3. ​定义损失函数和优化器​​:使用交叉熵损失函数(nn.CrossEntropyLoss)和Adam优化器(torch.optim.Adam)。
  4. ​启动训练​​:遍历训练集,通过train函数更新模型参数,直到达到指定轮次(如10轮)。

4.2 测试模型

  1. ​加载训练好的模型​​:通过model.load_state_dict(torch.load('food_classifier.pth'))加载训练好的权重。
  2. ​运行测试函数​​:调用test函数评估模型在测试集上的准确率和平均损失。

4.3 输入图片预测

  1. ​输入路径​​:根据提示输入图片路径(如./食物分类/food_dataset/test/骨肉相连/img_骨肉相连_05.jpg)。
  2. ​输出结果​​:程序会输出预测的类别名称(如“这张图片是 骨肉相连”)。

五、常见问题与优化建议

5.1 常见问题

  • ​问题1:训练损失不下降​

    • 可能原因:学习率过大(模型无法收敛)或过小(收敛过慢)、数据预处理错误(如标签错误)、模型容量不足(网络太浅)。
    • 解决方法:调整学习率(如从0.001降至0.0001)、检查txt文件中的标签是否正确、增加卷积层或全连接层的神经元数量。
  • ​问题2:测试准确率远低于训练准确率​

    • 可能原因:过拟合(模型过度记忆训练数据)。
    • 解决方法:增加数据增强(如RandomHorizontalFlipRandomRotation)、添加Dropout层(如在conv1后加nn.Dropout2d(0.5))、使用早停法(Early Stopping)。
  • ​问题3:输入图片预测错误​

    • 可能原因:图片尺寸不一致(未通过Resize变换)、预处理不一致(如训练时用了Normalize但预测时未用)、类别名称顺序错误。
    • 解决方法:确保预测时使用与训练时相同的预处理变换、检查class_names列表的顺序是否与txt文件一致。

5.2 优化建议

  • ​数据增强​​:在data_transforms中添加数据增强(如RandomHorizontalFlipRandomRotation),提升模型的泛化能力。
  • ​更深的网络​​:尝试使用ResNet、VGG等预训练模型,或自定义更深的网络结构(如增加卷积层)。
  • ​学习率调整​​:使用torch.optim.lr_scheduler动态调整学习率(如余弦退火、阶梯衰减),提升收敛速度。

六、总结

本文通过完整的代码示例和详细解析,演示了如何基于PyTorch框架和本地自定义数据集实现食物分类任务,并扩展了“输入图片路径→输出类别”的预测功能。核心步骤包括:

  1. ​数据准备​​:按类别分文件夹组织数据,自动生成train.txttest.txt(动态获取类别名称)。
  2. ​自定义数据集​​:继承Dataset类,实现__len____getitem__方法(读取txt文件,加载图像和标签)。
  3. ​模型构建​​:设计卷积神经网络(CNN),提取图像特征并映射到类别空间。
  4. ​训练与测试​​:定义损失函数和优化器,通过DataLoader批量加载数据,迭代训练并评估模型性能。
  5. ​输入预测​​:实现用户输入图片路径的读取、预处理和预测功能,输出具体类别名称。

通过本案例,读者可以掌握PyTorch处理自定义数据集的核心流程,并为后续的图像分类、目标检测等任务打下坚实基础。建议在实际项目中尝试不同的网络结构、数据增强方法和超参数调优,以进一步提升模型性能。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐