[Understanding CNN Algorithms] Part 2: The AlexNetTrainer Class for Training AlexNet (Code Included)
This article covers the training configuration and implementation details of an AlexNet model. Main contents: 1) initialization with a GPU-first device-selection strategy; 2) training setup using an SGD optimizer with learning rate 0.001 and a multi-step learning-rate schedule; 3) a training loop with gradient clipping and early stopping; 4) an evaluation routine that reports overall and per-class accuracy; 5) visualization of the training curves. The experiments section walks through training on CIFAR-10, where a simplified AlexNet reaches roughly 80-83% test accuracy. The complete trainer code is provided at the end.
Initialization
trainer = AlexNetTrainer(model, device='cuda')
Device selection strategy
- Prefer the GPU (cuda)
- Fall back to the CPU automatically when no GPU is available
- Multi-GPU training is possible (requires manual extension; see the "Multi-GPU Training" section below)
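A minimal sketch of this fallback logic (illustrative; it mirrors the default argument used in the trainer's constructor):
import torch

# Prefer CUDA when available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Training on: {device}")
# trainer = AlexNetTrainer(model, device=device)  # assumes the model is already built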
Training Setup (the setup_training method)
Optimizer configuration (following the original paper)
# Stochastic gradient descent (SGD) configuration
learning_rate = 0.001   # initial learning rate
momentum = 0.9          # momentum factor
weight_decay = 0.0005   # L2 regularization coefficient
Numerical example:
Parameter update rule (classical form, close to the one in the AlexNet paper):
v_t = momentum * v_{t-1} + learning_rate * gradient
θ_t = θ_{t-1} - v_t - weight_decay * learning_rate * θ_{t-1}
Example calculation:
gradient = 0.1
previous velocity v_{t-1} = 0.05
current velocity v_t = 0.9 * 0.05 + 0.001 * 0.1 = 0.0451
parameter update = 0.0451 + 0.0005 * 0.001 * θ_{t-1}
Note that PyTorch's optim.SGD folds weight decay into the gradient (g ← g + weight_decay * θ) before the momentum update, so its numbers differ slightly from this classical form.
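To see the PyTorch variant concretely, the sketch below (illustrative only, not part of the trainer) runs a single optim.SGD step on one scalar parameter with a hand-set gradient:
import torch

theta = torch.nn.Parameter(torch.tensor(1.0))
opt = torch.optim.SGD([theta], lr=0.001, momentum=0.9, weight_decay=0.0005)

theta.grad = torch.tensor(0.1)  # pretend backward() produced gradient 0.1
opt.step()                      # first step: v_1 = grad + wd * theta = 0.1005
print(theta.item())             # 1.0 - 0.001 * 0.1005 = 0.9998995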
Learning-Rate Scheduler Options
| Scheduler | Configuration | Use case |
|---|---|---|
| StepLR | step_size=30, gamma=0.1 | Decay at a fixed interval |
| MultiStepLR | milestones=[30,60,80,90], gamma=0.1 | Multi-stage decay (closest to the original paper) |
| ReduceLROnPlateau | patience=10, factor=0.1 | Adaptive decay driven by validation loss |
Learning-rate schedule example (MultiStepLR):
Epoch 1-29: 0.001
Epoch 30-59: 0.0001
Epoch 60-79: 0.00001
Epoch 80-89: 0.000001
Epoch 90+: 0.0000001
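A quick way to verify this schedule (a dummy parameter stands in for the model; purely illustrative):
import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.001)
sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[30, 60, 80, 90], gamma=0.1)

for epoch in range(1, 95):
    # ... train one epoch here, then advance the schedule ...
    sched.step()
    if epoch in (29, 30, 60, 80, 90):
        print(f"after epoch {epoch}: lr = {opt.param_groups[0]['lr']:.7f}")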
Training Flow
Per-Epoch Training (the train_epoch method)
Numerical example:
Batch size: 128
Training-set size: 45,000 (90% of the CIFAR-10 training split)
Number of batches: 45,000 / 128 ≈ 352
Progress is printed every 100 batches:
Epoch: 1 [100/352] | Loss: 2.3145 | Acc: 12.34% | LR: 0.001000
Epoch: 1 [200/352] | Loss: 1.8923 | Acc: 28.56% | LR: 0.001000
Epoch: 1 [300/352] | Loss: 1.6543 | Acc: 35.78% | LR: 0.001000
Gradient Clipping
torch.nn.utils.clip_grad_norm_(parameters, max_norm=1.0)
- Prevents exploding gradients
- Rescales gradients so their global norm does not exceed 1.0
- Keeps training stable
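A small illustrative sketch (toy tensor, not taken from the trainer) showing how clip_grad_norm_ rescales an oversized gradient:
import torch

w = torch.nn.Parameter(torch.zeros(3))
w.grad = torch.tensor([3.0, 4.0, 0.0])  # gradient norm = 5.0

total_norm = torch.nn.utils.clip_grad_norm_([w], max_norm=1.0)
print(total_norm)  # 5.0 (the norm *before* clipping)
print(w.grad)      # ≈ tensor([0.6000, 0.8000, 0.0000]); the norm is now 1.0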
Training Loop (the train method)
Full training example
# Train for 90 epochs with early stopping
history = trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=90,
    early_stopping_patience=20,
    save_path='alexnet_cifar10_best.pth'
)
Sample training output
Epoch 1/90 | time: 45.23s
Train loss: 1.8923, train acc: 32.45%
Val loss: 1.6543, val acc: 40.12%
Learning rate: 0.001000
Patience counter: 0/20
✅ Saved best model (val acc: 40.12%)
Early stopping
- Training stops when validation accuracy has not improved for 20 consecutive epochs
- Avoids overfitting and wasted compute
- The best model is saved automatically
Evaluation (the evaluate method)
Metric computation
# Overall metrics
test loss = Σ(batch loss) / number of batches
test accuracy = correct predictions / total samples × 100%
# Per-class metrics
class accuracy = correct predictions for the class / samples of the class × 100%
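The per-class computation can be sketched in NumPy (toy arrays, purely illustrative):
import numpy as np

preds = np.array([0, 1, 1, 2, 2, 2])    # hypothetical predictions
targets = np.array([0, 1, 2, 2, 2, 1])  # hypothetical ground truth

for c in range(3):
    mask = targets == c                  # samples whose true label is class c
    if mask.sum() > 0:
        correct = (preds[mask] == c).sum()
        acc = 100.0 * correct / mask.sum()
        print(f"class {c}: {acc:.1f}% ({correct}/{mask.sum()})")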
Sample evaluation output
============================================================
Test-set evaluation
============================================================
Test loss: 0.8923
Test accuracy: 85.34%
Correctly classified: 8534/10000
Per-class accuracy:
----------------------------------------
plane : 88.20% (882/1000)
car : 92.10% (921/1000)
bird : 76.30% (763/1000)
cat : 65.40% (654/1000)
deer : 82.50% (825/1000)
dog : 78.90% (789/1000)
frog : 91.20% (912/1000)
horse : 87.60% (876/1000)
ship : 93.80% (938/1000)
truck : 90.50% (905/1000)
Visualization
Training history charts (plot_training_history)
- Loss curves (train vs. validation)
- Accuracy curves (train vs. validation)
- Learning-rate curve
- Train-vs-validation accuracy scatter plot
Typical training curves:
Epoch 1-30: training loss drops quickly; validation accuracy climbs steadily
Epoch 31-60: after the LR drop, loss decreases slowly and accuracy keeps improving
Epoch 61-90: convergence; training and validation curves are essentially flat
Prediction visualization (visualize_predictions)
- Shows predictions for 12 test samples
- Correct predictions are titled in green, incorrect ones in red
- Gives an intuitive view of model performance
Weight Initialization Strategies
Original AlexNet initialization
# Convolutional layers: normal distribution N(0, 0.01)
nn.init.normal_(weight, mean=0, std=0.01)
# Fully connected layers: normal distribution N(0, 0.01)
nn.init.normal_(weight, mean=0, std=0.01)
# Biases: initialized to 0 (1 for some layers)
nn.init.constant_(bias, 0)  # most layers
nn.init.constant_(bias, 1)  # conv layers 2, 4, 5 and the hidden FC layers (per the paper)
Simplified AlexNet initialization
# Kaiming initialization (suited to ReLU)
nn.init.kaiming_normal_(weight, mode='fan_out', nonlinearity='relu')
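As a hedged sketch (the helper name and the scheme switch are assumptions, not from the source), both schemes can be wrapped in one function and applied to every conv/linear layer:
import torch.nn as nn

def initialize_weights(model, scheme='kaiming'):
    """Apply paper-style or Kaiming init to all conv and linear layers."""
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            if scheme == 'paper':
                nn.init.normal_(m.weight, mean=0.0, std=0.01)  # N(0, 0.01) as in the paper
            else:
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)  # the paper sets some biases to 1; omitted here for brevity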
Worked Training Example
CIFAR-10 training configuration
# 1. Load the dataset
data_handler = AlexNetDataHandler(batch_size=128)
train_loader, val_loader, test_loader, classes = data_handler.load_cifar10()
# 2. Create the model (simplified AlexNet)
model = SimplifiedAlexNet(num_classes=10, use_batchnorm=False)
# 3. Create the trainer
trainer = AlexNetTrainer(model)
# 4. Configure training
trainer.setup_training(
    learning_rate=0.001,
    weight_decay=0.0005,
    momentum=0.9,
    lr_scheduler='multi_step'
)
# 5. Start training
history = trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=90,
    early_stopping_patience=20,
    save_path='alexnet_cifar10.pth'
)
# 6. Evaluate the model
test_loss, test_acc, preds, targets = trainer.evaluate(test_loader, classes)
# 7. Visualize the results
trainer.plot_training_history()
trainer.visualize_predictions(test_loader, classes)
Expected Performance
| Dataset | Model | Epochs | Val accuracy | Test accuracy | Training time |
|---|---|---|---|---|---|
| CIFAR-10 | Original AlexNet | 90 | ~75-80% | ~70-75% | ~2-3 hours |
| CIFAR-10 | Simplified AlexNet | 90 | ~82-85% | ~80-83% | ~1-2 hours |
| ImageNet | Original AlexNet | 90 | ~56-57% top-1 | ~56.5% top-1 | ~5-6 days (2× GTX 580, per the paper) |
Tuning Tips
Learning-rate strategy
# A schedule worth trying on small datasets
milestones = [50, 75, 90]  # decay later
gamma = 0.5                # gentler decay
Stronger regularization
# Raise the dropout rate to curb overfitting
dropout_rate = 0.5   # the original paper uses 0.5
# Or increase weight decay
weight_decay = 0.001 # a moderate increase
Data augmentation
# Increase data diversity (see the pipeline sketch below)
transforms.RandomRotation(15)      # random rotation
transforms.RandomPerspective(0.3)  # random perspective distortion
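A sketch of a fuller CIFAR-10 training transform that folds these in (the specific values are illustrative choices, not tuned results):
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),        # standard CIFAR-10 crop
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),               # random rotation, up to ±15 degrees
    transforms.RandomPerspective(distortion_scale=0.3, p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),   # CIFAR-10 channel means
                         (0.2470, 0.2435, 0.2616)),  # CIFAR-10 channel stds
])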
Caveats
Memory usage
- Original AlexNet: roughly 240 MB of GPU memory (batch size 128)
- Simplified AlexNet: roughly 60 MB of GPU memory (batch size 128)
- Adjust the batch size to fit your GPU memory
Training time
- CIFAR-10: 1-3 hours (single GPU)
- ImageNet: 5-7 days (multi-GPU)
- Mixed-precision training can speed things up (see the sketch below)
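A minimal mixed-precision sketch using torch.cuda.amp (a stand-in linear model and a random batch keep it self-contained; this is not part of the trainer above, and it degrades to a no-op on CPU):
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(10, 2).to(device)              # stand-in model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
scaler = torch.cuda.amp.GradScaler(enabled=(device.type == 'cuda'))

inputs = torch.randn(8, 10, device=device)       # stand-in batch
targets = torch.randint(0, 2, (8,), device=device)

optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=(device.type == 'cuda')):
    loss = criterion(model(inputs), targets)     # forward pass runs in fp16 where safe
scaler.scale(loss).backward()                    # scale the loss to avoid fp16 underflow
scaler.unscale_(optimizer)                       # unscale gradients before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()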
Common Issues
Vanishing / exploding gradients
- Use gradient clipping (max_norm=1.0)
- Use an appropriate initialization strategy
- Add Batch Normalization (optional in the simplified model)
Overfitting
- Raise the dropout rate
- Add more data augmentation
- Stop earlier (use a smaller early-stopping patience)
Choosing a learning rate
- Too large: the loss oscillates and fails to converge
- Too small: convergence is slow
- Recommendation: start at 0.001 and use an LR scheduler
Extensions
Feature-map visualization
# Grab intermediate feature maps
feature_maps = model.get_feature_maps(input_image, layer_idx=[0, 2, 4])
# Useful for analyzing the features the network has learned
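get_feature_maps is a method of the author's model class; a generic way to get the same effect for any model with an indexable features block is forward hooks. A hedged sketch:
import torch

def get_feature_maps(model, x, layer_indices):
    """Capture outputs of selected layers in model.features via forward hooks."""
    captured, hooks = {}, []
    for idx in layer_indices:
        def make_hook(i):
            def hook(module, inputs, output):
                captured[i] = output.detach()
            return hook
        hooks.append(model.features[idx].register_forward_hook(make_hook(idx)))
    with torch.no_grad():
        model(x)                 # one forward pass fills `captured`
    for h in hooks:
        h.remove()               # always clean up hooks
    return captured

# e.g. with torchvision's AlexNet (indices 0, 3, 6 are its conv layers):
# from torchvision import models
# maps = get_feature_maps(models.alexnet(), torch.randn(1, 3, 224, 224), [0, 3, 6])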
Saving and loading models
# Save the full training state
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'val_acc': val_acc,
}, 'checkpoint.pth')
# Load to resume training (map_location makes the checkpoint portable across devices)
checkpoint = torch.load('checkpoint.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
Multi-GPU training
# Use DataParallel
model = nn.DataParallel(model)
# Data distribution and gradient aggregation are handled automatically
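For reference, a hedged sketch of the wrap itself (a stand-in module here; note that torch.nn.parallel.DistributedDataParallel is the recommended approach for serious multi-GPU work):
import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # stand-in for SimplifiedAlexNet
if torch.cuda.device_count() > 1:
    # Replicates the model on each GPU, splits every batch across them,
    # and gathers the outputs back on the default device.
    model = nn.DataParallel(model)
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')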
Appendix: Code
AlexNetTrainer
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import time

class AlexNetTrainer:
    """AlexNet trainer."""
    def __init__(self, model, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.model = model
        self.device = device
        self.model.to(self.device)
        # Training history
        self.history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'learning_rates': []
        }
        # Optimizer and learning-rate scheduler
        self.optimizer = None
        self.scheduler = None
        self.criterion = None

    def setup_training(self, learning_rate=0.001, weight_decay=0.0005,
                       momentum=0.9, lr_scheduler='step'):
        """Configure the loss, optimizer, and LR scheduler."""
        # Loss function (cross entropy)
        self.criterion = nn.CrossEntropyLoss()
        # Optimizer (SGD with momentum, following the original paper)
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=learning_rate,
            momentum=momentum,
            weight_decay=weight_decay
        )
        # Learning-rate scheduler
        if lr_scheduler == 'step':
            # Decay the LR by 10x every 30 epochs
            self.scheduler = optim.lr_scheduler.StepLR(
                self.optimizer, step_size=30, gamma=0.1
            )
        elif lr_scheduler == 'multi_step':
            # Multi-step schedule (closer to the original paper):
            # drop the LR at epochs 30, 60, 80, 90
            milestones = [30, 60, 80, 90]
            self.scheduler = optim.lr_scheduler.MultiStepLR(
                self.optimizer, milestones=milestones, gamma=0.1
            )
        elif lr_scheduler == 'reduce_on_plateau':
            # Reduce the LR when the validation loss plateaus
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode='min', factor=0.1, patience=10
            )
        else:
            self.scheduler = None
        print(f"Optimizer: SGD (lr={learning_rate}, momentum={momentum}, weight_decay={weight_decay})")
        print(f"LR scheduler: {lr_scheduler}")

    def train_epoch(self, train_loader, epoch, print_freq=100):
        """Train for one epoch."""
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            # Move data to the device
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            # Zero the gradients
            self.optimizer.zero_grad()
            # Forward pass
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            # Backward pass
            loss.backward()
            # Gradient clipping (guards against exploding gradients)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            # Parameter update
            self.optimizer.step()
            # Statistics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            # Progress reporting
            if (batch_idx + 1) % print_freq == 0:
                batch_loss = running_loss / (batch_idx + 1)
                batch_acc = 100. * correct / total
                current_lr = self.optimizer.param_groups[0]['lr']
                print(f'Epoch: {epoch} [{batch_idx + 1}/{len(train_loader)}] | '
                      f'Loss: {batch_loss:.4f} | Acc: {batch_acc:.2f}% | '
                      f'LR: {current_lr:.6f}')
        # Epoch statistics
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100. * correct / total
        return epoch_loss, epoch_acc

    def validate(self, val_loader):
        """Evaluate on the validation set."""
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)
                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        val_loss = running_loss / len(val_loader)
        val_acc = 100. * correct / total
        return val_loss, val_acc

    def train(self, train_loader, val_loader, epochs=90,
              early_stopping_patience=20, save_path='best_model.pth'):
        """Full training loop."""
        print(f"Starting training for {epochs} epochs")
        print(f"Device: {self.device}")
        print("-" * 60)
        best_val_acc = 0.0
        patience_counter = 0
        for epoch in range(1, epochs + 1):
            start_time = time.time()
            # Train one epoch
            train_loss, train_acc = self.train_epoch(train_loader, epoch)
            # Validate
            val_loss, val_acc = self.validate(val_loader)
            # Advance the learning-rate schedule
            if self.scheduler is not None:
                if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                    self.scheduler.step(val_loss)
                else:
                    self.scheduler.step()
            # Record history
            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
            self.history['learning_rates'].append(self.optimizer.param_groups[0]['lr'])
            # Print epoch results
            epoch_time = time.time() - start_time
            current_lr = self.optimizer.param_groups[0]['lr']
            print(f"\nEpoch {epoch}/{epochs} | time: {epoch_time:.2f}s")
            print(f"Train loss: {train_loss:.4f}, train acc: {train_acc:.2f}%")
            print(f"Val loss: {val_loss:.4f}, val acc: {val_acc:.2f}%")
            print(f"Learning rate: {current_lr:.6f}")
            print("-" * 50)
            # Save the best model
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                patience_counter = 0
                # Checkpoint the model
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'val_acc': val_acc,
                    'val_loss': val_loss,
                }, save_path)
                print(f"✅ Saved best model (val acc: {val_acc:.2f}%)")
            else:
                patience_counter += 1
                print(f"Patience counter: {patience_counter}/{early_stopping_patience}")
            # Early-stopping check
            if patience_counter >= early_stopping_patience:
                print(f"\n⚠️ Early stopping triggered after {epoch} epochs")
                break
        print(f"\nTraining complete! Best val accuracy: {best_val_acc:.2f}%")
        return self.history

    def evaluate(self, test_loader, classes=None):
        """Evaluate the model on the test set."""
        print("\n" + "=" * 60)
        print("Test-set evaluation")
        print("=" * 60)
        self.model.eval()
        test_loss = 0.0
        correct = 0
        total = 0
        # For the confusion matrix / per-class statistics
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(test_loader):
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                # Collect predictions
                all_preds.extend(predicted.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())
        # Overall metrics
        avg_loss = test_loss / len(test_loader)
        accuracy = 100. * correct / total
        print(f"Test loss: {avg_loss:.4f}")
        print(f"Test accuracy: {accuracy:.2f}%")
        print(f"Correctly classified: {correct}/{total}")
        # Per-class accuracy
        if classes is not None:
            print(f"\nPer-class accuracy:")
            print("-" * 40)
            # Convert the lists to numpy arrays
            all_preds = np.array(all_preds)
            all_targets = np.array(all_targets)
            for i, class_name in enumerate(classes):
                class_mask = all_targets == i
                class_total = class_mask.sum()
                if class_total > 0:
                    class_correct = (all_preds[class_mask] == i).sum()
                    class_acc = 100. * class_correct / class_total
                    print(f" {class_name:10s}: {class_acc:.2f}% ({class_correct}/{class_total})")
        return avg_loss, accuracy, all_preds, all_targets

    def plot_training_history(self):
        """Plot the training history."""
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        epochs = range(1, len(self.history['train_loss']) + 1)
        # Loss curves
        axes[0, 0].plot(epochs, self.history['train_loss'], 'b-', label='Train loss')
        axes[0, 0].plot(epochs, self.history['val_loss'], 'r-', label='Val loss')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].set_title('Training and Validation Loss')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)
        # Accuracy curves
        axes[0, 1].plot(epochs, self.history['train_acc'], 'b-', label='Train accuracy')
        axes[0, 1].plot(epochs, self.history['val_acc'], 'r-', label='Val accuracy')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy (%)')
        axes[0, 1].set_title('Training and Validation Accuracy')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)
        # Learning-rate curve
        axes[1, 0].plot(epochs, self.history['learning_rates'], 'g-')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Learning rate')
        axes[1, 0].set_title('Learning-Rate Schedule')
        axes[1, 0].set_yscale('log')
        axes[1, 0].grid(True, alpha=0.3)
        # Train-vs-validation accuracy scatter plot
        axes[1, 1].scatter(self.history['train_acc'], self.history['val_acc'],
                           alpha=0.6, c=epochs, cmap='viridis')
        axes[1, 1].set_xlabel('Train accuracy (%)')
        axes[1, 1].set_ylabel('Val accuracy (%)')
        axes[1, 1].set_title('Train vs. Val Accuracy')
        # Diagonal reference line
        min_acc = min(min(self.history['train_acc']), min(self.history['val_acc']))
        max_acc = max(max(self.history['train_acc']), max(self.history['val_acc']))
        axes[1, 1].plot([min_acc, max_acc], [min_acc, max_acc], 'r--', alpha=0.5)
        # Colorbar
        plt.colorbar(axes[1, 1].collections[0], ax=axes[1, 1], label='Epoch')
        axes[1, 1].grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

    def visualize_predictions(self, test_loader, classes, num_samples=12):
        """Visualize model predictions."""
        # Grab one batch of test data
        data_iter = iter(test_loader)
        images, labels = next(data_iter)
        # Move to the device and predict
        images_gpu = images.to(self.device)
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(images_gpu)
            _, predictions = outputs.max(1)
        # Move predictions back to the CPU
        predictions = predictions.cpu()
        # Un-normalize the images for display
        # Note: assumes CIFAR-10 channel means and standard deviations
        mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
        std = torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1)
        images = images * std + mean
        images = torch.clamp(images, 0, 1)
        # Build the figure
        fig, axes = plt.subplots(3, 4, figsize=(15, 10))
        axes = axes.ravel()
        for i in range(min(num_samples, len(images))):
            # CHW -> HWC for matplotlib
            img = images[i].permute(1, 2, 0).numpy()
            # Show the image
            axes[i].imshow(img)
            # Title color: green if correct, red if wrong
            pred_class = classes[predictions[i]]
            true_class = classes[labels[i]]
            is_correct = predictions[i] == labels[i]
            title_color = 'green' if is_correct else 'red'
            axes[i].set_title(f'Pred: {pred_class}\nTrue: {true_class}',
                              color=title_color, fontsize=11)
            axes[i].axis('off')
        plt.suptitle('Model Predictions', fontsize=16)
        plt.tight_layout()
        plt.show()
Main script
import torch
import torch.nn as nn
from torchvision import models
from learn_alexnet_pytorch.alex_net import SimplifiedAlexNet
from learn_alexnet_pytorch.alex_net_data_handler import AlexNetDataHandler
from learn_alexnet_pytorch.alex_net_trainer import AlexNetTrainer

def main():
    """Main training entry point."""
    print("=" * 70)
    print("AlexNet training script")
    print("=" * 70)
    # Fix the random seed for reproducibility
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    # 1. Hyperparameters
    config = {
        'batch_size': 128,
        'learning_rate': 0.01,  # a larger learning rate works on CIFAR-10
        'epochs': 100,
        'num_classes': 10,
        'weight_decay': 0.0005,
        'momentum': 0.9,
        'data_dir': './data',
        'model_save_path': 'alexnet_cifar10_best.pth'
    }
    print("\nTraining configuration:")
    for key, value in config.items():
        print(f" {key}: {value}")
    # 2. Prepare the data
    print("\n" + "-" * 70)
    print("Step 1: prepare the data")
    print("-" * 70)
    data_handler = AlexNetDataHandler(
        data_dir=config['data_dir'],
        batch_size=config['batch_size']
    )
    train_loader, val_loader, test_loader, classes = data_handler.load_cifar10()
    # Visualize one batch of training data
    data_handler.visualize_batch(train_loader, classes)
    # 3. Create the model
    print("\n" + "-" * 70)
    print("Step 2: create the model")
    print("-" * 70)
    # Simplified AlexNet (adapted for CIFAR-10)
    model = SimplifiedAlexNet(num_classes=config['num_classes'])
    # Print a model summary
    print(f"Model: {model.__class__.__name__}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
    # 4. Create the trainer
    print("\n" + "-" * 70)
    print("Step 3: set up the trainer")
    print("-" * 70)
    trainer = AlexNetTrainer(model)
    trainer.setup_training(
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        momentum=config['momentum'],
        lr_scheduler='multi_step'
    )
    # 5. Train the model
    print("\n" + "-" * 70)
    print("Step 4: train the model")
    print("-" * 70)
    history = trainer.train(
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=config['epochs'],
        early_stopping_patience=25,
        save_path=config['model_save_path']
    )
    # 6. Plot the training history
    print("\n" + "-" * 70)
    print("Step 5: analyze the training run")
    print("-" * 70)
    trainer.plot_training_history()
    # 7. Evaluate on the test set
    print("\n" + "-" * 70)
    print("Step 6: test-set evaluation")
    print("-" * 70)
    # Load the best checkpoint (map_location keeps this portable across devices)
    checkpoint = torch.load(config['model_save_path'], map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded best model (epoch {checkpoint['epoch']}, val acc: {checkpoint['val_acc']:.2f}%)")
    # Evaluate
    test_loss, test_acc, all_preds, all_targets = trainer.evaluate(test_loader, classes)
    # 8. Visualize predictions
    print("\n" + "-" * 70)
    print("Step 7: visualize predictions")
    print("-" * 70)
    trainer.visualize_predictions(test_loader, classes, num_samples=12)
    print("\n" + "=" * 70)
    print("Training complete!")
    print("=" * 70)
    return model, trainer, history

def load_pretrained_alexnet(num_classes=1000):
    """
    Load the official PyTorch pretrained AlexNet.
    Args:
        num_classes: number of output classes, default 1000 (ImageNet)
    """
    print("Loading the official PyTorch pretrained AlexNet...")
    # Load the pretrained model
    # (on torchvision >= 0.13 prefer: models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1))
    model = models.alexnet(pretrained=True)
    # Swap the output layer if a different class count is needed
    if num_classes != 1000:
        # Replace the final fully connected layer
        in_features = model.classifier[6].in_features
        model.classifier[6] = nn.Linear(in_features, num_classes)
        print(f"Replaced output layer: 1000 -> {num_classes} classes")
    # Optionally freeze the feature extractor
    # for param in model.features.parameters():
    #     param.requires_grad = False
    print(f"Model loaded (ImageNet pretrained)")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    return model

if __name__ == "__main__":
    # Run the main training function (simplified AlexNet on CIFAR-10)
    model, trainer, history = main()