欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

机器学习:决策树

程序员文章站 2024-02-03 18:42:46
...

1 简介

  决策树是一种树形结构,由决策树的根结点到叶结点的每一条路径构建一条规则;路径上的内部结点的特征对应着规则的条件,而叶结点对应着分类的结论。


2 算法

2.1 树的构建

  在构造决策树时,第一个需要解决的问题就是,如何确定出哪个特征在划分数据分类是起决定性作用,或者说使用哪个特征分类能实现最好的分类效果。这样,为了找到决定性的特征,划分川最好的结果,我们就需要评估每个特征。当找到最优特征后,依此特征,数据集就被划分为几个数据子集,这些数据自己会分布在该决策点的所有分支中。此时,如果某个分支下的数据属于同一类型,则该分支下的数据分类已经完成,无需进行下一步的数据集分类;如果分支下的数据子集内数据不属于同一类型,那么就要重复划分该数据集的过程,按照划分原始数据集相同的原则,确定出该数据子集中的最优特征,继续对数据子集进行分类,直到所有的特征已经遍历完成,或者所有叶结点分支下的数据具有相同的分类。

创建分支的伪代码函数createBranch()如下:

 检测数据集中的每一个子项是否属于同一分类:
if so return 类标签;
else
    寻找划分数据集的最好特征
    划分数据集
    创建分支结点
        for 每个分支结点
            调用函数createBranch并增加返回结点到分支结点中
    return 分支结点

下面我们给出使用决策树的一般流程:

(1)收集数据
(2)准备数据:构造树算法只适用于标称型数据,因此数值型数据必须离散化
(3)分析数据
(4)训练数据:上述的构造树过程构造决策树的数据结构
(5)测试算法:使用经验树计算错误率
(6)使用算法:在实际中更好地理解数据内在含义

2.2 划分数据集的原则:信息增益

  划分数据集的大原则是:使得无序的数据变得更加有序。
  我们可以使用多种方法划分数据集,每种方法都有各自的优缺点,这里我们使用信息增益来度量划分数据前后信息发生的变化,进而指导我们划分数据。
  这里我们先引出信息熵的概念:
对于可能被划分在多个分类中的待分类的事务,符号机器学习:决策树的信息被定义为:

机器学习:决策树)机器学习:决策树)

其中机器学习:决策树)是选择该分类的概率。
  为了计算熵,我们需要计算所有类别所有可能值包含的信息期望值,计算公式如下:
机器学习:决策树log_2p(x_i))

2.3 代码实现

  1:计算熵的代码:

from math import log
def calEnt(dataSet):
    numEntries=len(dataSet)
    labelCounts={}
    for featVec in dataSet:
        currentLabel=featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel]=0
        labelCounts[currentLabel]+=1
    Ent=0.0
    for key in labelCounts:
        prob=float(labelCounts[key])/numEntries
        Ent-=prob*log(prob,2)
    return Ent

  2:划分数据集的代码:

#按照给定特征划分数据集
def splitDataSet(dataSet,axis,value):
    retDataSet=[]
    for featVec in dataSet:
        if featVec[axis]==value:
            reducedFeatVec=featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

#选择最好的数据集方式:
def chooseBestFeatureToSplit(dataSet):
    numFeatures=len(dataSet[0])-1
    baseEntropy=calEnt(dataSet)
    bestInfoGain=0.0;bestFeature=-1
    for i in range(numFeatures):
        featList=[example[i] for example in dataSet]
        uniqueVals=set(featList)
        newEntropy=0.0
        for value in uniqueVals:
            subDataSet=splitDataSet(dataSet,i,value)
            prob=len(subDataSet)/float(len(dataSet))
            newEntropy+=prob*calEnt(subDataSet)
        infoGain=baseEntropy-newEntropy
        if (infoGain>bestInfoGain):
           bestInfoGain=infoGain
           bestFeature=i
        return bestFeature

  3:构建树的代码:

def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote]=0
        classCount[vote]+=1
    sortedClassCount=sorted(classCount.items,\
    key=operator.itemgetter(1),reverse=true)
    return sortedClassCount[0][0]

#创建树
def createTree(dataSet,labels):
    classList=[example[-1] for example in dataSet]
    if classList.count(classList[0])==len(classList):
        return classList[0]
    if len(dataSet[0])==1:
        return majorityCnt(classList)
    bestFeat=chooseBestFeatureToSplit(dataSet)
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}}       
    subLabels=labels[:]
    del(subLabels[bestFeat])
    featValues=[example[bestFeat] for example in dataSet]
    uniqueVals=set(featValues)
    for value in uniqueVals:
    myTree[bestFeatLabel][value]=createTree(splitDataSet\
            (dataSet,bestFeat,value),subLabels)
    return myTree

  4:绘制树的代码

import matplotlib.pyplot as plt


decisionNode = dict(boxstyle="sawtooth",fc="0.8")
leafNode = dict(boxstyle="round4",fc="0.8")
arrow_args = dict(arrowstyle="<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction'
                            ,xytext=centerPt,textcoords='axes fraction',
                                va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)


def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]) == dict:
            numLeafs= numLeafs+getNumLeafs(secondDict[key])
        else:
            numLeafs= numLeafs+1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]) == dict:
            thisDepth = 1+ getTreeDepth(secondDict[key])
        else:
            thisDepth =1
        if thisDepth>maxDepth:
            maxDepth=thisDepth
    return maxDepth

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2 +cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2 +cntrPt[1]
    createPlot.ax1.text(xMid,yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,  plotTree.yOff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]) == dict:
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key],   (plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff),cntrPt,str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD


def createPlot(intree):
    fig = plt.figure(1,facecolor='white')
    fig.clf()
    axprops = dict(xticks=[],yticks=[])
    createPlot.ax1 = plt.subplot(111,frameon=False,**axprops)
    plotTree.totalW = float(getNumLeafs(intree))
    plotTree.totalD = float(getTreeDepth(intree))
    plotTree.xOff = -0.5/plotTree.totalW;plotTree.yOff =1.0;
    plotTree(intree,(0.5,1.0),' ')
    plt.show()

3 实践

  这次我把决策树应用到lenses数据集上,这是一个关于隐形眼镜推荐的分类数据集,训练出的决策树经过我们的代码,绘出的图如下:

机器学习:决策树


4 总结

  优点:
  1) 可以生成可以理解的规则
  2) 计算量相对来说不是很大
  3) 对中间值的缺失不敏感
  4) 可以清晰的显示哪些字段比较重要
  缺点:
  1) 对连续性的字段需要进行离散化处理
  2) 对有时间顺序的数据,需要很多预处理的工作
  3) 当类别太多时,错误可能就会增加的比较快
  4) 有时会产生过度匹配的现象


worked by zzzzzr

  
机器学习:决策树