欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

数据挖掘十大算法(十):CART(分类回归树)

程序员文章站 2022-06-18 11:09:52
...

本文记录一下关于CART的相关知识其中包括(回归树、树的后剪枝、模型树、树回归模型的预测(树回归模型的评估))。在之前学习完ID3算法有记录一篇相关的学习笔记,所以后面学习CART算法能有一个比较和熟悉的理解。

    贪心算法的决策树,构建算法是ID3,即通过香农熵计算数据的混乱程度,然后求出信息增益,每次选择最大信息增益的划分方式,作为当前的划分方式,直到数据集完成划分,被划分过的特征在之后不会再有任何作用。所以这种划分方式被认为过于迅速,并且处理连续型数据时需要先离散化,这样可能会破坏连续型数据的内在性质。

    另一种切分方式是二元切分法即每次把数据切成两份。如果数据的某特征值等于切分所要求的值,那么这些数据就进入左子树,反之则进入右子树,这就是CART算法的思想。

    CART(分类回归树)算法,该算法既可以用来分类还可以用来回归,所以很值得学习。下面首先使用CART算法构建回归树,并介绍如何为复杂的回归树剪枝(防止过拟合问题)。然后引入一种更高级的方法——模型树。最后对回归树、模型树、线性回归做一个预测(评估)。

    模型树与回归树(在叶子节点使用各自的均值做预测)不同,该算法需要在每个叶子节点构建出一个线性模型

一个核心递归伪代码:

找到最佳的待切分特征: 
    如果该节点不能再分,将该节点存为叶节点 
    执行二元切分 
    在右子树继续调用该函数
    在左子树继续调用该函数

回归树:

说明:创建树函数creatTree()的两个参数默认值为 回归树的叶子节点创建函数、误差计算函数,所以这决定了如果使用默认值,则创建的是回归树。后面我们需要构建模型树,只需要改为传入模型树的两个函数参数即可。参数ops为预剪枝方法,该参数的设置决定了树构建的大小。

样例数据(来自第九章):

数据挖掘十大算法(十):CART(分类回归树)

from numpy import *
import matplotlib.pyplot as plt

# 读取本地文件,python3 list(map)
def loadDataSet(fileName):
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = list(map(float,curLine))  # python3问题修改
        dataMat.append(fltLine)
    return dataMat

# 根据特征值划分数据集,得到两个数据集
def binSplitDataSet(dataSet, feature, value):   # nonzero返回真(True)值的下标
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:] # 取该列某值大于特征值的行
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]   # python3问题修改
    return mat0,mat1

# 返回一个值,生成叶子节点(目标变量均值)
def regLeaf(dataSet):
    return mean(dataSet[:,-1])

# 误差计算函数 返回方差总和
def regErr(dataSet):
    return var(dataSet[:,-1]) * shape(dataSet)[0]   # var方差计算函数

# 选择最佳的特征、特征值   (一旦不满足划分的条件便返回叶子节点)
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    tolS = ops[0]; tolN = ops[1]
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:  # 特征值唯一,返回None和叶子节点
        return None, leafType(dataSet)
    m,n = shape(dataSet)
    S = errType(dataSet)    # 获得数据集的混乱程度误差,后面求混乱程度减少了多少
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]):   # python3问题修改
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)      # 二分 划分数据集
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue # 若划分效果不好(数据集太小),继续划分
            newS = errType(mat0) + errType(mat1)    # 两个数据集的混乱程度求和,与bestS相比较
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # 如果混乱程度减少不大,则返回叶子节点
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue) # 根据最佳的特征、特征值来二分划分数据
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  # 若划分效果不好(数据集太小),返回叶子节点
        return None, leafType(dataSet)
    return bestIndex,bestValue # 返回最佳特征、特征值

# 数据集、创建叶子节点、误差计算函数、(1:最小的误差下降阈值 4:切分的最少样本数要求)
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops) # 选择最佳特征、特征值
    if feat == None: return val  # 若特征为None,返回叶子节点值
    retTree = {}    # 创建字典,用于保存树节点的信息
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)    # 根据已经划分返回的特征、特征值继续划分数据集
    retTree['left'] = createTree(lSet, leafType, errType, ops)  # 这两个函数为递归,直到叶子节点
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

myDat2 = loadDataSet('ex00.txt')
myMat2 = mat(myDat2)
result = createTree(myMat2)
print(result)

数据挖掘十大算法(十):CART(分类回归树)

本段代码在书中有几处错误,我找到了两处,另一处参考了一篇博客:

1、TypeError: unsupported operand type(s) for /: ‘map‘ and ‘int‘ 
    修改loadDataSet函数某行为fltLine = list(map(float,curLine)),因为python3中map的返回值变了,所以要加list() 
2、TypeError: unhashable type: ‘matrix’ 
    修改chooseBestSplit函数某行为:for splitVal in set((dataSet[:,featIndex].T.A.tolist())[0]): matrix类型不能被hash。 
3、TypeError: index 0 is out of bounds 
    函数修改两行binSplitDataSet 
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :] 
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]

上面的代码不是特别的复杂,核心思想就是通过二元切分法,用目前最佳的方式对数据进行切分。

看一下数据集的图型:

import matplotlib.pyplot as plt
myDat=loadDataSet('ex00.txt')
myMat=mat(myDat)
plt.plot(myMat[:,0],myMat[:,1],'ro')
plt.show()

数据挖掘十大算法(十):CART(分类回归树)

树的剪枝:

    前面我们提过ops参数的设置,可以决定我们树的构建大小,可能过拟合也可能欠拟合。该参数的设置会对我们树在构建过程中就进行剪枝操作,所以这是一种预剪枝操作。下面介绍一下另一种后剪枝操作,一般需要将两种剪枝操作同时使用,能达到更好的剪枝效果。

    后剪枝操作:后剪枝需要将数据分为训练集、测试集,首先给定参数,构建足够复杂的树,然后从上而下找到叶子节点,用测试集来判断将这些叶子节点合并能否降低测试误差,如果可以则合并。

# 后剪枝操作
# 判断该节点是否为子节点(字典 True)
def isTree(obj):
    return (type(obj).__name__ == 'dict')

# 递归 从上到下遍历直到两个叶子节点计算它们的平均值(塌陷处理)
def getMean(tree):
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):                        # 塌陷处理 简单描述就是从最下面的叶子节点(通过某种计算方式)开始两两合并
        tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0

# 修剪过程主函数
def prune(tree, testData):
    if shape(testData)[0] == 0: return getMean(tree)  # 如果没有测试集,塌陷处理(及getMean函数)
    if (isTree(tree['right']) or isTree(tree['left'])):  # 如果有树,则根据树的信息划分测试集
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)   # 测试集非空,有树,则继续prune递归对测试集进行切分
    if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
    # 如果它们现在都是叶子,看看是否可以合并它们
    if not isTree(tree['left']) and not isTree(tree['right']):
        # 划分测试集,计算划分后的误差与划分前的误差,两者比较,若划分更好则合并操作,否则不合并直接返回
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + sum(power(rSet[:, -1] - tree['right'], 2))   # 未合并误差
        treeMean = (tree['left'] + tree['right']) / 2.0   # 合并及将两个叶子节点的值求均值
        errorMerge = sum(power(testData[:, -1] - treeMean, 2))  # 合并后误差
        if errorMerge < errorNoMerge:
            print("merging")
            return treeMean # 合并便返回两者的均值
        else:
            return tree
    else:
        return tree

# 获得数据
myDat2 = loadDataSet('ex2.txt')
myMat2 = mat(myDat2)
# 创建尽可能大的树(0,1)
myTree = createTree(myMat2,ops=(0,1))
myDatTest = loadDataSet('ex2test.txt')
myMat2Test = mat(myDatTest)
# 剪枝过程
result = prune(myTree,myMat2Test)
print(result)

数据挖掘十大算法(十):CART(分类回归树)

这里是剪枝函数,需要调用上面的树模型构建函数。

剪枝过程的判断条件、递归有点多,需要仔细的理解,当然一步一步来都是比较容易理解的。下面介绍更高级的模型树。

模型树:

    上面提到过,该方法与回归树不同的地方是:该算法需要在每个叶子节点构建出一个线性模型,取代回归树的均值表示法。

树模型的叶子节点可以是一个常数,当然也可以是分段的线性函数,下面来看一个图就明白:

数据挖掘十大算法(十):CART(分类回归树)

模型树的可解释性优于回归树,同时具有更高的预测准确度。如图中如果我们使用分段的线性函数肯定比一组常数拟合的效果好,而分段点大概在0.3左右,等下我们构建出模型树后,便可以得到该分段点和分段函数了。

由于模型树的构建与就回归树大致相同,只是叶子节点的创建函数leafType()、误差计算函数errType()需要重新定义,以及传参改变原来的默认参数。下面为模型树的主要函数:

# 模型树构建
# 这里的模型树使用到了上面回归树的函数createTree(),该函数只需改变两个固定参数(子节点生成函数 误差计算函数)
# 便可以在回归树与模型树之间切换

# 获得线性回归系数  与线性回归那里一样
def linearSolve(dataSet):
    m,n = shape(dataSet)
    X = mat(ones((m,n)))
    Y = mat(ones((m,1)))
    X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]
    xTx = X.T*X        # 线性回归公式代入
    if linalg.det(xTx) == 0.0:
        raise NameError('This matrix is singular, cannot do inverse,\n try increasing the second value of ops')
    ws = xTx.I * (X.T * Y)
    return ws,X,Y # 返回回归系数 数据 目标值

# 创建叶子节点(即回归系数)
def modelLeaf(dataSet):#create linear model and return coeficients
    ws,X,Y = linearSolve(dataSet)
    return ws

# 预测目标值 用于与Y求平方误差和
def modelErr(dataSet):
    ws,X,Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat,2))

myMat = mat(loadDataSet('exp2.txt'))
myTree = createTree(myMat,modelLeaf,modelErr,(1,10)) # 传入中间两个用于构建模型树的函数参数
print(myTree)

数据挖掘十大算法(十):CART(分类回归树)

同过模型树返回划分的信息我们来看看它线性回归的拟合线如何:

    y = kx + b           左值为b,右值为k,带入横坐标x得到y   y = 3.46+1.185x  y = 12x

数据挖掘十大算法(十):CART(分类回归树)

可以看到,分段的拟合线,很不错,也更加的直观。

树回归于标准回归的评估(预测):

# (树回归模型与标准回归的预测)树回归模型与标准回归的评估

# 回归树的值计算函数
def regTreeEval(model, inDat):
    return float(model)     # 不是树,则为叶节点(值)

# 模型树的值计算函数
def modelTreeEval(model, inDat):
    n = shape(inDat)[1]     # 模型树通过线性回归系数来计算预测值
    X = mat(ones((1, n + 1)))   # 第一个值为1
    X[:, 1:n + 1] = inDat
    return float(X * model) # 测试数据向量*回归系数向量 得到预测值

# 预测 (通过树模型预测当前值)
def treeForeCast(tree, inData, modelEval=regTreeEval):
    if not isTree(tree):    # 如果不是一颗树,则为叶节点(数值)
        return modelEval(tree, inData)
    # 根据树来查询当前测试数据位置
    if inData[tree['spInd']] > tree['spVal']:   #tree['spInd'] 本次树的划分特征点   inData[tree['spInd']] 该特征值
        if isTree(tree['left']):
            return treeForeCast(tree['left'], inData, modelEval)
        else:
            return modelEval(tree['left'], inData)
    else:
        if isTree(tree['right']):
            return treeForeCast(tree['right'], inData, modelEval)
        else:
            return modelEval(tree['right'], inData)

# 预测 (循环所有测试集)
def createForeCast(tree, testData, modelEval=regTreeEval):
    m = len(testData)
    yHat = mat(zeros((m, 1)))
    for i in range(m):
        yHat[i, 0] = treeForeCast(tree, mat(testData[i]), modelEval)
    return yHat

trainMat = mat(loadDataSet('bikeSpeedVsIq_train.txt'))
testMat = mat(loadDataSet('bikeSpeedVsIq_test.txt'))

myTree1 = createTree(trainMat,ops=(1,20))       # 创建回归树
yHat = createForeCast(myTree1,testMat[:,0])     # 预测
result1 = corrcoef(yHat,testMat[:,1],rowvar=0)[0,1] # 该函数计算预测值与真实值的相关系数
print(result1)

myTree2 = createTree(trainMat,modelLeaf,modelErr,ops=(1,20))    # 创建模型树
yHat = createForeCast(myTree2,testMat[:,0],modelTreeEval)
result2 = corrcoef(yHat,testMat[:,1],rowvar=0)[0,1]
# print(myTree2)
print(result2)

ws,X,y = linearSolve(trainMat)    # 标准回归
for i in range(shape(testMat)[0]):
    yHat[i] = testMat[i,0]*ws[1,0]+ws[0,0]
result3 = corrcoef(yHat,testMat[:,1],rowvar=0)[0,1]
print(result3)

数据挖掘十大算法(十):CART(分类回归树)

从结果可以看到模型树的效果最好,回归树其次,标准回归的效果最差。

以上是所有内容,通过实践可以看到CART算法,相对于ID3算法确实有很大的优势。尤其是对分类和回归通吃更是让人欲罢不能,CART可以用于构建二元树并处理离散型或连续型数据的切分,使用不同的误差准则、叶节点创建,我们可以构建回归树和模型树。

 

参考书籍:《机器学习实战》

参考博客:https://blog.csdn.net/sinat_17196995/article/details/69621687    某条代码错误参考