import numpy as np
import matplotlib.pyplot as plt
import os
from keras.layers import Input, Add, Dense, Activation, BatchNormalization, Flatten, Conv2D, AveragePooling2D, Dropout
from keras.models import Model, load_model
from keras.initializers import glorot_uniform
from keras.optimizers import SGD
from keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from keras.regularizers import l2
import keras.backend as K

# 配置中文显示
plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei", "Heiti TC"]  # 设置中文字体
plt.rcParams["axes.unicode_minus"] = False  # 解决负号显示问题
K.set_image_data_format('channels_last')  # 设置Keras图像数据格式为通道在后


# 轻量化残差块
def basic_block(X, filters, stride=1):
    """
    基础残差块实现
    参数:
        X: 输入张量
        filters: 卷积核数量列表 [F1, F2]
        stride: 步长,默认为1
    返回:
        经过残差连接后的激活输出
    """
    F1, F2 = filters  # 解包滤波器数量
    X_shortcut = X  # 保存shortcut连接

    # 主分支 - 第一个卷积层
    X = Conv2D(F1, (3, 3), strides=(stride, stride), padding='same',
               kernel_initializer=glorot_uniform(seed=0), kernel_regularizer=l2(1e-6))(X)
    X = BatchNormalization(axis=3)(X)  # 批归一化,沿通道轴
    X = Activation('relu')(X)  # ReLU激活函数
    X = Dropout(0.05)(X)  # 添加dropout防止过拟合

    # 主分支 - 第二个卷积层
    X = Conv2D(F2, (3, 3), strides=(1, 1), padding='same',
               kernel_initializer=glorot_uniform(seed=0), kernel_regularizer=l2(1e-6))(X)
    X = BatchNormalization(axis=3)(X)  # 批归一化

    # 调整shortcut路径以匹配主分支的维度
    if stride != 1 or F2 != X_shortcut.shape[3]:
        # 当步长不为1或通道数不匹配时,使用1x1卷积调整shortcut
        X_shortcut = Conv2D(F2, (1, 1), strides=(stride, stride), padding='valid',
                            kernel_initializer=glorot_uniform(seed=0))(X_shortcut)
        X_shortcut = BatchNormalization(axis=3)(X_shortcut)

    # 残差连接:主分支输出 + shortcut
    X = Add()([X, X_shortcut])
    return Activation('relu')(X)  # 最终激活输出


# 构建极简ResNet
def SimpleResNet(input_shape=(64, 64, 3), classes=6):
    """
    构建简化的ResNet模型
    参数:
        input_shape: 输入图像形状,默认(64, 64, 3)
        classes: 分类类别数,默认6
    返回:
        编译好的Keras模型
    """
    # 定义输入层
    X_input = Input(input_shape)

    # 初始卷积层
    X = Conv2D(8, (3, 3), padding='same', kernel_initializer=glorot_uniform(seed=0))(X_input)
    X = BatchNormalization(axis=3)(X)  # 批归一化
    X = Activation('relu')(X)  # ReLU激活

    # 残差块 - 使用基础残差块
    X = basic_block(X, [8, 16], stride=2)  # 滤波器数量从8到16,步长为2进行下采样

    # 输出层
    X = AveragePooling2D((8, 8))(X)  # 全局平均池化
    X = Flatten()(X)  # 展平为全连接层输入
    # 输出层,使用softmax激活函数进行多分类
    outputs = Dense(classes, activation='softmax')(X)
    
    return Model(inputs=X_input, outputs=outputs)  # 创建并返回模型


# 训练模型
def train_model(X_train, Y_train, X_val, Y_val, X_test, Y_test, aug_generator, steps):
    """
    训练ResNet模型
    参数:
        X_train, Y_train: 训练数据及标签
        X_val, Y_val: 验证数据及标签  
        X_test, Y_test: 测试数据及标签
        aug_generator: 数据增强生成器
        steps: 每个epoch的步数
    返回:
        best_model: 最佳模型
        history: 训练历史记录
    """
    # 创建模型
    model = SimpleResNet()
    # 编译模型:使用SGD优化器,分类交叉熵损失函数,准确率作为评估指标
    model.compile(optimizer=SGD(0.001, momentum=0.95), 
                  loss='categorical_crossentropy', 
                  metrics=['accuracy'])

    # 定义回调函数
    callbacks = [
        # 早停:监控验证损失, patience=20表示连续20轮无改善则停止
        EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True),
        # 学习率调整:当验证损失停止改善时降低学习率
        ReduceLROnPlateau(monitor='val_loss', factor=0.9, patience=10, min_lr=1e-6),
        # 模型检查点:保存验证准确率最高的模型
        ModelCheckpoint('model/best_resnet.h5', monitor='val_accuracy', 
                       save_best_only=True, mode='max')
    ]

    # 训练模型
    history = model.fit(aug_generator, 
                        steps_per_epoch=steps,  # 每个epoch的批次数
                        epochs=100,  # 最大训练轮数
                        validation_data=(X_val, Y_val),  # 验证数据
                        callbacks=callbacks,  # 回调函数
                        verbose=1)  # 显示训练进度

    # 加载最佳模型进行评估
    best_model = load_model('model/best_resnet.h5')
    # 在测试集上评估模型性能
    loss, acc = best_model.evaluate(X_test, Y_test, verbose=0)
    print(f"\n测试集准确率: {acc:.4f}, 损失: {loss:.4f}")
    
    # 确保模型保存目录存在
    os.makedirs('model', exist_ok=True)
    # 保存最终训练好的模型
    model.save('model/simple_resnet_trained.h5')

    # 绘制训练曲线
    plt.figure(figsize=(12, 4))
    
    # 准确率曲线
    plt.subplot(121)
    plt.plot(history.history['accuracy'], label='训练准确率')
    plt.plot(history.history['val_accuracy'], label='验证准确率')
    plt.legend()
    plt.title('模型准确率')
    
    # 损失曲线
    plt.subplot(122)
    plt.plot(history.history['loss'], label='训练损失')
    plt.plot(history.history['val_loss'], label='验证损失')
    plt.legend()
    plt.title('模型损失')
    
    plt.tight_layout()  # 自动调整子图间距
    plt.show()

    return best_model, history  # 返回最佳模型和训练历史


# 主程序入口
if __name__ == "__main__":
    # 从数据预处理模块导入数据
    from data_preprocessing import X_train, Y_train, X_val, Y_val, X_test, Y_test, augmented_train_generator, steps_per_epoch

    # 开始训练模型
    train_model(X_train, Y_train, X_val, Y_val, X_test, Y_test, augmented_train_generator, steps_per_epoch)
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐