ID3决策树程序实现
原文地址:https://blog.csdn.net/hongbin_xu/article/details/78516114
前言
之前的博客中介绍了决策树算法的原理并进行了数学推导(机器学习入门学习笔记:(3.1)决策树算法)。决策树的原理相对简单,决策树算法有:ID3,C4.5,CART等算法。接下来将对ID3决策树算法进行程序实现,参考了《机器学习实战》一书。这篇博客也作为自己个人的学习笔记,以便自己以后温习。
伪代码以及算法流程
伪代码:
创建分支的伪代码函数createBranch():
检测数据集中每一个子项是否属于统一分类:
If so return 类标签
Else
寻找划分数据集的最好特征
划分数据集
创建分支结点
for 每个划分的子集
调用函数createBranch()并增加返回结果到分支结点中
return 分支结点
算法流程:
决策树的一般流程:
(1)收集数据:可以使用任何方法。
(2)准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。
(3)分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。
(4)训练算法:构造树的数据结构。
(5)测试算法:使用经验树计算错误率。
(6)使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。
程序实现
导入模块
import operator
import copy
from math import log
operator:是python中的一个标准库,包含了Python的各种内置操作符,诸如逻辑、比较、计算等。而我们后面要使用的是operator.itemgetter,后面碰到了再说。
copy:这个库就是字面意思,浅拷贝。(有关浅拷贝和深拷贝请参考这篇文章:python的复制,深拷贝和浅拷贝的区别)
math:后面要用到对数函数,所以导入log函数。
计算给定数据的香农熵
# 计算给定数据集的香农熵
def calcShannonEnt(dataSet):
numEntries = len(dataSet) # 数据总数
labelCounts = {} # 标签计数,字典类型
for featVec in dataSet: # 遍历每一组数据
currentLabel = featVec[-1] # 最后一列为标签
if currentLabel not in labelCounts.keys(): # 如果标签之前没有出现过,则新建一个字典的元素
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1 # 该元素+1
shannonEnt = 0.0 # 香农熵
for key in labelCounts: # 遍历每一组标签
prob = float(labelCounts[key]) / numEntries # 计算概率
shannonEnt -= prob * log(prob, 2) # 套用香农公式
return shannonEnt
这段程序比较简单。
- 首先计算输入数据集的样本总数。
- 随后创建字典labelCounts,用以保存所有出现过得样本,如果新加入的这个样本的类别没有出现过,则将其新创建一个键(key),而它对应的键值(value)就是这个类别出现过的次数。每个键值都会记录当前这个类别出现的次数,所以每次都将键值(value)加一。
- 字典中的每一个键的键值都统计好了每种类别出现的次数,套用公式计算香农熵。
计算香农熵的公式:
Ent(D)=−∑k=1|γ|pklog2pkEnt(D)=−∑k=1|γ|pklog2pk
测试看看:
再在程序中加入一个函数:
# 创建数据集
def createDataSet():
dataSet = [ [1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no'] ]
labels = ['no surfing', 'flippers']
return dataSet, labels
python命令行下进行测试:
划分数据集
# 按照给定特征划分数据集
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value: #抽取出这个元素
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
在数据集dataSet中,axis表示第几组特征,value则是那一组特征的某个取值。
这个函数的功能就是根据原数据集中第axis组特征,其中只要值为value,则将它抽取出来;retDataSet用来保存最后剩下的那些数据集,并返回。
看下测试示例就很好理解了:
选择最好的数据集划分方式
# 选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 # 总特征数
baseEntropy = calcShannonEnt(dataSet) # 初始时计算一次香农熵
bestInfoGain = 0.0 # 最佳信息增益,越大越好;初始默认为0
bestFeature = -1 # 最佳划分属性
for i in range(numFeatures): # 遍历每一个特征
# 找到该特征所有可能的取值,存入uniqueVals列表中
# 这个是我常用的写法,功能一样,但是较为复杂
# uniqueVals = []
# for example in dataSet:
# if example[i] not in uniqueVals:
# uniqueVals.append(example[i])
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy = 0.0
for value in uniqueVals: # 遍历所有可能取值
subDataSet = splitDataSet(dataSet, i, value) # 对数据集分类
prob = len(subDataSet) / float(len(dataSet)) # 计算概率
newEntropy += prob * calcShannonEnt(subDataSet) # 对熵加权求和
infoGain = baseEntropy - newEntropy # 求到最后对应这个特征的熵
if(infoGain > bestInfoGain): # 比较,取最大的熵
bestInfoGain = infoGain
bestFeature = i
return bestFeature
这里可以说是决策树算法的核心部分了。前面的博客(机器学习入门学习笔记:(3.1)决策树算法)中介绍了几种选择最佳划分方式的方法:信息增益(Information Gain)、信息率(Gain Ratio)、基尼指数(Gini Index)。我们这里要实现的是ID3决策树算法,其中使用的是信息增益(Information Gain),当然也是最简单的一个。
再次给出信息增益的公式:
之前我们已经写了计算香农熵的函数了,使用
calcShannonEnt()
函数计算出香农熵,套用上面公式就可以计算出信息增益了。这里不过多介绍概念,如果对概念不熟悉,请查看(机器学习入门学习笔记:(3.1)决策树算法)。 简述一下程序流程,结合注释不难读懂:
我们要遍历所有的特征,首先要遍历每个特征,i表示第几个特征,使用for循环实现;随后再来遍历这个特征下每个可能取值uniqueVals,又要遍历,使用for循环实现;接下来,要判断当前的特征和值是不是最佳的划分属性,怎么判断?使用前面的函数
splitDataSet()
将数据集划分一下,再计算划分数据集之后的子数据集的香农熵,好了,接下来就要使用信息增益了,计算信息增益,比较最大的那个就是我们要的最佳划分属性的情况。
有点绕,最好还是看程序吧。
测试看看:
程序运行结果告诉我们,0是当前最佳的用于划分数据集的特征。
统计出现次数最多的分类
def majorityCnt(classList):
# 字典对象标记了每个标签出现的概率
classCount = {}
for vote in classList:
if vote not in classList.keys():
classCount[vote] = 0
classCount[vote] += 1
# 使用operator操作键值排序字典
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0] # 返回出现次数最多的分类
输入样本集的标签,统计每种类别出现的次数,若没有,则新建一个键保存这个类,每次循环将当前的类的键值加1。最后使用sorted()
方法进行排序,key=operator.itemgetter(1)
指定了第1维,即字典classCount中所有键的键值,统计每种类别的次数;reverse=True
表示倒序,从大到小排序。最后返回第一个键的键值就是出现次数最多的那个。
通常在递归中,决策树无法再继续递归,即到达树节点时,就需要统计当前剩下的未分类的样本中出现最多的分类,将这个叶节点的结果取为那个出现次数最多的类,这时使用这个函数。
构建决策树
# 创建决策树的函数代码
def createTree(dataSet, labels):
# 这里是浅拷贝
# 注意:在创建决策树的过程中会删减labels中的成员,为了防止原始的labels也被更改,这里复制了一个新的labels出来
labelsTemp = copy.copy(labels) # 拷贝一个labels列表
classList = [example[-1] for example in dataSet] # 统计dataSet中所有的标签
# 递归的第一个停止条件
# 如果所有的类标签都是相同的,则可以停止递归
if classList.count(classList[0]) == len(classList):
return classList[0]
# 递归的第二个停止条件
# 如果先遍历完了所有特征,即无法简单地返回唯一的类标签,则使用出现次数最多的作为返回值
if len(dataSet[0]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet) # 选择最适合分类的特征
bestFeatLabel = labelsTemp[bestFeat] # 该特征对应的标签
myTree = {bestFeatLabel:{}} # 构建树,以字典来表示
del(labelsTemp[bestFeat]) # 删除这个最佳划分的特征的标签
# 统计该特征所含的所有属性值
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
# 只要还可以划分就继续递归调用
for value in uniqueVals:
subLabels = labelsTemp[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
这里是最后递归生成决策树的函数。有两个输入参数:dataSet,labels。dataSet是样本数据集,labels是标签列表。
最初使用labelsTemp = copy.copy(labels)
是为了对labels列表进行浅拷贝,改变labelsTemp不会影响labels。
使用递归创建决策树,那么最关心的当然是递归的结束条件。有三个递归的结束条件:
- 递归的第一个停止条件:如果所有的类标签都是相同的,则可以停止递归;
- 递归的第二个停止条件:如果先遍历完了所有特征,即无法简单地返回唯一的类标签,则使用出现次数最多的作为返回值;
- 递归的第三个停止条件:到达叶节点,无法继续划分,返回myTree。
测试结果看看:
结果是一个嵌套字典,no surfing
为划分属性,他的键值包含字典{0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}
。如果它的键值是0
,则到达叶节点;结果是no
,如果是1
,则还需要继续划分。
使用决策树的分类函数
# 测试算法:使用决策树进行分类
def classify_ID3(inputTree, featLabels, testVec):
firstStr = inputTree.keys()[0]
# print(inputTree.keys())
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr) # 将标签字符串转换为索引
for key in secondDict.keys():
if(testVec[featIndex] == key):
if type(secondDict[key]).__name__ == 'dict': # 如果还有节点,则继续递归
classLabel = classify_ID3(secondDict[key], featLabels, testVec)
else: # 如果到达叶子节点,则返回当前节点的分类标签
classLabel = secondDict[key]
return classLabel
输入中有三个参数:inputTree是决策树,前面使用createTree()
函数生成;featLabels是样本集的标签列表;testVec是要预测的变量。
首先使用index()
方法,寻找标签集中匹配的firstStr
的那一项。目的是找到那一项在标签列表中是第几个,随后可以知道我们的测试数据集testVec
中对应特征的值testVec[featIndex]
。再往下就是循环遍历整个树,只要不到叶节点,就不断调用递归。这样遍历完整棵树,最后返回预测结果。
测试:
决策树的保存和读取
# 使用pickle模块存储决策树
def storeTree(inputTree, filename):
import pickle
fw = open(filename, 'w')
pickle.dump(inputTree, fw)
fw.close()
def grabTree(filename):
import pickle
fr = open(filename, 'r')
return pickle.load(fr)
这里都是调用python的pickle模块实现保存和读取。
测试如下:
数据可视化
这里直接贴上机器学习实战中的代码,复杂些,重点是决策树算法,这部分不做介绍。主要思想就是遍历整颗决策树,求出深度和叶节点个数,由深度计算出决策树的高度,由叶节点数计算出决策树的宽度。分配每个节点占的位置,画框显示结点对应的属性,递归过程中如果碰到叶节点则画出叶节点,若不是则递归调用。
# coding: utf-8
import matplotlib.pyplot as plt
# 使用文本注解绘制树结点
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeText, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeText, xy=parentPt, xycoords='axes fraction',\
xytext=centerPt, textcoords='axes fraction', va='center', ha='center', \
bbox=nodeType, arrowprops=arrow_args)
# matplot绘图函数
# def createPlot():
# fig = plt.figure(1, facecolor='white')
# fig.clf()
# createPlot.ax1 = plt.subplot(111, frameon=False)
# plotNode('Decision Node', (0.5, 0.1), (0.1, 0.5), decisionNode)
# plotNode('Leaf Node', (0.8, 0.1), (0.3, 0.8), leafNode)
# plt.show()
# 获取叶子节点的数目和树的层数
def getNumLeaves(myTree):
numLeaves = 0
firstStr = myTree.keys()[0] # 决策树的第一个键值
secondDict = myTree[firstStr] # 字典中第一个key所对应的value,即下面的目录
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict': # 如果是字典类型,则继续递归
numLeaves += getNumLeaves(secondDict[key])
else: # 如果不是字典类型,则停止递归
numLeaves += 1
return numLeaves
def getTreeDepth(myTree):
maxDepth = 0
thisDepth = 0
firstStr = myTree.keys()[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':
thisDepth += getTreeDepth(secondDict[key])
else:
thisDepth = 1
if(thisDepth > maxDepth):
maxDepth = thisDepth
return maxDepth
# 预先存储树的信息,方便读取测试
def retrieveTree(i):
listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return listOfTrees[i]
def plotMidText(centerPt, parentPt, txtStr):
xMid = (parentPt[0] + centerPt[0]) / 2.0
yMid = (parentPt[1] + centerPt[1]) / 2.0
createPlot.ax1.text(xMid, yMid, txtStr)
def plotTree(myTree, parentPt, nodeTxt):
firstStr = myTree.keys()[0]
numLeaves = getNumLeaves(myTree)
depth = getTreeDepth(myTree)
centerPt = (plotTree.xOff + (1.0 + float(numLeaves))/2.0/plotTree.totalW, plotTree.yOff)
# centerPt = (0.5, plotTree.yOff)
plotMidText(centerPt, parentPt, nodeTxt)
plotNode(firstStr, centerPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff -= 1.0 / plotTree.totalD # 由于绘图是自顶向下的,所以y的偏移要递减
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], centerPt, str(key))
else:
plotTree.xOff += 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), centerPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), centerPt, str(key))
plotTree.yOff += 1.0 / plotTree.totalD # 在绘制了所有子节点后,增加y的偏移
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeaves(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show()
使用决策树预测实例
lenses.txt中是隐形眼睛数据集。隐形眼镜数据集是非常著名的数据集 , 它包含很多患者眼部状况的观察条件以及医生推荐的隐形眼镜类型 。隐形眼镜类型包括硬材质 、软材质以及不适合佩戴 隐形眼镜 。数据来源于UCI数据库 ,为了更容易显示数据 , 将数据存储在源代码下载路径的文本文件中。
python命令行下输入:
import trees
import treePlotter
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = trees.createTree(lenses, lensesLabels)
lensesTree
treePlotter.createPlot(lensesTree)
运行结果:
以上就是使用ID3决策树算法的程序实现。虽然最后很好地拟合了样本集,但是这些结点不免让人觉得太多了,即有可能过拟合(overfiting)了。为了减小过拟合的风险,我们还可以对决策树进行剪枝,而剪枝又有:预剪枝和后剪枝。剪枝后能在保证同等性能的前提下去掉一些不必要的结点。
后记
这篇博客是我个人学习机器学习实战中的笔记,写下来也有不少字数了。
http://www.cnblogs.com/fydeblog/p/7159775.html
工程下载链接:
百度云盘:链接: https://pan.baidu.com/s/1eSeRQIQ 密码: 3zwm
参考资料:
- 《机器学习》周志华
- 《机器学习实战》
上一篇: 决策树+Python3实现ID3