欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

决策树算法

程序员文章站 2022-06-13 15:53:35
...

ID3算法是一种贪心算法,用来构造决策树。ID3算法起源于概念学习系统(CLS),以信息熵的下降速度为选取测试属性的标准,即在每个节点选取还尚未被用来划分的具有最高信息增益的属性作为划分标准,然后继续这个过程,直到生成的决策树能完美分类训练样例。本代码使用ID3算法来构造决策树:

from math import log
import operator

#创建数据集
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

def majorityCnt(classList):
    classCount = {} #创建标签字典
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True) #按标签个数降序排列
    return sortedClassCount[0][0] #输出标签数量最多的标签值

def createTree(dataSet, labels): #labels为特征的标签集
    classList = [example[-1] for example in dataSet] #数据的标签集
    if classList.count(classList[0]) == len(classList): #count函数统计classList[0]出现次数
        return classList[0]
    if len(dataSet[0]) == 1: #只剩下一列标签
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet) #最好的分类特征
    bestFeatLabel = labels[bestFeat] #最好的分类特征的标签
    myTree = {bestFeatLabel:{}} #建立当前树节点,即此标签下分类的数据
    del(labels[bestFeat])  #删除已选的特征标签值
    featValues = [example[bestFeat] for example in dataSet]  #具体特征值
    uniqueVals = set(featValues) #删除特征重复项
    for value in uniqueVals:
        subLabels = labels[:] #复制新的labels特征标签
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels) #递归
    return myTree

myDat, labels = createDataSet()
myTree = createTree(myDat, labels)
print(myTree)

>>{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

1.如果数据集已经处理了所有属性,但是类标签依然不是唯一的,此时使用majorityCnt方法来定义该节点的分类。
2.在createTree中,递归函数的第一个停止条件是所有的类标签完全相同,则直接返回该类标签。递归函数的第二个停止条件是使用完了所有特征,仍然不能讲数据集划分成仅包含唯一类别的分组,所以使用majorityCnt方法。
3.{bestFeatLabel:{}}和myTree[bestFeatLabel][value]是关于字典的嵌套,可以简单测试,例如:

myTree = {"a":{}}
myTree["a"][1]=1313
myTree["a"][2]=3345
print(myTree)

>>{'a': {1: 1313, 2: 3345}}

4.输出结果解释:

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

从左边开始,第一个关键字no surfacing是第一个划分数据集的特征名称,该关键字的值也是另一个数据字典。第二个关键字是no surfacing特征划分的数据集,这些关键字的值是no surfacing节点的子节点。这些值可能是类标签,也可能是另一个数据字典。如果值是类标签,则该子节点是叶子节点;如果值是另一个数据字典,则子节点是一个判断节点,不断重复就构成了整棵树。