欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

机器学习实战---ID3决策树

程序员文章站 2024-02-16 10:46:16
...

今天给大家带来一个很好用的分类和预测方法—决策树

决策树是一种监督算法,以树状图为基础,输出结果为一系列简单实用的规则,有点像if-then语句,这里偷一个图:
机器学习实战---ID3决策树
我们看到,决策树就是把数据按照其内在特征进行一层一层的分类,从而发现数据中蕴含的一些潜在信息从而帮助人们更好的进行事情的决策,这样说的话就很像数据挖掘了,的确在数据挖掘中经常会用到决策树

构建决策树一般流程

(1) 收集数据:可以使用任何方法。
(2) 准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。
(3) 分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。
(4) 训练算法:构造树的数据结构。
(5) 测试算法:使用经验树计算错误率。
(6) 使用 算法 :此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。

而一个具体的决策树的构建主要分为三个步骤:
1.特征选择: 按照最优特征进行划分数据
2.决策数生成 :根据特征划分生成决策树
3.决策树剪枝 : 由于按照训练数据进行生成决策树的时候每个数据都被考虑到了,因此会产生过拟合问题,也就是过分的满足这些数据而丧失对一般性的分类,因此要将一些不是很重要的分类去掉

特征选择

上面讲到,决策树非常重要的一点就是按照数据特征进行分类,而用哪一个特征进行分类的效果将会对决策树算法的结果产生很大的影响(选取有较强分类能力的特征)至于分类方法也有很多种:ID3,C4.5,CART,这里我们使用的是ID3算法
划分数据的很重要的原则就是,将无序的数据变得更加有序,在划分数据前后信息发生的变化称为信息增益,而如何评测信息增益量呢?有熵,基尼系数和方差,这里我们使用求熵的方法来计算我们的各种划分数据的信息增益

说到熵,我们就要向克劳德.香农致敬—-一个伟大的天才
那么熵是什么呢?

在信息论中,熵的本质是一个系统“内在的混乱程度,这里正好用来衡量我们的数据的有序性
机器学习实战---ID3决策树
其中p(xi)为选择分类的概率
熵的存在是个很难解释的高深存在,但是我们只需要直到我们为何要用到熵以及怎么用就够了

首先先构建一个简短的数据用来实验

def createDataSet():
    dataSet = [[1,1,'yes'],
               [1,1,'yes'],
               [1,0,'no'],
               [0,1,'no'],
               [0,1,'no']]
    labels = ['no surfacing' , 'flippers']
    return dataSet, labels

下面我们给出计算熵的程序:

def calcSHannonEnt(dataSet):   #计算数据的熵
    numEnteries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEnteries  #该分类的概率
        shannonEnt -= prob* log(prob, 2)  #香农求熵
    return shannonEnt

很容易理解我们首先计算出样本个数,然后为将所有可能的分类划分出来,然后按照求熵的方法进行计算就行

上面我们知道了如果衡量一个分类方法的好坏,下面我们就开始划分数据,
我们选择对样本的每个特征进行划分然后求出信息熵,最后判断哪个划分方式是最好的
由于要重复进行划分,我们将划分方式单独写一个函数


# 按照特征划分数据集
def splitDataSet(dataSet, axis, value): # 待划分数据集, 划分数据集的特征, 特征的返回值
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]    # 将符合特征的数据抽取出来,这里每个数据样本是有多个特征的
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

划分方式就是找到进行划分的特征,然后将包含这个特征的样本抽取出来
下面来选择最好的划分方式

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # 找到要进行划分的特征数
    baseEntropy = calcSHannonEnt(dataSet) # 计算数据熵
    bestInfoGain = 0.0  # 信息增益
    bestFeature = -1    # 特征
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)  # 转化为集合,得到列表中唯一元素
        newEntropy = 0.0  # 熵
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)  # 计算每种划分方式结果
            prob = len(subDataSet)/float(len(dataSet))    # 求概率
            newEntropy += prob * calcSHannonEnt(subDataSet)  # 对所有唯一特征值进行熵求和
        infoGain = baseEntropy - newEntropy
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

这里我们的思想为: 首先我们先测一下原始数据熵以及要划分的特征数,然后分别进行分类,然后求出其划分后的熵与其他分类方法进行对比,找到熵最少的返回其划分方法

决策树生成

现在我们已经知道如何对数据进行划分能达到最高增益,然后开始按照这个方法去划分数据并生成决策树
先说下划分步骤:
首先得到最好的属性值划分数据集,如果特征还没划分完,就将数据传递到树分支的下一个节点,继续进行划分
递归结束的两个条件:
1.每个分支下的所有实例都具有相同的分类,直接返回类标签
2.程序遍历完所有划分数据集的属性,仍然不能将数据集护发费为仅包含唯一类别的分组,这时候我们找出现最多的类别作为返回值

这里我们先处理一下第二种情况


def majorityCat(classList):
    classCount= {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key= operator.itemgetter(1), reversed= True)# 进行排序
    return sortedClassCount[0][0]

也就是当我们遍历完所有特征但仍未将数据集划分为仅包含唯一类别分组时,返回出现最多的类别
然后是开始创建树:

def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):  # count统计字符出现的次数
        return classList[0]   # 类别完全相同则停止继续划分
    if len(dataSet[0]) == 1:    # 遍历完所有特征时还未完全相同返回出现最多的
        return majorityCat(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # 找到最好的划分方式
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]  # 获得列表包含所有属性值
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
        # 上面呢这里进行递归循环,在每次进入循环的时候,会首先将用来进行划分的特征去掉 
    return myTree
测试

如何认为我们构建的树是对的呢?
需要决策树以及用于构造树的标签向量。然 后 , 程序比较测试数据与决策树上的数值

def classify(inputTree, featLabels, testVec):  # 使用决策树的分类函数
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # 将标签字符串转换为索引,即开始出现firstStr字符的索引
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel
画图

我们已经完成了图创建,现在我们可以将这个图使用matplotlib库将其展现出来,由于对绘画方面天分不高,这里只提供代码,就不进行讲解

# encoding:utf-8
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle= "sawtooth", fc = '0.8')
leafNode= dict(boxstyle= 'round4', fc= '0.8')
array_args = dict(arrowstyle= "<-")      # 定义文本框和箭头格式

def plotNode(nodeTxt, centerPt, parentPt, nodeType):  # 绘制带箭头的注解
    createPlot.axl.annotate(nodeTxt, xy= parentPt, xycoords= 'axes fraction', xytext= centerPt,
                            textcoords= 'axes fraction',va= 'center', ha= 'center', bbox= nodeType,
                            arrowprops= array_args)

def createPlot():
    fig= plt.figure(1, facecolor= 'white')
    fig.clf()
    createPlot.axl = plt.subplot(111, frameon= False)
    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('a leaf node', (0.8, 0.1),(0.3, 0.8), leafNode)
    plt.show()

def getNumLeafs(myTree):  # 获取叶节点数
    numLeafs = 0
    firstStr = list(myTree.keys())[0]  # 获取第一个键
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # 测试节点的数据类型是否为字典
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):  # 获得数层数
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def retrieveTree(i):
    listOfTrees = [{'no surfacing' : {0: 'no', 1: {'flippers' : {0: 'no', 1: 'yes'}}}},
                   {'no surfacing' : {0: 'no', 1: {'flippers' : {0: {'head': {0: 'no', 1: 'yes'}}}}}}]
    return listOfTrees[i]

def plotMidText(cntrPt, parentPt, txtString):  # 在父子节点之间填充文本信息
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]- cntrPt[1])/2.0 + cntrPt[1]
    createPlot.axl.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeText):  # 对这个画图问题真的头大
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]   # plotTree.totalW 存储树宽度, totalD 存储树深度
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeText)   # 标记子节点属性值
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff= plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():                           #减少y偏移
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

def createPlot(inTree):   # 绘图
    fig = plt.figure()
    fig.clf()
    axprops = dict(xticks= [], yticks = [])
    createPlot.axl = plt.subplot(111, frameon = False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

一般来说,我们构造出一个决策树要花费很多的时间,因此我们不能每次需要进行分类的时候都再构建一次决策树,这里我们在每次构建好一个决策树的时候,最好是把它存储下来,下次用到的时候直接用就好了,
我们用的是python的pickle模块

def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr = open(filename)
    return pickle.load(fr)