《深度学习》卷积神经网络:数据增强与保存最优模型解析及实现
通过继承Dataset# 从txt文件读取图像路径和标签(格式:图像路径 标签)# 加载图像并应用增强# 标签转换为Tensor# 加载训练集和验证集# 数据加载器(批量处理)其中train.1txttest.1txt文件内容:其中的每个文件地址都有其对应的图片,数据量较大,训练时间会较长,如需使用,可私信发送打包文件。整篇文章所有代码连接为一份完整代码。# 卷积层1:3通道输入→16通道输出,5
·
目录
在卷积神经网络(CNN)的训练过程中,数据增强和模型保存是提升性能与实用性的关键环节。以下结合理论与实例,详细解析其原理及实现方式。
一、数据增强
1. 核心概念
数据增强是通过对原始训练数据进行随机变换(如旋转、翻转、调整亮度等),生成新的训练样本的技术。其本质是扩展数据多样性,让模型在训练中接触更多 “变体”,从而提升泛化能力(减少过拟合)。
2. 核心目的
- 模拟真实场景中的变量(如光照变化、视角差异、遮挡等)。
- 解决训练数据不足的问题,通过 “人工扩充” 提升模型鲁棒性。
3. 常用方法
4. 实现示例(基于 PyTorch)
import torch
from torchvision import transforms
# 定义训练集和验证集的数据增强策略
data_transforms = {
'train': transforms.Compose([
transforms.Resize([300, 300]), # 缩放图像
transforms.RandomRotation(45), # 随机旋转(-45°~45°)
transforms.CenterCrop(256), # 中心裁剪至256x256
transforms.RandomHorizontalFlip(p=0.5), # 50%概率水平翻转
transforms.ColorJitter(brightness=0.2, contrast=0.1), # 颜色调整
transforms.ToTensor(), # 转换为Tensor(像素值归一化到[0,1])
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 标准化
]),
'valid': transforms.Compose([
transforms.Resize([256, 256]), # 验证集仅缩放,不做随机增强
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
}
5. 自定义数据集加载
通过继承Dataset
类,将增强策略应用于实际数据:
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
class FoodDataset(Dataset):
def __init__(self, file_path, transform=None):
self.transform = transform
self.imgs = []
self.labels = []
# 从txt文件读取图像路径和标签(格式:图像路径 标签)
with open(file_path, 'r') as f:
for line in f.readlines():
img_path, label = line.strip().split(' ')
self.imgs.append(img_path)
self.labels.append(int(label))
def __len__(self):
return len(self.imgs)
def __getitem__(self, idx):
# 加载图像并应用增强
image = Image.open(self.imgs[idx]).convert('RGB')
if self.transform:
image = self.transform(image)
# 标签转换为Tensor
label = torch.tensor(self.labels[idx], dtype=torch.long)
return image, label
# 加载训练集和验证集
train_dataset = FoodDataset('./train.1txt', transform=data_transforms['train'])
valid_dataset = FoodDataset('./test.1txt', transform=data_transforms['valid'])
# 数据加载器(批量处理)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=64, shuffle=False)
其中train.1txt文件内容为:
test.1txt文件内容:
其中的每个文件地址都有其对应的图片,数据量较大,训练时间会较长,如需使用,可私信发送打包文件。
整篇文章所有代码连接为一份完整代码。
二、保存最优模型
1. 核心概念
训练过程中,模型性能(如验证集准确率)会随迭代波动。保存最优模型指在训练中跟踪关键指标(如最高准确率),并保存对应状态,以便后续直接使用最佳模型。
2. 实现步骤
(1)定义 CNN 模型
import torch.nn as nn
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
# 卷积层1:3通道输入→16通道输出,5x5卷积核
self.conv1 = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2) # 池化后尺寸减半
)
# 卷积层2:16通道→32通道
self.conv2 = nn.Sequential(
nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
nn.ReLU(),
nn.MaxPool2d(2)
)
# 卷积层3:32通道→128通道(无池化)
self.conv3 = nn.Sequential(
nn.Conv2d(32, 128, kernel_size=5, stride=1, padding=2),
nn.ReLU()
)
# 全连接层:输入为128×64×64(经3次卷积+池化后的尺寸),输出20类
self.fc = nn.Linear(128 * 64 * 64, 20)
def forward(self, x):
x = self.conv1(x) # 输出:16×128×128
x = self.conv2(x) # 输出:32×64×64
x = self.conv3(x) # 输出:128×64×64
x = x.view(x.size(0), -1) # 展平为向量
x = self.fc(x)
return x
# 初始化模型并移动到设备(GPU/CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SimpleCNN().to(device)
运行结果:
(2)定义训练与测试函数
# 损失函数与优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 训练函数
def train(dataloader, model, loss_fn, optimizer):
model.train() # 开启训练模式(启用 dropout/batchnorm)
for batch, (X, y) in enumerate(dataloader):
X, y = X.to(device), y.to(device)
pred = model(X)
loss = loss_fn(pred, y)
# 反向传播
optimizer.zero_grad() # 清空梯度
loss.backward() # 计算梯度
optimizer.step() # 更新参数
if batch % 100 == 0:
print(f"Batch {batch}, Loss: {loss.item():.4f}")
# 测试函数(含最优模型保存)
best_acc = 0.0 # 记录最佳准确率
def test(dataloader, model, loss_fn):
global best_acc
model.eval() # 开启评估模式(固定 dropout/batchnorm)
size = len(dataloader.dataset)
num_batches = len(dataloader)
test_loss, correct = 0, 0
with torch.no_grad(): # 关闭梯度计算,节省内存
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test: Accuracy: {(100*correct):.1f}%, Avg loss: {test_loss:.4f}")
# 保存最优模型(准确率提升时)
if correct > best_acc:
best_acc = correct
# 保存完整模型(含结构和参数)
torch.save(model, "best_model.pt")
# 或仅保存参数(更轻量):torch.save(model.state_dict(), "best_model.pth")
(3)启动训练
epochs = 150 # 训练轮数
for t in range(epochs):
print(f"\nEpoch {t+1}/{epochs}")
train(train_loader, model, loss_fn, optimizer)
test(valid_loader, model, loss_fn)
print("训练完成!最优模型已保存为 best_model.pt")
3. 模型加载与使用
训练结束后,可直接加载最优模型进行预测:
# 加载保存的模型
loaded_model = torch.load("best_model.pt").to(device)
loaded_model.eval() # 切换至评估模式
# 示例:对单张图像预测
def predict(image_path):
image = Image.open(image_path).convert('RGB')
# 应用验证集的预处理
transform = data_transforms['valid']
image = transform(image).unsqueeze(0).to(device) # 增加批次维度
with torch.no_grad():
pred = loaded_model(image)
return pred.argmax(1).item() # 返回预测类别
# 测试预测
print("预测类别:", predict("test_image.jpg"))
训练结束得到当前训练的最优模型,其为pt\pth\t7文件,此时该文件即为当前模型,可直接调用该文件使用。
三、总结
- 数据增强通过模拟真实场景变化,提升模型泛化能力,需注意训练集用随机增强、验证集仅做标准化。
- 保存最优模型通过跟踪验证集指标(如准确率),保留性能最佳的模型状态,避免训练后期过拟合导致的性能下降。
以上方法可直接应用于图像分类、目标检测等 CNN 任务,实际使用时需根据数据集特点调整增强策略和模型结构。
更多推荐
所有评论(0)