欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

机器学习笔记--机器学习实战CART算法错误

程序员文章站 2022-06-18 10:54:19
...

分类与回归树(Classification And Regression Tree,CART)生成过程中:
对回归树用平方误差最小化准则;
对分类树用基尼指数(Gini index)最小化准则.

使用《机器学习实战》第九章中介绍的CART算法代码,用平方误差最小化准则构造回归树时,发现部分代码有问题:
问题处:
机器学习笔记--机器学习实战CART算法错误

机器学习笔记--机器学习实战CART算法错误

更改后:

# !/usr/bin/env python
# coding:utf-8

from numpy import *

# Load the data
def loadDataSet(fileName):
    """Read a tab-separated numeric text file into a list of rows.

    Args:
        fileName: path of a text file whose lines hold tab-separated numbers.

    Returns:
        A list with one entry per non-blank line; each entry is a list of
        floats, in the order the fields appear on the line.

    Raises:
        ValueError: if a field cannot be converted by ``float``.
    """
    dataMat = []
    # The context manager closes the file even when parsing raises.
    with open(fileName) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # A blank (e.g. trailing) line carries no sample.
                continue
            # Build a real list: on Python 3 ``map`` would return a one-shot
            # iterator, which numpy cannot turn into a numeric matrix row.
            dataMat.append([float(field) for field in stripped.split('\t')])
    return dataMat

# Build a leaf node
def regLeaf(dataSet):
    """Return the leaf value for ``dataSet``: the mean of its last column."""
    targets = dataSet[:, -1]
    return mean(targets)

# Error estimate (total squared error)
def regErr(dataSet):
    """Return the variance of the last column multiplied by the row count.

    This equals the sum of squared deviations of the last column from its
    mean.
    """
    rowCount = shape(dataSet)[0]
    targetVariance = var(dataSet[:, -1])
    return targetVariance * rowCount

# (data set, index of the feature to split on, threshold for that feature)
# Split the data into two subsets and return both
def binSplitDataSet(dataSet, feature, value):
    """Partition the rows of ``dataSet`` by thresholding one column.

    Returns:
        (mat0, mat1): the rows whose ``feature`` column is > ``value``,
        then the rows whose ``feature`` column is <= ``value``.
    """
    column = dataSet[:, feature]
    # nonzero() returns the indices of the True entries; [0] is the row axis.
    rowsAbove = nonzero(column > value)[0]
    rowsAtOrBelow = nonzero(column <= value)[0]
    return dataSet[rowsAbove, :], dataSet[rowsAtOrBelow, :]

# Tree construction
# (data set, leaf-building function, error function, stopping options)
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Recursively build a binary tree by repeatedly splitting ``dataSet``.

    Args:
        dataSet: the samples; passed unchanged to ``chooseBestSplit`` and
            ``binSplitDataSet``.
        leafType: callable that turns a data set into a leaf value.
        errType: callable that returns the error of a data set.
        ops: stopping options forwarded to ``chooseBestSplit``.

    Returns:
        Either a leaf value (when ``chooseBestSplit`` reports no split), or a
        dict with keys ``'spInd'`` (split feature), ``'spVal'`` (threshold),
        ``'left'`` (subtree for rows above the threshold) and ``'right'``
        (subtree for the remaining rows).
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # Identity test: "no split" is signalled by None only; feature index 0 is
    # a valid split and `==` may be overloaded by numpy types.
    if feat is None:
        return val
    # Split the data in two and recurse on each half.
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    return {
        'spInd': feat,
        'spVal': val,
        'left': createTree(lSet, leafType, errType, ops),
        'right': createTree(rSet, leafType, errType, ops),
    }

# Split selection: find the best place to split the data set
# Returns (best split feature, threshold)
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Pick the feature/threshold pair that minimises the summed split error.

    Args:
        dataSet: 2-D numpy matrix or array; the last column is the target and
            every other column is a candidate split feature.
        leafType: callable that turns a data set into a leaf value.
        errType: callable that returns the error of a data set.
        ops: ``(tolS, tolN)`` - the minimum error reduction a split must
            achieve, and the minimum number of rows each side must keep.

    Returns:
        ``(featureIndex, splitValue)`` for the best split, or
        ``(None, leafType(dataSet))`` when the node should become a leaf.
    """
    # Minimum error reduction required
    tolS = ops[0]
    # Minimum number of samples on each side of a split
    tolN = ops[1]
    # asarray(...).ravel() flattens the column for matrices and arrays alike.
    targets = asarray(dataSet[:, -1]).ravel().tolist()
    # Stop when every target value is identical
    if len(set(targets)) == 1:
        return None, leafType(dataSet)
    n = shape(dataSet)[1]
    S = errType(dataSet)
    bestS = inf
    # None marks "no admissible split seen yet"
    bestIndex = None
    bestValue = None
    for featIndex in range(n - 1):
        # set() de-duplicates the candidate thresholds of this feature
        for splitVal in set(asarray(dataSet[:, featIndex]).ravel().tolist()):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            # Skip splits that leave either side with too few samples
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # No candidate passed the size check: never fall through to a split that
    # was not evaluated.
    if bestIndex is None:
        return None, leafType(dataSet)
    # Stop when the error reduction is below the tolerance
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    # The chosen split already passed the tolN check inside the loop, so no
    # second size check is needed here.
    return bestIndex, bestValue



if __name__ == '__main__':
    myDat = loadDataSet("ex0.txt")
    # asmatrix is the long-standing name; the `mat` alias is gone in NumPy 2.0
    myMat = asmatrix(myDat)
    # Best split feature and threshold for the whole data set.
    # print(...) with one argument behaves the same on Python 2 and 3.
    print(chooseBestSplit(myMat))
    # The full regression tree
    print(createTree(myMat))

输出:

(1, 0.39435)
{'spInd': 1, 'spVal': 0.39435, 'right': {'spInd': 1, 'spVal': 0.197834, 'right': -0.023838155555555553, 'left': 1.0289583666666666}, 'left': {'spInd': 1, 'spVal': 0.582002, 'right': 1.980035071428571, 'left': {'spInd': 1, 'spVal': 0.797583, 'right': 2.9836209534883724, 'left': 3.9871631999999999}}}

文件ex00.txt中的简单数据集查看:

if __name__ == '__main__':
    # pyplot is not imported by the listing above, so import it here
    import matplotlib.pyplot as plt

    myDat = loadDataSet("ex00.txt")
    # asmatrix is the long-standing name; the `mat` alias is gone in NumPy 2.0
    myMat = asmatrix(myDat)

    # Scatter plot of column 0 against column 1
    plt.figure()
    plt.scatter(myMat[:,0],myMat[:,1],s=15, c='b')
    plt.show()

可以看到数据比较明显的在两个区域聚集:

机器学习笔记--机器学习实战CART算法错误

将文件ex00.txt中的数据的y轴放大100倍,得到的数据放入ex2.txt中:
机器学习笔记--机器学习实战CART算法错误
此时数据仍保持原来的聚集状态.
但再次构建回归树时:

if __name__ == '__main__':
    myDat2 = loadDataSet("ex2.txt")
    # asmatrix is the long-standing name; the `mat` alias is gone in NumPy 2.0
    myMat2 = asmatrix(myDat2)
    # Default ops=(1,4): this is the untuned run that yields the very large
    # tree printed below; the tuned run with ops=(10000,4) comes afterwards.
    print(createTree(myMat2))

输出:

{'spInd': 0, 'spVal': 0.499171, 'right': {'spInd': 0, 'spVal': 0.457563, 'right': {'spInd': 0, 'spVal': 0.126833, 'right': {'spInd': 0, 'spVal': 0.084661, 'right': {'spInd': 0, 'spVal': 0.044737, 'right': 4.0916259999999998, 'left': -2.5443927142857148}, 'left': 6.5098432857142843}, 'left': {'spInd': 0, 'spVal': 0.373501, 'right': {'spInd': 0, 'spVal': 0.335182, 'right': {'spInd': 0, 'spVal': 0.324274, 'right': {'spInd': 0, 'spVal': 0.297107, 'right': {'spInd': 0, 'spVal': 0.166765, 'right': {'spInd': 0, 'spVal': 0.156067, 'right': -6.2479000000000013, 'left': -12.107972500000001}, 'left': {'spInd': 0, 'spVal': 0.202161, 'right': 3.4496025000000001, 'left': {'spInd': 0, 'spVal': 0.217214, 'right': -11.822278500000001, 'left': {'spInd': 0, 'spVal': 0.228473, 'right': 6.770429, 'left': {'spInd': 0, 'spVal': 0.25807, 'right': -13.070501, 'left': 0.40377471428571476}}}}}, 'left': -19.994155200000002}, 'left': 15.059290750000001}, 'left': {'spInd': 0, 'spVal': 0.350725, 'right': -22.693879600000002, 'left': -15.085111749999999}}, 'left': {'spInd': 0, 'spVal': 0.437652, 'right': {'spInd': 0, 'spVal': 0.412516, 'right': {'spInd': 0, 'spVal': 0.385021, 'right': 3.6584772500000016, 'left': -0.89235549999999952}, 'left': 14.38417875}, 'left': -12.558604833333334}}}, 'left': {'spInd': 0, 'spVal': 0.467383, 'right': 3.4331330000000007, 'left': 12.50675925}}, 'left': {'spInd': 0, 'spVal': 0.729397, 'right': {'spInd': 0, 'spVal': 0.640515, 'right': {'spInd': 0, 'spVal': 0.613004, 'right': {'spInd': 0, 'spVal': 0.582311, 'right': {'spInd': 0, 'spVal': 0.553797, 'right': {'spInd': 0, 'spVal': 0.51915, 'right': 101.73699325000001, 'left': {'spInd': 0, 'spVal': 0.543843, 'right': 110.979946, 'left': 109.38961049999999}}, 'left': 97.200180249999988}, 'left': 123.2101316}, 'left': 93.673449714285724}, 'left': {'spInd': 0, 'spVal': 0.666452, 'right': 114.15162428571431, 'left': {'spInd': 0, 'spVal': 0.706961, 'right': {'spInd': 0, 'spVal': 0.698472, 'right': 108.92921799999999, 'left': 
104.82495374999999}, 'left': 114.554706}}}, 'left': {'spInd': 0, 'spVal': 0.952833, 'right': {'spInd': 0, 'spVal': 0.759504, 'right': 78.085643250000004, 'left': {'spInd': 0, 'spVal': 0.790312, 'right': 102.35780185714285, 'left': {'spInd': 0, 'spVal': 0.833026, 'right': {'spInd': 0, 'spVal': 0.811602, 'right': 88.784498800000009, 'left': 81.110151999999999}, 'left': {'spInd': 0, 'spVal': 0.944221, 'right': {'spInd': 0, 'spVal': 0.85497, 'right': 95.275843166666661, 'left': {'spInd': 0, 'spVal': 0.910975, 'right': {'spInd': 0, 'spVal': 0.892999, 'right': {'spInd': 0, 'spVal': 0.872883, 'right': 102.25234449999999, 'left': 95.181792999999999}, 'left': 104.82540899999999}, 'left': 96.452866999999998}}, 'left': 87.310387500000004}}}}, 'left': {'spInd': 0, 'spVal': 0.958512, 'right': 112.42895575000001, 'left': 105.24862350000001}}}}

可以看到此时构造的树会变得非常庞大,拥有很多叶节点.原因是停止条件tolS是误差下降量的绝对阈值,对误差的数量级非常敏感:y轴放大100倍后,平方误差大约放大10000倍,原先合适的tolS相对而言就太小了.
此时,可以通过调整停止条件(例如把tolS增大到10000)得到仅有两个叶节点的树.

if __name__ == '__main__':
    myDat2 = loadDataSet("ex2.txt")
    # asmatrix is the long-standing name; the `mat` alias is gone in NumPy 2.0
    myMat2 = asmatrix(myDat2)
    # Require an error reduction of at least 10000 before splitting.
    # print(...) with one argument behaves the same on Python 2 and 3.
    print(createTree(myMat2, ops=(10000,4)))

输出:

{'spInd': 0, 'spVal': 0.499171, 'right': -2.6377193297872341, 'left': 101.35815937735848}

文件ex0.txt中的测试数据集:

if __name__ == '__main__':
    # pyplot is not imported by the listing above, so import it here
    import matplotlib.pyplot as plt

    # Read the test data
    myDat = loadDataSet("ex0.txt")
    # asmatrix is the long-standing name; the `mat` alias is gone in NumPy 2.0
    myMat = asmatrix(myDat)
    # print(...) with one argument behaves the same on Python 2 and 3.
    print(createTree(myMat))
    # Scatter plot of column 1 against column 2
    plt.figure()
    plt.scatter(myMat[:,1],myMat[:,2],s=15, c='b')
    plt.show()

机器学习笔记--机器学习实战CART算法错误

利用GUI对回归树调优

# !/usr/bin/env python
# coding:utf-8

from numpy import *

# Load the data
def loadDataSet(fileName):
    """Read a tab-separated numeric text file into a list of rows.

    Args:
        fileName: path of a text file whose lines hold tab-separated numbers.

    Returns:
        A list with one entry per non-blank line; each entry is a list of
        floats, in the order the fields appear on the line.

    Raises:
        ValueError: if a field cannot be converted by ``float``.
    """
    dataMat = []
    # The context manager closes the file even when parsing raises.
    with open(fileName) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # A blank (e.g. trailing) line carries no sample.
                continue
            # Build a real list: on Python 3 ``map`` would return a one-shot
            # iterator, which numpy cannot turn into a numeric matrix row.
            dataMat.append([float(field) for field in stripped.split('\t')])
    return dataMat

# Build a leaf node
def regLeaf(dataSet):
    """Return the leaf value for ``dataSet``: the mean of its last column."""
    targets = dataSet[:, -1]
    return mean(targets)

# Error estimate (total squared error)
def regErr(dataSet):
    """Return the variance of the last column multiplied by the row count.

    This equals the sum of squared deviations of the last column from its
    mean.
    """
    rowCount = shape(dataSet)[0]
    targetVariance = var(dataSet[:, -1])
    return targetVariance * rowCount

# (data set, index of the feature to split on, threshold for that feature)
# Split the data into two subsets and return both
def binSplitDataSet(dataSet, feature, value):
    """Partition the rows of ``dataSet`` by thresholding one column.

    Returns:
        (mat0, mat1): the rows whose ``feature`` column is > ``value``,
        then the rows whose ``feature`` column is <= ``value``.
    """
    column = dataSet[:, feature]
    # nonzero() returns the indices of the True entries; [0] is the row axis.
    rowsAbove = nonzero(column > value)[0]
    rowsAtOrBelow = nonzero(column <= value)[0]
    return dataSet[rowsAbove, :], dataSet[rowsAtOrBelow, :]

# Tree construction
# (data set, leaf-building function, error function, stopping options)
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Recursively build a binary tree by repeatedly splitting ``dataSet``.

    Args:
        dataSet: the samples; passed unchanged to ``chooseBestSplit`` and
            ``binSplitDataSet``.
        leafType: callable that turns a data set into a leaf value.
        errType: callable that returns the error of a data set.
        ops: stopping options forwarded to ``chooseBestSplit``.

    Returns:
        Either a leaf value (when ``chooseBestSplit`` reports no split), or a
        dict with keys ``'spInd'`` (split feature), ``'spVal'`` (threshold),
        ``'left'`` (subtree for rows above the threshold) and ``'right'``
        (subtree for the remaining rows).
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # Identity test: "no split" is signalled by None only; feature index 0 is
    # a valid split and `==` may be overloaded by numpy types.
    if feat is None:
        return val
    # Split the data in two and recurse on each half.
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    return {
        'spInd': feat,
        'spVal': val,
        'left': createTree(lSet, leafType, errType, ops),
        'right': createTree(rSet, leafType, errType, ops),
    }

# Split selection: find the best place to split the data set
# Returns (best split feature, threshold)
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Pick the feature/threshold pair that minimises the summed split error.

    Args:
        dataSet: 2-D numpy matrix or array; the last column is the target and
            every other column is a candidate split feature.
        leafType: callable that turns a data set into a leaf value.
        errType: callable that returns the error of a data set.
        ops: ``(tolS, tolN)`` - the minimum error reduction a split must
            achieve, and the minimum number of rows each side must keep.

    Returns:
        ``(featureIndex, splitValue)`` for the best split, or
        ``(None, leafType(dataSet))`` when the node should become a leaf.
    """
    # Minimum error reduction required
    tolS = ops[0]
    # Minimum number of samples on each side of a split
    tolN = ops[1]
    # asarray(...).ravel() flattens the column for matrices and arrays alike.
    targets = asarray(dataSet[:, -1]).ravel().tolist()
    # Stop when every target value is identical
    if len(set(targets)) == 1:
        return None, leafType(dataSet)
    n = shape(dataSet)[1]
    S = errType(dataSet)
    bestS = inf
    # None marks "no admissible split seen yet"
    bestIndex = None
    bestValue = None
    for featIndex in range(n - 1):
        # set() de-duplicates the candidate thresholds of this feature
        for splitVal in set(asarray(dataSet[:, featIndex]).ravel().tolist()):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            # Skip splits that leave either side with too few samples
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # No candidate passed the size check: never fall through to a split that
    # was not evaluated.
    if bestIndex is None:
        return None, leafType(dataSet)
    # Stop when the error reduction is below the tolerance
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    # The chosen split already passed the tolN check inside the loop, so no
    # second size check is needed here.
    return bestIndex, bestValue


def isTree(obj):
    """Return True when ``obj`` is a tree node (a dict) rather than a leaf.

    ``isinstance`` is used instead of comparing the type name, so dict
    subclasses are recognised and unrelated classes that merely happen to be
    named ``dict`` are not.
    """
    return isinstance(obj, dict)


def getMean(tree):
    """Collapse ``tree`` to one number: the average of its two branches.

    Branches that are themselves trees are collapsed first (right, then
    left) and replaced in place by their averaged value, so ``tree`` is
    modified by this call.
    """
    for side in ('right', 'left'):
        branch = tree[side]
        if isTree(branch):
            tree[side] = getMean(branch)
    return (tree['left'] + tree['right']) / 2.0


def prune(tree, testData):
    """Post-prune ``tree`` against ``testData``, merging leaves when it helps.

    Args:
        tree: a tree node (dict with 'spInd', 'spVal', 'left', 'right');
            its branches are replaced in place while pruning.
        testData: held-out samples; the last column is the target.

    Returns:
        The pruned tree, or a single averaged leaf value when the node was
        collapsed.
    """
    # No test samples reach this node: collapse it to its mean value.
    if shape(testData)[0] == 0:
        return getMean(tree)

    # Route the test samples to the two branches. The same split serves both
    # the recursive pruning and the merge test below.
    lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])

    # Prune whichever branches are still subtrees.
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)

    # A remaining subtree on either side means this node cannot be merged.
    if isTree(tree['left']) or isTree(tree['right']):
        return tree

    # Both branches are leaves: compare the test error with and without
    # merging them into their average.
    leftError = sum(power(lSet[:, -1] - tree['left'], 2))
    rightError = sum(power(rSet[:, -1] - tree['right'], 2))
    errorNoMerge = leftError + rightError
    treeMean = (tree['left'] + tree['right']) / 2.0
    errorMerge = sum(power(testData[:, -1] - treeMean, 2))
    if errorMerge < errorNoMerge:
        return treeMean
    return tree


def linearSolve(dataSet):
    """Fit an ordinary least-squares line/plane to ``dataSet``.

    Args:
        dataSet: 2-D numpy matrix or array; the last column is the target,
            the other columns are the features.

    Returns:
        ``(ws, X, Y)``: the coefficient column vector (intercept first), the
        design matrix (a column of ones followed by the features), and the
        target column.

    Raises:
        NameError: when ``X.T * X`` has a determinant of exactly zero and so
            cannot be inverted.
    """
    # asmatrix returns a matrix input unchanged and lets plain arrays in; the
    # `mat` alias used previously was removed in NumPy 2.0.
    data = asmatrix(dataSet)
    m,n = shape(data)
    # Design matrix: column 0 stays 1.0 (intercept), the rest are features.
    X = asmatrix(ones((m,n)))
    X[:,1:n] = data[:,0:n-1]
    # Target column
    Y = data[:,-1]
    xTx = X.T*X
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\n\
        try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws,X,Y

# Fit a linear model and return its coefficients
def modelLeaf(dataSet):
    """Return the coefficient vector that ``linearSolve`` fits to ``dataSet``."""
    return linearSolve(dataSet)[0]

def modelErr(dataSet):
    """Return the sum of squared residuals of the linear fit to ``dataSet``."""
    ws, X, Y = linearSolve(dataSet)
    residuals = Y - X * ws
    return sum(power(residuals, 2))


def regTreeEval(model, inDat):
    """Evaluate a constant leaf: return ``model`` as a float.

    ``inDat`` is accepted only so the signature matches ``modelTreeEval``;
    it is not used.
    """
    prediction = float(model)
    return prediction


def modelTreeEval(model, inDat):
    """Evaluate a linear-model leaf on one sample.

    Args:
        model: coefficient column vector, intercept first.
        inDat: a 1 x n row holding the sample's feature values.

    Returns:
        The prediction ``[1, inDat] * model`` as a float.
    """
    n = shape(inDat)[1]
    # Prepend a constant 1.0 for the intercept term. asmatrix replaces the
    # `mat` alias, which was removed in NumPy 2.0.
    X = asmatrix(ones((1, n + 1)))
    X[:, 1:n + 1] = inDat
    # Pull out the single element explicitly; calling float() on a 1x1 matrix
    # is deprecated in recent NumPy.
    return float((X * model)[0, 0])


def treeForeCast(tree, inData, modelEval=regTreeEval):
    """Predict the target of one sample by walking ``tree``.

    Args:
        tree: a tree node (dict) or a leaf model.
        inData: the sample's feature values (1 x n matrix, 1-D array or
            list).
        modelEval: callable ``(leafModel, inData)`` that evaluates a leaf.

    Returns:
        Whatever ``modelEval`` returns for the leaf the sample reaches.
    """
    if not isTree(tree):
        return modelEval(tree, inData)
    # Flatten before indexing: on a 1 x n matrix, inData[k] selects ROW k, so
    # any split feature other than 0 raised IndexError.
    featureValue = asarray(inData).ravel()[tree['spInd']]
    # Values above the threshold go left, the rest go right.
    branch = tree['left'] if featureValue > tree['spVal'] else tree['right']
    # The recursive call evaluates the leaf itself when branch is not a tree.
    return treeForeCast(branch, inData, modelEval)


def createForeCast(tree, testData, modelEval=regTreeEval):
    """Predict the target of every sample in ``testData``.

    Args:
        tree: tree (or leaf) passed to ``treeForeCast``.
        testData: sequence of samples; ``len(testData)`` predictions are made.
        modelEval: leaf evaluator forwarded to ``treeForeCast``.

    Returns:
        An m x 1 matrix holding one prediction per sample.
    """
    m = len(testData)
    # asmatrix replaces the `mat` alias, which was removed in NumPy 2.0.
    yHat = asmatrix(zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, asmatrix(testData[i]), modelEval)
    return yHat





from Tkinter import *

import matplotlib

matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure



def reDraw(tolS,tolN):
    """Rebuild the tree with the given stopping options and redraw the plot.

    Args:
        tolS: minimum error reduction passed to ``createTree``.
        tolN: minimum samples per split side passed to ``createTree``.

    Reads ``reDraw.f``, ``reDraw.canvas``, ``reDraw.rawDat`` and
    ``reDraw.testDat`` (attached to this function by the start-up code) and
    the global ``chkBtnVar`` check-box variable.
    """
    reDraw.f.clf()
    reDraw.a = reDraw.f.add_subplot(111)
    if chkBtnVar.get():
        # Model-tree mode: keep at least two samples per leaf
        if tolN < 2: tolN = 2
        myTree=createTree(reDraw.rawDat, modelLeaf,modelErr, (tolS,tolN))
        yHat = createForeCast(myTree, reDraw.testDat, modelTreeEval)
    else:
        myTree = createTree(reDraw.rawDat, ops=(tolS,tolN))
        yHat = createForeCast(myTree, reDraw.testDat)
    # Raw samples as points, tree prediction as a line
    reDraw.a.scatter(reDraw.rawDat[:,0], reDraw.rawDat[:,1], s=5)
    reDraw.a.plot(reDraw.testDat, yHat, linewidth=2.0)
    # draw() replaces FigureCanvasTkAgg.show(), which newer Matplotlib
    # releases no longer provide.
    reDraw.canvas.draw()

def getInputs():
    """Read tolN and tolS from the two entry widgets.

    Text that cannot be parsed is replaced by the default (10 for tolN, 1.0
    for tolS): a hint is printed and the entry widget is reset to show the
    default.

    Returns:
        ``(tolN, tolS)`` as ``(int, float)``.
    """
    try:
        tolN = int(tolNentry.get())
    except ValueError:
        # Only malformed numeric text is handled; other errors propagate.
        tolN = 10
        print("enter Integer for tolN")
        tolNentry.delete(0, END)
        tolNentry.insert(0, '10')
    try:
        tolS = float(tolSentry.get())
    except ValueError:
        tolS = 1.0
        print("enter Float for tolS")
        tolSentry.delete(0, END)
        tolSentry.insert(0, '1.0')
    return tolN, tolS


def drawNewTree():
    """Button callback: read the current inputs and redraw the tree plot."""
    minSamples, minErrorDrop = getInputs()
    # getInputs returns (tolN, tolS) but reDraw expects (tolS, tolN).
    reDraw(minErrorDrop, minSamples)









import matplotlib.pyplot as plt


if __name__ == '__main__':
    root = Tk()
    # Create the figure and embed it in the Tk window
    reDraw.f = Figure(figsize=(5, 4), dpi=100)
    reDraw.canvas = FigureCanvasTkAgg(reDraw.f, master=root)
    # draw() replaces FigureCanvasTkAgg.show(), which newer Matplotlib
    # releases no longer provide.
    reDraw.canvas.draw()
    reDraw.canvas.get_tk_widget().grid(row=0, columnspan=3)

    # Entry for tolN (minimum samples per split side)
    Label(root, text="tolN").grid(row=1, column=0)
    tolNentry = Entry(root)
    tolNentry.grid(row=1, column=1)
    tolNentry.insert(0, '10')
    # Entry for tolS (minimum error reduction)
    Label(root, text="tolS").grid(row=2, column=0)
    tolSentry = Entry(root)
    tolSentry.grid(row=2, column=1)
    tolSentry.insert(0, '1.0')
    Button(root, text="ReDraw", command=drawNewTree).grid(row=1, column=2, rowspan=3)
    # Check box switching between regression tree and model tree
    chkBtnVar = IntVar()
    chkBtn = Checkbutton(root, text="Model Tree", variable=chkBtnVar)
    chkBtn.grid(row=3, column=0, columnspan=2)

    # asmatrix is the long-standing name; the `mat` alias is gone in NumPy 2.0
    reDraw.rawDat = asmatrix(loadDataSet('sine.txt'))
    # The matrix methods return plain scalars, whichever min/max names the
    # star import leaves in scope.
    xColumn = reDraw.rawDat[:, 0]
    reDraw.testDat = arange(xColumn.min(), xColumn.max(), 0.01)
    reDraw(1.0, 10)
    root.mainloop()

可得如图效果:

机器学习笔记--机器学习实战CART算法错误

机器学习笔记--机器学习实战CART算法错误