欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

一文搞懂决策树的ID3算法

程序员文章站 2024-02-15 15:11:41
...

决策树是机器学习中一个比较重要的算法,和其他机器学习算法不一样的是,你不懂过多的数学理论知识,也能理解这个算法的原理。前几天我看到一篇文章,大概内容是利用决策树来预测世界杯最终的冠军是谁,最终预测结果是巴西(巴西好像已经凉凉了),不过这并不能影响你学习今天的算法。好了,废话不多说,进入今天的主题。

一听决策树,就知道这个算法和树形数据结构有关。确实如此,决策树本质上是一个树形结构,但和我们熟知的二叉树又不太一样,它是一个多叉树。这是我在百度上找的一张样图

一文搞懂决策树的ID3算法

首先要说明一下决策树的两大用途:分类和回归。其实很多的机器学习算法都可以做分类与回归,比如支持向量机(SVM),它有支持向量分类机和支持向量回归机,关于支持向量机这里不做深究,以后会专门写一篇关于支持向量机的文章。本文只用决策树来解决分类问题,而不探讨回归(其实是我还没学)。

为了接下来更好的描述决策树的整个算法流程,我找了一个具体的例子,一步一步分析。
一文搞懂决策树的ID3算法

这张表是拥有15个训练样本的贷款申请数据集。四个特征分别是年龄、是否有工作、是否有自己的房子、信贷情况。最后一列表示是否同意此人贷款申请。接下来我们要做的就是根据这个样本数据集,做一个合理的决策,对于未知的申请人,是否同意申请贷款。

就目前来说,我们已知的只有样本数据集,然后利用某种算法,希望通过样本数据集来学习出一组分类规则,并且这个分类规则要与样本数据集的矛盾尽可能小,这样才算是一组合理的分类规则,决策树就是一个不错的选择。

决策树中除了叶节点以外的节点,都是代表的特征,通过特征来不断划分数据集,使得最终叶节点上的数据集是“纯的”(数据集标签要么全为是,要么全为否),如果能选出这样的特征,那么我们就能轻易的判断未知样本到底该属于哪一类。比如下面这种特征选择。

一文搞懂决策树的ID3算法

不过难就难在我们应该如何选取最佳特征来划分数据集。这个就要从信息论的知识说起了。

信息论中,有一个很重要的概念就是香农熵,下面直接给出随机变量的香农熵计算公式:

一文搞懂决策树的ID3算法

当香农熵越大时,不确定性就越大,我们所能获取的信息量就越少。为了更加直观的理解这一句话,我针对随机变量只有两种取值的时候进行讨论。

当n = 2时,表达式如下,画出函数图像:

一文搞懂决策树的ID3算法

可以看出,当概率为0.5的时候,熵值最大,也就代表这个时候所能获得的信息量就越少。那这是为什么呢?直观上可以这样理解,当有人告诉你X = 0和X = 1的概率都是0.5时,你觉得他说的是不是废话,你既然说他们都是等可能的,那我根本不能判断谁的取值可能性更大,也就是说这个取值在概率上讲,是没有偏向性的,那我得到的信息量自然就是最少的,也就是香农熵最大。本身熵这个概念就有混乱的含义,熵越大,就越混乱,我们就很难提取出有用的信息。

下面要介绍另一个概念,条件熵。顾名思义,跟条件概率有关,定义为:X给定条件下Y的条件概率分布的熵对X的数学期望,公式如下:
一文搞懂决策树的ID3算法
然后今天的主角就登场了——信息增益

特征A对训练数据集D的信息增益g(D,A),定义为集合D的经验熵H(D)与特征A给定条件下D的条件熵H(D,A)之差,即

一文搞懂决策树的ID3算法
这个定义简直不是人话,翻译成人话就是说:我们先计算出训练集D的熵,然后计算出在给定特征A的条件下,计算出条件熵H(D,A),由于熵能表示信息的混乱度,两个熵值之差那不就代表信息混乱减少的程度么,换句话说,也就是在给定特征A的条件下,我们得到的信息增加的程度,那就是信息增益。决策树中选定哪个特征划分数据集,也就是看哪个特征的信息增益最大,信息增益最大对应的特征就是最佳特征

现在知道了怎么选取最佳特征,剩下的就是递归创建决策树了,这一部分需要较强的数据结构知识,我不做过多的探究,直接贴上代码:

#计算香农熵
def calcShannonEnt(dataSet):

    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1

    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob*math.log(prob, 2)

    return shannonEnt

#划分数据集
def splitDataSet(dataSet, axis, value):

    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

#计算最佳特征
def chooseBestFeatureToSplit(dataSet):


    # 每个样例的特征数目
    numFeatures = len(dataSet[0]) - 1
    # 计算香农熵
    baseEntropy = calcShannonEnt(dataSet)
    # 初始化最佳信息增益
    bestInfoGain = 0.0
    # 初始化用来划分特征空间的最佳特征
    bestFeatures = -1


    for i in range(numFeatures):   # 遍历所有特征
        featList = [example[i] for example in dataSet]  # 把数据集中第i个特征存入列表
        uniqueVals = set(featList)  # 去除重复的特征取值
        newEntropy = 0.0  # 初始化条件熵
        for value in uniqueVals:  # 遍历第i个特征的所有可能取值
            subDataSet = splitDataSet(dataSet, i, value)  # 划分数据集
            prob = len(subDataSet)/float(len(dataSet))  # 计算该特征取值的频率
            newEntropy += prob*calcShannonEnt(subDataSet)  # 计算条件熵
        infoGain = baseEntropy - newEntropy  # 计算信息增益
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain  # 找出最佳信息增益
            bestFeatures = i  # 找出划分特征空间的最佳特征
    return bestFeatures

#返回递归终止条件
def majorityCnt(classList):

    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]

#递归创建树
def createTree(dataSet, labels):

    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeat:{}}
    del (labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    return myTree

对于决策树的ID3算法还有几点补充说明:

(1)递归的终止条件还有一种情况,当所有特征都已经用完的时候,也是递归的终止条件,这个时候的叶节点标签可能“不纯”,需要通过多数表决的方法决定返回值,这个类似于knn算法的思想。

(2)ID3算法有一个致命的缺点,它过分的要求决策树去匹配数据集,使得模型在很多时候会过于复杂,就产生了机器学习中一个致命的点——overffiting(过拟合),这个时候就需要降低模型复杂度,采用的技巧是——剪枝,这就涉及到后面的C4.5和CART算法,就更加复杂了,还涉及到动态规划的思想。

这篇文章主要是为了梳理最近学习决策树算法整个流程,其中肯定会有一些不恰当的描述,欢迎大家指正,一起交流,一起学习。