![现代决策树模型及其编程实践:从传统决策树到深度决策树](https://wfqqreader-1252317822.image.myqcloud.com/cover/409/44888409/b_44888409.jpg)
2.2.3 CART分类决策树的编程实践
针对2.2.2节的天气与是否打网球的数据集(PlayTennis数据集),我们利用Python和PyTorch编码展示CART分类树模型的细节。
2.2.3.1 整体流程
首先介绍整体流程,如代码段2.1所示。程序主要由四部分组成:数据集加载、模型训练、模型预测和决策树可视化。
代码段2.1 CART分类树测试主程序(源码位于Chapter02/test_CartClassifier.py)
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/d2-1.jpg?sign=1739291735-rUoaEXrewZ3hxD1bNv2gfBmxMC8sifhC-0-f50b1ca5b827fc641f69001a6d0939e7)
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/056-i.jpg?sign=1739291735-hmmi4bbP2LeNl4acuVM9xf4YQDXoQrgD-0-c1dd71d84689740949401061c7443b0b)
1. 数据集加载(第14~23行)
首先,第17行使用“with open as”语法打开指定数据文件,并获取文件句柄f。该语法会在执行完毕“with open”作用域内的代码后自动关闭数据文件。open函数的第一个参数为数据文件的路径;第二个参数为文件打开方式,“r”代表以只读方式打开;encoding参数指定读取文件的编码格式,“gbk”代表使用GBK编码。
然后,第18行使用csv库读取数据文件内容并将其转换成Python内置的list类型。使用csv库需要在代码片段开头添加“import csv”语句。读取数据时使用csv库的reader函数,参数传入“with open”获取的文件句柄f。
最后,第19~23行利用Python切片和列表生成式将原始数据集分割成属性名列表feature_names、目标变量名y_name、属性集X、目标变量集y。由于本例是解决数据集比较小的分类问题,因此令训练集(X_train和y_train)和测试集(X_test和y_test)使用相同的数据集(X和y),而通常的做法是将原始数据集按8:1:1或6:2:2的比例划分成训练集、测试集和验证集。另外,为了便于使用PyTorch进行GPU加速,我们将数据集从list类型进一步转换为numpy类型(PyTorch内部集成了numpy与Tensor的快速转换方法)。同样,使用numpy库也需要导入相应的包,语句为“import numpy as np”。
2. 决策树模型的训练和生成(第25~32行)
CART分类树的创建和训练过程被封装成CartClassifier类,通过使用“from cart import CartClassifier”将其导入当前环境。
在第27行创建决策树的过程中,触发CartClassifier类的构造函数。在这里,设置use_gpu参数为True,代表启用GPU加速。此外,该构造函数还可以传入其他参数,后面的内容将对其进行展开介绍。
在第31行CART分类树的训练过程中,调用CartClassifier类的train成员函数,传入训练集数据X_train、y_train和feature_names。训练完成后,返回模型数据如下:
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/057-i.jpg?sign=1739291735-gVPkB5cmOKZdgm4le1Q32sfK3f4eEjzJ-0-92dfdd67e510eaf27c812c6b8b51c1ab)
该model变量实际上是由一组规则表示的。以上输出结果为决策树的字典(树形结构)数据结构形式,在这棵model树中,从根节点到每个叶子节点的每条路径都代表一条规则。为了更清晰地表示规则,我们可以将以上数据结构转换成“if-then”的格式,如下所示:
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/057-2-i.jpg?sign=1739291735-l49BqGR4RqyXCSC9E05tsWPp8Azp4uzq-0-4beb471794cdb8d6eb357ad0c995a9c2)
事实上,CART分类决策树与一组“if-then”规则是等价的。
3. 决策树模型的使用(第34~41行)
在第36~38行的模型预测阶段,调用CartClassifier类的成员函数predict,传入测试集数据X_test,返回numpy.array类型的预测结果y_pred,并且打印输出测试集的真实值y_test和预测值y_pred。
在第39~41行的模型评估阶段,首先使用Python的列表生成式生成测试集中预测值与真实值相等的元素,每种相等的情况用int型变量1表示。然后使用numpy的sum函数对上述列表求和,统计出预测正确的计数。最后打印预测正确的样本计数、总样本计数以及预测准确率。实际执行结果如下:
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/057-3-i.jpg?sign=1739291735-VqRK8pkI3Ep2tFR5WXLrt0vqmoYxX3LB-0-1cb577e01ead85940c7e2139899a4a61)
4. 决策树的可视化(第43~45行)
决策树可视化阶段使用了tree_plotter包的tree_plot函数,传入前面训练好的模型model。tree_plotter包是我们使用Matplotlib自定义的决策树绘图包,在随后的内容中我们将详细介绍,在此先展示一下可视化的效果,如图2.9所示。
2.2.3.2 训练和创建过程
首先介绍用到的构造函数,见代码段2.2。从第8~9行可以看到,CartClassifier类的实现依赖于torch和numpy。
在CartClassifier类的构造函数__init__中,需要提供use_gpu和min_samples_split两个参数。其中,use_gpu是一个布尔值,代表该类是否启用GPU加速,默认为False,代表使用CPU;min_samples_split是一个整型数,代表决策树分裂完成后叶子节点的最少样本数,默认值为1,代表树完全分裂。
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/2d9.jpg?sign=1739291735-KSwuFejSAZ1DecvwgVKhLdpB2xYNUfcb-0-86cf02077a0a81b9879ff65d0b104674)
图2.9 PlayTennis数据集生成的CART分类树
代码段2.2 CART分类树的构造函数(源码位于Chapter02/CartClassifier.py)
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/d2-2.jpg?sign=1739291735-y6cnoYdFsaYRW8m9oEQlWSOikTZb54SP-0-c28277fab0e06d444dd29afacb2cdceb)
另外,构造函数中还维护一系列类的核心变量。其中,self.tree为存储树模型的核心结构,初始情况下为空dict;self.feature_names为存储数据集属性名的numpy数组;self.str_map、self.num_map、self.x_use_map、self.y_use_map为字符串与数值之间的映射器和开关,用于将numpy数组中的字符串类型映射成数值类型,以兼容PyTorch,与之相关的函数接口为__deal_value_map和__get_value,在后文中将逐一介绍。
接下来介绍CART分类树的训练函数train,如代码段2.3所示。由于Python函数中传递的numpy变量是引用,为了避免后续对数据集进行分割时破坏原始数据集,首先在第36~37行执行numpy.array的copy函数制作X和y的副本。然后在第41行进行数据预处理,将X_copy和y_copy中的字符串通过__deal_value_map函数映射成数值。之后在第44~48行将numpy数组X_copy和y_copy转换成Tensor数组,并根据self.use_gpu的值决定是否启用GPU加速。其中,torch.from_numpy函数是PyTorch提供的内置函数,负责将numpy.array数组转化成torch.Tensor格式,torch.Tensor.cuda函数也是PyTorch提供的内置函数,负责对当前的tensor数组启用GPU加速。最后,第51行进入创建CART分类树的核心函数__create_tree。
代码段2.3 CART分类树训练过程(源码位于Chapter02/CartClassifier.py)
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/d2-3.jpg?sign=1739291735-9ipDaIIj4XRGzFRGR17OeyIaHL5jpLS0-0-68294adad3799031b89db8836f872bcf)
第41行提到了一个关键的字符串映射函数__deal_value_map,在这里我们对它进行详细介绍。之所以将X_copy和y_copy中的字符串通过该函数映射成数值,是因为在执行计算时,为了实现GPU加速,numpy.array需要转换成PyTorch的torch.Tensor数据结构,而torch.Tensor仅支持数值类型。具体实现过程见代码段2.4。
代码段2.4 建立字符串与数值的映射的函数__deal_value_map(源码位于Chapter02/CartClassifier.py)
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/d2-4.jpg?sign=1739291735-lDaWyGtyX5q8h7Ur8B9WO70fwy91cxEt-0-a0b155c39d4948b8ac98a3f601ed5638)
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/d2-4-1.jpg?sign=1739291735-wlrDMRwedeZdv2bgQ5TyyIXcFPD96oaj-0-57397f425594d0ee551d5acd1ab18e8a)
代码段2.4展示了从字符串到数值建立映射的过程。首先,在第91~103行处理X,改变self.x_use_map的标记,遍历X中的每个元素,以元素值为key,以当前self.str_map的长度为value,在self.str_map中建立映射,同时,在self.num_map中建立反向的映射。然后,在第106~117行处理y,同理,在y不为None的情况下,在当前self.str_map和self.num_map的基础上继续建立字符串映射。最后,在第120~123行返回映射好的X和y。
代码段2.5则实现将数值映射回字符串的功能。其中分为两种情况:一种是使用了字符串到数值的映射的key(通过to_X、self.x_use_map和self.y_use_map的值可以判别,如第132~133行所示),此时使用self.num_map映射回字符串;另一种是没有使用映射的key(如第134~137行所示),这种情况直接返回key或者key.item(key为Tensor中的数值类型时)的值。
代码段2.5 从数值到字符串的映射中还原值(源码位于Chapter02/CartClassifier.py)
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/d2-5.jpg?sign=1739291735-lPgvUVez9oDWHIpcxkNBPxFROljtCu6h-0-ed686327ac21ad385d9e73d263b241ad)
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/d2-5-1.jpg?sign=1739291735-dQrxhXhf5LPSwNHtD8bpfH2uvmU8Oxw4-0-747cd1fa97f2d4029257ef9da0cb9822)
接下来,我们回到代码段2.3。在代码段2.3的第51行调用了__create_tree函数,它完成了决策树的创建过程,其具体代码见代码段2.6。
代码段2.6 创建CART分类树的核心代码(源码位于Chapter02/CartClassifier.py)
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/d2-6.jpg?sign=1739291735-HRgTL1YgX3QWDjmfE6cg8OdHpgIplczW-0-4466de2445ab80c00790556d425675e1)
在上述代码段中,__create_tree函数是一个递归创建决策树的过程。首先,在第147~153行判断三种递归终止条件:X中样本全部属于同一类别、当前节点样本数小于self.min_samples_split、属性集上的取值均相同。若满足终止条件,则调用__get_value函数返回从数值到字符串的映射值,若未满足终止条件,则继续往下计算。然后,在第155~157行根据基尼增益从属性值中选择最优分裂属性的最优切分点,具体过程如__choose_best_point_to_split函数所示。最后,在第159~169行根据最优切分点对子树进行划分,对于其子树再继续执行__create_tree函数完成划分过程。
在代码段2.6的第153行调用了__majority_y_id函数,它用于计算节点中出现次数最多的类别,具体见代码段2.7。它首先在第191行进行合法性检查,确保输入参数y_tensor的元素个数大于0。然后在第193~197行初始化一个空dict,遍历y_tensor并对其元素进行计数。最后在第199~203行从字典y_count中查找出现次数最多的类别ID(或映射值)。
在代码段2.6的第156行,调用了__choose_best_point_to_split函数,它用于选择最优切分点,具体见代码段2.8。
代码段2.7 计算节点中出现次数最多的类别(源码位于Chapter02/CartClassifier.py)
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/d2-7.jpg?sign=1739291735-u880GCsyNYKG6XAzCLwM8iCoIX3zRXH9-0-a7ad73e87a87ef4b3295bebb113b6f86)
代码段2.8 选择最优切分点(源码位于Chapter02/CartClassifier.py)
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/d2-8.jpg?sign=1739291735-KHQ3IDEJJ91WglSZhDY4jeOEWLxFvwyQ-0-e380491e4ac3dddac98a43ea5c58f9ee)
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/d2-8-1.jpg?sign=1739291735-fNZ4G51xE2qWImYyeXIzIbIGK3UAFMRl-0-f2abb20fa1d3f3a1afd945e8ab1098d4)
代码段2.8是CART分类树中最核心的函数,该函数负责选择最优切分点。根据前面的理论推导,该函数的目的是计算取得最大基尼增益的属性值。首先在第219行调用__cal_gini_impurity函数计算总数据集的基尼不纯度G(root)。然后在第220~237行遍历每个属性的每个属性值,根据是否等于属性值(二分类问题)将数据集分割到左右子树,依次计算左右子树的基尼不纯度G(left)和G(right),以及左右子树中数据样本在总样本中占的比例P(left)和P(right),并且将G(root)、G(left)、G(right)、P(left)和P(right)代入__cal_gini_gain函数中计算基尼增益。最后在第239~243行选出具有最大基尼增益的属性值,作为当前节点的最优切分点,并返回最优切分点和最优分裂属性索引。
在代码段2.8的第236行,调用__cal_gini_gain函数来计算基尼增益。在计算基尼增益之前,我们需要知道如何计算一个数据集的基尼不纯度。如代码段2.8的第229行和第232行所示,通过调用__cal_gini_impurity函数来计算基尼不纯度,它的具体实现见代码段2.9。
代码段2.9 计算基尼不纯度(源码位于Chapter02/CartClassifier.py)
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/d2-9.jpg?sign=1739291735-bxSh1hmA62XAptb5iM1joMAm90fXMLBR-0-ccb1e5e29ba1023f6de78fc175ebbbcf)
在代码段2.9的第253~258行,我们分析导入的数据集的最后一列(一般默认为数据类别),根据不同类别按出现次数统计到分类字典中。在第260~265行遍历该字典,根据公式用1减去不同的类分布概率的平方和,得到最终的基尼不纯度。接下来在计算基尼不纯度的基础上进一步实现基尼增益的计算,即__cal_gini_gain函数。它的具体代码见代码段2.10。
代码段2.10 计算基尼增益(源码位于Chapter02/CartClassifier.py)
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/d2-10.jpg?sign=1739291735-6GHPwGThWnN0NG4bn0aAmVw5ovzz66Lz-0-c248d4ce80982e2a63d2dfabb37350cb)
求解基尼指数的过程与求解基尼增益的过程有着相似之处,它们都需要划分数据求出基尼不纯度,以及左右子树中类的比例,只不过基尼指数不需要求总数据集的基尼不纯度,而是将pro_left*gini_impurity_left与pro_right*gini_impurity_right累加求和。因此,在选择最优切分点时,我们选择具有最大基尼增益的属性值,或者具有最小基尼指数的属性值。
2.2.3.3 预测过程
代码段2.11a和代码段2.11b演示了CART分类树进行预测时的整体过程。在预测过程中,依然首先在代码段2.11a的第61~69行拷贝数据集、处理字符串映射和判断是否启用GPU加速,然后在代码段2.11a的第72行和代码段2.11b的第178~182行遍历测试集X_tensor的每个样本,使用__classify函数分别对其进行预测,最终返回拼接好的预测结果。从代码段2.11b的结构中可以看出非常好的并行性,因此在多核CPU机器上处理大数据预测时,使用多线程将其并行化可以大大提升预测效率,对此不做赘述。
代码段2.11a CART分类树预测过程(源码位于Chapter02/CartClassifier.py)
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/d2-11a.jpg?sign=1739291735-RQRq3ZvQmF4a1OhXZSmS3iI7xHi3h2vf-0-36fc49c1c67a420b128110c86a84de9c)
代码段2.11b CART分类树预测过程(源码位于Chapter02/CartClassifier.py)
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/d2-11b.jpg?sign=1739291735-RGHzqhJa62Il002AMUYvHdOCHsFwkqSi-0-78cc17a0babf191b7b8dd5eda06fe7e3)
在代码段2.11b的第180行,通过调用__classify进行预测分类,其具体代码见代码段2.12。在函数__classify的参数中,树模型tree是字典结构,它的每两层代表了实际意义上的一层决策树。因此在第291~304行的递归遍历过程中,每次取出tree的前两层(根节点和根节点的左右孩子节点),其中根节点代表属性,根节点的左右孩子节点代表属性的取值及路由方向。根据以上特点,从根节点开始,递归遍历CART分类树,最终路由到某个叶子节点,叶子节点上的值即为该决策树的预测结果。
代码段2.12 CART分类树预测的核心代码(源码位于Chapter02/CartClassifier.py)
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/d2-12.jpg?sign=1739291735-b2Ae5oFNxKmAOXNGxUjo24IfY2UfWItP-0-a54aa7817970e0ef2bc31e6dcd139ad0)
2.2.3.4 可视化过程
对于可视化过程,可以借助Matplotlib库来实现,为此,我们结合树的遍历特点,封装了一套适用于上述决策树的tree_plotter可视化包。
在代码段2.13a和2.13b中,tree_plot函数为该包对外提供的决策树绘制接口,其整体算法的思路可分为两个步骤:首先绘制自身节点,然后判断自身节点类型,若为非叶子节点则继续递归创建子树,若为叶子节点则直接绘制。关于更详细的实现细节,请读者自行阅读源码,由于篇幅原因在此不再展开。
代码段2.13a tree_plotter可视化包(源码位于Chapter02/tree_plotter.py)
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/d2-13a.jpg?sign=1739291735-mDfcQNUCfEDSSqizw721Tl6IvRQLaTjI-0-30c0d2e341c284f3b6425d739dd6c4c7)
代码段2.13b tree_plotter可视化包(源码位于Chapter02/tree_plotter.py)
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/d2-13b.jpg?sign=1739291735-E6JN4UmnW5k7kZM2GIcSGsQVRya5dFX4-0-9b977b81c41202cde42780f24af39f25)
![](https://epubservercos.yuewen.com/01202D/24002389209517606/epubprivate/OEBPS/Images/d2-13b-1.jpg?sign=1739291735-yWSZIoMMxiihXinNCu6cyQVWUWmCmKRb-0-4d8d0b687e23cd7d5da0d9d367fc6a6f)