欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

【机器学习入门一】决策树及ID3决策树的python实现

程序员文章站 2024-02-16 12:47:46
...

本文实现的是ID3决策树。一开始是想实现一下adaboost算法,但是弱分类器选择的是决策树桩,因此干脆先实现决策树。本文基于周志华老师的《机器学习》第四章


目录

  1. 基本流程
  2. 划分选择
    • 信息增益
    • 增益率
    • 基尼指数
  3. 剪枝处理
    • 预剪枝
    • 后剪枝
  4. 连续与缺失值
    • 连续值处理
    • 缺失值处理
  5. 代码
    • 代码说明
    • 完整代码

正文

1. 基本流程

决策树对某个样本进行分类实际上是模拟人们思维的一种决策过程。以买西瓜为例,人们在挑选西瓜时的决策过程大致为如下过程,首先看一看西瓜的色泽,如果是青绿的,那么这个瓜可能是好瓜,如果这个瓜外表已经泛黄甚至是黑的,就几乎可以判定这个瓜不好吃。是不是颜色是青绿的一定是好瓜呢?显然也未必,人们还要继续看看其他的一些属性,比如根蒂,根据经验,如果根蒂是硬挺的,说明瓜可能刚采摘不久,还比较新鲜,如果是蜷缩柔软的,说明已经摘下很长时间了。那么根据我们的经验,可能根蒂硬挺的瓜是好瓜。这样继续判断别的属性,直到我们可以断定这个瓜是好瓜还是坏瓜为止。
这个决策过程可以用如下的决策流程图来表示:
【机器学习入门一】决策树及ID3决策树的python实现
这是西瓜问题的一种可能的决策过程。这个树形的结构就是所说的决策树。特殊一点,我在买瓜的时候没有什么经验,看着别人买瓜的时候听一听瓜的敲声,我也会装作很懂的敲一敲,最后挑个长的好看的。那么显然此时我的决策树只有一层,就叫做决策树桩。
决策树的生成过程大致如下:首先选择一个“好的属性”,根据该属性的所有取值为这个节点生成孩子节点,如上图中的色泽属性有“青绿”和“乌黑”2种可能的取值,于是色泽节点就有两个孩子节点,将所有色泽为“青绿”的样本划入第一个节点,所有色泽为“乌黑”的样本划入第二个节点。这样递归生成决策树。
这里有两个问题:(1)什么时候递归返回?(2)什么是“好的属性”?
首先看第一个问题。递归返回一个有3种情况,(1)当前节点包含的样本全部属于一个类别。(2)当前已经使用了样本的所有属性,或者样本在所有属性上的取值相同,无法区分。(3)当前节点包含的样本集合为空,不能划分。
第一种情况下,叶节点的类别就是这些样本的类别。第二种情况下,叶节点的类别是样本集中所属样本数量最多的类别。第三种情况下,取父节点中所含样本最多的类别。
这是第一个问题,那么什么是“好的属性”呢?

2. 划分选择

2.1 信息增益

直观上来看,我们在做出决策时,希望当前选择的属性能够把尽可能多的样本分对,即按照每个取值,子样本集的“纯度”尽可能高。在信息学上,经常用信息熵来度量集合的纯度。
假定当前样本集合 D中第κ类样本中所占的比例为pκ(κ=1,2,3,,γ),则D的信息熵定义为:

Ent(D)=k=1γpklog2pk(1)

Ent(D)的值越小,则D的纯度越高。
假定离散属性a有V可能的取值{a1,a2,,aV},若使用a来对样本集合D进行划分,则会产生V个分枝节点,其中第ν个分枝节点包含了D中所有在属性a上取值为aν的样本,记为Dν。根据式(1)计算出Dν的信息熵,再考虑到不同的分枝节点所包含的样本数不同,给分枝节点赋予权重Dν/D,即样本数越多的分支节点的影响越大,于是可以计算出属性a对样本集D记性划分所获得的“信息增益”:
Gain(D,a)=Ent(D)v=1VDνDEnt(Dν)(2)

一般而言,信息增益越大,则意味着使用属性a来进行划分所获得的“纯度提升”越大,也可以理解为在划分的过程中获取的信息越多。因此,我们就优先选择使得Gain(D,a)最大的属性a。
用机器学习中的数据集为例:

编号 色泽 根蒂 敲声 纹理 脐部 触感 好瓜?
1 0 0 0 0 0 0 Y
2 1 0 1 0 0 0 Y
3 1 0 0 0 0 0 Y
4 0 0 1 0 0 0 Y
5 2 0 0 0 0 0 Y
6 0 1 0 0 1 1 Y
7 1 1 0 1 1 1 Y
8 1 1 0 0 1 0 Y
9 1 1 1 1 1 0 N
10 0 2 2 0 2 1 N
11 2 2 2 2 2 0 N
12 2 0 0 2 2 1 N
13 0 1 0 1 0 0 N
14 2 1 1 1 0 0 N
15 1 1 0 0 1 1 N
16 2 0 0 2 2 0 N
17 0 0 1 1 1 0 N

对于上表中数字和字母的说明:

数字 色泽 根蒂 敲声 纹理 脐部 触感 好瓜?
0 青绿 蜷缩 浊响 清晰 凹陷 硬滑
1 乌黑 稍蜷 沉闷 稍糊 稍凹 软粘
2 浅白 硬挺 清脆 模糊 平坦
Y
N

数据集正例(好瓜)的比例为p1=817,反例(坏瓜)的比例为p2=917,所以数据集D的信息熵为:

Ent(D)=k=12pklog2pk=(817log2817+917log2917)=0.998

然后计算每个属性的信息增益。以色泽为例:
D0=1,4,6,10,13,17D1=2,3,7,8,9,15D2=5,11,12,14,16由此可以计算出根据色泽划分之后所获得的三个分支节点的信息熵为:
Ent(D1)=(36log236+36log236)=1.000

Ent(D2)=(46log246+26log226)=0.918

Ent(D3)=(15log215+45log245)=0.722

由此计算出此时的信息增益为
(1)Gain(D,)=Ent(D)ν=13DνDEnt(Dν)(2)=0.998(617×1.0000+617×0.918+517×0.722)(3)=0.109

类似的可以计算出其他属性的信息增益,可以发现属性“纹理”信息增益最大,因此选择纹理进行划分。如果有多个的属性的信息增益相同,则随机选择一个一个属性划分。
ID3决策树就是基于信息增益来生成决策树。最终计算出的决策树在文章的最后。

2.2 增益率

实际上,信息增益准则对可取值数目较多的属性有所偏好。C4.5算法使用“增益率”来选择最优的属性。“增益率”的定义式如下:

Gain_ratio(D,a)=Gain(D,a)v=1VDvDlog2DvD(3)

当属性a的取值越多,通常分母越小。增益率准则对可取值数目较少的属性有所偏好,因此,C4.5 算法并不是直接选择增益率最大的候选属性划分,而是使用了一个启发式的方法:先从候选属性中找出信息增益高于平均水平的属性,再从中选择增益率最高的。

2.3 基尼指数

CART决策树使用“基尼系数”来选择划分属性。基尼系数的定义式如下:

(4)Gini(D)=k=1ykkpkpk(5)=1k=1ypk2(4)

直观的说,Gini(D)反映了从数据集D中随机抽取两个样本,其类别标记不一致的概率。因此,Gini(D)越小,则数据集D的初度越高。
具体到数据集D,属性a的基尼指数定义为:
Gini_index(D,a)=v=1VDvDGini(Dv)(5)

于是,选取使得基尼指数最小的属性作为划分的依据。

3.剪枝处理

剪枝处理是决策树学习过程中解决过拟合的主要手段。决策树的过拟合主要体现在考虑的属性过多。决策树剪枝的方法主要有“预剪枝”和“后剪枝”。

3.1 预剪枝

预剪枝是指在决策树的生成过程中,对每个节点在划分阶段前先进行估计,若当前节点的划分不能带来决策树泛化性能的提升,则停止划分并将当前节点标记为叶节点。
预剪枝的好处在于不仅降低了过拟合的危险,还显著减少了决策树的训练时间开销和测试时间开销。但另一方面,有些分支的当前划分虽不能提升泛化性能,甚至可能导致泛化性能暂时下降,但在其基础上的后续划分却有可能导致性能显著提升。即预剪枝有可能会带来欠拟合的风险。

3.2 后剪枝

后剪枝则是先从训练集生成一棵决策树,然后自底向上地对非叶节点进行考察,若将该节点对应的子树替换为叶节点能带来决策树泛化性能的提升,则将该子树替换为叶节点。
一般情形下,后剪枝决策树的欠拟合风险很小,泛化性能往往优于预剪枝决策树,但是坏处是训练时间要比未剪枝和预剪枝高的多。
预剪枝和后剪枝的详细计算请参见《机器学习》80至83页。不再展开。

4. 连续与缺失值

4.1 连续值处理

通常是将连续的属性离散化。最简单策略是采用二分法,这也是C4.5决策树算法中采用的机制。
对于属性a,假设有n个不同的取值,将这些值按升序排列后为{a1,a2,,an},基于划分点t可以将D分为子集DtDt+,前者包含属性a的取值不大于t的样本,后者为前者的补集。划分点通常取相邻两个取值的中点。这样划分点的集合为:

Ta={ai+ai+12|1in1}(6)

然后就可以像离散属性一样来考察这些划分点。
需要注意的是,与离散属性不同,若当前节点划分属性为连续属性,该属性还可作为其后代节点的划分属性。

4.2 缺失值处理

在数据集中的缺失值较多时,需要解决两个问题:(1)如何在属性值缺失的情况下进行划分属性的选择。(2)给定划分属性,若样本在该属性上的值缺失,如何对样本进行划分?
对于第一个问题,我们只能选择属性a的值没有缺失的样本进行计算。
对于第二个问题,如果样本x在划分属性a上的值缺失,则将其划入a的所有子节点中,且样本权值进行相应的调整,直观的看,这就是让同一个样本以不同的概率划分到不同的子节点中去。


5. 代码

5.1 代码说明

我实现的是ID3决策树,C4.5决策树不仅仅是将划分标准变为增益率,还有属性选择的启发式方法,连续值的处理剪枝等等,但是网上很多代码都是只使用了增益率,其他的与ID3没有差别
这里的数据采用了上面提到的数据集
手动输入训练集,输出的两个array对象,dataSet是包含标记在内的数据集,labels实际上是属性的名字的array,萌新表示英语很渣,勿喷。

import numpy as np
import math
import treePlotter
import matplotlib.pyplot as plt
%matplotlib inline
##创建训练集
def createDataset():
    dataSet=np.array([[0,0,0,0,0,0,'Y'],
            [1,0,1,0,0,0,'Y'],
            [1,0,0,0,0,0,'Y'],
            [0,0,1,0,0,0,'Y'],
            [2,0,0,0,0,0,'Y'],
            [0,1,0,0,1,1,'Y'],
            [1,1,0,1,1,1,'Y'],
            [1,1,0,0,1,0,'Y'],
            [1,1,1,1,1,0,'N'],
            [0,2,2,0,2,1,'N'],
            [2,0,0,2,2,1,'N'],
            [2,2,2,2,2,0,'N'],
            [0,1,0,1,0,0,'N'],
            [2,1,1,1,0,0,'N'],
            [1,1,0,0,1,1,'N'],
            [2,0,0,2,2,0,'N'],
            [0,0,1,1,1,0,'N']])
    labels=np.array(['color','root','lisen','look','qibu','feel'])
    return dataSet,labels

下面函数的作用是计算信息熵
输入:数据集dataSet,格式是m×nm是样本的数量,n是当前样本集的属性数+1,即最后一列存储类别。注意,若当前节点选择属性a作为划分属性,则传递给下一层的自数据集的dataSet中不再含有a这一列。
输出:实数shannonEnt,计算出的信息熵。
描述:

##计算信息熵 
def calcShannonEnt(dataSet):
##求解数据列表的维度
    rows=np.shape(dataSet)[1]
    cols=np.shape(dataSet)[0]
##用转化为集合的方式获取数据的标记,并且保存
    labels=set(dataSet[:,-1])
    shannonEnt=0.0
##以此求解每类样本所占比例,以求出信息熵
    for i in np.arange(len(labels)):
        label=labels.pop()
        prob=float(np.sum(dataSet==label)/cols)
        shannonEnt-=prob*math.log(prob,2)
    return shannonEnt

输入:数据集dataSet,格式和上面一样
输出:整数feature,选择信息增益最大的属性所在的列坐标
描述:实际上的第一个属性对应的列坐标为0,首先计算出现在的数据集的熵Ent0,然后循环便利数据集的每个特征i,在每次循环中,用集合提取出该属性的所有取值values,循环遍历values的每个可能取值value,index是所有该属性取值为value的样本的索引。现将array数组转化为list(用array.tolist())方法,然后使用list.count()方法可以统计出这些样本的数量。这样计算出按当前属性划分的熵Ent1,Ent0-Ent1就是信息增益。

##选择出最大的信息增益所对应的特征的索引
def chooseBestFeatureToSplit(dataSet):
    rows=np.shape(dataSet)[1]
    cols=np.shape(dataSet)[0]
    ##计算出分类前数据集的熵
    Ent0=calcShannonEnt(dataSet)
    feature=-1
    maxGain=0.0
    ##依次考察每个特征
    for i in np.arange(rows-1):
        values=set(dataSet[:,i])
        Ent1=0.0
        ##计算该特征的每个值所对应的分类后的熵值
        for j in np.arange(len(values)):
            value=values.pop()
            index=np.where(dataSet[:,i]==value)
            subEnt=calcShannonEnt(dataSet[index,:].reshape(-1,rows))
            Ent1+=(dataSet[:,i].tolist().count(value)/cols)*subEnt
        ##如果该特征的信息增益大于之前特征的信息增益,则保留该特征的索引
        if(Ent0-Ent1)>maxGain:
            feature=i
            maxGain=Ent0-Ent1
    return feature

输入:数据集dataSet,作为划分依据的属性的列坐标axis,该属性的取值value
输出:array对象retDataSet,存储着dataSet上被选中属性的取值为value的所有样本。注意,此时的retDataSet中不再含有该属性所在的列。
描述:首先找出取值为value的样本dataSet[index,:],然后将这个array分为两部分,axis列左边的和右边的,将这两部分横向拼接就是retDataSet。

##划分数据,为下一层计算做准备
def splitDataSet(dataSet,axis,value):
    subDataSet=[] ##subDataSet是要传递给下一层的数据集,即应该划分给第axis个特征值为value的样本
    rows=np.shape(dataSet)[1]
    index=np.where(dataSet[:,axis]==value)##找到axis特征的取值为value的所有样本的索引index
    ##构建新的数据集,注意要去掉第axis列,使该特征之后不再参加比较
    retDataSet=np.hstack((dataSet[index,:].reshape(-1,rows)[:,:axis],dataSet[index,:].reshape(-1,rows)[:,axis+1:]))
    return retDataSet

输入:一个由类别标记组成的列向量matrix
输出:标记中出现次数最多的标记label
描述:先用构建集合的方法找出matrix中所有出现过的标记,然后循环查找每个标记出现的次数num。

##查找出现次数最多的标记并返回
def majorityCnt(matrix):
    labelSet=set(matrix)
    maxSum=0
    for i in np.arange(len(labelSet)):
        t=labelSet.pop()
        num=matrix.tolist().count(t)
        if(maxSum<num):
            maxSum=num
            label=t
    return label

输入:数据集dataSet(array对象),所有属性的名字字符串featurenames(array对象)。
输出:基于字典的决策树
描述:见注释

##递归构建决策树
def createTree(dataSet,featureNames):
    labels=dataSet[:,-1]##labels是数据集的所有标记
    ##如果数据集中只有一种标记,则停止递归,并返回该标记作为叶节点
    if(len(set(labels))==1):
        return dataSet[0,-1]
    ##如果数据集中的样本已经没有特征可选,则停止递归,并返回现有样本中出现次数最多的标记作为该叶节点的标记
    if(np.shape(dataSet)[0]==1):
        return majorityCnt(dataSet)
    ##选择出使得信息增益最大的特征,bestFeatIndex是该特征在当前数据集dataSet中的列坐标
    bestFeatIndex=chooseBestFeatureToSplit(dataSet)
    myTree={bestFeatIndex:{}}##构建字典树,用bestFeatIndex作为根节点
    subFeatureNames=np.delete(featureNames,bestFeatIndex)##去除掉本次划分使用的特征的名字,featureNames是特征名字的数组
    ##求出该特征的所有取值的集合featValues
    featValues=set(dataSet[:,bestFeatIndex])
    ##用featvalues中的每个取值做为根节点递归建树
    for value in featValues:
        myTree[bestFeatIndex][value]=createTree(splitDataSet(dataSet,bestFeatIndex,value),subFeatureNames)
    return myTree

输入:见注释
输出:依据决策树判定的测试样本所属的类别
描述:for循环中提取出的k实际就是找出整个决策树的根节点是哪个属性,没有想到其他的方法,迫不得已用了个for循环,很low,大家如果有其他方法欢迎告知。首先判断测试样本在该属性上的取值是否超出了该节点的子节点的范围,subTree是该样本所对应的子决策树,由于生成决策树的每一层时都丢掉了相应的属性列,因此测试样本在向下层传递时也要丢掉相应的属性列。

##分类函数,decisionTree是生成的决策树,testData是用于测试的数据,labels是所有样本可能取值的标价组成的array
def classify(decisionTree,testData,labels):
    ##这里提取decisionTree的键值,由于字典树的长度是1,因此这个循环注定只能运行一次,暂时没有想到更好的提取键值的方法,以后再更新
    for k in decisionTree:
        if k in labels:
            return k
        else:
            values=decisionTree[k]
            ##首先判断是否存在不合法的情况,即某个特征的取值不在决策树内
            if(not testData[k] in values):
                return "wrong"
            subTree=decisionTree[k][testData[k]]
            newTestData=np.hstack((testData[:k],testData[k+1:]))
            return classify(subTree,newTestData,labels)

下面是训练,测试的过程。

dataSet,labels=createDataset()
myTree=createTree(dataSet,labels)
testData=np.array(['3','3','3','3','3','3'])
labels=np.array(['N','Y'])
label=classify(myTree,testData,labels)
treePlotter.createPlot(myTree)##调用找到的一个可以画出树的形状的文件treePlotter.py,该文件见文末。

【机器学习入门一】决策树及ID3决策树的python实现
这个是生成的决策树的效果图,与《机器学习》78页图4.4是完全相同的。

5.2 完整代码
import numpy as np
import math
import treePlotter
import matplotlib.pyplot as plt
%matplotlib inline
##创建训练集
def createDataset():
    dataSet=np.array([[0,0,0,0,0,0,'Y'],
            [1,0,1,0,0,0,'Y'],
            [1,0,0,0,0,0,'Y'],
            [0,0,1,0,0,0,'Y'],
            [2,0,0,0,0,0,'Y'],
            [0,1,0,0,1,1,'Y'],
            [1,1,0,1,1,1,'Y'],
            [1,1,0,0,1,0,'Y'],
            [1,1,1,1,1,0,'N'],
            [0,2,2,0,2,1,'N'],
            [2,0,0,2,2,1,'N'],
            [2,2,2,2,2,0,'N'],
            [0,1,0,1,0,0,'N'],
            [2,1,1,1,0,0,'N'],
            [1,1,0,0,1,1,'N'],
            [2,0,0,2,2,0,'N'],
            [0,0,1,1,1,0,'N']])
    labels=np.array(['color','root','lisen','look','qibu','feel'])
    return dataSet,labels
##计算信息熵
def calcShannonEnt(dataSet):
##求解数据列表的维度
    rows=np.shape(dataSet)[1]
    cols=np.shape(dataSet)[0]
##用转化为集合的方式获取数据的标记,并且保存
    labels=set(dataSet[:,-1])
    shannonEnt=0.0
##以此求解每类样本所占比例,以求出信息熵
    for i in np.arange(len(labels)):
        label=labels.pop()
        prob=float(np.sum(dataSet==label)/cols)
        shannonEnt-=prob*math.log(prob,2)
    return shannonEnt
##选择出最大的信息增益所对应的特征的索引
def chooseBestFeatureToSplit(dataSet):
    rows=np.shape(dataSet)[1]
    cols=np.shape(dataSet)[0]
    ##计算出分类前数据集的熵
    Ent0=calcShannonEnt(dataSet)
    feature=-1
    maxGain=0.0
    ##依次考察每个特征
    for i in np.arange(rows-1):
        values=set(dataSet[:,i])
        Ent1=0.0
        ##计算该特征的每个值所对应的分类后的熵值
        for j in np.arange(len(values)):
            value=values.pop()
            index=np.where(dataSet[:,i]==value)
            subEnt=calcShannonEnt(dataSet[index,:].reshape(-1,rows))
            Ent1+=(dataSet[:,i].tolist().count(value)/cols)*subEnt
        ##如果该特征的信息增益大于之前特征的信息增益,则保留该特征的索引
        if(Ent0-Ent1)>maxGain:
            feature=i
            maxGain=Ent0-Ent1
    return feature
##划分数据,为下一层计算做准备
def splitDataSet(dataSet,axis,value):
    subDataSet=[] ##subDataSet是要传递给下一层的数据集,即应该划分给第axis个特征值为value的样本
    rows=np.shape(dataSet)[1]
    index=np.where(dataSet[:,axis]==value)##找到axis特征的取值为value的所有样本的索引index
    ##构建新的数据集,注意要去掉第axis列,使该特征之后不再参加比较
    retDataSet=np.hstack((dataSet[index,:].reshape(-1,rows)[:,:axis],dataSet[index,:].reshape(-1,rows)[:,axis+1:]))
    return retDataSet
##查找出现次数最多的标记并返回
def majorityCnt(matrix):
    labelSet=set(matrix)
    maxSum=0
    for i in np.arange(len(labelSet)):
        t=labelSet.pop()
        num=matrix.tolist().count(t)
        if(maxSum<num):
            maxSum=num
            label=t
    return label
##递归构建决策树
def createTree(dataSet,featureNames):
    labels=dataSet[:,-1]##labels是数据集的所有标记
    ##如果数据集中只有一种标记,则停止递归,并返回该标记作为叶节点
    if(len(set(labels))==1):
        return dataSet[0,-1]
    ##如果数据集中的样本已经没有特征可选,则停止递归,并返回现有样本中出现次数最多的标记作为该叶节点的标记
    if(np.shape(dataSet)[0]==1):
        return majorityCnt(dataSet)
    ##选择出使得信息增益最大的特征,bestFeatIndex是该特征在当前数据集dataSet中的列坐标
    bestFeatIndex=chooseBestFeatureToSplit(dataSet)
    myTree={bestFeatIndex:{}}##构建字典树,用bestFeatIndex作为根节点
    subFeatureNames=np.delete(featureNames,bestFeatIndex)##去除掉本次划分使用的特征的名字,featureNames是特征名字的数组
    ##求出该特征的所有取值的集合featValues
    featValues=set(dataSet[:,bestFeatIndex])
    ##用featvalues中的每个取值做为根节点递归建树
    for value in featValues:
        myTree[bestFeatIndex][value]=createTree(splitDataSet(dataSet,bestFeatIndex,value),subFeatureNames)
    return myTree
##分类函数,decisionTree是生成的决策树,testData是用于测试的数据,labels是所有样本可能取值的标价组成的array
def classify(decisionTree,testData,labels):
    ##这里提取decisionTree的键值,由于字典树的长度是1,因此这个循环注定只能运行一次,暂时没有想到更好的提取键值的方法,以后再更新
    for k in decisionTree:
        if k in labels:
            return k
        else:
            values=decisionTree[k]
            ##首先判断是否存在不合法的情况,即某个特征的取值不在决策树内
            if(not testData[k] in values):
                return "wrong"
            subTree=decisionTree[k][testData[k]]
            newTestData=np.hstack((testData[:k],testData[k+1:]))
            return classify(subTree,newTestData,labels)


dataSet,labels=createDataset()
myTree=createTree(dataSet,labels) 
testData=np.array(['3','3','3','3','3','3'])
labels=np.array(['N','Y'])
classify(myTree,testData,labels)  
treePlotter.createPlot(myTree)         

绘制决策树的treePlotter.py文件在这里:

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', \
                            xytext=centerPt, textcoords='axes fraction', \
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = getTreeDepth(secondDict[key]) + 1
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalw, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalw
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalw = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalw
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

初入计算机,请大家多多指教嘛,共同学习~~~持续更新中……
参考一位大佬的文章在这里
参考书为周志华老师的《机器学习》