欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

机器学习(四):决策树

程序员文章站 2022-03-30 23:05:20
...

机器学习(四):决策树

七、代码实现(python)

以下代码来自Peter Harrington《Machine Learing in Action》
本例代码实现算法5,生成最小二乘回归树。
代码如下(保存为CART.py):

 

# -- coding: utf-8 --
from numpy import *

def loadDataSet(fileName):
    # 获取训练集
    dataMat = []
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        fltLine = map(float,curLine)
        dataMat.append(fltLine)
    return dataMat

def binSplitDataSet(dataSet, feature, value):
    # 该函数接收3个参数,数据集、第几个特征(切分变量)、划分条件(切分点),根据选择的特征和划分条件将数据分成两个区域
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]
    return mat0,mat1

def regLeaf(dataSet):
    # 获取数据集dataSet最后一列的平均值
    return mean(dataSet[:,-1])

def regErr(dataSet):
    # 根据式(4)计算数据集dataSet的平方误差
    # var用于计算方差
    return var(dataSet[:,-1]) * shape(dataSet)[0]

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    # 该函数用于寻找对于数据集dataSet的最好切分变量及切分点(即使得平方误差最小),ops用于控制函数停止机制
    tolS = ops[0]                              # 容许的误差下降值
    tolN = ops[1]                              # 切分的最小样本数
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)         # 若所有类别值相等,退出,此时无最好切分量
    m,n = shape(dataSet)
    S = errType(dataSet)                       # 存储数据集的平方误差
    bestS = inf
    bestIndex = 0                              # 初始化切分变量
    bestValue = 0                              # 初始化切分点
    for featIndex in range(n-1):
        # 循环特征数目,featIndex此时为切分变量
        for splitVal in set(dataSet[:,featIndex]):
            # 循环数据集行数,splitVal此时为切分点
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)      # 根据循环到的切分变量与切分点将数据分成两个区域
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue # 若切分后的样本点小于最小样本数,退出此次循环,继续下一个循环
            newS = errType(mat0) + errType(mat1)# 计算划分后数据集的平方误差
            if newS < bestS:                    # 若新的平方误差更小,更新各个数据
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    if (S - bestS) < tolS:                      # 若误差下降值小于容许的误差下降值,退出,此时无最好切分量
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)              # 用新的切分变量与切分点将数据分成两个区域
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):                   # 若切分后的样本点小于最小样本数,退出,此时无最好切分量
        return None, leafType(dataSet)
    return bestIndex,bestValue

def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    # 该函数根据接收的数据集创建决策树(子树)
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)             # 寻找对于数据集dataSet的最好切分变量及切分点
    if feat == None: return val                 # 若无最好的切分点,则返回数据集均值作为叶节点
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)                         # 用最好的切分变量与切分点将数据分成两个区域,作为左右子树
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

 

以上全部内容参考书籍如下:
李航《统计学习方法》