引言

大家好,今天继续我们的"每天一个Python小技巧"系列。今天给大家分享Python中生成混淆矩阵的几种实用方法,帮助大家更好地评估分类模型性能。混淆矩阵是机器学习中非常重要的评估工具,它能直观展示模型的分类效果。

生成

使用scikit-learn基础方法

# pip install scikit-learn  -i https://pypi.tuna.tsinghua.edu.cn/simple
from sklearn.metrics import confusion_matrix
import numpy as np

# 模拟真实标签和预测标签
y_true = np.array([0, 1, 0, 1, 1, 0, 1, 0, 0, 1])
y_pred = np.array([1, 1, 0, 1, 0, 0, 1, 0, 1, 1])

# 生成基础混淆矩阵
cm = confusion_matrix(y_true, y_pred)
print("基础混淆矩阵:")
print(cm)
"""
[[3 2]
 [1 4]]
"""

# 三分类问题示例
y_true_multi = [0, 1, 2, 0, 1, 2, 0, 1, 2]
y_pred_multi = [0, 1, 1, 0, 2, 2, 0, 1, 2]

cm_multi = confusion_matrix(y_true_multi, y_pred_multi)
print("多分类混淆矩阵:")
print(cm_multi)
"""
[[3 0 0]
 [0 2 1]
 [0 1 2]]
"""

# 归一化
cm_normalized = confusion_matrix(y_true_multi, y_pred_multi, normalize='true')
print("\n归一化混淆矩阵(按行):")
print(np.round(cm_normalized, 2))
"""
[[1.   0.   0.  ]
 [0.   0.67 0.33]
 [0.   0.33 0.67]]
"""

使用pandas的交叉表(数据探索版)

import pandas as pd

df = pd.DataFrame({'Actual': y_true, 'Predicted': y_pred})
cross_tab = pd.crosstab(df['Actual'], df['Predicted'], 
                       rownames=['Actual'], 
                       colnames=['Predicted'],
                       margins=True)

print(cross_tab)
"""
Predicted  0  1  All
Actual
0          3  2    5
1          1  4    5
All        4  6   10
"""

可视化

使用seaborn可视化混淆矩阵

# pip install seaborn matplotlib -i https://pypi.tuna.tsinghua.edu.cn/simple
import seaborn as sns
import matplotlib.pyplot as plt

plt.figure(figsize=(10,6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['predict -', 'predict +'],
            yticklabels=['true -', 'true +'])
plt.title('cm')
plt.xlabel('predict')
plt.ylabel('true')
plt.show()
plt.savefig('aa.png')

使用yellowbrick的ConfusionMatrix(交互式版)

# pip install yellowbrick  -i https://pypi.tuna.tsinghua.edu.cn/simple
from yellowbrick.classifier import ConfusionMatrix
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris

# 加载数据
data = load_iris()
X, y = data.data, data.target

# 创建可视化器
model = RandomForestClassifier()
cm = ConfusionMatrix(model, classes=data.target_names)
cm.fit(X, y)
cm.score(X, y)
cm.show()
cm.poof('cm.png')

结论

建议日常使用scikit-learn+seaborn可视化的方法

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐