欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

决策树——ID3算法实现

程序员文章站 2024-02-11 12:50:16
...

决策树:构建一个基于属性的树形分类器。
1.每个非叶节点表示一个特征属性上的测试(分割),
2.每个分支代表这个特征属性在某个值域上的输出,
3.每个叶节点存放一个类别。
使用决策树进行决策的过程就是从根节点开始,测试待分类项中相应的特征属性,并按照其值选择输出分支,直到到达叶子节点,将叶子节点存放的类别作为决策结果。
采用递归的方法进行建树

递归的结束条件

1.当前结点样本均属于同一类别,无需划分。
2.当前属性集为空。
3.所有样本在当前属性集上取值相同,无法划分。
4.当前结点包含的样本集合为空,不能划分。

决策树的核心

经过属性划分后,不同类样本被更好的分离
理想情况:划分后样本被完美分类。即每个分支的样本都属性同一类。
实际情况:不可能完美划分!尽量使得每个分支某一类样本比例尽量高!即尽量提高划分后子集的纯度。

划分的目标:提升划分后子集的纯度,降低划分后子集的不纯度

决策树算法分类

决策树算法的区别主要在于所采用的纯度判别标准

ID3算法:

使用信息增益作为判别标准
信息熵计算公式:
决策树——ID3算法实现
假设属性��有��可能取值{��^1,��^2,⋯⋯,��^��}, ��^��对应划分后的数据子集为��^��.
决策树——ID3算法实现
信息增益:
决策树——ID3算法实现
信息增益越大,说明当前的划分效果越好
决策树——ID3算法实现

C4.5算法

使用信息增益率作为判别准则
决策树——ID3算法实现
����(��)称为属性��的“固有值”(Intrinsic Value)
决策树——ID3算法实现
信息增益率越大,说明当前划分效果越好

CART算法

使用基尼系数作为判别准则
决策树——ID3算法实现

实验环境

python3.6
macOS 10.12

代码思路

BuildTree函数:在该函数中完成递归建树,递归返回条件的判断,建立存储树所用的字典,打印各类信息
ChooseAttr函数:在该函数中完成选出最佳特征的功能,根据Ent函数计算出的所有样本的信息熵和加权的信息熵计算信息增益,信息增越大的意味着该属性的纯度越高,选取信息增益最大的属性为最佳属性。
Ent函数:计算输入样本的信息熵,通过输入Sample的最后一列统计出该正例与反例出现的概率,根据信息熵公式计算信息熵
SpiltData函数:该函数用于对数据进行拆分,去掉已经判断过的属性对应的样本
CreatePlot函数:用于决策树的可视化

数据集

使用西瓜书上的西瓜数据集2.0
为了方便计算,将西瓜数据集的内容转换为数字
色泽: 0:青绿 1:乌黑 2:浅白
根底: 0:蜷缩 1:少蜷 2:硬挺
敲声: 0:浊响 1:沉闷 2:清脆
纹理: 0:清晰 1:稍糊 2:模糊
脐部: 0:凹陷 1:稍凹 2:平坦
触感: 0:硬滑 1:软黏
好瓜: 0:不是 1:是

上代码

import math
import numpy
import DrawTree

数据集,属性列表

#初始化一个属性列表
AttrArr=["色泽","根蒂","敲声","纹理","脐部","触感","好瓜"]
#此处使用西瓜数据集2.0
data = numpy.array(
[[0,0,0,0,0,0,1],
[1,0,1,0,0,0,1],
[1,0,0,0,0,0,1],
[0,0,1,0,0,0,1],
[2,0,0,0,0,0,1],
[0,1,0,0,1,1,1],
[1,1,0,1,1,1,1],
[1,1,0,0,1,0,1],
[1,1,1,1,1,0,0],
[0,2,2,0,2,1,0],
[2,2,2,2,2,0,0],
[2,0,0,2,2,1,0],
[0,1,0,1,0,0,0],
[2,1,1,1,0,0,0],
[1,1,0,0,1,1,0],
[2,0,0,2,2,0,0],
[0,0,1,1,1,0,0]]
)

BuildTree函数:在该函数中完成递归建树,递归返回条件的判断,建立存储树所用的字典,打印各类信息

#建树的函数
def BuildTree(Sample,Label):
    #Sample 为输入的数据
    #Label 为对应的标签

    #获取输入数据的的大小
    [Count, Attr] = Sample.shape;
    n = Attr - 1;
    m = Count - 1;
    print("Sample:")
    print(Sample)
    #使用classlist存储表示正例与反例所在的列
    classList = Sample[:, n];
    # 记录第一个类中的个数
    classOne = 1;
    for i in range(1, Count):
        if (classList[i] == classList[0]):
            classOne = classOne + 1;
    #如果当前结点包含的样本全属于同一个样本,则停止划分
    if (classOne == Count):
        print("Final")
        print(Sample)
        if(classList[0]==0):return "no" #通过 classlist 的 0 1 判断 最终的结果
        if (classList[0] == 1): return "yes" #通过 classlist 的 0 1 判断 最终的结果
    #如果当前属性集为空,无法划分
    if (Attr == 0):
        print("Final")
        print(Sample)
        return classList[0]
    #使用ChooseAttr函数获取最佳的特征对应编号
    bestAttr = ChooseAttr(Sample)
    #通过最佳特征的编号获得标签名
    name=Label[bestAttr]
    #新建一个字典用于存储树
    Tree = {name: {}}
    #打印出最佳特征
    print("最佳特征:", name);
    #取出对最佳属性对应的一列 并去掉重复值 用于得出一个属性下所包含的取值
    featValue = numpy.unique(Sample[:, bestAttr])
    #计算出一个属性下包含的取值
    numOfFeatValue = len(featValue);

    #最佳属性的每一个评级都打印出来
    for i in range(0, numOfFeatValue):
        print(name, "评级:", featValue[i])
        subLabels = Label[:]
        #对现有的树执行 SpiltData 去掉计算过的属性所对应的样本 递归调用buildTree
        Tree[name][i] = BuildTree(SpiltData(Sample, bestAttr, featValue[i]),subLabels)
        print('-------------------------');
    return Tree

ChooseAttr函数:在该函数中完成选出最佳特征的功能,根据Ent函数计算出的所有样本的信息熵和加权的信息熵计算信息增益,信息增越大的意味着该属性的纯度越高,选取信息增益最大的属性为最佳属性。

#Choose函数用于 选出最佳的属性
def ChooseAttr(Sample):
    #Sample 为输入的数据


    #获取输入数据的大小
    [Count, Attr] = Sample.shape
    numOfFeature = Attr - 1
    #计算整个数据的信息熵
    baseEnt = Ent(Sample)
    #初始信息增益
    bestInfoGain = 0.0
    #初始的最佳属性为 -1
    bestFeature = -1
    #遍历当前所有属性
    for j in range(0, numOfFeature):
        #记录出每一个属性中的取值 并去掉重复值
        featureTemp = numpy.unique(Sample[:, j])
        #记录下属性取值的个数
        numF = len(featureTemp)
        newEnt = 0.0;
        #遍历所有的取值
        for i in range(0, numF):
            #去除掉当前已经判断的样本
            subSet = SpiltData(Sample, j, featureTemp[i])
            #得到每一个取值的个数
            [newCount, newAttr] = subSet.shape
            #计算每一个取值出现的概率
            prob = newCount / Count
            #计算新的信息熵
            newEnt = newEnt + prob * Ent(subSet)
        #计算信息增益
        infoGain = baseEnt - newEnt
        #找到信息增益最大的属性 作为当前最佳属性
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = j
    return bestFeature

Ent函数:计算输入样本的信息熵,通过输入Sample的最后一列统计出该正例与反例出现的概率,根据信息熵公式计算信息熵

#Ent函数用于计算信息熵
def Ent(Sample):
    #Sample为输入的数据

    #得到输入数据的大小
    [Count, Attr] = Sample.shape
    n = Attr - 1
    m = Count - 1
    #获取正例与反例所在的列
    label = Sample[:, n]
    #去掉重复的数据
    deal = numpy.unique(label)
    #得到最后判别情况的个数
    numOfLabel = len(deal)
    #新建一个概率list 用于存储概率
    prob = numpy.zeros([numOfLabel, 2])
    for i in range(0, numOfLabel):
        #获取正例与反例
        prob[i, 0] = deal[i]
        for j in range(0, Count):
            #对正例 与 反例 进行计数
            if (label[j] == deal[i]):
                prob[i, 1] = prob[i, 1] + 1
    #计算出概率
    prob[:, 1] = prob[:, 1] / Count
    ent = 0
    #根据信息熵公示计算出信息熵
    for i in range(0, numOfLabel):
        ent = ent - prob[i, 1] * math.log2(prob[i, 1])
    return ent

SpiltData函数:该函数用于对数据进行拆分,去掉已经判断过的属性对应的样本

#对数据进行拆分 去掉已经判断过的属性所对应的样本
def SpiltData(Sample, axis, value):
    #Sample 代表输入的数据
    #axis 表示要删除值所在的行
    #value表示要删除的值
    [Count, Attr] = Sample.shape
    subSet = Sample
    k = 0
    #对每一个样本都做判断 把已经做过判断的样本删掉
    for i in range(0, Count):
        if (Sample[i, axis]) != value:
            subSet=numpy.delete(subSet,i-k,0)
            k = k + 1
    return subSet

主函数部分

TreeDict=BuildTree(data,AttrArr)

实验结果

存储树的字典:
{‘纹理’: {0: {‘根蒂’: {0: ‘yes’, 1: {‘色泽’: {0: ‘yes’, 1: {‘触感’: {0: ‘yes’, 1: ‘no’}}}}, 2: ‘no’}}, 1: {‘触感’: {0: ‘no’, 1: ‘yes’}}, 2: ‘no’}}

递归的过程:
Sample:
[[0 0 0 0 0 0 1]
[1 0 1 0 0 0 1]
[1 0 0 0 0 0 1]
[0 0 1 0 0 0 1]
[2 0 0 0 0 0 1]
[0 1 0 0 1 1 1]
[1 1 0 1 1 1 1]
[1 1 0 0 1 0 1]
[1 1 1 1 1 0 0]
[0 2 2 0 2 1 0]
[2 2 2 2 2 0 0]
[2 0 0 2 2 1 0]
[0 1 0 1 0 0 0]
[2 1 1 1 0 0 0]
[1 1 0 0 1 1 0]
[2 0 0 2 2 0 0]
[0 0 1 1 1 0 0]]
最佳特征: 纹理
纹理 评级: 0
Sample:
[[0 0 0 0 0 0 1]
[1 0 1 0 0 0 1]
[1 0 0 0 0 0 1]
[0 0 1 0 0 0 1]
[2 0 0 0 0 0 1]
[0 1 0 0 1 1 1]
[1 1 0 0 1 0 1]
[0 2 2 0 2 1 0]
[1 1 0 0 1 1 0]]
最佳特征: 根蒂
根蒂 评级: 0
Sample:
[[0 0 0 0 0 0 1]
[1 0 1 0 0 0 1]
[1 0 0 0 0 0 1]
[0 0 1 0 0 0 1]
[2 0 0 0 0 0 1]]
Final
[[0 0 0 0 0 0 1]
[1 0 1 0 0 0 1]
[1 0 0 0 0 0 1]
[0 0 1 0 0 0 1]
[2 0 0 0 0 0 1]]

根蒂 评级: 1
Sample:
[[0 1 0 0 1 1 1]
[1 1 0 0 1 0 1]
[1 1 0 0 1 1 0]]
最佳特征: 色泽
色泽 评级: 0
Sample:
[[0 1 0 0 1 1 1]]
Final
[[0 1 0 0 1 1 1]]

色泽 评级: 1
Sample:
[[1 1 0 0 1 0 1]
[1 1 0 0 1 1 0]]
最佳特征: 触感
触感 评级: 0
Sample:
[[1 1 0 0 1 0 1]]
Final
[[1 1 0 0 1 0 1]]

触感 评级: 1
Sample:
[[1 1 0 0 1 1 0]]
Final
[[1 1 0 0 1 1 0]]

根蒂 评级: 2
Sample:
[[0 2 2 0 2 1 0]]
Final
[[0 2 2 0 2 1 0]]

纹理 评级: 1
Sample:
[[1 1 0 1 1 1 1]
[1 1 1 1 1 0 0]
[0 1 0 1 0 0 0]
[2 1 1 1 0 0 0]
[0 0 1 1 1 0 0]]
最佳特征: 触感
触感 评级: 0
Sample:
[[1 1 1 1 1 0 0]
[0 1 0 1 0 0 0]
[2 1 1 1 0 0 0]
[0 0 1 1 1 0 0]]
Final
[[1 1 1 1 1 0 0]
[0 1 0 1 0 0 0]
[2 1 1 1 0 0 0]
[0 0 1 1 1 0 0]]

触感 评级: 1
Sample:
[[1 1 0 1 1 1 1]]
Final
[[1 1 0 1 1 1 1]]

纹理 评级: 2
Sample:
[[2 2 2 2 2 0 0]
[2 0 0 2 2 1 0]
[2 0 0 2 2 0 0]]
Final
[[2 2 2 2 2 0 0]
[2 0 0 2 2 1 0]
[2 0 0 2 2 0 0]]

树的可视化

此处使用的是网上一个常见的可视化代码

可视化函数

#绘制树形图
import matplotlib
# matplotlib.use('qt4agg')
from matplotlib.font_manager import *
import matplotlib.pyplot as plt
myfont = FontProperties(fname='/Users/zhangxuancheng/Library/Fonts/simhei.ttf')
decision_node = dict(boxstyle="sawtooth",fc="0.8")
leaf_node = dict(boxstyle="round4",fc="0.8")
arrow_args = dict(arrowstyle="<-")
plt.rcParams['font.sans-serif'] = ['SimHei']
#获取树的叶子结点个数(确定图的宽度)
def get_leaf_num(tree):
    leaf_num = 0
    first_key = list(tree.keys())[0]
    next_dict = tree[first_key]
    for key in next_dict.keys():
        if type(next_dict[key]).__name__=="dict":
            leaf_num +=get_leaf_num(next_dict[key])
        else:
            leaf_num +=1
    return leaf_num
#获取数的深度(确定图的高度)
def get_tree_depth(tree):
    depth = 0
    first_key = list(tree.keys())[0]
    next_dict = tree[first_key]
    for key in next_dict.keys():
        if type(next_dict[key]).__name__ == "dict":
            thisdepth = 1+ get_tree_depth(next_dict[key])
        else:
            thisdepth = 1
        if thisdepth>depth: depth = thisdepth
    return depth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

#在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = get_leaf_num(myTree)
    depth = get_tree_depth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decision_node)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[
                    key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leaf_node)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD



def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(get_leaf_num(inTree))
    plotTree.totalD = float(get_tree_depth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

调用该函数

DrawTree.createPlot(TreeDict)

可视化结果

决策树——ID3算法实现