【AI课程领学】第二课 · 机器学习基础(课时 2)模型评估与选择(含混淆矩阵、ROC、交叉验证、深度学习早停)

【AI课程领学】第二课 · 机器学习基础(课时 2)模型评估与选择(含混淆矩阵、ROC、交叉验证、深度学习早停)



欢迎铁子们点赞、关注、收藏!
祝大家逢考必过!逢投必中!上岸上岸上岸!upupup

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文。详细信息可扫描博文下方二维码 “学术会议小灵通”或参考学术信息专栏:https://ais.cn/u/mmmiUz
详细免费的AI课程可在这里获取→www.lab4ai.cn


1. 为什么模型评估是机器学习最难的一步?

一个模型训练得好 ≠ 能泛化。

真正衡量模型性能的是:

  • 在从未见过的数据上的表现(泛化性能)。

评估还有一个更重要的问题:

  • 不同任务需要不同指标,没有万能指标。

2. 分类任务的模型评估指标

2.1 混淆矩阵(Confusion Matrix)

实际\预测 Positive Negative
Positive TP FN
Negative FP TN
  • Python 绘制混淆矩阵:
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

y_true = [1,0,1,1,0]
y_pred = [1,0,0,1,0]

cm = confusion_matrix(y_true, y_pred)

sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.show()

2.2 Accuracy / Precision / Recall / F1

  • 准确率 Accuracy:整体正确率
  • 精确率 Precision:预测为正类中有多少是真的
  • 召回率 Recall:真正类中有多少被预测出来
  • F1:精确率与召回率的调和平均

2.3 ROC 曲线与 AUC

  • 衡量模型排序能力。
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

fpr, tpr, _ = roc_curve(y_true, y_pred)
roc_auc = auc(fpr, tpr)

plt.plot(fpr, tpr, label=f"AUC={roc_auc:.2f}")
plt.xlabel("FPR")
plt.ylabel("TPR")
plt.legend()
plt.show()

3. 回归任务的评估指标

常见指标:

  • MSE(均方误差)
  • RMSE(平方根误差)
  • MAE(平均绝对误差)
  • R²(拟合优度)

代码:

from sklearn.metrics import mean_squared_error, r2_score

mse = mean_squared_error(y_true, y_pred)
r2 = r2_score(y_true, y_pred)

4. 交叉验证(Cross Validation)

  • 用于模型选择 / 调参。
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import Ridge

model = Ridge(alpha=1.0)
scores = cross_val_score(model, X, y, cv=5, scoring='neg_mean_squared_error')
print(scores.mean())

深度学习中,由于训练成本高,不常用 k-fold,但常用:

  • Hold-out(验证集)
  • Early stopping(早停)
  • Learning rate schedule

5. 模型选择:偏差-方差权衡

  • 模型太简单 → 高偏差(欠拟合)
  • 模型太复杂 → 高方差(过拟合)

深度学习通过:

  • Dropout
  • BatchNorm
  • 数据增强
  • 正则化
  • Early Stopping

来解决方差过高的问题。

6. 小结

本课的核心目标:

  • 选择正确的指标评价模型
  • 使用验证集或交叉验证进行调参
  • 控制偏差与方差,避免过拟合

下一课进入核心算法:线性模型(机器学习和深度学习的底座)。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐