学习曲线(Learning Curve)和验证曲线(Validation Curve)是用于诊断机器学习模型性能的重要工具。它们可以帮助识别模型是否过拟合(高方差)、欠拟合(高偏差)或者刚刚好。下面分别介绍这两种曲线的概念和如何使用它们。

学习曲线(Learning Curve)

学习曲线是在训练集大小不同时,通过绘制模型训练集和交叉验证集上的准确率来观察模型在新数据上的表现,进而判断模型的方差或偏差是否过高,以及增大训练集是否可以减小过拟合。

如何绘制学习曲线

  1. 准备数据:加载数据并划分训练集和验证集。

  2. 定义模型:选择要使用的机器学习模型。

  3. 计算误差:对于不同的训练集大小,训练模型并记录训练误差和验证误差。

  4. 绘图:将训练误差和验证误差随训练集大小的变化绘制出来。

示例代码

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_boston
from sklearn.model_selection import learning_curve
from sklearn.linear_model import LinearRegression
​
# 加载数据
boston = load_boston()
X, y = boston.data, boston.target
​
# 定义模型
model = LinearRegression()
​
# 计算学习曲线
train_sizes, train_scores, validation_scores = learning_curve(
    estimator=model,
    X=X,
    y=y,
    train_sizes=np.linspace(0.1, 1.0, 10),
    cv=5,
    scoring='neg_mean_squared_error'
)
​
# 计算平均值
train_scores_mean = -train_scores.mean(axis=1)
validation_scores_mean = -validation_scores.mean(axis=1)
​
# 绘制学习曲线
plt.figure()
plt.plot(train_sizes, train_scores_mean, label='Training error')
plt.plot(train_sizes, validation_scores_mean, label='Validation error')
plt.xlabel('Training Set Size')
plt.ylabel('Mean Squared Error')
plt.title('Learning Curve')
plt.legend()
plt.show()

验证曲线(Validation Curve)

验证曲线展示了模型在某个超参数的不同取值下的训练误差和验证误差的变化情况。通过观察这些曲线,可以了解模型在不同超参数设置下的性能,并选择最优的超参数。

如何绘制验证曲线

  1. 准备数据:加载数据并划分训练集和验证集。

  2. 定义模型:选择要使用的机器学习模型。

  3. 选择超参数范围:确定要调整的超参数及其取值范围。

  4. 计算误差:对于不同的超参数值,训练模型并记录训练误差和验证误差。

  5. 绘图:将训练误差和验证误差随超参数值的变化绘制出来。

示例代码

from sklearn.model_selection import validation_curve
from sklearn.svm import SVR
​
# 加载数据
boston = load_boston()
X, y = boston.data, boston.target
​
# 定义模型
model = SVR()
​
# 选择超参数范围
param_range = np.logspace(-6, -1, 5)
​
# 计算验证曲线
train_scores, validation_scores = validation_curve(
    estimator=model,
    X=X,
    y=y,
    param_name='C',
    param_range=param_range,
    cv=5,
    scoring='neg_mean_squared_error'
)
​
# 计算平均值
train_scores_mean = -train_scores.mean(axis=1)
validation_scores_mean = -validation_scores.mean(axis=1)
​
# 绘制验证曲线
plt.figure()
plt.plot(param_range, train_scores_mean, label='Training error')
plt.plot(param_range, validation_scores_mean, label='Validation error')
plt.xlabel('C')
plt.ylabel('Mean Squared Error')
plt.title('Validation Curve for C')
plt.xscale('log')  # 对数尺度
plt.legend()
plt.show()

解释曲线

  • 学习曲线

    • 如果训练误差和验证误差都较高且接近,说明模型欠拟合(高偏差)。

    • 如果训练误差较低而验证误差较高,且两者差距较大,说明模型过拟合(高方差)。

    • 如果随着训练集大小的增加,验证误差逐渐下降并趋于稳定,说明模型正在逐渐改善其泛化能力。

  • 验证曲线

    • 观察训练误差和验证误差随超参数变化的趋势,选择使验证误差最小的超参数值。

    • 如果训练误差和验证误差之间的差距较大,说明模型对特定的超参数值非常敏感,可能存在过拟合。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐