Machine Learning-kDtree
学会用 matplotlib 画树图
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0]) / 2.0 +cntrPt[0]
yMid = (parentPt[1]-cntrPt[1]) / 2.0 +cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString)
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree)
getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 /plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW; plotTree.yOff = 1.0;
plotTree(inTree, (0.5, 1.0), '')
plt.show()
def getNumLeafs(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == "dict":
numLeafs += getNumLeafs(secondDict[key])
else:numLeafs += 1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == "dict":
thisDepth = 1 + getTreeDepth(secondDict[key])
else: thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
def retrieveTree(i):
listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}, {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
return listOfTrees[i]
分析一下这几个功能函数的作用:
retrieveTree(i)
: 生成一个 dict,使用retrieveTree(0)
才能拿到这个 dict
getTreeDepth(myTree)
:
-
拿到一个 dict 的第一个 key
-
拿到前面的 key 对应的 value,是另一个 dict (secondDict)
-
对 secondDict 的 key 进行循环
-
检查 secondDict 的 key 是否是 dict (是 dict 意味着还可以进入)
-
如果可以进入,递归检测
-
如果不可以进入,则深度为 1(抛弃)
-
-
拿到最大深度
-
-
返回最大深度
核心步骤是检测 key 的 type 是否为 dict 和递归
getNumLeafs(myTree)
:
与getTreeDepth(myTree)
的不同之处在于:
-
key 的 type 不是 dict , numLeafs 就 + 1,而不是直接抛弃
-
最终返回的是累加的结果,而不是最大值
createTree(inTree)
这是确定了 tree 的各个参数,并且绘制了图,是主函数:
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW; plotTree.yOff = 1.0;
plotTree(inTree, (0.5, 1.0), '')
plt.show()
-
plt.figure()
, num 是当前图的编号
matplotlib.pyplot.figure(num=None, figsize=None, dpi=None, facecolor=None, edgecolor=None, frameon=True, FigureClass=<class 'matplotlib.figure.Figure'>, clear=False, **kwargs)[source]
-
fig.clf()
, clear the current figure. -
plt.subplot()
,第一个参数 111, 表示横轴的 start number 是 1, 纵轴的 start number 是 1,subplot 的序号是 1。**kwargs 是 key word arguments, 这里指定了 x 轴和 y 轴上的数据标签 list。参见https://devdocs.io/matplotlib~3.1/_as_gen/matplotlib.pyplot.subplot -
拿到宽度 Width( 总 leafs ),和 Depth ( 最大 Depth )
-
设置x 和 y 偏移量
-
调用
plotTree()
绘制 tree
在看 plotTree()
之前,我们先看一看它的两个子函数:plotMidText()
和 plotNode()
plotMidTxt(cntrPt, parentPt, txtString)
:
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0]) / 2.0 +cntrPt[0]
yMid = (parentPt[1]-cntrPt[1]) / 2.0 +cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString)
xMid 算的是 x 轴的 parentPt 和 cntrPt 的中点坐标,同理算了 y 轴的中点坐标,然后调用 c把文字添加到相应坐标位置(父子 point 的连线中点)
plotNode(nodeTxt, centerPt, parentPt, nodeType)
:
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',xytext=centerPt, textcoords='axes fraction',va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
annotate 单词自身的意思是注释,本身是为 plot 添加注释,但是 built-in 的工具可以让你把文字画到 plot 里面
-
nodeTxt 就是将要显示的文字
-
xy = parentPt 表示将要注释的 point 的坐标
-
xycoords = ‘axes fraction’ 表示按照比例(而不是像素值)从轴(而不是整张图片)的左下角开始来绘点
-
xytext = centerPt 表示注释的文字的位置
-
textcoords=‘axes fraction’ 应当是文字的绘制方法,和坐标的类似
-
va=‘center’ 应当是 vertical align,类似地 ha 是 horizon align
-
bbox=nodeType,bbox 属性自身是方块(就是那个节点)的样式,是 dict 类型,而我们的两个 dict 分别是 decisionNode 和 lefaNode,这两个 dict 在最开始的时候便定义好了
-
arrow_args 同上,是箭头样式
这个 plotNode()
函数绘制的是一个箭头加上一个 Node,类似于这样:
到了这里可能大家对 centerPt 和 parentPt 的意义不太理解了,而且对于前面 xOff 以及 yOff 也不太理解,我也一样。
这个我们等会儿到调用它们的时候再看
plotTree(myTree, parentPt, nodeText)
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = getNumLeafs(myTree)
getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 /plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
先看 cntrPt : 它是一个二维元组,它的两个值是当前 decisionNode 的位置:
先看 (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW)
plotTree.xOff = -0.5 / plotTree.totalW
plotTree.totalW = float(getNumLeafs(inTree))
-
totalW 是总的叶节点个数,再上图里面, leafNode 的个数其实决定了整个图有多宽
-
xOff 是偏移量,向左偏移 0.5 / leafNodesNum
正常情况下我们会使用 1 / number 来均分 x 轴宽度,但是那样会使图像偏左(假如 3 个节点,那三个坐标分别是 1/3, 2/3, 3/3,起始点在 x 轴右边,因此需要加上一个向左的偏移量,移动多少呢? 不能直接又向左移动 1/3,因此移动一半,这样整个图像在 x 轴上才能位于图像中间)
-
plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW 是当前 decisionNode 的 x 坐标:它位于它的子节点的*位置
参考博客:https://www.cnblogs.com/fantasy01/p/4595902.html
首先由于整个画布根据叶子节点数和深度进行平均切分,并且x轴的总长度为1,即如同下图:
1、其中方形为非叶子节点的位置,@是叶子节点的位置,因此每份即上图的一个表格的长度应该为1/plotTree.totalW,但是叶子节点的位置应该为@所在位置,则在开始的时候plotTree.xOff的赋值为-0.5/plotTree.totalW,即意为开始x位置为第一个表格左边的半个表格距离位置,这样作的好处为:在以后确定@位置时候可以直接加整数倍的1/plotTree.totalW,
2、对于plotTree函数中的红色部分即如下:
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
plotTree.xOff即为最近绘制的一个叶子节点的x坐标,在确定当前节点位置时每次只需确定当前节点有几个叶子节点,因此其叶子节点所占的总距离就确定了即为float(numLeafs)/plotTree.totalW1(因为总长度为1),因此当前节点的位置即为其所有叶子节点所占距离的中间即一半为float(numLeafs)/2.0/plotTree.totalW1,但是由于开始plotTree.xOff赋值并非从0开始,而是左移了半个表格,因此还需加上半个表格距离即为1/2/plotTree.totalW1,则加起来便为(1.0 + float(numLeafs))/2.0/plotTree.totalW1,因此偏移量确定,则x位置变为plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW
整体来说:
-
先拿到所有的 leafNode 数,把整个图像宽度均分为这么多份
-
先画出此 decisionNode(坐标由子 node 数和 depth 确定)
-
对 Tree dict 的 keys 进行遍历
-
如果这个 key 的 value 的 type name 是 dict,递归进去
-
如果这个 key 的 value 的 type name 不是 dict, 画出它的这个子 node(宽度总是用)
-
-
恢复 yOff 值
另外还需要解释一下初始时的 plotTree(inTree, (0.5, 1.0), ‘’), 这是因为我们虚拟了一个父 node (* Node 的父 node),它的父 node 和它自身(位置)重合,但是没有内容。
看了这么多绘制 tree 的内容,我们的核心仍然是 classify
回到 tree.py , 我们的主要任务就变成了怎样根据一堆数据生成一个 dict,然后供给 treePlotter 来绘图
def calcShannonEnt(dataSet):
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key]) / numEntries
shannonEnt -= prob * log(prob, 2)
return shannonEnt
这个函数用来计算熵值,熵值会在后面被用于判断哪个属性用来做分类是最合适的。
-
先拿到整个 dataSet,类型为 list 的长度
-
对 dataSet 中的每一项:
- 先拿到最后一项(应该是一项属性值)如果 labelCounts 这个 dict 里面没有这个属性,加上
-
相应的属性值 +1
-
对 labelCounts 中的每一个元素:
-
算出具有这个属性的元素被选中的可能性: probablity(xi)
-
log2(p(xi)) 称为 information
-
熵值 Entropy 就是 information 的期望值
-
因此熵值的最终计算公式是
-
-
最后返回熵值
def splitDataSet(dataSet, axis, value):
retDataSet = []
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
这只是简单的列表操作:
根据某个 axis 上的值,把所有的元素分为值是 value 的和值不是 value 的。
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0
bestFeature = -1
for i in range (numFeatures):
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob *calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if(infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature
(这里假定了数据的一些内容:最后一列是 label)
-
先拿到 features 的个数
-
计算基础 Entropy(整个 dataSet 的熵,没有被分类过的情况下)
-
循环进每一个 feature:
-
拿到此 feature 的所有不同值( set 特性)
-
循环进每一个 value:
-
根据这个 feature 的这个 value 进行分割
-
计算新的 entropy
-
算出这个 feature 的各个 value 的 entropy 的和
-
-
计算出这个 feature 的 information gain :所有 entropy 之差
-
总之只要 按照这个 feature 分割之后的 dataSet 的 entropy 之和最小,那么这个 feature 就是 bestFeature,最后返回的是 bestFeature 的 index
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount: classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount
上面的 split 函数可能出现的一个问题是,当跑完所有的 value 之后发现还是有一些元素没法被分类出来(比如某个数据的某个 feature 的 value 缺失,那么它将无法被分类出来)
因此需要确定怎样算是分类结束,于是我们选择了只做二分,不做多分(每一个 feature 只判断一个 value)
这个 value 就是这个 feture 之下出现次数最多的那个
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0]
if len(dataSet[0]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
这是真正创建了 Tree 的函数:
-
拿到 dataSet 的最后一列( label )
-
如果所有的 labels 都相同:
- 直接返回这个 label
-
如果 dataSet 只有一个 feature:
- 返回这个 feature 出现次数最多的那个 value
-
拿到 best feature 的 index
-
拿到上面的 index 对应的 label(这个 label 是参数中 labels 的)
-
删掉参数 labels 中的 best feature 项
-
拿到 dataSet 里 best feature 对应的所有 value 并且去重
-
拿到 value (这个 vakue 可能还是一个 dict )之后递归
def classify(inputTree, featLabels, testVec):
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else: classLabel = secondDict[key]
return classLabel
最终又回到了我们的 classify 函数,这里直接用的是 Tree(实际上是深层 dict ) 来做 classify
-
先拿到 tree 的第一个 key
-
判断 testVec 的各个 feature 是否能被分类进 tree
-
递归
我们的 k-Dtree 算法大概就到这儿了,核心是创建一个 tree 出来
condDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict':
classLabel = classify(secondDict[key], featLabels, testVec)
else: classLabel = secondDict[key]
return classLabel
最终又回到了我们的 classify 函数,这里直接用的是 Tree(实际上是深层 dict ) 来做 classify
-
先拿到 tree 的第一个 key
-
判断 testVec 的各个 feature 是否能被分类进 tree
-
递归
我们的 k-Dtree 算法大概就到这儿了,核心是创建一个 tree 出来
上一篇: JS将滚动条保持在最下方
下一篇: Python:计算两个向量的欧式距离
推荐阅读
-
用自定义的节来扩展web.config和machine.config配置文件的结构
-
eclipse编程时出现Fail to create the java Virtual Machine怎么办?
-
天云大数据CEO雷涛:AI建模平台演进趋势着力于Auto Machine Learning
-
could not create the java virtual machine解决办法
-
java.lang.UnsatisfiedLinkError:dlopen failed: “**/*/arm/*.so” has unexpected e_machine: 3
-
解决Eclipse启动出错:Failed to create the Java Virtual Machine
-
Docker Machine深入学习
-
Docker Machine是什么?
-
Implementing a virtual machine in C(虚拟机C语言实现)
-
苹果电脑mac系统备份 mac通过Time Machine实行系统备份与还原方法