sklearn.model_selection

一. 认识 sklearn.model_selection

sklearn (Scikit-Learn) 是 Python 中最负盛名、应用最广泛的传统机器学习算法库。而 model_selection 则是这个库中专门用于“模型选择与评估”的核心模块。

无论你使用的是传统的机器学习算法(如随机森林、SVM)还是深度学习框架(如 PyTorch、TensorFlow),在将数据喂给模型之前,几乎都离不开这个模块的帮助。它主要负责三大核心任务:

  1. 数据集划分 (Data Splitting):如切分训练集、测试集、验证集。
  2. 交叉验证 (Cross-Validation):更科学地评估模型的泛化能力。
  3. 超参数调优 (Hyperparameter Tuning):如网格搜索,自动寻找模型的最优参数。

导入时,通常按需导入具体的函数或类,例如:

from sklearn.model_selection import train_test_split

二. 核心功能拆解

1. train_test_split - 划分训练集与测试集

作用:将数组或矩阵随机分割为训练子集和测试子集。这是深度学习与机器学习工程中最基础、最必须的步骤,用于防止模型“死记硬背”(过拟合),并评估模型在未见过的新数据上的表现。

train_test_split(*arrays, test_size=None, train_size=None, random_state=None, shuffle=True, stratify=None)

参数:

  • 待划分数据 (*arrays): 允许同时传入多个长度相同的可迭代对象(如 List、NumPy 数组、Pandas DataFrame/Series)。最常见的是同时传入特征数据 X 和标签数据 y
  • 测试集比例 (test_size): 用于决定测试集占总数据的比例。如果是浮点数(0.0 到 1.0 之间),代表比例(例如 0.2 代表 20% 做测试集);如果是整数,则代表测试集的绝对样本数量 (float / int 或 None)。
  • 随机种子 (random_state): 极度重要的参数! 控制数据打乱的随机数生成器。如果设定为一个固定的整数(如 42),那么无论代码运行多少次,每次切分出来的数据都是一模一样的,这保证了实验的“可复现性” (int 或 None)。
  • 是否打乱 (shuffle): 在切分前是否先打乱数据的顺序(bool,默认为 True)。
  • 分层抽样 (stratify): 传入类别的标签数组。如果数据集分类极不平衡(如正样本 90 个,负样本 10 个),设为 stratify=y 可以保证切分后的训练集和测试集中,正负样本的比例依然保持 9:1,防止某一个类别全被切进测试集 (array-like 或 None)。

返回值:

  • 成功: 返回一个列表,包含了切分后的训练集和测试集。返回的数量是输入 *arrays 数量的两倍(因为每个输入都会被切成 train 和 test 两部分)。

示例:

import numpy as np
from sklearn.model_selection import train_test_split

# 假设 X 是特征矩阵 (10个样本),y 是标签
X = np.arange(20).reshape((10, 2))
y = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

# 同时切分 X 和 y,测试集占 20% (即2个样本),固定随机种子42
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 如果加入了分层抽样 (保证切分后 0 和 1 的比例一致)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

2. KFold - K折交叉验证 (进阶)

作用:在数据量较小或需要极其严谨地评估模型时,我们不仅做一次 train/test 切分,而是将数据集等分为 K 份。每次用其中 1 份做测试集,另外 K-1 份做训练集,循环 K 次。这叫做交叉验证。KFold 类用于生成切分的索引。

sklearn.model_selection.KFold(n_splits=5, shuffle=False, random_state=None)

参数:

  • 折数 (n_splits): 决定将数据集切分为多少份(int,默认值为 5)。通常使用 5 折或 10 折。
  • 是否打乱 (shuffle): 在切分前是否打乱数据 (bool,默认为 False)。建议设为 True。
  • 随机种子 (random_state): 控制打乱的随机性,仅在 shuffle=True 时生效 (int 或 None)。

返回值:

  • 成功: 这是一个生成器类。调用其 .split(X) 方法时,会按轮次产生训练集索引(train_idx)和测试集索引(test_idx)的元组。

示例:

from sklearn.model_selection import KFold
import numpy as np

X = np.array(["a", "b", "c", "d", "e", "f"])

# 实例化 3折 交叉验证对象
kf = KFold(n_splits=3, shuffle=True, random_state=42)

# 循环 3 次,每次产生不同的训练/测试索引
for train_index, test_index in kf.split(X):
    print("训练集索引:", train_index, " 测试集索引:", test_index)
    
# 结果输出:
# 训练集索引: [0 1 2 5]  测试集索引: [3 4]
# 训练集索引: [3 4 5]  测试集索引: [0 1 2]
# ...

(注:对于分类问题,通常使用它的变体 StratifiedKFold,它自带分层抽样功能,保证每折数据的类别比例均匀。)


3. GridSearchCV - 网格搜索调参 (机器学习常用)

作用:模型有很多超参数(比如决策树的深度、学习率大小),手动一个一个试非常麻烦。网格搜索通过提供一个“参数字典”,穷举所有参数的组合,并自动使用交叉验证找出能在验证集上取得最高分的“最优参数组合”。

sklearn.model_selection.GridSearchCV(estimator, param_grid, cv=None, scoring=None)

参数:

  • 模型实例 (estimator): 需要被调优的未训练的模型对象(如 RandomForestClassifier())。
  • 参数网格 (param_grid): 字典或字典列表。键为模型的参数名,值为需要测试的参数值列表 (dict)。
  • 交叉验证折数 (cv): 决定评估时的切分折数 (int,默认为 5)。
  • 评分标准 (scoring): 评估好坏的标准,如分类可用 'accuracy' (准确率),回归可用 'neg_mean_squared_error' (均方误差) (str)。

返回值:

  • 成功: 返回一个封装好的搜索对象。调用 .fit(X, y) 运行后,可以通过 .best_params_ 提取出找出的最牛参数组合。

示例:

from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

# 1. 准备要测试的参数池 (一共 2 x 2 = 4 种组合)
param_grid = {
    'C': [0.1, 1, 10], 
    'kernel': ['linear', 'rbf']
}

# 2. 实例化搜索对象
grid_search = GridSearchCV(estimator=SVC(), param_grid=param_grid, cv=5)

# 3. 开始暴力搜索跑数据 (假设已定义好 X_train, y_train)
# grid_search.fit(X_train, y_train)

# 4. 获取得分最高的参数
# print(grid_search.best_params_)  
# 结果类似: {'C': 1, 'kernel': 'rbf'}
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐