机器学习:决策树
1 简介
决策树是一种树形结构,由决策树的根结点到叶结点的每一条路径构建一条规则;路径上的内部结点的特征对应着规则的条件,而叶结点对应着分类的结论。
2 算法
2.1 树的构建
在构造决策树时,第一个需要解决的问题就是,如何确定出哪个特征在划分数据分类是起决定性作用,或者说使用哪个特征分类能实现最好的分类效果。这样,为了找到决定性的特征,划分川最好的结果,我们就需要评估每个特征。当找到最优特征后,依此特征,数据集就被划分为几个数据子集,这些数据自己会分布在该决策点的所有分支中。此时,如果某个分支下的数据属于同一类型,则该分支下的数据分类已经完成,无需进行下一步的数据集分类;如果分支下的数据子集内数据不属于同一类型,那么就要重复划分该数据集的过程,按照划分原始数据集相同的原则,确定出该数据子集中的最优特征,继续对数据子集进行分类,直到所有的特征已经遍历完成,或者所有叶结点分支下的数据具有相同的分类。
创建分支的伪代码函数createBranch()如下:
检测数据集中的每一个子项是否属于同一分类:
if so return 类标签;
else
寻找划分数据集的最好特征
划分数据集
创建分支结点
for 每个分支结点
调用函数createBranch并增加返回结点到分支结点中
return 分支结点
下面我们给出使用决策树的一般流程:
(1)收集数据
(2)准备数据:构造树算法只适用于标称型数据,因此数值型数据必须离散化
(3)分析数据
(4)训练数据:上述的构造树过程构造决策树的数据结构
(5)测试算法:使用经验树计算错误率
(6)使用算法:在实际中更好地理解数据内在含义
2.2 划分数据集的原则:信息增益
划分数据集的大原则是:使得无序的数据变得更加有序。
我们可以使用多种方法划分数据集,每种方法都有各自的优缺点,这里我们使用信息增益来度量划分数据前后信息发生的变化,进而指导我们划分数据。
这里我们先引出信息熵的概念:
对于可能被划分在多个分类中的待分类的事务,符号的信息被定义为:
其中)是选择该分类的概率。
为了计算熵,我们需要计算所有类别所有可能值包含的信息期望值,计算公式如下:
2.3 代码实现
1:计算熵的代码:
from math import log
def calEnt(dataSet):
numEntries=len(dataSet)
labelCounts={}
for featVec in dataSet:
currentLabel=featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel]=0
labelCounts[currentLabel]+=1
Ent=0.0
for key in labelCounts:
prob=float(labelCounts[key])/numEntries
Ent-=prob*log(prob,2)
return Ent
2:划分数据集的代码:
#按照给定特征划分数据集
def splitDataSet(dataSet,axis,value):
retDataSet=[]
for featVec in dataSet:
if featVec[axis]==value:
reducedFeatVec=featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
#选择最好的数据集方式:
def chooseBestFeatureToSplit(dataSet):
numFeatures=len(dataSet[0])-1
baseEntropy=calEnt(dataSet)
bestInfoGain=0.0;bestFeature=-1
for i in range(numFeatures):
featList=[example[i] for example in dataSet]
uniqueVals=set(featList)
newEntropy=0.0
for value in uniqueVals:
subDataSet=splitDataSet(dataSet,i,value)
prob=len(subDataSet)/float(len(dataSet))
newEntropy+=prob*calEnt(subDataSet)
infoGain=baseEntropy-newEntropy
if (infoGain>bestInfoGain):
bestInfoGain=infoGain
bestFeature=i
return bestFeature
3:构建树的代码:
def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount.keys():
classCount[vote]=0
classCount[vote]+=1
sortedClassCount=sorted(classCount.items,\
key=operator.itemgetter(1),reverse=true)
return sortedClassCount[0][0]
#创建树
def createTree(dataSet,labels):
classList=[example[-1] for example in dataSet]
if classList.count(classList[0])==len(classList):
return classList[0]
if len(dataSet[0])==1:
return majorityCnt(classList)
bestFeat=chooseBestFeatureToSplit(dataSet)
bestFeatLabel=labels[bestFeat]
myTree={bestFeatLabel:{}}
subLabels=labels[:]
del(subLabels[bestFeat])
featValues=[example[bestFeat] for example in dataSet]
uniqueVals=set(featValues)
for value in uniqueVals:
myTree[bestFeatLabel][value]=createTree(splitDataSet\
(dataSet,bestFeat,value),subLabels)
return myTree
4:绘制树的代码
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="sawtooth",fc="0.8")
leafNode = dict(boxstyle="round4",fc="0.8")
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction'
,xytext=centerPt,textcoords='axes fraction',
va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)
def getNumLeafs(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]) == dict:
numLeafs= numLeafs+getNumLeafs(secondDict[key])
else:
numLeafs= numLeafs+1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]) == dict:
thisDepth = 1+ getTreeDepth(secondDict[key])
else:
thisDepth =1
if thisDepth>maxDepth:
maxDepth=thisDepth
return maxDepth
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2 +cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2 +cntrPt[1]
createPlot.ax1.text(xMid,yMid, txtString)
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt,parentPt,nodeTxt)
plotNode(firstStr,cntrPt,parentPt,decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]) == dict:
plotTree(secondDict[key],cntrPt,str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
plotMidText((plotTree.xOff, plotTree.yOff),cntrPt,str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def createPlot(intree):
fig = plt.figure(1,facecolor='white')
fig.clf()
axprops = dict(xticks=[],yticks=[])
createPlot.ax1 = plt.subplot(111,frameon=False,**axprops)
plotTree.totalW = float(getNumLeafs(intree))
plotTree.totalD = float(getTreeDepth(intree))
plotTree.xOff = -0.5/plotTree.totalW;plotTree.yOff =1.0;
plotTree(intree,(0.5,1.0),' ')
plt.show()
3 实践
这次我把决策树应用到lenses数据集上,这是一个关于隐形眼镜推荐的分类数据集,训练出的决策树经过我们的代码,绘出的图如下:
4 总结
优点:
1) 可以生成可以理解的规则
2) 计算量相对来说不是很大
3) 对中间值的缺失不敏感
4) 可以清晰的显示哪些字段比较重要
缺点:
1) 对连续性的字段需要进行离散化处理
2) 对有时间顺序的数据,需要很多预处理的工作
3) 当类别太多时,错误可能就会增加的比较快
4) 有时会产生过度匹配的现象
worked by zzzzzr
下一篇: PHP中session使用方法详解