《机器学习实战》chapter03 决策树
程序员文章站
2022-07-14 21:03:49
...
构建分类决策树(trees.py)
import operator
import pickle
from collections import Counter
from math import log
# Compute the Shannon entropy of the class labels of a data set.
def calcShannonEnt(dataSet):
    """Return the Shannon entropy, in bits, of the class-label column.

    Each row of dataSet is a sequence whose last element is the class label.
    An empty dataSet yields 0.0.
    """
    numEntries = len(dataSet)
    # 1. Frequency of each class label (last element of every row).
    labelCounts = Counter(featVec[-1] for featVec in dataSet)
    # 2. H = -sum(p * log2(p)) over the observed labels.
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
def createDataSet():
    """Return the toy sample data set together with its feature names.

    Each row is [no surfacing, flippers, class label]; fresh lists are built
    on every call.
    """
    features = ['no surfacing', 'flippers']
    samples = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    return samples, features
# Split the data set on one feature column.
def splitDataSet(dataSet, axis, value):
    """Return the rows whose column `axis` equals `value`, minus that column.

    Each returned row is a freshly sliced copy; dataSet is left unchanged.
    """
    def _dropColumn(row):
        # Copy of row without the element at index `axis`.
        trimmed = row[:axis]
        trimmed.extend(row[axis + 1:])
        return trimmed

    return [_dropColumn(row) for row in dataSet if row[axis] == value]
# Pick the feature whose split gives the largest information gain.
def choseBestFeatureToSplit(dataSet):
    """Return the column index of the most informative feature.

    The last column of every row is the class label and is never a
    candidate. Information gain = entropy of dataSet minus the size-weighted
    entropy of the subsets produced by splitting on that column. Only a
    strictly larger gain replaces the current best, so ties keep the lower
    index; if no feature has a gain above 0.0 the result is -1.
    """
    featureCount = len(dataSet[0]) - 1
    parentEntropy = calcShannonEnt(dataSet)
    rowCount = float(len(dataSet))
    bestIndex, bestGain = -1, 0.0
    for col in range(featureCount):
        splitEntropy = 0.0
        for v in {row[col] for row in dataSet}:
            part = splitDataSet(dataSet, col, v)
            splitEntropy += len(part) / rowCount * calcShannonEnt(part)
        gain = parentEntropy - splitEntropy
        if gain > bestGain:
            bestIndex, bestGain = col, gain
    return bestIndex
# Majority vote: choose the class label for a leaf node.
def majorityCnt(classList):
    """Return the most frequent label in classList.

    Among equally frequent labels, the one that appears first in classList
    wins. An empty classList raises IndexError.
    """
    # Counter keeps first-seen order and most_common() orders like a stable
    # descending sort, so ties resolve to the earliest label.
    return Counter(classList).most_common(1)[0][0]
# Recursively build the decision tree.
def createTree(dataSet, labels):
    """Recursively build a decision tree by information-gain splits.

    Args:
        dataSet: non-empty list of rows; each row lists the feature values
            followed by the class label in the last position.
        labels: feature names, one per feature column of dataSet. The
            caller's list is not modified.

    Returns:
        A class label when no further split is made, otherwise a nested dict
        {featureName: {featureValue: subtree_or_label, ...}}.
    """
    classList = [example[-1] for example in dataSet]
    # Every row has the same class: stop and return it as a leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column is left: nothing to split on, use a majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = choseBestFeatureToSplit(dataSet)
    if bestFeat < 0:
        # choseBestFeatureToSplit returns -1 when no feature has a positive
        # information gain. Used as an index, -1 would select the class-label
        # column and build a split on the class itself. Split on the first
        # real feature instead; splitDataSet removes that column, so each
        # level shrinks the rows and the recursion still terminates.
        bestFeat = 0
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Names of the remaining feature columns, built by slicing (not `del`)
    # so the caller's labels list stays intact.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    for value in set(example[bestFeat] for example in dataSet):
        subDataSet = splitDataSet(dataSet, bestFeat, value)
        myTree[bestFeatLabel][value] = createTree(subDataSet, subLabels)
    return myTree
def classify(inputTree, featLabels, testVec):
    """Classify testVec by walking inputTree from its root.

    Args:
        inputTree: nested dict {featureName: {featureValue: subtree_or_label}}.
        featLabels: feature names, in the same column order as testVec.
        testVec: feature values of the sample to classify.

    Returns:
        The class label stored in the leaf that is reached.

    Raises:
        ValueError: if a feature name in the tree is not in featLabels, or if
            testVec holds a value for which the tree has no branch.
    """
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    featValue = testVec[featIndex]
    if featValue not in secondDict:
        # Unseen value: fail with a clear message instead of an
        # UnboundLocalError from an unassigned result variable.
        raise ValueError(
            "value {!r} of feature {!r} has no branch in the tree".format(
                featValue, firstStr))
    child = secondDict[featValue]
    if isinstance(child, dict):
        return classify(child, featLabels, testVec)
    return child
def storeTree(inputTree, fileName):
    """Pickle inputTree into fileName, opened in 'wb' mode (overwrites).

    An IOError while opening or writing is reported on stdout instead of
    being raised.
    """
    try:
        with open(fileName, 'wb') as outFile:
            pickle.dump(inputTree, outFile)
    except IOError as err:
        message = "File Error : " + str(err)
        print(message)
def grabTree(fileName):
    """Load and return the object pickled in fileName (e.g. by storeTree).

    The file is closed once loading finishes. Only load files you trust:
    unpickling can execute arbitrary code.
    """
    with open(fileName, 'rb') as fr:
        return pickle.load(fr)
使用Matplotlib注解绘制树形图(treePlotter.py)
import matplotlib.pyplot as plt
# Box style for decision (internal) nodes: a "round4" box with padding.
# fc is the box face colour, given as a grey-level string ("0.8" = light grey).
decisionNode = {"boxstyle": "round4, pad=0.5", "fc": "0.8"}
# Box style for leaf nodes: a circle with the same face colour.
leafNode = {"boxstyle": "circle", "fc": "0.8"}
# Arrow style for the parent-child connectors drawn by annotate().
arrow_args = {"arrowstyle": "<-"}
# Draw a single node.
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw nodeTxt in a box at centerPt, with an arrow to/from parentPt.

    nodeType is the bbox style dict; both points are in "axes fraction"
    coordinates ((0, 0) bottom-left, (1, 1) top-right) of createPlot.ax1.
    The text is centred vertically and horizontally in its box.
    """
    axes = createPlot.ax1
    axes.annotate(
        nodeTxt,
        xy=parentPt, xycoords="axes fraction",
        xytext=centerPt, textcoords="axes fraction",
        va="center", ha="center",
        bbox=nodeType, arrowprops=arrow_args,
    )
# Count the leaf nodes of a tree.
def getNumLeafs(myTree):
    """Return the number of leaves in a nested-dict decision tree.

    myTree has the form {featureName: {value: subtree_or_label, ...}}. Any
    child that is a dict (including dict subclasses) is treated as a subtree;
    anything else is a leaf.
    """
    # Python 3 dict views are not indexable, so materialise the keys first.
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    return sum(
        getNumLeafs(child) if isinstance(child, dict) else 1
        for child in secondDict.values()
    )
# Compute the number of decision levels in a tree.
def getTreeDepth(myTree):
    """Return the depth of a nested-dict decision tree.

    A root whose children are all leaves has depth 1; each nested subtree
    (any dict child, including dict subclasses) adds one level. A root with
    no children has depth 0.
    """
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    return max(
        (1 + getTreeDepth(child) if isinstance(child, dict) else 1
         for child in secondDict.values()),
        default=0,
    )
# Label the edge between a parent and a child node.
def plotMidText(cntrPt, parentPt, txtString):
    """Write txtString at the midpoint between cntrPt and parentPt on createPlot.ax1."""
    childX, childY = cntrPt[0], cntrPt[1]
    parentX, parentY = parentPt[0], parentPt[1]
    midX = childX + (parentX - childX) / 2.0
    midY = childY + (parentY - childY) / 2.0
    createPlot.ax1.text(midX, midY, txtString)
# Recursively draw a (sub)tree below its parent point.
def plotTree(myTree, parentPt, nodeTxt):
    """Draw myTree on createPlot.ax1, connected to parentPt.

    Layout state lives in function attributes initialised by createPlot:
    plotTree.totalW (leaf count of the whole tree), plotTree.totalD (its
    depth), plotTree.xOff (x of the most recently placed leaf) and
    plotTree.yOff (y of the current level).

    Args:
        myTree: nested dict {featureName: {value: subtree_or_label, ...}}.
        parentPt: (x, y) of the parent node, in axes-fraction coordinates.
        nodeTxt: text for the edge from the parent (the branch value).
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = list(myTree.keys())[0]
    # This subtree's leaves will sit at xOff + 1/W, ..., xOff + numLeafs/W,
    # so their centre is xOff + (1 + numLeafs) / (2W).
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / (2.0 * plotTree.totalW), plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Move down one level (1/totalD) for the children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # Leaf: advance to the next leaf slot (1/totalW) and draw it.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore the level so the caller continues at its own depth.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
    """Draw the decision tree inTree in matplotlib figure 1 and show it.

    Sets up the shared drawing state used by the other plotting helpers:
    createPlot.ax1 (the axes that plotNode and plotMidText draw on) and the
    plotTree layout attributes totalW, totalD, xOff and yOff.
    """
    canvas = plt.figure(1, facecolor='white')
    canvas.clf()
    # A single frameless subplot with no tick marks.
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    # Tree width (number of leaves) and depth (number of levels).
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Each leaf gets a slot of width 1/totalW. Starting xOff half a slot left
    # of 0 puts leaf k (1-based) at (2k - 1) / (2 * totalW): e.g. 3 leaves at
    # 1/6, 1/2, 5/6 and 4 leaves at 1/8, 3/8, 5/8, 7/8.
    plotTree.xOff = -0.5 / plotTree.totalW
    # Begin at the top level.
    plotTree.yOff = 1.0
    # The root's parent point is the top centre (0.5, 1.0).
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
# Hand-written demo tree; the 'no surfacing' node has three branches (0, 1, 3).
myTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}, 3: 'maybe'}}
# Draw the demo only when this file is run directly, so that importing the
# module (to call createPlot on another tree) does not open a demo figure.
if __name__ == "__main__":
    createPlot(myTree)
测试:用隐形眼镜数据集(lenses.txt)构建并绘制决策树
from chapter3 import treePlotter
from chapter3 import trees
# Read the lenses data set: one tab-separated row per non-blank line, with
# the class label in the last column (createTree reads it from example[-1]).
with open('lenses.txt', encoding='utf-8') as fr:
    lenses = [line.strip().split('\t') for line in fr if line.strip()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
# Pass a copy so lensesLabels stays intact even if createTree modifies the
# list it receives.
lensesTree = trees.createTree(lenses, lensesLabels[:])
print(lensesTree)
treePlotter.createPlot(lensesTree)