决策树算法详解:从入门到精通

一、什么是决策树?——灵魂三问帮你理解

想象你是一个月老,你的工作是判断一对男女是否适合结婚。你会怎么判断?

你肯定不会用一个简单的“是”或“否”来概括,而是会问一连串的问题:

  1. 第一问:他们的年龄差大吗?
    • 如果年龄差 > 10岁:可能会有代沟,直接判断为 “不太合适”
    • 如果年龄差 ≤ 10岁:年龄不是主要问题,进入下一个问题
  2. 第二问:他们的兴趣爱好相似吗?
    • 如果完全不相似:可能玩不到一块儿,判断为 “不太合适”
    • 如果有共同爱好:有戏,进入下一个问题
  3. 第三问:他们的消费观念一致吗?
    • 如果一个节俭,一个奢侈:日后容易为钱吵架,判断为 “不太合适”
    • 如果双方都理性消费:天作之合,判断为 “非常合适”

你看,我们通过一连串的是非问题,就像一棵倒着长的树一样,最终把不同的人分到了不同的类别里。这就是决策树的核心思想:通过一系列的是/否问题,将数据不断划分,最终到达叶子节点,得到预测结果。

专业说明

决策树是一种树形结构的监督学习算法,用于分类和回归任务。它通过对特征进行提问,将特征空间划分为若干个矩形区域,每个区域对应一个预测值。决策树由根节点(初始问题)、内部节点(中间问题)、分支(答案)和叶子节点(最终决策)组成。


二、解剖一棵树——组成部分

把上面的过程画成图,就是一棵倒置的树:

                    [根节点: 年龄差大吗?]
                      /              \
                   是/                \否
                    /                  \
            [不合适]           [内部节点: 兴趣相似吗?]
                                  /              \
                               是/                \否
                                /                  \
              [内部节点: 消费观念一致吗?]           [不合适]
                      /              \
                   是/                \否
                    /                  \
              [非常合适]               [不合适]
  • 根节点:最顶端的问题,包含所有数据
  • 内部节点:中间的问题,代表对特征的测试
  • 分支:问题的答案
  • 叶子节点:最终的决策结果

三、树是怎么“学习”的?——选择最佳问题

决策树的核心在于:如何自动选择出“最好”的问题作为根节点和内部节点?

举个例子:相亲数据集

假设我们有8对情侣的数据,以及他们最终是否“幸福”:

情侣 年龄差 兴趣相似 消费观念一致 幸福?
1
2
3
4
5
6
7
8

初始状态(根节点):8个样本,2个“是”(幸福),6个“否”(不幸福)。这是一个相当混乱的状态。

专业说明:信息熵的计算

信息熵衡量系统的混乱程度,公式为:
H(X)=−∑i=1nP(xi)log⁡2P(xi)H(X) = -\sum_{i=1}^{n} P(x_i) \log_2 P(x_i)H(X)=i=1nP(xi)log2P(xi)

根节点的熵计算:
H(根)=−[28log⁡2(28)+68log⁡2(68)]≈0.811 比特H(根) = - [\frac{2}{8} \log_2(\frac{2}{8}) + \frac{6}{8} \log_2(\frac{6}{8})] \approx 0.811 \text{ 比特}H()=[82log2(82)+86log2(86)]0.811 比特

尝试问题A:先问“年龄差大吗?”

  • 左分支(年龄差=大):情侣1、2、8,3个人,幸福人数0
    • 熵 = 0(完全纯净)
  • 右分支(年龄差=小):情侣3、4、5、6、7,5个人,幸福2人,不幸福3人
    • 熵 = -[0.4×log₂(0.4) + 0.6×log₂(0.6)] ≈ 0.971

加权平均熵 = (3/8)×0 + (5/8)×0.971 = 0.607

信息增益 = 0.811 - 0.607 = 0.204

专业说明

决策树的学习过程就是最大化信息增益的过程。信息增益 = 父节点熵 - 子节点加权平均熵。增益越大,说明该特征对分类的贡献越大。算法会选择信息增益最大的特征作为当前节点的分裂特征。


四、主要决策树算法对比

1. ID3 (Iterative Dichotomiser 3)

通俗理解:ID3是决策树的“祖师爷”,它用信息增益来选择特征。但它比较“死板”,只能处理离散特征,而且会把树长得特别深,容易过拟合。

专业说明

  • 提出时间:1986年,Ross Quinlan
  • 划分标准:信息增益
  • 特征类型:只能处理离散型特征
  • 缺点:不能处理连续值、缺失值,无剪枝,容易过拟合

2. C4.5 (ID3的改进版)

通俗理解:C4.5是ID3的“超级升级版”。它更聪明,不仅能处理数字(连续值),还能处理数据不全的情况,而且学会了“修剪枝叶”,不让树长得太疯。

专业说明

  • 提出时间:1993年,Ross Quinlan
  • 划分标准:信息增益率(克服ID3偏向多值特征的缺点)
  • 改进:可处理连续特征、缺失值,引入剪枝策略

3. CART (Classification and Regression Tree)

通俗理解:CART是一个“全能型选手”。它不仅能做分类题,还能做回归题(预测具体数值)。它习惯把问题拆成“是”和“否”的二分法,逻辑非常清晰。

专业说明

  • 提出时间:1984年,Breiman等
  • 划分标准:分类用基尼系数,回归用均方误差
  • 树结构:严格的二叉树
  • 特点:当前最主流的决策树实现(如sklearn中的DecisionTreeClassifier和DecisionTreeRegressor)

4. CHAID (Chi-squared Automatic Interaction Detector)

通俗理解:CHAID像一个严谨的统计学家。它不相信直觉,只相信统计检验的“显著性”,会先做卡方检验,只有显著相关的特征才会被使用。

专业说明

  • 划分标准:卡方检验(分类)或F检验(回归)
  • 树结构:多叉树
  • 特点:基于统计显著性,善于合并类别,主要用于市场调研、社会科学等领域

五、基尼系数 vs 信息熵

基尼系数

通俗理解:基尼系数可以理解为“从数据集中随机抽取两个样本,其类别标签不一致的概率”。

计算公式
Gini=1−∑i=1npi2Gini = 1 - \sum_{i=1}^{n} p_i^2Gini=1i=1npi2

例子

  • 节点全是“是”:( p_{是}=1, p_{否}=0 ),Gini = 1 - (1² + 0²) = 0(最纯净)
  • 节点一半“是”一半“否”:( p_{是}=0.5, p_{否}=0.5 ),Gini = 1 - (0.5² + 0.5²) = 0.5(最混乱)

信息熵

通俗理解:信息熵衡量系统的混乱程度,熵越高,系统越混乱。

计算公式
H=−∑i=1npilog⁡2(pi)H = -\sum_{i=1}^{n} p_i \log_2(p_i)H=i=1npilog2(pi)

例子

  • 全“是”:H = - (1×log₂1 + 0×log₂0) = 0
  • 一半“是”一半“否”:H = - (0.5×log₂0.5 + 0.5×log₂0.5) = 1

专业说明:两者的区别

特征 基尼系数 信息熵
计算速度 更快(无对数运算) 较慢(有对数运算)
取值范围 0 到 1-1/n 0 到 log₂(n)
敏感性 对纯度变化较平滑 对纯度变化更敏感
实际效果 绝大多数情况与熵接近 与基尼系数差异很小

在实际应用中,基尼系数是更主流的选择(如sklearn默认设置),因为计算更快且效果相当。


六、剪枝——防止树“长疯”的艺术

为什么要剪枝?

通俗理解:决策树太“勤奋”了。如果不加限制,它会一直生长,直到把所有训练数据都正确分类。这会导致过拟合——树不仅学到了规律,还把每个训练样本的噪声和异常值都记住了。

打个比方:一个学生背下了整本习题册的答案,但没学会解题方法。考试时题目一换,他就懵了。

预剪枝

通俗理解:“边生长边剪”。在树构建过程中,提前设置停止条件。

常见停止条件

  • 最大深度:树最多问几个问题
  • 节点最小样本数:样本太少就不再分裂
  • 叶子节点最小样本数:分裂后叶子太小就不分裂

专业说明
预剪枝效率高,适合大规模数据,但有“视野局限”——可能当前分裂增益不大,但能为后续更好分裂创造条件,预剪枝可能错过这些机会。

后剪枝

通俗理解:“先长成,再修剪”。让树充分生长,然后自底向上,把那些不可靠的枝叶剪掉。

CART的后剪枝:最小代价复杂性剪枝法(MCCP)

核心公式
代价复杂度=误差+α×叶子节点数\text{代价复杂度} = \text{误差} + \alpha \times \text{叶子节点数}代价复杂度=误差+α×叶子节点数

  • 误差:树在训练集上的误分类率(分类)或均方误差(回归)
  • 叶子节点数:代表树的复杂度
  • α (alpha):惩罚系数,调节对复杂度的惩罚力度

专业说明
当α=0时,只关心误差最小,选最深的树;当α增大时,开始为每个叶子节点“付费”,倾向于选更简单的树。通过交叉验证选择使交叉验证误差最小的α,得到最终树。

Python中的剪枝实现

from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# 分类树的预剪枝
clf = DecisionTreeClassifier(
    max_depth=5,              # 树的最大深度
    min_samples_split=10,      # 内部节点再划分所需最小样本数
    min_samples_leaf=5,        # 叶子节点最少样本数
    max_leaf_nodes=20,         # 最大叶子节点数
    min_impurity_decrease=0.01 # 分裂使不纯度下降的最小值
)

# 回归树的预剪枝
reg = DecisionTreeRegressor(
    max_depth=5,
    min_samples_split=10,
    min_samples_leaf=5
)

# 后剪枝(代价复杂度剪枝)
clf = DecisionTreeClassifier()
path = clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas  # 获取一系列alpha值
# 对每个alpha训练树,用验证集选出最优

七、分类变量的处理——一个棘手的问题

名义变量 vs 等级变量

通俗理解:想象你去电影院看电影。

  • 名义变量:就像座位号(3排5座)。这个数字只是标签,告诉你位置在哪儿,但你不能说5座比6座好。
  • 等级变量:就像座位等级(普通厅 < VIP厅 < 总统厅)。有明确顺序,但不知道具体好多少。

专业说明

特征 名义变量 等级变量
核心意义 是什么 怎么样
有无顺序 (只是标签) (明确顺序)
例子 民族、性别、血型 学历、满意度、尺码

高基数分类变量的挑战——以56个民族为例

通俗理解
如果你把56个民族直接变成56个0/1特征(独热编码),会出现几个问题:

  1. 维度爆炸:特征一下子多了56个
  2. 数据稀疏:每个样本只有1个是1,其余55个都是0
  3. 失去语义:汉族和藏族在模型眼里,距离和汉族与“外星民族”一样远

专业说明
SAS EM等工具会尝试所有可能的组合(如{汉族,满族} vs. {其他})来找到最优分裂,但计算量巨大(2^55种组合)。sklearn的决策树采用独热编码+逐个特征尝试的方式,虽然计算快,但可能错过某些组合。

解决方案

1. 目标编码

# 将民族替换为该民族的目标变量均值
# 例如:汉族的好客户比例=0.8,藏族=0.6,回族=0.7
# 民族 → [0.8, 0.6, 0.7, ...]

2. 使用LightGBM/CatBoost
这两个框架原生支持分类特征,内部实现了高效的最优组合查找。

3. 业务合并
按语系、地域等合并为高层次的类别(如“汉族”、“西南少数民族”、“西北少数民族”)。


八、决策树回归——当你要预测数值而不是类别

从分类到回归:问题变了

通俗理解
之前我们问“幸福还是不幸福?”——这是分类问题,答案是非此即彼的类别。

但现在问题变了:

  • 预测一套房子的价格是多少万
  • 预测一个人明天的体温是多少度
  • 预测某支股票的涨跌幅是多少

这些都是回归问题——我们要预测的是一个具体的数值,而不是一个类别。

回归树的核心思想:还是那棵树,但叶子不同

分类树的叶子节点上,存的是类别(如“幸福”或“不幸福”)。

回归树的叶子节点上,存的是数值——具体来说是落在这个叶子节点里的所有训练样本的平均值

一个具体的例子:预测房价

假设我们有这样的数据:

房子 面积(㎡) 卧室数 位置 实际价格(万)
1 80 2 郊区 200
2 85 2 郊区 210
3 90 3 郊区 250
4 100 3 市区 400
5 110 3 市区 420
6 120 4 市区 500

回归树的学习过程

  1. 第一问:位置是市区还是郊区?
    • 左枝(郊区):房子1、2、3,价格 = [200, 210, 250]
    • 右枝(市区):房子4、5、6,价格 = [400, 420, 500]
  2. 第二问(郊区枝):面积 > 87㎡?
    • 左枝(面积≤87):房子1、2,价格 = [200, 210] → 预测值 = 205万
    • 右枝(面积>87):房子3,价格 = [250] → 预测值 = 250万
  3. 第二问(市区枝):面积 > 105㎡?
    • 左枝(面积≤105):房子4、5,价格 = [400, 420] → 预测值 = 410万
    • 右枝(面积>105):房子6,价格 = [500] → 预测值 = 500万

生成的回归树:

                    [根节点: 位置?]
                      /          \
                    /            \
              [郊区]              [市区]
                /                    \
           [面积>87?]              [面积>105?]
            /      \                /      \
       205万    250万           410万    500万

回归树的划分标准:均方误差

核心思想:找到一个特征和切分点,使得切分后左右子节点的加权均方误差最小

均方误差(MSE)公式
MSE=1n∑i=1n(yi−yˉ)2MSE = \frac{1}{n} \sum_{i=1}^{n} (y_i - \bar{y})^2MSE=n1i=1n(yiyˉ)2

其中 (\bar{y}) 是该节点所有样本的平均值。

通俗理解

  • 一个节点的MSE越小,说明里面的数值越集中(比如[200, 201, 199])
  • 一个节点的MSE越大,说明里面的数值越分散(比如[200, 300, 100])

计算例子:用“位置”划分

根节点(所有6个房子):

  • 平均价格 = (200+210+250+400+420+500)/6 ≈ 330万
  • MSE = [(200-330)² + (210-330)² + (250-330)² + (400-330)² + (420-330)² + (500-330)²]/6 = 13267

用“位置”划分后

  • 左节点(郊区)MSE = 467
  • 右节点(市区)MSE = 1867
  • 加权平均MSE = (3/6)×467 + (3/6)×1867 = 1167

划分带来的MSE下降 = 13267 - 1167 = 12100,下降非常大,说明这个划分很好。

回归树的预测过程

当你要预测一个新房子时:

  1. 从根节点开始,根据特征值一路向下
  2. 最终到达一个叶子节点
  3. 预测值 = 这个叶子节点里所有训练样本的平均值

为什么用平均值?
因为均方误差的最小化,就是在每个节点上用平均值作为预测值。数学上可以证明,对于给定的节点,使MSE最小的预测值就是该节点所有样本的平均值。

回归树的局限性

  1. 预测是阶梯函数:预测结果不是平滑曲线,而是一段一段的常数
  2. 无法外推:不能预测超出训练数据范围的值
  3. 对线性关系拟合不佳:需要用很多阶梯去近似直线

九、分类树 vs 回归树:完整对比

方面 分类树 回归树
预测目标 类别(如幸福/不幸福) 数值(如房价)
划分标准 基尼系数、信息熵 均方误差、平均绝对误差
叶子节点内容 多数类别 平均值
预测输出 类别标签 连续数值
评估指标 准确率、ROC-AUC、PR曲线 MSE、R²、MAE
是否支持CART ✅ 是 ✅ 是(CART中的R)
过拟合表现 在训练集上准确率极高 在训练集上MSE极低

十、模型评估(一)——ROC曲线和AUC

为什么不用准确率?

通俗理解
假设你开发了一个诊断罕见癌症的模型(只有5%的人得病),模型准确率达到95%。你很高兴,但真相是:模型可能什么都没学会,只是把所有人都预测为“健康”。因为95%的人本来就没病,所以猜“健康”的正确率就是95%。

这个模型不会发现任何真正的病人,但在准确率指标上却表现很好。这就是数据不平衡时准确率的局限性

混淆矩阵

首先定义:我们关心的是发现“敌机”(正例)。

预测为敌机 预测为友军
真实是敌机 TP (真正例) FN (假负例) ❌ 漏报
真实是友军 FP (假正例) ❌ 误报 TN (真负例)

两个核心指标

  • TPR(真正例率) = 召回率 = TP/(TP+FN)
    • 通俗理解:“在所有真正的敌机中,你识别出了多少?”
    • 希望越高越好
  • FPR(假正例率) = FP/(FP+TN)
    • 通俗理解:“在所有真正的友军中,你有多少误认为是敌机?”
    • 希望越低越好

ROC曲线

通俗理解
你心里有一个“开火阈值”。阈值设得高,你很少误伤,但容易漏掉敌机;阈值设得低,你很少漏掉敌机,但容易误伤。

ROC曲线就是研究阈值变化时,TPR和FPR如何变化的工具。取遍所有可能的阈值,把(FPR, TPR)点连起来,就得到ROC曲线。

完美模型:存在一个阈值,能抓到所有敌机(TPR=1),同时不误伤任何友军(FPR=0)。曲线从(0,0)冲到(0,1),再到(1,1)。

随机猜测:TPR和FPR总是同步上升,曲线是从(0,0)到(1,1)的对角线。

AUC——一把尺子量到底

通俗理解
AUC是ROC曲线下方的面积。它有一个非常直观的统计意义:

随机从正例(敌机)和负例(友军)中各抽一个样本,AUC就是模型将正例排到负例前面的概率。

  • AUC = 0.5:等于随机猜测,没用
  • AUC = 1:完美区分
  • AUC越接近1,模型越好

专业说明
AUC的优点:①不受数据不平衡影响(数学公式上);②与阈值无关,衡量模型整体排序能力;③直观易懂。


十一、模型评估(二)——PR曲线:当正例稀缺时的利器

PR曲线是什么?

通俗理解
假设你是警察,要在1000个人中找出10个通缉犯。用ROC曲线时,FPR的分母是990个普通人,只要误报几个,FPR还是很小,ROC曲线看起来可能还不错。

但你真正关心的是:在你抓回来的人里,有多少是真的通缉犯? 这就是PR曲线要回答的问题。

PR曲线的横纵轴是:

  • 横轴:召回率(Recall) = TPR = TP/(TP+FN)
    • 问:所有通缉犯里,我抓到了多少?
  • 纵轴:精确率(Precision) = TP/(TP+FP)
    • 问:在我抓回来的人里,有多少是真的通缉犯?

一个具体的例子

还是那个雷达兵场景,但这次敌机很少(5架敌机,95架友军):

阈值 TP FP FN TN 召回率 精确率
高(0.9) 3 1 2 94 3/5=0.6 3/4=0.75
中(0.5) 4 10 1 85 4/5=0.8 4/14≈0.29
低(0.1) 5 50 0 45 5/5=1.0 5/55≈0.09

解读

  • 阈值高时:你很谨慎,抓回4个人,3个是真敌机(精确率75%),但漏掉了2个(召回率60%)
  • 阈值低时:你宁可错杀一千,抓回55个人,虽然抓到了所有敌机(召回率100%),但只有5个是真的(精确率9%)

PR曲线就是把这些点连起来,展示召回率和精确率之间的权衡


十二、ROC曲线 vs PR曲线:核心区别

一个关键追问:ROC曲线到底受不受数据不平衡影响?

这是一个容易混淆的点,需要讲清楚。

从数学公式上看

  • TPR = TP/(TP+FN) —— 分母只涉及真正的正例,与负例数量无关
  • FPR = FP/(FP+TN) —— 分母只涉及真正的负例,与正例数量无关

当负例从1,000增加到1,000,000时,如果模型性能不变,TPR和FPR都会保持不变。所以从数学定义上,ROC曲线确实与正负例比例无关

但为什么实际中要说“在极度不平衡时要用PR曲线”?

因为虽然FPR的数学定义不变,但FPR的“实际含义”变了

场景 FPR=0.01的含义 实际误报人数
平衡数据(500正,500负) 误报1%的负例 5人
不平衡(100正,999,900负) 误报1%的负例 9,999人

同样的FPR,在实际应用中可能意味着完全不同的后果。ROC曲线看不到这个“绝对值”的问题,它只关心“比例”。而PR曲线直接面对这个绝对值问题,因为它的分母是TP+FP(你预测为正的总人数),这个数字直接反映了误报的规模。

用一个极端例子说明

场景:100万人做癌症筛查,只有100个真病人(0.01%)。

模型A:随机猜,把1%的人判为癌症

  • ROC点:TPR=0.01, FPR=0.01(在对角线上,AUC≈0.5,正确反映是随机)
  • PR点:Precision=0.01(预测为正的人里只有1%是真的)

模型B:稍好的模型,抓到50个病人,误报5,000个健康人

  • ROC点:TPR=0.5, FPR=0.005(看起来不错,远高于对角线)
  • PR点:Precision=50/(50+5,000)≈0.01(还是只有1%!)

问题:ROC曲线说模型B很好(TPR=0.5, FPR很低),但PR曲线揭示真相:为了找出50个病人,你要让5,000个健康人承受不必要的恐慌和检查。这个模型在实际中可能根本不可用。

区别总结

方面 ROC曲线 PR曲线
数学上是否受不平衡影响 不受影响(公式决定) 受影响(依赖正例比例)
实际应用中是否“好用” 不平衡时可能“过于乐观” 不平衡时更“真实”
能看出误报的绝对规模吗 不能(只看比例) (通过精确率反映)
关注点 模型区分正负例的整体能力 模型在正例上的预测质量
横轴 FPR = FP/(FP+TN) Recall = TP/(TP+FN)
纵轴 TPR = TP/(TP+FN) Precision = TP/(TP+FP)

什么时候用哪个?

问题:你的数据平衡吗?
├─ 是,正负例差不多 → 用ROC曲线(足够了)
└─ 否,极度不平衡
   ├─ 你想了解模型的整体排序能力 → 可以用ROC曲线,但要记住:
   │    ROC曲线上的“好”可能只是“比例好”,实际误报可能很多
   └─ 你想了解模型在实际应用中的价值 → 必须用PR曲线
       它能告诉你:当你让模型去抓正例时,抓对的比例有多高

专业建议:两个一起看

在实际工作中,两个曲线一起看是最好的:

  1. ROC曲线 + AUC:快速了解模型的整体排序能力,便于不同模型间比较
  2. PR曲线 + AP(平均精确率):深入了解模型在正例上的表现,评估业务可行性

一句话总结两个曲线的核心差异

  • ROC曲线问:敌机和友军,模型能分开吗?(数学上公平)
  • PR曲线问:当模型说“这是敌机”时,到底对不对?(业务上真实)

两者不是矛盾,而是互补——ROC告诉你模型有没有区分能力,PR告诉你这种能力在实际中值不值得用。


十三、Python实战:完整的决策树代码示例

分类树示例

# -*- coding: utf-8 -*-
"""
决策树分类实战:鸢尾花数据集
"""
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
import matplotlib.pyplot as plt
import seaborn as sns

# 加载数据
iris = load_iris()
X, y = iris.data, iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

# 创建分类树(带预剪枝)
clf = DecisionTreeClassifier(
    max_depth=3,
    min_samples_split=5,
    min_samples_leaf=2,
    random_state=42
)

# 训练
clf.fit(X_train, y_train)

# 预测
y_pred = clf.predict(X_test)
y_pred_proba = clf.predict_proba(X_test)

# 评估
print("分类报告:")
print(classification_report(y_test, y_pred, target_names=iris.target_names))

# 可视化决策树
plt.figure(figsize=(20, 10))
plot_tree(clf, feature_names=iris.feature_names, 
          class_names=iris.target_names, filled=True, rounded=True)
plt.title("鸢尾花分类决策树")
plt.show()

回归树示例

# -*- coding: utf-8 -*-
"""
决策树回归实战:波士顿房价预测(用替代数据集)
"""
from sklearn.datasets import make_regression
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt
import numpy as np

# 生成模拟数据(实际可用boston数据集,但sklearn已移除)
X, y = make_regression(n_samples=500, n_features=1, noise=20, random_state=42)

# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# 创建回归树(不同深度对比)
reg_shallow = DecisionTreeRegressor(max_depth=2, random_state=42)
reg_deep = DecisionTreeRegressor(max_depth=10, random_state=42)

# 训练
reg_shallow.fit(X_train, y_train)
reg_deep.fit(X_train, y_train)

# 预测
y_pred_shallow = reg_shallow.predict(X_test)
y_pred_deep = reg_deep.predict(X_test)

# 评估
print("浅树(max_depth=2)表现:")
print(f"MSE: {mean_squared_error(y_test, y_pred_shallow):.2f}")
print(f"R²: {r2_score(y_test, y_pred_shallow):.2f}")

print("\n深树(max_depth=10)表现:")
print(f"MSE: {mean_squared_error(y_test, y_pred_deep):.2f}")
print(f"R²: {r2_score(y_test, y_pred_deep):.2f}")

# 可视化对比
X_plot = np.linspace(X.min(), X.max(), 300).reshape(-1, 1)
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.scatter(X_test, y_test, s=20, alpha=0.6, label='真实值')
plt.plot(X_plot, reg_shallow.predict(X_plot), color='red', linewidth=2, label='预测值')
plt.title('浅树回归 (max_depth=2)')
plt.xlabel('X')
plt.ylabel('y')
plt.legend()

plt.subplot(1, 2, 2)
plt.scatter(X_test, y_test, s=20, alpha=0.6, label='真实值')
plt.plot(X_plot, reg_deep.predict(X_plot), color='red', linewidth=2, label='预测值')
plt.title('深树回归 (max_depth=10)')
plt.xlabel('X')
plt.ylabel('y')
plt.legend()

plt.tight_layout()
plt.show()

ROC和PR曲线绘制

# -*- coding: utf-8 -*-
"""
ROC曲线和PR曲线绘制示例
"""
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
import matplotlib.pyplot as plt
import numpy as np

# 生成不平衡的二分类数据
X, y = make_classification(
    n_samples=10000, 
    n_features=20,
    n_classes=2,
    weights=[0.95, 0.05],  # 5%的正例
    random_state=42
)

# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

# 训练决策树
clf = DecisionTreeClassifier(max_depth=5, random_state=42)
clf.fit(X_train, y_train)

# 获取预测概率
y_score = clf.predict_proba(X_test)[:, 1]

# 计算ROC曲线
fpr, tpr, _ = roc_curve(y_test, y_score)
roc_auc = auc(fpr, tpr)

# 计算PR曲线
precision, recall, _ = precision_recall_curve(y_test, y_score)
pr_auc = average_precision_score(y_test, y_score)

# 绘制对比图
plt.figure(figsize=(14, 6))

# ROC曲线
plt.subplot(1, 2, 1)
plt.plot(fpr, tpr, color='darkorange', lw=2, 
         label=f'ROC曲线 (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='随机猜测')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('假正例率 (FPR)')
plt.ylabel('真正例率 (TPR)')
plt.title('ROC曲线')
plt.legend(loc="lower right")

# PR曲线
plt.subplot(1, 2, 2)
plt.plot(recall, precision, color='green', lw=2,
         label=f'PR曲线 (AP = {pr_auc:.2f})')
plt.xlabel('召回率 (Recall)')
plt.ylabel('精确率 (Precision)')
plt.title('PR曲线 (数据极度不平衡)')
plt.legend(loc="lower left")
plt.grid(True)

plt.tight_layout()
plt.show()

print(f"ROC-AUC: {roc_auc:.3f}")
print(f"PR-AUC (Average Precision): {pr_auc:.3f}")
print("\n注意:在不平衡数据中,PR曲线比ROC曲线更能反映真实性能。")

十四、完整知识体系:决策树全景图

决策树
├── 分类树
│   ├── 划分标准:基尼系数 / 信息熵
│   ├── 叶子节点:多数类别
│   ├── 评估:准确率、ROC-AUC、PR曲线
│   └── 典型算法:ID3, C4.5, CART分类
│
├── 回归树
│   ├── 划分标准:均方误差 / 平均绝对误差
│   ├── 叶子节点:平均值
│   ├── 评估:MSE、R²、MAE
│   └── 典型算法:CART回归
│
├── 剪枝策略
│   ├── 预剪枝:提前停止生长
│   └── 后剪枝:先长后剪(MCCP)
│
├── 变量处理
│   ├── 名义变量:无顺序,需要编码
│   ├── 等级变量:有顺序,可直接用
│   └── 高基数问题:目标编码、集成模型
│
└── 模型评估
    ├── ROC曲线:衡量整体区分能力
    ├── PR曲线:衡量正例预测质量
    └── 两者互补:数学公平 vs 业务真实

十五、决策树优缺点总结

优点

  1. 可解释性强:可以清晰地看到整个决策过程,像流程图一样
  2. 数据预处理少:不用归一化、标准化
  3. 能处理混合数据:既能处理数值型,也能处理类别型
  4. 白盒模型:易于理解和解释

缺点

  1. 容易过拟合:需要通过剪枝来控制
  2. 不稳定:数据微小变化可能导致完全不同树→集成学习解决
  3. 对高基数分类变量处理不佳:需要特征工程
  4. 回归时预测是阶梯函数:不光滑,无法外推

十六、核心要点回顾

模块 关键点
划分标准 分类:基尼系数、信息熵;回归:均方误差
剪枝 预剪枝(提前停止)、后剪枝(先长后剪)
分类评估 ROC曲线+AUC(整体能力)、PR曲线(正例质量)
回归评估 MSE、R²、MAE
变量类型 名义变量(无顺序)、等级变量(有顺序)
高基数处理 目标编码、LightGBM/CatBoost、业务合并
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐