目录

一、数据增强

1. 核心概念

2. 核心目的

3. 常用方法

4. 实现示例(基于 PyTorch)

5. 自定义数据集加载

二、保存最优模型

1. 核心概念

2. 实现步骤

(1)定义 CNN 模型

(2)定义训练与测试函数

(3)启动训练

3. 模型加载与使用

三、总结


在卷积神经网络(CNN)的训练过程中,数据增强和模型保存是提升性能与实用性的关键环节。以下结合理论与实例,详细解析其原理及实现方式。

一、数据增强
1. 核心概念

数据增强是通过对原始训练数据进行随机变换(如旋转、翻转、调整亮度等),生成新的训练样本的技术。其本质是扩展数据多样性,让模型在训练中接触更多 “变体”,从而提升泛化能力(减少过拟合)。

2. 核心目的
  • 模拟真实场景中的变量(如光照变化、视角差异、遮挡等)。
  • 解决训练数据不足的问题,通过 “人工扩充” 提升模型鲁棒性。
3. 常用方法

4. 实现示例(基于 PyTorch)
import torch
from torchvision import transforms

# 定义训练集和验证集的数据增强策略
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize([300, 300]),  # 缩放图像
        transforms.RandomRotation(45),  # 随机旋转(-45°~45°)
        transforms.CenterCrop(256),     # 中心裁剪至256x256
        transforms.RandomHorizontalFlip(p=0.5),  # 50%概率水平翻转
        transforms.ColorJitter(brightness=0.2, contrast=0.1),  # 颜色调整
        transforms.ToTensor(),  # 转换为Tensor(像素值归一化到[0,1])
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 标准化
    ]),
    'valid': transforms.Compose([
        transforms.Resize([256, 256]),  # 验证集仅缩放,不做随机增强
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}
5. 自定义数据集加载

通过继承Dataset类,将增强策略应用于实际数据:

from torch.utils.data import Dataset
from PIL import Image
import numpy as np

class FoodDataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.transform = transform
        self.imgs = []
        self.labels = []
        # 从txt文件读取图像路径和标签(格式:图像路径 标签)
        with open(file_path, 'r') as f:
            for line in f.readlines():
                img_path, label = line.strip().split(' ')
                self.imgs.append(img_path)
                self.labels.append(int(label))
    
    def __len__(self):
        return len(self.imgs)
    
    def __getitem__(self, idx):
        # 加载图像并应用增强
        image = Image.open(self.imgs[idx]).convert('RGB')
        if self.transform:
            image = self.transform(image)
        # 标签转换为Tensor
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label

# 加载训练集和验证集
train_dataset = FoodDataset('./train.1txt', transform=data_transforms['train'])
valid_dataset = FoodDataset('./test.1txt', transform=data_transforms['valid'])

# 数据加载器(批量处理)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=64, shuffle=False)

其中train.1txt文件内容为:

test.1txt文件内容:

  其中的每个文件地址都有其对应的图片,数据量较大,训练时间会较长,如需使用,可私信发送打包文件。

        整篇文章所有代码连接为一份完整代码。

二、保存最优模型
1. 核心概念

训练过程中,模型性能(如验证集准确率)会随迭代波动。保存最优模型指在训练中跟踪关键指标(如最高准确率),并保存对应状态,以便后续直接使用最佳模型。

2. 实现步骤
(1)定义 CNN 模型
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        # 卷积层1:3通道输入→16通道输出,5x5卷积核
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)  # 池化后尺寸减半
        )
        # 卷积层2:16通道→32通道
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # 卷积层3:32通道→128通道(无池化)
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 128, kernel_size=5, stride=1, padding=2),
            nn.ReLU()
        )
        # 全连接层:输入为128×64×64(经3次卷积+池化后的尺寸),输出20类
        self.fc = nn.Linear(128 * 64 * 64, 20)
    
    def forward(self, x):
        x = self.conv1(x)  # 输出:16×128×128
        x = self.conv2(x)  # 输出:32×64×64
        x = self.conv3(x)  # 输出:128×64×64
        x = x.view(x.size(0), -1)  # 展平为向量
        x = self.fc(x)
        return x

# 初始化模型并移动到设备(GPU/CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SimpleCNN().to(device)

运行结果:

(2)定义训练与测试函数
# 损失函数与优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练函数
def train(dataloader, model, loss_fn, optimizer):
    model.train()  # 开启训练模式(启用 dropout/batchnorm)
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        
        # 反向传播
        optimizer.zero_grad()  # 清空梯度
        loss.backward()        # 计算梯度
        optimizer.step()       # 更新参数
        
        if batch % 100 == 0:
            print(f"Batch {batch}, Loss: {loss.item():.4f}")

# 测试函数(含最优模型保存)
best_acc = 0.0  # 记录最佳准确率

def test(dataloader, model, loss_fn):
    global best_acc
    model.eval()  # 开启评估模式(固定 dropout/batchnorm)
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    
    with torch.no_grad():  # 关闭梯度计算,节省内存
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    
    test_loss /= num_batches
    correct /= size
    print(f"Test: Accuracy: {(100*correct):.1f}%, Avg loss: {test_loss:.4f}")
    
    # 保存最优模型(准确率提升时)
    if correct > best_acc:
        best_acc = correct
        # 保存完整模型(含结构和参数)
        torch.save(model, "best_model.pt")
        # 或仅保存参数(更轻量):torch.save(model.state_dict(), "best_model.pth")
(3)启动训练
epochs = 150  # 训练轮数
for t in range(epochs):
    print(f"\nEpoch {t+1}/{epochs}")
    train(train_loader, model, loss_fn, optimizer)
    test(valid_loader, model, loss_fn)
print("训练完成!最优模型已保存为 best_model.pt")
3. 模型加载与使用

训练结束后,可直接加载最优模型进行预测:

# 加载保存的模型
loaded_model = torch.load("best_model.pt").to(device)
loaded_model.eval()  # 切换至评估模式

# 示例:对单张图像预测
def predict(image_path):
    image = Image.open(image_path).convert('RGB')
    # 应用验证集的预处理
    transform = data_transforms['valid']
    image = transform(image).unsqueeze(0).to(device)  # 增加批次维度
    with torch.no_grad():
        pred = loaded_model(image)
        return pred.argmax(1).item()  # 返回预测类别

# 测试预测
print("预测类别:", predict("test_image.jpg"))

 训练结束得到当前训练的最优模型,其为pt\pth\t7文件,此时该文件即为当前模型,可直接调用该文件使用。

三、总结
  • 数据增强通过模拟真实场景变化,提升模型泛化能力,需注意训练集用随机增强、验证集仅做标准化。
  • 保存最优模型通过跟踪验证集指标(如准确率),保留性能最佳的模型状态,避免训练后期过拟合导致的性能下降。

以上方法可直接应用于图像分类、目标检测等 CNN 任务,实际使用时需根据数据集特点调整增强策略和模型结构。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐