人工智能综合项目开发9---手势识别resnetmodel_train
·
import numpy as np
import matplotlib.pyplot as plt
import os
from keras.layers import Input, Add, Dense, Activation, BatchNormalization, Flatten, Conv2D, AveragePooling2D, Dropout
from keras.models import Model, load_model
from keras.initializers import glorot_uniform
from keras.optimizers import SGD
from keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from keras.regularizers import l2
import keras.backend as K
# 配置中文显示
plt.rcParams["font.family"] = ["SimHei", "WenQuanYi Micro Hei", "Heiti TC"] # 设置中文字体
plt.rcParams["axes.unicode_minus"] = False # 解决负号显示问题
K.set_image_data_format('channels_last') # 设置Keras图像数据格式为通道在后
# 轻量化残差块
def basic_block(X, filters, stride=1):
"""
基础残差块实现
参数:
X: 输入张量
filters: 卷积核数量列表 [F1, F2]
stride: 步长,默认为1
返回:
经过残差连接后的激活输出
"""
F1, F2 = filters # 解包滤波器数量
X_shortcut = X # 保存shortcut连接
# 主分支 - 第一个卷积层
X = Conv2D(F1, (3, 3), strides=(stride, stride), padding='same',
kernel_initializer=glorot_uniform(seed=0), kernel_regularizer=l2(1e-6))(X)
X = BatchNormalization(axis=3)(X) # 批归一化,沿通道轴
X = Activation('relu')(X) # ReLU激活函数
X = Dropout(0.05)(X) # 添加dropout防止过拟合
# 主分支 - 第二个卷积层
X = Conv2D(F2, (3, 3), strides=(1, 1), padding='same',
kernel_initializer=glorot_uniform(seed=0), kernel_regularizer=l2(1e-6))(X)
X = BatchNormalization(axis=3)(X) # 批归一化
# 调整shortcut路径以匹配主分支的维度
if stride != 1 or F2 != X_shortcut.shape[3]:
# 当步长不为1或通道数不匹配时,使用1x1卷积调整shortcut
X_shortcut = Conv2D(F2, (1, 1), strides=(stride, stride), padding='valid',
kernel_initializer=glorot_uniform(seed=0))(X_shortcut)
X_shortcut = BatchNormalization(axis=3)(X_shortcut)
# 残差连接:主分支输出 + shortcut
X = Add()([X, X_shortcut])
return Activation('relu')(X) # 最终激活输出
# 构建极简ResNet
def SimpleResNet(input_shape=(64, 64, 3), classes=6):
"""
构建简化的ResNet模型
参数:
input_shape: 输入图像形状,默认(64, 64, 3)
classes: 分类类别数,默认6
返回:
编译好的Keras模型
"""
# 定义输入层
X_input = Input(input_shape)
# 初始卷积层
X = Conv2D(8, (3, 3), padding='same', kernel_initializer=glorot_uniform(seed=0))(X_input)
X = BatchNormalization(axis=3)(X) # 批归一化
X = Activation('relu')(X) # ReLU激活
# 残差块 - 使用基础残差块
X = basic_block(X, [8, 16], stride=2) # 滤波器数量从8到16,步长为2进行下采样
# 输出层
X = AveragePooling2D((8, 8))(X) # 全局平均池化
X = Flatten()(X) # 展平为全连接层输入
# 输出层,使用softmax激活函数进行多分类
outputs = Dense(classes, activation='softmax')(X)
return Model(inputs=X_input, outputs=outputs) # 创建并返回模型
# 训练模型
def train_model(X_train, Y_train, X_val, Y_val, X_test, Y_test, aug_generator, steps):
"""
训练ResNet模型
参数:
X_train, Y_train: 训练数据及标签
X_val, Y_val: 验证数据及标签
X_test, Y_test: 测试数据及标签
aug_generator: 数据增强生成器
steps: 每个epoch的步数
返回:
best_model: 最佳模型
history: 训练历史记录
"""
# 创建模型
model = SimpleResNet()
# 编译模型:使用SGD优化器,分类交叉熵损失函数,准确率作为评估指标
model.compile(optimizer=SGD(0.001, momentum=0.95),
loss='categorical_crossentropy',
metrics=['accuracy'])
# 定义回调函数
callbacks = [
# 早停:监控验证损失, patience=20表示连续20轮无改善则停止
EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True),
# 学习率调整:当验证损失停止改善时降低学习率
ReduceLROnPlateau(monitor='val_loss', factor=0.9, patience=10, min_lr=1e-6),
# 模型检查点:保存验证准确率最高的模型
ModelCheckpoint('model/best_resnet.h5', monitor='val_accuracy',
save_best_only=True, mode='max')
]
# 训练模型
history = model.fit(aug_generator,
steps_per_epoch=steps, # 每个epoch的批次数
epochs=100, # 最大训练轮数
validation_data=(X_val, Y_val), # 验证数据
callbacks=callbacks, # 回调函数
verbose=1) # 显示训练进度
# 加载最佳模型进行评估
best_model = load_model('model/best_resnet.h5')
# 在测试集上评估模型性能
loss, acc = best_model.evaluate(X_test, Y_test, verbose=0)
print(f"\n测试集准确率: {acc:.4f}, 损失: {loss:.4f}")
# 确保模型保存目录存在
os.makedirs('model', exist_ok=True)
# 保存最终训练好的模型
model.save('model/simple_resnet_trained.h5')
# 绘制训练曲线
plt.figure(figsize=(12, 4))
# 准确率曲线
plt.subplot(121)
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.legend()
plt.title('模型准确率')
# 损失曲线
plt.subplot(122)
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.legend()
plt.title('模型损失')
plt.tight_layout() # 自动调整子图间距
plt.show()
return best_model, history # 返回最佳模型和训练历史
# 主程序入口
if __name__ == "__main__":
# 从数据预处理模块导入数据
from data_preprocessing import X_train, Y_train, X_val, Y_val, X_test, Y_test, augmented_train_generator, steps_per_epoch
# 开始训练模型
train_model(X_train, Y_train, X_val, Y_val, X_test, Y_test, augmented_train_generator, steps_per_epoch)更多推荐


所有评论(0)