Initialization and Configuration

trainer = AlexNetTrainer(model, device='cuda')
Device selection strategy
  • Prefer the GPU ('cuda') when available
  • Fall back to the CPU automatically when no GPU is present (see the sketch below)
  • Multi-GPU training is supported (requires manual extension)
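A minimal sketch of that fallback logic, mirroring the default device argument of AlexNetTrainer.__init__ in the appendix:

import torch

# Pick the GPU when available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')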

Training Setup (the setup_training method)

Optimizer configuration (following the original paper)
# Stochastic gradient descent (SGD) configuration
learning_rate = 0.001    # initial learning rate
momentum = 0.9           # momentum factor
weight_decay = 0.0005    # L2 regularization coefficient

Numerical example:

Parameter update rule (momentum form with weight decay, following the original paper):
v_t = momentum * v_{t-1} + learning_rate * gradient
θ_t = θ_{t-1} - v_t - weight_decay * learning_rate * θ_{t-1}

Worked example:
gradient = 0.1
previous momentum v_{t-1} = 0.05
current momentum v_t = 0.9*0.05 + 0.001*0.1 = 0.0451
parameter update = v_t + weight_decay * learning_rate * θ_{t-1} = 0.0451 + 0.0005*0.001*θ_{t-1}
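The same arithmetic checked in plain Python (0.1 and 0.05 are just the illustrative values above):

momentum, lr, weight_decay = 0.9, 0.001, 0.0005
grad, v_prev = 0.1, 0.05

v_t = momentum * v_prev + lr * grad   # 0.9*0.05 + 0.001*0.1
print(f'{v_t:.4f}')                   # 0.0451
# total step for a weight θ: v_t + weight_decay * lr * θ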
Learning rate scheduler options

Scheduler          Configuration                         Use case
StepLR             step_size=30, gamma=0.1               Decay at fixed intervals
MultiStepLR        milestones=[30,60,80,90], gamma=0.1   Multi-stage decay (closest to the original paper)
ReduceLROnPlateau  patience=10, factor=0.1               Adapt dynamically to the validation loss

Learning rate schedule example (MultiStepLR, stepping once per epoch; each drop takes effect from the following epoch):

Epoch 1-30:  0.001
Epoch 31-60: 0.0001
Epoch 61-80: 0.00001
Epoch 81-90: 0.000001
Epoch 91+:   0.0000001
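The schedule is easy to verify by stepping a MultiStepLR against a dummy parameter:

import torch
import torch.optim as optim

param = torch.nn.Parameter(torch.zeros(1))
opt = optim.SGD([param], lr=0.001)
sched = optim.lr_scheduler.MultiStepLR(opt, milestones=[30, 60, 80, 90], gamma=0.1)

for epoch in range(1, 95):
    # ...one epoch of training would run here...
    sched.step()
    if epoch in (30, 60, 80, 90):
        print(f'after epoch {epoch}: lr = {opt.param_groups[0]["lr"]:.7f}')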

Training Workflow

Single-epoch training (the train_epoch method)

Numerical example:

Batch size: 128
Training set size: 45,000 (90% of the CIFAR-10 training split)
Number of batches: ceil(45,000 / 128) = 352
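With the DataLoader default drop_last=False, len(train_loader) is the ceiling of that division:

import math
print(math.ceil(45000 / 128))  # 352 batches per epoch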

Progress is printed every 100 batches:
Epoch: 1 [100/352] | Loss: 2.3145 | Acc: 12.34% | LR: 0.001000
Epoch: 1 [200/352] | Loss: 1.8923 | Acc: 28.56% | LR: 0.001000
Epoch: 1 [300/352] | Loss: 1.6543 | Acc: 35.78% | LR: 0.001000
Gradient clipping
torch.nn.utils.clip_grad_norm_(parameters, max_norm=1.0)
  • Prevents exploding gradients
  • Caps the global gradient norm at 1.0
  • Keeps training stable (see the toy example below)
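A toy example of the effect, using nothing beyond standard PyTorch: a parameter is given an oversized gradient, and clip_grad_norm_ rescales it so the global norm is at most 1.0.

import torch

p = torch.nn.Parameter(torch.zeros(3))
p.grad = torch.tensor([3.0, 4.0, 0.0])  # norm = 5.0, well above max_norm

total_norm = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
print(total_norm)  # 5.0 (the norm measured before clipping)
print(p.grad)      # tensor([0.6000, 0.8000, 0.0000]), norm = 1.0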

Training loop (the train method)

Full training example
# Train for 90 epochs with early stopping
history = trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=90,
    early_stopping_patience=20,
    save_path='alexnet_cifar10_best.pth'
)
Sample training output
Epoch 1/90 | Time: 45.23s
Train loss: 1.8923, train accuracy: 32.45%
Val loss: 1.6543, val accuracy: 40.12%
Learning rate: 0.001000
Patience counter: 0/20

✅ Saved best model (accuracy: 40.12%)
Early stopping
  • Training stops when validation accuracy has not improved for 20 consecutive epochs (see the sketch below)
  • Avoids overfitting and wasted compute
  • The best model is saved automatically
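Stripped to its essentials, the mechanism in the train method is a counter that resets on improvement. A self-contained sketch with a stand-in validation curve:

best_val_acc = 0.0
patience_counter = 0
early_stopping_patience = 20

for epoch in range(1, 91):
    # stand-in for validate(): accuracy climbs, then plateaus at 80%
    val_acc = min(80.0, epoch * 2.0)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0   # improvement: reset (and save the model here)
    else:
        patience_counter += 1  # no improvement this epoch
    if patience_counter >= early_stopping_patience:
        print(f'early stopping at epoch {epoch} (best: {best_val_acc:.2f}%)')
        break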

Evaluation (the evaluate method)

Metric computation
# Overall metrics
test loss     = Σ(batch loss) / number of batches
test accuracy = correct predictions / total samples × 100%

# Per-class metrics
class accuracy = correct predictions in class / samples in class × 100%
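A NumPy sketch of the per-class computation with made-up prediction and target arrays; the mask-based logic mirrors the evaluate method in the appendix:

import numpy as np

all_preds   = np.array([0, 1, 1, 2, 2, 2, 0, 1])
all_targets = np.array([0, 1, 2, 2, 2, 1, 0, 1])
classes = ['plane', 'car', 'bird']

for i, class_name in enumerate(classes):
    class_mask = all_targets == i          # samples whose true label is class i
    class_total = class_mask.sum()
    class_correct = (all_preds[class_mask] == i).sum()
    print(f'{class_name:10s}: {100.0 * class_correct / class_total:.2f}% '
          f'({class_correct}/{class_total})')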
Sample evaluation output
============================================================
Test set evaluation
============================================================
Test loss: 0.8923
Test accuracy: 85.34%
Correctly classified: 8534/10000

Per-class accuracy:
----------------------------------------
  plane     : 88.20% (882/1000)
  car       : 92.10% (921/1000)
  bird      : 76.30% (763/1000)
  cat       : 65.40% (654/1000)
  deer      : 82.50% (825/1000)
  dog       : 78.90% (789/1000)
  frog      : 91.20% (912/1000)
  horse     : 87.60% (876/1000)
  ship      : 93.80% (938/1000)
  truck     : 90.50% (905/1000)

Visualization

Training history plots (plot_training_history)
  • Loss curves (train vs. validation)
  • Accuracy curves (train vs. validation)
  • Learning rate curve
  • Train vs. validation accuracy scatter plot

Typical training-curve behavior:

Epoch 1-30: training loss drops quickly; validation accuracy climbs steadily
Epoch 31-60: after the learning rate drop, loss falls slowly and accuracy keeps improving
Epoch 61-90: convergence; training and validation curves largely flatten out
Prediction visualization (visualize_predictions)
  • Shows predictions for 12 test samples
  • Correct predictions are titled in green, incorrect ones in red
  • Gives an at-a-glance view of model behavior

Weight Initialization

Original AlexNet initialization
# Convolutional layers: Gaussian, mean 0, std 0.01
nn.init.normal_(weight, mean=0, std=0.01)

# Fully connected layers: Gaussian, mean 0, std 0.01
nn.init.normal_(weight, mean=0, std=0.01)

# Biases: 0 for most layers, 1 for a few
nn.init.constant_(bias, 0)  # most layers
nn.init.constant_(bias, 1)  # conv layers 2, 4, 5 and the FC hidden layers (per the paper)
Simplified AlexNet initialization
# Kaiming initialization (suited to ReLU activations)
nn.init.kaiming_normal_(weight, mode='fan_out', nonlinearity='relu')
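The lines above show the per-tensor calls; a minimal sketch of wiring the Kaiming rule across a whole network with Module.apply (the toy Sequential here is illustrative, not the SimplifiedAlexNet definition):

import torch.nn as nn

def init_weights(m):
    # Kaiming init for conv and linear weights, zero biases
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

# Toy model (assumes 32x32 RGB inputs)
model = nn.Sequential(nn.Conv2d(3, 64, 3), nn.ReLU(), nn.Flatten(),
                      nn.Linear(64 * 30 * 30, 10))
model.apply(init_weights)  # recursively applies init_weights to every submodule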

End-to-End Training Example

CIFAR-10 training configuration

# 1. Load the dataset
data_handler = AlexNetDataHandler(batch_size=128)
train_loader, val_loader, test_loader, classes = data_handler.load_cifar10()

# 2. Create the model (simplified AlexNet)
model = SimplifiedAlexNet(num_classes=10, use_batchnorm=False)

# 3. Create the trainer
trainer = AlexNetTrainer(model)

# 4. Configure training
trainer.setup_training(
    learning_rate=0.001,
    weight_decay=0.0005,
    momentum=0.9,
    lr_scheduler='multi_step'
)

# 5. Start training
history = trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=90,
    early_stopping_patience=20,
    save_path='alexnet_cifar10.pth'
)

# 6. Evaluate the model
test_loss, test_acc, preds, targets = trainer.evaluate(test_loader, classes)

# 7. Visualize the results
trainer.plot_training_history()
trainer.visualize_predictions(test_loader, classes)

Expected Performance

Dataset   Model               Epochs  Val accuracy  Test accuracy  Training time
CIFAR-10  Original AlexNet    90      ~75-80%       ~70-75%        ~2-3 hours
CIFAR-10  Simplified AlexNet  90      ~82-85%       ~80-83%        ~1-2 hours
ImageNet  Original AlexNet    90      ~56-57%       ~56.5%         5-6 days (2× GTX 580, per the paper)

Tuning Suggestions

Learning rate strategy
# A schedule better suited to small datasets
milestones = [50, 75, 90]  # decay later in training
gamma = 0.5  # a gentler decay
Stronger regularization
# Raise the dropout rate to fight overfitting
dropout_rate = 0.5  # the original paper uses 0.5
# Or increase the weight decay
weight_decay = 0.001  # a moderate increase
Data augmentation
# Increase data diversity (see the pipeline sketch below)
transforms.RandomRotation(15)  # random rotation
transforms.RandomPerspective(0.3)  # random perspective warp
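These transforms only take effect inside a full preprocessing pipeline. A minimal sketch of a CIFAR-10 training transform that folds them in; RandomCrop and RandomHorizontalFlip are common additions assumed here, and the normalization constants match the ones used in visualize_predictions below:

from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),               # pad-and-crop augmentation
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),                      # random rotation up to ±15°
    transforms.RandomPerspective(distortion_scale=0.3),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2470, 0.2435, 0.2616)),
])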

Notes and Caveats

Memory usage

  • Original AlexNet: roughly 240 MB of GPU memory (batch size 128)
  • Simplified AlexNet: roughly 60 MB of GPU memory (batch size 128)
  • Adjust the batch size to fit your GPU's memory

Training time

  • CIFAR-10: 1-3 hours (single GPU)
  • ImageNet: 5-7 days (multi-GPU)
  • Mixed-precision training can speed this up (see the sketch below)
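A hedged sketch of a mixed-precision training step using torch.cuda.amp. It is not wired into AlexNetTrainer; it only shows where autocast and GradScaler would slot into a train_epoch-style inner loop:

import torch

scaler = torch.cuda.amp.GradScaler()

def amp_train_step(model, criterion, optimizer, inputs, targets):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():  # run the forward pass in mixed precision
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    scaler.scale(loss).backward()    # scale the loss to avoid fp16 underflow
    scaler.unscale_(optimizer)       # unscale gradients before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)           # skips the step if gradients overflowed
    scaler.update()
    return loss.item()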

Common Issues

Vanishing/exploding gradients
  • Use gradient clipping (max_norm=1.0)
  • Use a suitable initialization strategy
  • Add Batch Normalization (optional in the simplified model)
Overfitting
  • Raise the dropout rate
  • Add more data augmentation
  • Use a tighter early-stopping patience
Learning rate selection
  • Too large: the loss oscillates and training fails to converge
  • Too small: convergence is slow
  • Suggestion: start at 0.001 and use a learning rate scheduler

Extensions

Feature map visualization

# Get intermediate feature maps (get_feature_maps is a custom helper on the model)
feature_maps = model.get_feature_maps(input_image, layer_idx=[0, 2, 4])
# Useful for analyzing the features the network has learned
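get_feature_maps is a custom helper that is not part of the appendix code. A generic way to capture the same information with standard PyTorch forward hooks, assuming the model exposes a features Sequential (as torchvision's AlexNet does):

import torch

def capture_feature_maps(model, x, layer_idx):
    """Collect the outputs of model.features[i] for each i in layer_idx."""
    feature_maps = {}
    hooks = []
    for i in layer_idx:
        # the default argument binds the current index into the hook closure
        def hook(module, inputs, output, i=i):
            feature_maps[i] = output.detach()
        hooks.append(model.features[i].register_forward_hook(hook))
    with torch.no_grad():
        model(x)
    for h in hooks:
        h.remove()  # clean up so later forward passes are unaffected
    return feature_maps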

Model saving and loading

# Save the full training state
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'val_acc': val_acc,
}, 'checkpoint.pth')

# Load and resume training
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1

Multi-GPU training

# Use DataParallel
model = nn.DataParallel(model)
# Handles data scattering and gradient aggregation automatically

Appendix: Code

AlexNetTrainer source

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import time

class AlexNetTrainer:
    """AlexNet训练器"""

    def __init__(self, model, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.model = model
        self.device = device
        self.model.to(self.device)

        # Training history
        self.history = {
            'train_loss': [],
            'train_acc': [],
            'val_loss': [],
            'val_acc': [],
            'learning_rates': []
        }

        # Optimizer and learning rate scheduler
        self.optimizer = None
        self.scheduler = None
        self.criterion = None

    def setup_training(self, learning_rate=0.001, weight_decay=0.0005,
                       momentum=0.9, lr_scheduler='step'):
        """设置训练配置"""
        # 定义损失函数(交叉熵损失)
        self.criterion = nn.CrossEntropyLoss()

        # 定义优化器(SGD with momentum,按原始论文)
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=learning_rate,
            momentum=momentum,
            weight_decay=weight_decay
        )

        # Learning rate scheduler
        if lr_scheduler == 'step':
            # Drop the learning rate by 10x every 30 epochs
            self.scheduler = optim.lr_scheduler.StepLR(
                self.optimizer, step_size=30, gamma=0.1
            )
        elif lr_scheduler == 'multi_step':
            # Multi-step schedule: drop at epochs 30, 60, 80, 90 (closest to the original paper)
            milestones = [30, 60, 80, 90]
            self.scheduler = optim.lr_scheduler.MultiStepLR(
                self.optimizer, milestones=milestones, gamma=0.1
            )
        elif lr_scheduler == 'reduce_on_plateau':
            # Reduce the learning rate when the validation loss plateaus
            self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode='min', factor=0.1, patience=10
            )
        else:
            self.scheduler = None

        print(f"优化器: SGD (lr={learning_rate}, momentum={momentum}, weight_decay={weight_decay})")
        print(f"学习率调度器: {lr_scheduler}")

    def train_epoch(self, train_loader, epoch, print_freq=100):
        """训练一个epoch"""
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            # Move data to the device
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            # Zero the gradients
            self.optimizer.zero_grad()

            # Forward pass
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)

            # Backward pass
            loss.backward()

            # Gradient clipping (prevents exploding gradients)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            # Update the parameters
            self.optimizer.step()

            # Statistics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # Print progress
            if (batch_idx + 1) % print_freq == 0:
                batch_loss = running_loss / (batch_idx + 1)
                batch_acc = 100. * correct / total
                current_lr = self.optimizer.param_groups[0]['lr']

                print(f'Epoch: {epoch} [{batch_idx + 1}/{len(train_loader)}] | '
                      f'Loss: {batch_loss:.4f} | Acc: {batch_acc:.2f}% | '
                      f'LR: {current_lr:.6f}')

        # Epoch statistics
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100. * correct / total

        return epoch_loss, epoch_acc

    def validate(self, val_loader):
        """验证模型"""
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)

                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        val_loss = running_loss / len(val_loader)
        val_acc = 100. * correct / total

        return val_loss, val_acc

    def train(self, train_loader, val_loader, epochs=90,
              early_stopping_patience=20, save_path='best_model.pth'):
        """完整的训练循环"""
        print(f"开始训练,共 {epochs} 个epoch")
        print(f"设备: {self.device}")
        print("-" * 60)

        best_val_acc = 0.0
        patience_counter = 0

        for epoch in range(1, epochs + 1):
            start_time = time.time()

            # Train for one epoch
            train_loss, train_acc = self.train_epoch(train_loader, epoch)

            # Validate
            val_loss, val_acc = self.validate(val_loader)

            # Step the learning rate scheduler
            if self.scheduler is not None:
                if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                    self.scheduler.step(val_loss)
                else:
                    self.scheduler.step()

            # Record history
            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
            self.history['learning_rates'].append(self.optimizer.param_groups[0]['lr'])

            # Print epoch results
            epoch_time = time.time() - start_time
            current_lr = self.optimizer.param_groups[0]['lr']

            print(f"\nEpoch {epoch}/{epochs} | 时间: {epoch_time:.2f}s")
            print(f"训练损失: {train_loss:.4f}, 训练准确率: {train_acc:.2f}%")
            print(f"验证损失: {val_loss:.4f}, 验证准确率: {val_acc:.2f}%")
            print(f"学习率: {current_lr:.6f}")
            print("-" * 50)

            # Save the best model
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                patience_counter = 0

                # Save a checkpoint
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'val_acc': val_acc,
                    'val_loss': val_loss,
                }, save_path)

                print(f"✅ 保存最佳模型 (准确率: {val_acc:.2f}%)")
            else:
                patience_counter += 1
                print(f"耐心计数器: {patience_counter}/{early_stopping_patience}")

            # Early stopping check
            if patience_counter >= early_stopping_patience:
                print(f"\n⚠️  早停触发!在 {epoch} 个epoch后停止训练")
                break

        print(f"\n训练完成!最佳验证准确率: {best_val_acc:.2f}%")

        return self.history

    def evaluate(self, test_loader, classes=None):
        """评估模型在测试集上的性能"""
        print("\n" + "=" * 60)
        print("测试集评估")
        print("=" * 60)

        self.model.eval()
        test_loss = 0.0
        correct = 0
        total = 0

        # Collected for per-class stats / a confusion matrix
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(test_loader):
                inputs, targets = inputs.to(self.device), targets.to(self.device)

                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                test_loss += loss.item()
                _, predicted = outputs.max(1)

                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

                # Collect predictions
                all_preds.extend(predicted.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        # Overall metrics
        avg_loss = test_loss / len(test_loader)
        accuracy = 100. * correct / total

        print(f"测试损失: {avg_loss:.4f}")
        print(f"测试准确率: {accuracy:.2f}%")
        print(f"正确分类: {correct}/{total}")

        # Per-class accuracy
        if classes is not None:
            print("\nPer-class accuracy:")
            print("-" * 40)

            # Convert the lists to numpy arrays
            all_preds = np.array(all_preds)
            all_targets = np.array(all_targets)

            for i, class_name in enumerate(classes):
                class_mask = all_targets == i
                class_total = class_mask.sum()

                if class_total > 0:
                    class_correct = (all_preds[class_mask] == i).sum()
                    class_acc = 100. * class_correct / class_total
                    print(f"  {class_name:10s}: {class_acc:.2f}% ({class_correct}/{class_total})")

        return avg_loss, accuracy, all_preds, all_targets

    def plot_training_history(self):
        """绘制训练历史图表"""
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))

        epochs = range(1, len(self.history['train_loss']) + 1)

        # Loss curves
        axes[0, 0].plot(epochs, self.history['train_loss'], 'b-', label='Train loss')
        axes[0, 0].plot(epochs, self.history['val_loss'], 'r-', label='Val loss')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].set_title('Training and Validation Loss')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        # Accuracy curves
        axes[0, 1].plot(epochs, self.history['train_acc'], 'b-', label='Train accuracy')
        axes[0, 1].plot(epochs, self.history['val_acc'], 'r-', label='Val accuracy')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy (%)')
        axes[0, 1].set_title('Training and Validation Accuracy')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)

        # Learning rate over time
        axes[1, 0].plot(epochs, self.history['learning_rates'], 'g-')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Learning rate')
        axes[1, 0].set_title('Learning Rate Schedule')
        axes[1, 0].set_yscale('log')
        axes[1, 0].grid(True, alpha=0.3)

        # Train vs. validation accuracy scatter
        axes[1, 1].scatter(self.history['train_acc'], self.history['val_acc'],
                           alpha=0.6, c=epochs, cmap='viridis')
        axes[1, 1].set_xlabel('Train accuracy (%)')
        axes[1, 1].set_ylabel('Val accuracy (%)')
        axes[1, 1].set_title('Train vs. Validation Accuracy')

        # Diagonal reference line (train == val)
        min_acc = min(min(self.history['train_acc']), min(self.history['val_acc']))
        max_acc = max(max(self.history['train_acc']), max(self.history['val_acc']))
        axes[1, 1].plot([min_acc, max_acc], [min_acc, max_acc], 'r--', alpha=0.5)

        # Colorbar keyed to epoch
        plt.colorbar(axes[1, 1].collections[0], ax=axes[1, 1], label='Epoch')
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    def visualize_predictions(self, test_loader, classes, num_samples=12):
        """可视化模型预测结果"""
        # 获取一个批次的测试数据
        data_iter = iter(test_loader)
        images, labels = next(data_iter)

        # Move to the device and predict
        images_gpu = images.to(self.device)
        self.model.eval()

        with torch.no_grad():
            outputs = self.model(images_gpu)
            _, predictions = outputs.max(1)

        # Move predictions back to the CPU
        predictions = predictions.cpu()

        # Un-normalize the images for display
        # Note: assumes the CIFAR-10 mean and std below were used for normalization
        mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
        std = torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1)
        images = images * std + mean
        images = torch.clamp(images, 0, 1)

        # Build the figure
        fig, axes = plt.subplots(3, 4, figsize=(15, 10))
        axes = axes.ravel()

        for i in range(min(num_samples, len(images))):
            # Rearrange dimensions from CHW to HWC
            img = images[i].permute(1, 2, 0).numpy()

            # Show the image
            axes[i].imshow(img)

            # Title color: green if correct, red if wrong
            pred_class = classes[predictions[i]]
            true_class = classes[labels[i]]
            is_correct = predictions[i] == labels[i]

            title_color = 'green' if is_correct else 'red'
            axes[i].set_title(f'Pred: {pred_class}\nTrue: {true_class}',
                              color=title_color, fontsize=11)
            axes[i].axis('off')

        plt.suptitle('Model Prediction Visualization', fontsize=16)
        plt.tight_layout()
        plt.show()

Main driver script

import torch
import torch.nn as nn
from torchvision import models
from learn_alexnet_pytorch.alex_net import SimplifiedAlexNet
from learn_alexnet_pytorch.alex_net_data_handler import AlexNetDataHandler
from learn_alexnet_pytorch.alex_net_trainer import AlexNetTrainer


def main():
    """主训练函数"""
    print("=" * 70)
    print("AlexNet训练脚本")
    print("=" * 70)

    # Set random seeds for reproducibility
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)

    # 1. Configuration
    config = {
        'batch_size': 128,
        'learning_rate': 0.01,  # a larger learning rate works on CIFAR-10
        'epochs': 100,
        'num_classes': 10,
        'weight_decay': 0.0005,
        'momentum': 0.9,
        'data_dir': './data',
        'model_save_path': 'alexnet_cifar10_best.pth'
    }

    print("\n训练配置:")
    for key, value in config.items():
        print(f"  {key}: {value}")

    # 2. Prepare the data
    print("\n" + "-" * 70)
    print("Step 1: Prepare the data")
    print("-" * 70)

    data_handler = AlexNetDataHandler(
        data_dir=config['data_dir'],
        batch_size=config['batch_size']
    )

    train_loader, val_loader, test_loader, classes = data_handler.load_cifar10()

    # Visualize one batch of training data
    data_handler.visualize_batch(train_loader, classes)

    # 3. Create the model
    print("\n" + "-" * 70)
    print("Step 2: Create the model")
    print("-" * 70)

    # Use the simplified AlexNet (adapted to CIFAR-10)
    model = SimplifiedAlexNet(num_classes=config['num_classes'])

    # Print model info
    print(f"Model: {model.__class__.__name__}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    # 4. Create the trainer
    print("\n" + "-" * 70)
    print("Step 3: Set up the trainer")
    print("-" * 70)

    trainer = AlexNetTrainer(model)
    trainer.setup_training(
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        momentum=config['momentum'],
        lr_scheduler='multi_step'
    )

    # 5. Train the model
    print("\n" + "-" * 70)
    print("Step 4: Train the model")
    print("-" * 70)

    history = trainer.train(
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=config['epochs'],
        early_stopping_patience=25,
        save_path=config['model_save_path']
    )

    # 6. Plot the training history
    print("\n" + "-" * 70)
    print("Step 5: Analyze the training results")
    print("-" * 70)

    trainer.plot_training_history()

    # 7. Evaluate on the test set
    print("\n" + "-" * 70)
    print("Step 6: Evaluate on the test set")
    print("-" * 70)

    # Load the best model
    checkpoint = torch.load(config['model_save_path'])
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"加载最佳模型 (epoch {checkpoint['epoch']}, 验证准确率: {checkpoint['val_acc']:.2f}%)")

    # Evaluate
    test_loss, test_acc, all_preds, all_targets = trainer.evaluate(test_loader, classes)

    # 8. Visualize predictions
    print("\n" + "-" * 70)
    print("Step 7: Visualize predictions")
    print("-" * 70)

    trainer.visualize_predictions(test_loader, classes, num_samples=12)

    print("\n" + "=" * 70)
    print("训练完成!")
    print("=" * 70)

    return model, trainer, history


def load_pretrained_alexnet(num_classes=1000):
    """
    加载PyTorch官方预训练的AlexNet

    参数:
        num_classes: 输出类别数,默认1000(ImageNet)
    """
    print("加载PyTorch官方预训练的AlexNet...")

    # Load the pretrained model ('weights=' replaces the deprecated pretrained=True in torchvision >= 0.13)
    model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)

    # Replace the classifier head if a different class count is needed
    if num_classes != 1000:
        # Swap the last fully connected layer
        in_features = model.classifier[6].in_features
        model.classifier[6] = nn.Linear(in_features, num_classes)

        print(f"修改输出层: 1000 -> {num_classes} 类别")

    # Optionally freeze the feature extractor:
    # for param in model.features.parameters():
    #     param.requires_grad = False

    print(f"模型加载完成 (ImageNet预训练)")
    print(f"总参数量: {sum(p.numel() for p in model.parameters()):,}")

    return model


if __name__ == "__main__":
    # Run the main training function (train the simplified AlexNet on CIFAR-10)
    model, trainer, history = main()