欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Machine Learning-kDtree

程序员文章站 2022-03-31 23:25:39
...

学会用 matplotlib 画树图

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0]) / 2.0 +cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1]) / 2.0 +cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 /plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == "dict":
            numLeafs += getNumLeafs(secondDict[key])
        else:numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == "dict":
            thisDepth = 1 + getTreeDepth(secondDict[key])

        else: thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}, {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]

分析一下这几个功能函数的作用:

retrieveTree(i) : 生成一个 dict,使用retrieveTree(0)才能拿到这个 dict


getTreeDepth(myTree):

  1. 拿到一个 dict 的第一个 key

  2. 拿到前面的 key 对应的 value,是另一个 dict (secondDict)

  3. 对 secondDict 的 key 进行循环

    1. 检查 secondDict 的 key 是否是 dict (是 dict 意味着还可以进入)

      • 如果可以进入,递归检测

      • 如果不可以进入,则深度为 1(抛弃)

    2. 拿到最大深度

  4. 返回最大深度

核心步骤是检测 key 的 type 是否为 dict 和递归


getNumLeafs(myTree):

getTreeDepth(myTree) 的不同之处在于:

  1. key 的 type 不是 dict , numLeafs 就 + 1,而不是直接抛弃

  2. 最终返回的是累加的结果,而不是最大值


createTree(inTree)这是确定了 tree 的各个参数,并且绘制了图,是主函数:

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
  1. plt.figure(), num 是当前图的编号
matplotlib.pyplot.figure(num=None, figsize=None, dpi=None, facecolor=None, edgecolor=None, frameon=True, FigureClass=<class 'matplotlib.figure.Figure'>, clear=False, **kwargs)[source]
  1. fig.clf() , clear the current figure.

  2. plt.subplot(),第一个参数 111, 表示横轴的 start number 是 1, 纵轴的 start number 是 1,subplot 的序号是 1。**kwargs 是 key word arguments, 这里指定了 x 轴和 y 轴上的数据标签 list。参见https://devdocs.io/matplotlib~3.1/_as_gen/matplotlib.pyplot.subplot

  3. 拿到宽度 Width( 总 leafs ),和 Depth ( 最大 Depth )

  4. 设置x 和 y 偏移量

  5. 调用 plotTree() 绘制 tree


在看 plotTree() 之前,我们先看一看它的两个子函数:plotMidText()plotNode()

plotMidTxt(cntrPt, parentPt, txtString):

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0]) / 2.0 +cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1]) / 2.0 +cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

xMid 算的是 x 轴的 parentPt 和 cntrPt 的中点坐标,同理算了 y 轴的中点坐标,然后调用 c把文字添加到相应坐标位置(父子 point 的连线中点)

plotNode(nodeTxt, centerPt, parentPt, nodeType):

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',xytext=centerPt, textcoords='axes fraction',va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

annotate 单词自身的意思是注释,本身是为 plot 添加注释,但是 built-in 的工具可以让你把文字画到 plot 里面

  • nodeTxt 就是将要显示的文字

  • xy = parentPt 表示将要注释的 point 的坐标

  • xycoords = ‘axes fraction’ 表示按照比例(而不是像素值)从轴(而不是整张图片)的左下角开始来绘点

  • xytext = centerPt 表示注释的文字的位置

  • textcoords=‘axes fraction’ 应当是文字的绘制方法,和坐标的类似

  • va=‘center’ 应当是 vertical align,类似地 ha 是 horizon align

  • bbox=nodeType,bbox 属性自身是方块(就是那个节点)的样式,是 dict 类型,而我们的两个 dict 分别是 decisionNode 和 lefaNode,这两个 dict 在最开始的时候便定义好了

  • arrow_args 同上,是箭头样式

这个 plotNode() 函数绘制的是一个箭头加上一个 Node,类似于这样:
Machine Learning-kDtree
到了这里可能大家对 centerPt 和 parentPt 的意义不太理解了,而且对于前面 xOff 以及 yOff 也不太理解,我也一样。

这个我们等会儿到调用它们的时候再看


plotTree(myTree, parentPt, nodeText)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 /plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

先看 cntrPt : 它是一个二维元组,它的两个值是当前 decisionNode 的位置:

先看 (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW)

plotTree.xOff = -0.5 / plotTree.totalW

plotTree.totalW = float(getNumLeafs(inTree))
Machine Learning-kDtree

  1. totalW 是总的叶节点个数,再上图里面, leafNode 的个数其实决定了整个图有多宽

  2. xOff 是偏移量,向左偏移 0.5 / leafNodesNum

    正常情况下我们会使用 1 / number 来均分 x 轴宽度,但是那样会使图像偏左(假如 3 个节点,那三个坐标分别是 1/3, 2/3, 3/3,起始点在 x 轴右边,因此需要加上一个向左的偏移量,移动多少呢? 不能直接又向左移动 1/3,因此移动一半,这样整个图像在 x 轴上才能位于图像中间)

  3. plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW 是当前 decisionNode 的 x 坐标:它位于它的子节点的*位置

参考博客:https://www.cnblogs.com/fantasy01/p/4595902.html

首先由于整个画布根据叶子节点数和深度进行平均切分,并且x轴的总长度为1,即如同下图:

Machine Learning-kDtree

1、其中方形为非叶子节点的位置,@是叶子节点的位置,因此每份即上图的一个表格的长度应该为1/plotTree.totalW,但是叶子节点的位置应该为@所在位置,则在开始的时候plotTree.xOff的赋值为-0.5/plotTree.totalW,即意为开始x位置为第一个表格左边的半个表格距离位置,这样作的好处为:在以后确定@位置时候可以直接加整数倍的1/plotTree.totalW,

2、对于plotTree函数中的红色部分即如下:

cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)

plotTree.xOff即为最近绘制的一个叶子节点的x坐标,在确定当前节点位置时每次只需确定当前节点有几个叶子节点,因此其叶子节点所占的总距离就确定了即为float(numLeafs)/plotTree.totalW1(因为总长度为1),因此当前节点的位置即为其所有叶子节点所占距离的中间即一半为float(numLeafs)/2.0/plotTree.totalW1,但是由于开始plotTree.xOff赋值并非从0开始,而是左移了半个表格,因此还需加上半个表格距离即为1/2/plotTree.totalW1,则加起来便为(1.0 + float(numLeafs))/2.0/plotTree.totalW1,因此偏移量确定,则x位置变为plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW

整体来说:

  1. 先拿到所有的 leafNode 数,把整个图像宽度均分为这么多份

  2. 先画出此 decisionNode(坐标由子 node 数和 depth 确定)

  3. 对 Tree dict 的 keys 进行遍历

    1. 如果这个 key 的 value 的 type name 是 dict,递归进去

    2. 如果这个 key 的 value 的 type name 不是 dict, 画出它的这个子 node(宽度总是用)

  4. 恢复 yOff 值

另外还需要解释一下初始时的 plotTree(inTree, (0.5, 1.0), ‘’), 这是因为我们虚拟了一个父 node (* Node 的父 node),它的父 node 和它自身(位置)重合,但是没有内容。

看了这么多绘制 tree 的内容,我们的核心仍然是 classify


回到 tree.py , 我们的主要任务就变成了怎样根据一堆数据生成一个 dict,然后供给 treePlotter 来绘图

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

这个函数用来计算熵值,熵值会在后面被用于判断哪个属性用来做分类是最合适的。

  1. 先拿到整个 dataSet,类型为 list 的长度

  2. 对 dataSet 中的每一项:

    1. 先拿到最后一项(应该是一项属性值)如果 labelCounts 这个 dict 里面没有这个属性,加上
  3. 相应的属性值 +1

  4. 对 labelCounts 中的每一个元素:

    1. 算出具有这个属性的元素被选中的可能性: probablity(xi)

    2. log2(p(xi)) 称为 information

    3. 熵值 Entropy 就是 information 的期望值

    4. 因此熵值的最终计算公式是
      Machine Learning-kDtree

  5. 最后返回熵值


def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

这只是简单的列表操作:

根据某个 axis 上的值,把所有的元素分为值是 value 的和值不是 value 的。


def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range (numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob *calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if(infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

(这里假定了数据的一些内容:最后一列是 label)

  1. 先拿到 features 的个数

  2. 计算基础 Entropy(整个 dataSet 的熵,没有被分类过的情况下)

  3. 循环进每一个 feature:

    1. 拿到此 feature 的所有不同值( set 特性)

    2. 循环进每一个 value:

      1. 根据这个 feature 的这个 value 进行分割

      2. 计算新的 entropy

      3. 算出这个 feature 的各个 value 的 entropy 的和

    3. 计算出这个 feature 的 information gain :所有 entropy 之差

总之只要 按照这个 feature 分割之后的 dataSet 的 entropy 之和最小,那么这个 feature 就是 bestFeature,最后返回的是 bestFeature 的 index


def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount: classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount

上面的 split 函数可能出现的一个问题是,当跑完所有的 value 之后发现还是有一些元素没法被分类出来(比如某个数据的某个 feature 的 value 缺失,那么它将无法被分类出来)

因此需要确定怎样算是分类结束,于是我们选择了只做二分,不做多分(每一个 feature 只判断一个 value)

这个 value 就是这个 feture 之下出现次数最多的那个


def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

这是真正创建了 Tree 的函数:

  1. 拿到 dataSet 的最后一列( label )

  2. 如果所有的 labels 都相同:

    • 直接返回这个 label
  3. 如果 dataSet 只有一个 feature:

    • 返回这个 feature 出现次数最多的那个 value
  4. 拿到 best feature 的 index

  5. 拿到上面的 index 对应的 label(这个 label 是参数中 labels 的)

  6. 删掉参数 labels 中的 best feature 项

  7. 拿到 dataSet 里 best feature 对应的所有 value 并且去重

  8. 拿到 value (这个 vakue 可能还是一个 dict )之后递归


def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else: classLabel = secondDict[key]
    return classLabel

最终又回到了我们的 classify 函数,这里直接用的是 Tree(实际上是深层 dict ) 来做 classify

  1. 先拿到 tree 的第一个 key

  2. 判断 testVec 的各个 feature 是否能被分类进 tree

  3. 递归

我们的 k-Dtree 算法大概就到这儿了,核心是创建一个 tree 出来

condDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else: classLabel = secondDict[key]
    return classLabel

最终又回到了我们的 classify 函数,这里直接用的是 Tree(实际上是深层 dict ) 来做 classify

  1. 先拿到 tree 的第一个 key

  2. 判断 testVec 的各个 feature 是否能被分类进 tree

  3. 递归

我们的 k-Dtree 算法大概就到这儿了,核心是创建一个 tree 出来