欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

决策树算法及python实现

程序员文章站 2022-05-21 23:47:26
...

参考书:周志华-西瓜书

# 数据:
#       不浮出水面是否可以生存      是否有脚蹼     是不是鱼
# 1              是                   是            是
# 2              是                   是            是
# 3              是                   否            否
# 4              否                   是            否
# 5              否                   是            否
from math import log
import operator


# 获取数据集
def createDateSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]

    labels = ['no surfacing', 'flippers']

    return dataSet, labels


# 求熵
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}

    for featVec in dataSet:  # the number of unique , featVec:特征向量
        currentLabel = featVec[-1]  # 即向量featVec的第三个维度:yes,no。相当于两类样本。
        if currentLabel not in labelCounts.keys():  # currentLabel是样本类别,计算数据集dataSet中,该类样本的数量。
            labelCounts[currentLabel] = 1
        else:
            labelCounts[currentLabel] += 1

    shannonEnt = 0.0

    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries  # 计算第key类样本所占的比例
        shannonEnt -= prob * log(prob, 2)  # 计算当前结点的信息熵

    return shannonEnt


# 划分数据
#  function:找出数据集中,所有在属性axis上取值为value的样本,组成一个子集retDataSet。
#  parameter:待划分的数据集,划分数据集的特征,需要返回的特征的值
#  return:返回值即公式中的 D^v
def splitDataSet(dataSet, axis, value):
    retDataSet = []  # 即公式中的D^v
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]  # 取不到axis这一列
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
#
# #  测试:
# myDat,labels = createDateSet()
# print(splitDataSet(myDat, 0, 1))
# print(splitDataSet(myDat, 0, 0))
# #  输出:
# [[1, 'yes'], [1, 'yes'], [0, 'no']]
# [[1, 'no'], [1, 'no']]


# function:求出当前属性集合中信息增益最大的特征
# parameter:样本数据集
# return ;信息增益最大的属性,对应的列号
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # 统计数据集的特征个数,减1是减去序号所在列。—— 对应公式4.2中|D|,即当前属性集合中属性的个数。
    baseEntropy = calcShannonEnt(dataSet)  # 计算香农熵 —— 对应公式4.2中的Ent(D),即待划分结点的信息熵

    bestInfoGain = 0.0  # 最大的信息增益
    bestFeature = -1  # 最大信息增益 对应的 属性

    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]  # 第i个特征的所有值
        uniqueVals = set(featList)  # 去除feastList中的重复元素
        newEntropy = 0.0

        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)  # 即公式中4.2中的 D^v,在属性i上,取值为value的样本组成的集合
            prob = len(subDataSet) / float(len(dataSet))  # 即公式中4.2中的 |D^v|/|D|
            newEntropy += prob * calcShannonEnt(subDataSet)  # 即公式中所有|D^v|/|D|*Ent(D^v)相加的和

        infoGain = baseEntropy - newEntropy  # infoGain即信息增益,即公式中的Gain(D,a)

        if (infoGain > bestInfoGain):  # 选出当前属性集合中所有特征,
            bestInfoGain = infoGain  # if better than current
            bestFeature = i

    return bestFeature  # 信息增益最大的属性,对应的列号

# 迭代到最后时,若所有属性都分裂了,剩下的数据项还是同一个类别,那该怎么划分类别?

# 思想:
# 经过N轮判断后到最后一列了,这时所有特征值已经全部分裂后了,然而剩余的数据项依然不是同一类别。
# 由于该算法要求最终明确划分,所以采用投票机制,即按照少数服从多数的思想进行分类。
# 具体操作:在余下的数据项中选择分类中个数最最多的分类作为该结点的分类。

# 返回值:样本数最多的类别
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0

        classCount[vote] += 1

    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=False)  ## 注意:python3.5中,iteritems()变为items()
    return sortedClassCount[0][0]  # 注意,sortedClassCount是一个list,元素是tuple


# 递归 求出决策树
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]  # 当前样本的类别集合,即“好瓜”or“坏瓜”

    # 若当前结点包含的样本全属于同一个类别,无需划分
    if classList.count(classList[0]) == len(classList):  # 统计类别为classList[0]的数据有多少,若与classList的长度相等,说明,classList中的数据属于同一类
        return classList[0]

    if len(dataSet[0]) == 1:
        return majorityCnt(classList)  # 返回样本数量最多的类别

    bestFeat = chooseBestFeatureToSplit(dataSet)  # 求出当前属性集合中,信息增益最高的属性 所对应的列号
    bestFeatLabel = labels[bestFeat]  # 当前属性集合中,信息增益最高的属性 的名字
    myTree = {bestFeatLabel: {}}  # 决策树 由属性的名字 组成,是个嵌套的字典
    del (labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # 除去信息增益最高的属性后,剩余的属性
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)  # 继续划分子树

    return myTree


myDat, labels = createDateSet()
myTree = createTree(myDat, labels)

print(myTree)
相关标签: 决策树 python