import numpy as np
import h5py
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import to_categorical

# 中文显示配置
plt.rcParams["font.family"] = ["SimHei", "SimSun", "Microsoft YaHei"]
plt.rcParams["axes.unicode_minus"] = False


def load_dataset():
    """
    加载并分割数据集为训练/验证/测试集
    
    从HDF5文件中加载手势识别数据集,并将训练集分割为训练集和验证集
    返回:
        train_x, train_y: 训练集图像和标签
        val_x, val_y: 验证集图像和标签  
        test_x, test_y: 测试集图像和标签
        classes: 类别名称列表
    """
    with h5py.File('./train_signs.h5', "r") as f:
        train_x, train_y = np.array(f["train_set_x"][:]), np.array(f["train_set_y"][:])
    with h5py.File('./test_signs.h5', "r") as f:
        test_x, test_y = np.array(f["test_set_x"][:]), np.array(f["test_set_y"][:])
        classes = np.array(f["list_classes"][:])

    # 分层分割训练集为训练/验证集 (80%训练, 20%验证)
    # stratify参数确保训练集和验证集中各类别比例相同
    train_x, val_x, train_y, val_y = train_test_split(
        train_x, train_y, test_size=0.2, random_state=42, stratify=train_y
    )

    return train_x, train_y, val_x, val_y, test_x, test_y, classes


def preprocess_data(images):
    """
    图像预处理:归一化到[0,1]并确保3通道
    
    参数:
        images: 输入图像数组
    返回:
        预处理后的图像数组
    """
    # 将像素值从[0,255]归一化到[0,1]
    images = images.astype('float32') / 255.0
    
    # 如果图像是单通道(灰度图),复制为3通道(RGB)
    # 这确保与预训练模型兼容,因为大多数CNN期望3通道输入
    if images.shape[-1] == 1:
        images = np.repeat(images, 3, axis=-1)
    return images


def create_data_augmenter():
    """
    创建轻量数据增强生成器
    
    通过数据增强增加训练数据的多样性,提高模型泛化能力
    返回:
        ImageDataGenerator: 配置好的数据增强器
    """
    return ImageDataGenerator(
        horizontal_flip=True,        # 水平翻转
        width_shift_range=0.03,      # 宽度方向随机平移3%
        height_shift_range=0.03,     # 高度方向随机平移3%
        fill_mode='nearest'          # 填充新像素的策略:使用最近的像素
    )


# 加载并预处理数据
X_train_orig, Y_train_orig, X_val_orig, Y_val_orig, X_test_orig, Y_test_orig, classes = load_dataset()

# 预处理图像数据:归一化并确保3通道
X_train = preprocess_data(X_train_orig)
X_val = preprocess_data(X_val_orig)
X_test = preprocess_data(X_test_orig)

# 标签转为one-hot编码 (6个类别的手势识别)
# 例如:标签2变为[0,0,1,0,0,0]
Y_train = to_categorical(Y_train_orig, 6)
Y_val = to_categorical(Y_val_orig, 6)
Y_test = to_categorical(Y_test_orig, 6)

# 验证数据形状
print(f"Y_train形状: {Y_train.shape} (应为 (864,6))")
print(f"Y_val形状: {Y_val.shape} (应为 (216,6))")
print(f"Y_test形状: {Y_test.shape} (应为 (120,6))")

# 数据增强配置
datagen = create_data_augmenter()
# 创建数据增强生成器,批量大小为32,每次迭代时打乱数据
augmented_train_generator = datagen.flow(X_train, Y_train, batch_size=32, shuffle=True)
# 计算每个epoch需要的步数(批次数量)
steps_per_epoch = X_train.shape[0] // 32  # 训练样本数除以批量大小


Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐