每天一个Python小技巧:混淆矩阵的高效生成+画图(七)
大家好,今天继续我们的"每天一个Python小技巧"系列。今天给大家分享Python中生成混淆矩阵的几种实用方法,帮助大家更好地评估分类模型性能。混淆矩阵是机器学习中非常重要的评估工具,它能直观展示模型的分类效果。
·
引言
大家好,今天继续我们的"每天一个Python小技巧"系列。今天给大家分享Python中生成混淆矩阵的几种实用方法,帮助大家更好地评估分类模型性能。混淆矩阵是机器学习中非常重要的评估工具,它能直观展示模型的分类效果。
生成
使用scikit-learn基础方法
# pip install scikit-learn -i https://pypi.tuna.tsinghua.edu.cn/simple
from sklearn.metrics import confusion_matrix
import numpy as np
# 模拟真实标签和预测标签
y_true = np.array([0, 1, 0, 1, 1, 0, 1, 0, 0, 1])
y_pred = np.array([1, 1, 0, 1, 0, 0, 1, 0, 1, 1])
# 生成基础混淆矩阵
cm = confusion_matrix(y_true, y_pred)
print("基础混淆矩阵:")
print(cm)
"""
[[3 2]
[1 4]]
"""
# 三分类问题示例
y_true_multi = [0, 1, 2, 0, 1, 2, 0, 1, 2]
y_pred_multi = [0, 1, 1, 0, 2, 2, 0, 1, 2]
cm_multi = confusion_matrix(y_true_multi, y_pred_multi)
print("多分类混淆矩阵:")
print(cm_multi)
"""
[[3 0 0]
[0 2 1]
[0 1 2]]
"""
# 归一化
cm_normalized = confusion_matrix(y_true_multi, y_pred_multi, normalize='true')
print("\n归一化混淆矩阵(按行):")
print(np.round(cm_normalized, 2))
"""
[[1. 0. 0. ]
[0. 0.67 0.33]
[0. 0.33 0.67]]
"""
使用pandas的交叉表(数据探索版)
import pandas as pd
df = pd.DataFrame({'Actual': y_true, 'Predicted': y_pred})
cross_tab = pd.crosstab(df['Actual'], df['Predicted'],
rownames=['Actual'],
colnames=['Predicted'],
margins=True)
print(cross_tab)
"""
Predicted 0 1 All
Actual
0 3 2 5
1 1 4 5
All 4 6 10
"""
可视化
使用seaborn可视化混淆矩阵
# pip install seaborn matplotlib -i https://pypi.tuna.tsinghua.edu.cn/simple
import seaborn as sns
import matplotlib.pyplot as plt
plt.figure(figsize=(10,6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=['predict -', 'predict +'],
yticklabels=['true -', 'true +'])
plt.title('cm')
plt.xlabel('predict')
plt.ylabel('true')
plt.show()
plt.savefig('aa.png')
使用yellowbrick的ConfusionMatrix(交互式版)
# pip install yellowbrick -i https://pypi.tuna.tsinghua.edu.cn/simple
from yellowbrick.classifier import ConfusionMatrix
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
# 加载数据
data = load_iris()
X, y = data.data, data.target
# 创建可视化器
model = RandomForestClassifier()
cm = ConfusionMatrix(model, classes=data.target_names)
cm.fit(X, y)
cm.score(X, y)
cm.show()
cm.poof('cm.png')
结论
建议日常使用scikit-learn+seaborn可视化的方法
更多推荐



所有评论(0)