欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

决策树实战

程序员文章站 2022-05-21 23:30:01
...

本文摘自“机器学习实战”中案例,在此对其进行了代码更新与简单注释。感兴趣者可回复资源需求!

问题描述

现有一份海洋生物数据表,如下图所示:

不浮出水面是否可以生存 是否有脚蹼 是否鱼类
1
2
3
4
5
要求:根据表中两个特征“不浮出水面是否可以生存”、“是否有脚蹼”以及标签“是否鱼类”,构造决策树,并预测不浮出水面不可以生存、没有脚蹼的海洋生物是否为鱼类。

一般流程

1、准备数据

因为没有大量样本存储于文档中,故在次没有将文档样本内的数据转成可以处理的数据形式,而是直接简单创造,如下所示:
将特征值“是”表示为1,“否”表示为0;标签中用“yes”、“no”表示。

def createDataSet():
    dataset=[[1,1,'yes'],
             [1,1,'yes'],
             [1,0,'no'],
             [0,1,'no'],
             [0,1,'no']]
    labels=['no surfing','flippers'] #特征名称
    return dataset,labels
2、划分数据集

划分数据集的最大原则:将无序的数据变得更加有序。使用信息论度化信息量是将数据变得更加有序的方法之一。划分数据集前后信息发生的变化称为信息增益,获得信息增益最高的特征作为每次划分的依据。在计算每种划分方式的信息增益之前,需要计算相应数据集的香农熵。
以上信息有想深入了解者,可自行查询。不甚了解不影响解题。
(1)计算给定数据集的香农熵

#计算给定数据集的香农熵
from math import log
def calcShannongEnt(dataset):
    numEntries=len(dataset)
    labelCount={} #统计所有类标签的发生频率
    for featVec in dataset:
        currentLabel=featVec[-1]
        if currentLabel not in labelCount.keys():
            labelCount[currentLabel]=0
        labelCount[currentLabel]+=1
    shannongEnt=0.0  #该数据集的香农熵
    for key in labelCount:
        prob=float(labelCount[key])/numEntries
        shannongEnt=-prob*log(prob,2)
    return shannongEnt

(2)按照给定的特征划分数据集

'''
例如将dataset=[[1,1,'yes'],
             [1,1,'yes'],
             [1,0,'no'],
             [0,1,'no'],
             [0,1,'no']]
 根据第0个特征(axis=0)“no surfing”,以及特征值为1(value=1)进行划分数据集,结果为
 [[1, 'yes'], 
 [1, 'yes'], 
 [0, 'no']]
'''
#按照给定的特征划分数据集
def spliteDataSet(dataset,axis,value):  #axis为划分数据集的特征,value为划分数据集的特征值
    retDataSet=[]
    for featVec in dataset:
        if featVec[axis]==value:
            reducedFeatVec=featVec[:axis]
            #[1,2].extend([3,4])结果为[1,2,3,4]
            #[1,2].append([3,4]结果为[1,2,[3,4]]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

(3)选择最好的数据集划分方式,即选取当前数据集中信息增益最高的特征

#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataset):
    # 条件:数据是由列表元素组成的列表,而且所有的列表元素具有相同的数据长度
    numFeatures=len(dataset[0])-1
    baseEntropy=calcShannongEnt(dataset)
    bestinfogain=0.0 ; bestFeature=-1
    for i in range(numFeatures):
        featList=[example[i] for example in dataset]
        #.set()将列表转化为每个值都不相同的集合
        uniqueVals=set(featList)
        newEntropy=0.0
        #计算每种划分方式的信息熵
        for value in uniqueVals:
            subdataset=spliteDataSet(dataset,i,value)
            prob=len(subdataset)/float(len(dataset))
            newEntropy+=prob*calcShannongEnt(subdataset)
        infogain=baseEntropy-newEntropy #信息增益是熵的减少
        if infogain>bestinfogain:
            bestinfogain=infogain
            bestFeature=i
    return bestFeature

3、构建决策树
#返回出现次数最多的分类名称
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote]=0
        classCount[vote]+=1
    sortedClassCount=sorted(classCount.items(),key=lambda x:x[1],reverse=True)
    return sortedClassCount[0][0]
#创建决策树
def createTree(dataset,labels):
    classList=[example[-1] for example in dataset]  #包含了数据集中所有类别的标签
    if classList.count(classList[0])==len(classList):
        return classList[0]
    # 使用完了所有特征,仍然不能将数据集划分成仅包含唯一类别的分组,此时返回出现次数最多的类别
    if len(dataset[0])==1:
        return majorityCnt(classList)
    bestFeat=chooseBestFeatureToSplit(dataset)
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}}
    del(labels[bestFeat])
    featValue=[example[bestFeat] for example in dataset]
    uniqueValues=set(featValue)
    for value in uniqueValues:
        sublabels=labels[:]
        myTree[bestFeatLabel][value]=createTree(spliteDataSet(dataset,
                                                              bestFeat,
                                                              value),sublabels)
    return myTree

'''
例如将dataset=[[1,1,'yes'],
             [1,1,'yes'],
             [1,0,'no'],
             [0,1,'no'],
             [0,1,'no']]
构造成决策树,结果为
{'no surfing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
'''
4、测试和存储分类器
#使用决策树的分类函数
def classify(inputTree,featLabels,testVec):
    firstStr=list(inputTree.keys())[0]  #决策树中的第一个特征名称
    featIndex=featLabels.index(firstStr)
    secondTree=inputTree[firstStr]  #决策树中的第一个特征在所有特征中的索引
    for key in secondTree:
        if testVec[featIndex]==key:
            if type(secondTree[key])==dict:
                classLabel=classify(secondTree,featLabels,testVec)
            else: classLabel=secondTree[key]
    return classLabel
#为了避免每次分类时都需要重新创建决策树,将决策树用pickle模块存储
def storeTree(inputTree,filename):
    import pickle
    fw=open(filename,'wb') #以二进制格式打开一个文件只用于写入
    pickle.dump(inputTree,fw)
    fw.close()
def grabTree(filename):
    import pickle
    fr=open(filename,'rb') #以二进制格式打开一个文件用于只读
    return pickle.load(fr)

预测功能,大家可以写成函数形式

myData,labels=createDataSet()
global_labels=labels[:] #将labels赋值给全局变量global_labels,若global_labels=labels则为引用传递
mytree=createTree(myData,labels) #labels在创造决策树时被修改
storeTree(mytree,'classifierStorage.txt')
#现有一种海洋生物:不浮出水面不可以生存(0),没有脚蹼(0)。预测其是否为鱼类
mytree=grabTree('classifierStorage.txt')
print(mytree)
if classify(mytree,global_labels,[0,0])=='no':
    print('不是鱼类')
else: print('是鱼类')
相关标签: 决策树实战