七、代码实现(python)
以下代码来自Peter Harrington《Machine Learing in Action》
本例代码实现算法5,生成最小二乘回归树。
代码如下(保存为CART.py):
# -- coding: utf-8 --
from numpy import *
def loadDataSet(fileName):
# 获取训练集
dataMat = []
fr = open(fileName)
for line in fr.readlines():
curLine = line.strip().split('\t')
fltLine = map(float,curLine)
dataMat.append(fltLine)
return dataMat
def binSplitDataSet(dataSet, feature, value):
# 该函数接收3个参数,数据集、第几个特征(切分变量)、划分条件(切分点),根据选择的特征和划分条件将数据分成两个区域
mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:][0]
mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:][0]
return mat0,mat1
def regLeaf(dataSet):
# 获取数据集dataSet最后一列的平均值
return mean(dataSet[:,-1])
def regErr(dataSet):
# 根据式(4)计算数据集dataSet的平方误差
# var用于计算方差
return var(dataSet[:,-1]) * shape(dataSet)[0]
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
# 该函数用于寻找对于数据集dataSet的最好切分变量及切分点(即使得平方误差最小),ops用于控制函数停止机制
tolS = ops[0] # 容许的误差下降值
tolN = ops[1] # 切分的最小样本数
if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
return None, leafType(dataSet) # 若所有类别值相等,退出,此时无最好切分量
m,n = shape(dataSet)
S = errType(dataSet) # 存储数据集的平方误差
bestS = inf
bestIndex = 0 # 初始化切分变量
bestValue = 0 # 初始化切分点
for featIndex in range(n-1):
# 循环特征数目,featIndex此时为切分变量
for splitVal in set(dataSet[:,featIndex]):
# 循环数据集行数,splitVal此时为切分点
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal) # 根据循环到的切分变量与切分点将数据分成两个区域
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue # 若切分后的样本点小于最小样本数,退出此次循环,继续下一个循环
newS = errType(mat0) + errType(mat1)# 计算划分后数据集的平方误差
if newS < bestS: # 若新的平方误差更小,更新各个数据
bestIndex = featIndex
bestValue = splitVal
bestS = newS
if (S - bestS) < tolS: # 若误差下降值小于容许的误差下降值,退出,此时无最好切分量
return None, leafType(dataSet)
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue) # 用新的切分变量与切分点将数据分成两个区域
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): # 若切分后的样本点小于最小样本数,退出,此时无最好切分量
return None, leafType(dataSet)
return bestIndex,bestValue
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
# 该函数根据接收的数据集创建决策树(子树)
feat, val = chooseBestSplit(dataSet, leafType, errType, ops) # 寻找对于数据集dataSet的最好切分变量及切分点
if feat == None: return val # 若无最好的切分点,则返回数据集均值作为叶节点
retTree = {}
retTree['spInd'] = feat
retTree['spVal'] = val
lSet, rSet = binSplitDataSet(dataSet, feat, val) # 用最好的切分变量与切分点将数据分成两个区域,作为左右子树
retTree['left'] = createTree(lSet, leafType, errType, ops)
retTree['right'] = createTree(rSet, leafType, errType, ops)
return retTree
以上全部内容参考书籍如下:
李航《统计学习方法》