MNIST Handwritten Digit Classification
Reposted from my personal blog: https://shar-pen.github.io/2025/05/04/torch-distributed-series/1.MNIST/
Basic single-GPU training
This notebook demonstrates training a convolutional neural network (CNN) to classify handwritten digits from the MNIST dataset. The workflow includes:
- Data preparation: loading and preprocessing the MNIST dataset.
- Model definition: building a CNN model with PyTorch.
- Model training: training the model on the MNIST training set.
- Model evaluation: testing the model on the MNIST test set and assessing its performance.
- Visualization: displaying sample images together with their labels.
Based on the official PyTorch example https://github.com/pytorch/examples/blob/main/mnist/main.py .
As for why the MNIST classification task: it is simply the "Hello World" of deep learning.
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import datasets, transforms
from time import time
In deep learning, the truly essential hyperparameters are roughly the following (a minimal PyTorch mapping follows the list):

- Learning rate
  - The most critical hyperparameter of all.
  - Determines the step size of every parameter update.
  - With an unsuitable learning rate, training almost certainly fails.
- Optimizer
  - E.g. SGD, Adam, AdamW.
  - Different optimizers can differ greatly in convergence speed and final quality.
  - Sometimes the optimizer's internal hyperparameters also need setting (e.g. Adam's $\beta_1, \beta_2$).
- Batch size
  - How many samples are grouped into one batch per training step.
  - Affects training stability, convergence speed, and hardware utilization.
- Number of epochs or max steps
  - How long to train in total.
  - Train too briefly and the model underfits; too long and it overfits or wastes resources.
- Loss function
  - Defines the training objective, e.g. CrossEntropyLoss for classification, MSELoss for regression.
  - Each task requires the matching loss.
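As a minimal sketch of how these knobs map onto PyTorch code (illustrative values only; the actual choices for this notebook are set below):

model = nn.Linear(10, 2)            # placeholder model, just to have parameters
criterion = nn.CrossEntropyLoss()   # loss function: defines the training objective
optimizer = optim.AdamW(            # optimizer choice
    model.parameters(),
    lr=1e-3,                        # learning rate
    betas=(0.9, 0.999),             # Adam's beta_1, beta_2
)
BATCH_SIZE = 64                     # batch size, fed to the DataLoader
EPOCHS = 10                         # total training length in epochs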
Hyperparameter setup
We set only the most basic hyperparameters: epochs, batch size, learning rate, and device.
EPOCHS = 5
BATCH_SIZE = 512
LR = 0.001
LR_DECAY_STEP_NUM = 1
LR_DECAY_FACTOR = 0.5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Data construction
We build the dataset and dataloader directly with library functions; the dataset mostly exists to construct the dataloader.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_data = datasets.MNIST(
    root='./mnist',
    train=True,  # True for the training split, False for the test split
    transform=transform,
    # download=True  # set True to download automatically; switch back to False once downloaded
)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_data = datasets.MNIST(
    root='./mnist',
    train=False,  # True for the training split, False for the test split
    transform=transform,
)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=True)
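The constants in Normalize((0.1307,), (0.3081,)) are the mean and standard deviation of the MNIST training pixels scaled to [0, 1]. As a quick sanity check (a sketch, assuming the data is already under ./mnist), they can be recomputed:

# Recompute the normalization constants from the raw uint8 pixels.
raw = train_data.data.float() / 255.0
print(raw.mean().item(), raw.std().item())  # ~0.1307, ~0.3081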
# plot one example
print(f'dataset: input shape: {train_data.data.size()}, label shape: {train_data.targets.size()}')
print(f'dataloader iter: input shape: {next(iter(train_loader))[0].size()}, label shape: {next(iter(train_loader))[1].size()}')
plt.imshow(train_data.data[0].numpy(), cmap='gray')
plt.title(f'Label: {train_data.targets[0]}')
plt.show()
dataset: input shape: torch.Size([60000, 28, 28]), label shape: torch.Size([60000])
dataloader iter: input shape: torch.Size([512, 1, 28, 28]), label shape: torch.Size([512])
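Note the shape difference: train_data.data is the raw uint8 tensor with no channel dimension, while indexing the dataset (and hence iterating the dataloader) applies the transform, so samples come out as normalized float tensors with a channel axis. A quick check (sketch):

img, label = train_data[0]  # indexing the Dataset applies ToTensor + Normalize
print(train_data.data[0].dtype, train_data.data[0].shape)  # torch.uint8 torch.Size([28, 28])
print(img.dtype, img.shape)                                # torch.float32 torch.Size([1, 28, 28])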

Network
We design a simple ConvNet: a few convolutional layers followed by an MLP. After instantiating a new model, move it onto DEVICE first.
class ConvNet(nn.Module):
    """
    A neural network model for MNIST digit classification.

    This model is designed to classify images from the MNIST dataset, which
    consists of grayscale images of handwritten digits (0-9). The network
    architecture includes convolutional layers for feature extraction,
    followed by fully connected layers for classification.

    Attributes:
        features (nn.Sequential): A sequential container of convolutional
            layers, activation functions, pooling, and dropout for feature
            extraction.
        classifier (nn.Sequential): A sequential container of fully connected
            layers, activation functions, and dropout for classification.

    Methods:
        forward(x):
            Defines the forward pass of the network. Takes an input tensor
            `x`, processes it through the feature extractor and classifier,
            and returns the log-softmax probabilities for each class.
    """

    def __init__(self):
        super(ConvNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25)
        )
        self.classifier = nn.Sequential(
            nn.Linear(9216, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        output = F.log_softmax(x, dim=1)
        return output
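A quick way to see why the first Linear layer has 9216 input features: a 28x28 input becomes 26x26 after the first 3x3 conv, 24x24 after the second, and 12x12 after 2x2 max pooling, with 64 channels, and 64 * 12 * 12 = 9216. A throwaway shape trace (sketch):

# Trace the feature extractor's output shape with a dummy input.
_model = ConvNet()
_dummy = torch.zeros(1, 1, 28, 28)
print(_model.features(_dummy).shape)                     # torch.Size([1, 64, 12, 12])
print(torch.flatten(_model.features(_dummy), 1).shape)   # torch.Size([1, 9216])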
Training and evaluation functions
Wrapping training and evaluation in separate functions keeps the main loop concise.
def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if (batch_idx + 1) % 30 == 0:
            print('Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                (batch_idx + 1) * len(data), len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up the batch loss
            pred = output.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
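Note that the log_softmax in the model pairs with F.nll_loss here; together they are equivalent to F.cross_entropy on raw logits. A small equivalence check (sketch):

# log_softmax + nll_loss == cross_entropy on the same logits.
logits = torch.randn(4, 10)
targets = torch.tensor([1, 0, 4, 9])
a = F.nll_loss(F.log_softmax(logits, dim=1), targets)
b = F.cross_entropy(logits, targets)
print(torch.allclose(a, b))  # True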
Main training loop
model = ConvNet().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_DECAY_STEP_NUM, gamma=LR_DECAY_FACTOR)

start_time = time()  # record the overall start time
for epoch in range(EPOCHS):
    epoch_start_time = time()  # record the start time of the current epoch
    print(f'Epoch {epoch + 1}/{EPOCHS}')
    print(f'Learning Rate: {scheduler.get_last_lr()[0]}')
    train(model, DEVICE, train_loader, optimizer)
    test(model, DEVICE, test_loader)
    scheduler.step()
    epoch_end_time = time()  # record the end time of the current epoch
    print(f"Time for epoch {epoch + 1}: {epoch_end_time - epoch_start_time:.2f} seconds")
end_time = time()  # record the overall end time
print(f"Total training time: {end_time - start_time:.2f} seconds")
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 1795609 C ...st/anaconda3/envs/xprepo/bin/python 448MiB |
| 0 N/A N/A 1814253 C ...st/anaconda3/envs/xprepo/bin/python 1036MiB |
| 7 N/A N/A 4167010 C ...guest/anaconda3/envs/QDM/bin/python 19416MiB |
+-----------------------------------------------------------------------------------------+
GPU 0 usage: 1484 MB.
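Peak memory can also be queried from inside the process via the PyTorch allocator (a sketch; note that nvidia-smi reports more, since it also counts the CUDA context and cached blocks):

# Peak GPU memory allocated by this process (bytes -> MiB).
peak_mib = torch.cuda.max_memory_allocated(DEVICE) / 1024 ** 2
print(f"Peak allocated on {DEVICE}: {peak_mib:.0f} MiB")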
Complete code
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import datasets, transforms
from time import time
import argparse
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25)
        )
        self.classifier = nn.Sequential(
            nn.Linear(9216, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        output = F.log_softmax(x, dim=1)
        return output
def arg_parser():
    parser = argparse.ArgumentParser(description="MNIST Training Script")
    parser.add_argument("--epochs", type=int, default=5, help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=512, help="Batch size for training")
    parser.add_argument("--lr", type=float, default=0.0005, help="Learning rate")
    parser.add_argument("--lr_decay_step_num", type=int, default=1, help="Step size for learning rate decay")
    parser.add_argument("--lr_decay_factor", type=float, default=0.5, help="Factor by which learning rate is decayed")
    parser.add_argument("--cuda_id", type=int, default=0, help="CUDA device ID to use")
    return parser.parse_args()
def prepare_data(batch_size):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_data = datasets.MNIST(
        root='./mnist',
        train=True,  # True for the training split, False for the test split
        transform=transform,
        # download=True  # set True to download automatically; switch back to False once downloaded
    )
    train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
    test_data = datasets.MNIST(
        root='./mnist',
        train=False,  # True for the training split, False for the test split
        transform=transform,
    )
    test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader
def train(model, device, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if (batch_idx + 1) % 30 == 0:
            print('Train: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                (batch_idx + 1) * len(data), len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up the batch loss
            pred = output.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
def train_mnist_classification():
    args = arg_parser()
    print(args)

    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size
    LR = args.lr
    LR_DECAY_STEP_NUM = args.lr_decay_step_num
    LR_DECAY_FACTOR = args.lr_decay_factor
    CUDA_ID = args.cuda_id
    DEVICE = torch.device(f"cuda:{CUDA_ID}")

    train_loader, test_loader = prepare_data(BATCH_SIZE)
    model = ConvNet().to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LR)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_DECAY_STEP_NUM, gamma=LR_DECAY_FACTOR)

    start_time = time()  # record the overall start time
    for epoch in range(EPOCHS):
        epoch_start_time = time()  # record the start time of the current epoch
        print(f'Epoch {epoch + 1}/{EPOCHS}')
        print(f'Learning Rate: {scheduler.get_last_lr()[0]}')
        train(model, DEVICE, train_loader, optimizer)
        test(model, DEVICE, test_loader)
        scheduler.step()
        epoch_end_time = time()  # record the end time of the current epoch
        print(f"Time for epoch {epoch + 1}: {epoch_end_time - epoch_start_time:.2f} seconds")
    end_time = time()  # record the overall end time
    print(f"Total training time: {end_time - start_time:.2f} seconds")


if __name__ == "__main__":
    train_mnist_classification()
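Assuming the script above is saved as, say, main.py (the filename is hypothetical), the defaults can be overridden from the command line, e.g. `python main.py --epochs 5 --batch_size 512 --lr 0.001 --cuda_id 0`.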