Implementing a Classification and Regression Tree (CART) in Java
First, let's get an overview of the classification and regression tree (CART) algorithm:
Algorithm description: T denotes the current sample set, and T_attributelist denotes the current set of candidate attributes.
(1) Create a root node N.
(2) Assign a class label to N.
(3) If all samples in T belong to the same class, or only one sample remains in T, return N as a leaf node; otherwise assign N a test attribute.
(4) For each attribute in T_attributelist, perform a split on that attribute and compute the Gini index of the split (a small worked sketch of the Gini computation follows this list).
(5) Set N's test attribute test_attribute to the attribute in T_attributelist whose split has the smallest Gini index.
(6) Split T into the subsets T1 and T2.
(7) Repeat steps (1)-(6) for T1.
(8) Repeat steps (1)-(6) for T2.
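Steps (4) and (5) rely on the Gini index, which the description uses but never spells out. As a minimal, hedged sketch (the helpers gini and giniSplit are hypothetical and are not part of the implementation below, which scores splits by squared error instead), the Gini index of a candidate binary split could be computed like this:
import java.util.HashMap;
import java.util.Map;

public class GiniExample {
    // gini(D) = 1 - sum_k p_k^2, where p_k is the fraction of samples with class k
    static double gini(int[] labels) {
        Map<Integer, Integer> counts = new HashMap<>();
        for (int y : labels) counts.merge(y, 1, Integer::sum);
        double impurity = 1.0;
        for (int c : counts.values()) {
            double p = (double) c / labels.length;
            impurity -= p * p;
        }
        return impurity;
    }

    // Gini index of a binary split: size-weighted average impurity of the two children
    static double giniSplit(int[] left, int[] right) {
        int n = left.length + right.length;
        return (left.length * gini(left) + right.length * gini(right)) / n;
    }

    public static void main(String[] args) {
        int[] left  = {0, 0, 0, 1};   // mostly class 0
        int[] right = {1, 1, 1, 0};   // mostly class 1
        System.out.println(giniSplit(left, right));   // 0.375: lower means a purer split
    }
}
For the example labels above, each child has impurity 1 - (0.75^2 + 0.25^2) = 0.375, so the size-weighted split score is 0.375; a perfectly pure split would score 0.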
Because every node may end up as a leaf, CART assigns a class label to each node. The label can simply be the most frequent class at that node, or it can be chosen by looking at the node's classification error or by other, more sophisticated methods.
CART also relies on post-pruning. During growth, each additional level of splitting can reveal more information, so CART keeps splitting until no further branches can be grown, yielding a maximal decision tree; that large tree is then pruned back (a sketch of such a pruning pass appears after the Java createTree function below).
The description above is already fairly detailed, so I won't elaborate further. Note that the code that follows implements the regression flavour of CART: leaves hold the mean of the target column and splits are scored by total squared error rather than by the Gini index. Here is the Python reference implementation, which follows the regression-tree code in Machine Learning in Action (Chapter 9):
from numpy import *  # dataSet is assumed to be a NumPy matrix

# split the data set into two subsets on the given feature and threshold value
def binSplitDataSet(dataSet, feature, value):
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1

# leaf value: the mean of the target (last) column
def regLeaf(dataSet):  # returns the value used for each leaf
    return mean(dataSet[:, -1])

# total squared error of the target column
def regErr(dataSet):
    return var(dataSet[:, -1]) * shape(dataSet)[0]

# find the best feature and value to split on
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    tolS = ops[0]; tolN = ops[1]
    # if all the target variables are the same value: quit and return the leaf value
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:  # exit cond 1
        return None, leafType(dataSet)
    m, n = shape(dataSet)
    # the choice of the best feature is driven by reduction in RSS error from the mean
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n - 1):
        for splitVal in set(dataSet[:, featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # if the decrease (S - bestS) is less than a threshold, don't do the split
    if (S - bestS) < tolS:
        return None, leafType(dataSet)  # exit cond 2
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  # exit cond 3
        return None, leafType(dataSet)
    return bestIndex, bestValue  # the best feature to split on and the value used for that split

# build the regression tree recursively
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):  # assume dataSet is a NumPy mat so we can use array filtering
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)  # choose the best split
    if feat is None: return val  # if the splitting hit a stop condition, return the leaf value
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree
Now let's implement it in Java. The matrices below use the DenseMatrix64F class from the EJML library. First, the data set splitting function:
public static DenseMatrix64F[] binSplitDataSet(DenseMatrix64F dataSet, int feature, double value) {
    DenseMatrix64F mat0 = new DenseMatrix64F(0, dataSet.numCols);
    DenseMatrix64F mat1 = new DenseMatrix64F(0, dataSet.numCols);
    for (int i = 0; i < dataSet.numRows; i++) {
        if (dataSet.get(i, feature) > value) {
            // row goes to mat0: grow it by one row (keeping existing values) and copy the row over
            mat0.reshape(mat0.numRows + 1, dataSet.numCols, true);
            for (int j = 0; j < dataSet.numCols; j++) {
                mat0.set(mat0.numRows - 1, j, dataSet.get(i, j));
            }
        } else {
            // otherwise the row goes to mat1
            mat1.reshape(mat1.numRows + 1, dataSet.numCols, true);
            for (int j = 0; j < dataSet.numCols; j++) {
                mat1.set(mat1.numRows - 1, j, dataSet.get(i, j));
            }
        }
    }
    return new DenseMatrix64F[] {mat0, mat1};
}
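Before moving on, a quick hedged usage sketch (the tiny matrix here is made up for illustration): rows whose chosen feature is greater than the split value end up in the first returned matrix, the rest in the second.
// Hypothetical sanity check for binSplitDataSet: a 4x2 matrix (feature column 0, target column 1),
// split on feature 0 at value 0.5.
DenseMatrix64F tiny = new DenseMatrix64F(new double[][] {
    {0.2, 1.0},
    {0.4, 1.1},
    {0.7, 2.0},
    {0.9, 2.1}
});
DenseMatrix64F[] parts = binSplitDataSet(tiny, 0, 0.5);
System.out.println(parts[0].numRows);   // 2 (rows with feature > 0.5)
System.out.println(parts[1].numRows);   // 2 (rows with feature <= 0.5)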
Next, the mean of the target (last) column, which is used as the leaf value:
public static double regLeaf(DenseMatrix64F dataSet) {
    double sum = 0;
    for (int i = 0; i < dataSet.numRows; i++) {
        sum += dataSet.get(i, dataSet.numCols - 1);
    }
    return sum / dataSet.numRows;
}
Then the total squared error of the target column (the per-sample variance times the number of samples); for example, targets 1, 2, 3 have mean 2 and total squared error (1-2)^2 + (2-2)^2 + (3-2)^2 = 2:
public static double regErr(DenseMatrix64F dataSet) {
    double avg = regLeaf(dataSet);
    double sum = 0;
    for (int i = 0; i < dataSet.numRows; i++) {
        sum += (dataSet.get(i, dataSet.numCols - 1) - avg) * (dataSet.get(i, dataSet.numCols - 1) - avg);
    }
    return sum;
}
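As a quick hedged sanity check of both helpers on the toy targets mentioned above (the matrix is made up; column 0 is a dummy feature, column 1 is the target):
// Hypothetical check: targets 1, 2, 3 -> regLeaf returns the mean 2.0, regErr returns 2.0.
DenseMatrix64F toy = new DenseMatrix64F(new double[][] {
    {0.0, 1.0},
    {0.0, 2.0},
    {0.0, 3.0}
});
System.out.println(regLeaf(toy));   // 2.0
System.out.println(regErr(toy));    // 2.0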
Define a small class that holds the result of the best split (the chosen feature index and split value):
package com.algorithm;

public class BestInfo {
    private int bestIndex;
    private double bestValue;

    public int getBestIndex() {
        return bestIndex;
    }
    public void setBestIndex(int bestIndex) {
        this.bestIndex = bestIndex;
    }
    public double getBestValue() {
        return bestValue;
    }
    public void setBestValue(double bestValue) {
        this.bestValue = bestValue;
    }
}
Next, the function that chooses the best split. ops[0] is the minimum error reduction required to accept a split, and ops[1] is the minimum number of samples allowed in each branch; when no acceptable split exists, the function returns bestIndex = -1 and the leaf value in bestValue:
public static BestInfo chooseBestSplit(DenseMatrix64F dataSet, double[] ops) {
    // exit condition 1: all target values are identical, so this node becomes a leaf
    List<Double> tmp = new ArrayList<Double>();
    for (int i = 0; i < dataSet.numRows; i++) {
        if (!tmp.contains(dataSet.get(i, dataSet.numCols - 1))) {
            tmp.add(dataSet.get(i, dataSet.numCols - 1));
        }
    }
    if (tmp.size() == 1) {
        BestInfo bi = new BestInfo();
        bi.setBestIndex(-1);
        bi.setBestValue(regLeaf(dataSet));
        return bi;
    }
    double S = regErr(dataSet);
    double bestS = Double.MAX_VALUE;
    int bestIndex = 0;
    double bestValue = 0;
    for (int featIndex = 0; featIndex < dataSet.numCols - 1; featIndex++) {
        // collect the distinct values of this feature as candidate split points
        List<Double> splitVal = new ArrayList<Double>();
        for (int i = 0; i < dataSet.numRows; i++) {
            if (!splitVal.contains(dataSet.get(i, featIndex))) {
                splitVal.add(dataSet.get(i, featIndex));
            }
        }
        for (int i = 0; i < splitVal.size(); i++) {
            DenseMatrix64F[] mat = binSplitDataSet(dataSet, featIndex, splitVal.get(i));
            if ((mat[0].numRows < ops[1]) || (mat[1].numRows < ops[1])) {
                continue;   // either branch would be too small
            }
            double newS = regErr(mat[0]) + regErr(mat[1]);
            if (newS < bestS) {
                bestIndex = featIndex;
                bestValue = splitVal.get(i);
                bestS = newS;
            }
        }
    }
    // exit condition 2: the best split does not reduce the error enough
    if ((S - bestS) < ops[0]) {
        BestInfo bi = new BestInfo();
        bi.setBestIndex(-1);
        bi.setBestValue(regLeaf(dataSet));
        return bi;
    }
    // exit condition 3: the best split leaves too few samples in a branch
    DenseMatrix64F[] mat = binSplitDataSet(dataSet, bestIndex, bestValue);
    if ((mat[0].numRows < ops[1]) || (mat[1].numRows < ops[1])) {
        BestInfo bi = new BestInfo();
        bi.setBestIndex(-1);
        bi.setBestValue(regLeaf(dataSet));
        return bi;
    }
    BestInfo bi = new BestInfo();
    bi.setBestIndex(bestIndex);
    bi.setBestValue(bestValue);
    return bi;
}
Next, the regression tree node class. A node with spInd == -1 is a leaf and its spVal holds the predicted value; otherwise spInd/spVal describe the split and left/right are the child subtrees:
package com.algorithm;

public class cnode {
    private int spInd;
    private double spVal;
    private cnode left;
    private cnode right;

    cnode() {
    }
    cnode(int spInd, double spVal) {
        this.spInd = spInd;
        this.spVal = spVal;
    }
    public int getSpInd() {
        return spInd;
    }
    public void setSpInd(int spInd) {
        this.spInd = spInd;
    }
    public double getSpVal() {
        return spVal;
    }
    public void setSpVal(double spVal) {
        this.spVal = spVal;
    }
    public cnode getLeft() {
        return left;
    }
    public void setLeft(cnode left) {
        this.left = left;
    }
    public cnode getRight() {
        return right;
    }
    public void setRight(cnode right) {
        this.right = right;
    }
}
Then build the regression tree recursively:
public static cnode createTree(DenseMatrix64F dataSet, double[] ops) {
    BestInfo bi = chooseBestSplit(dataSet, ops);
    if (bi.getBestIndex() == -1) {
        // no acceptable split: return a leaf node whose spVal is the mean target value
        return new cnode(bi.getBestIndex(), bi.getBestValue());
    }
    cnode cn = new cnode(bi.getBestIndex(), bi.getBestValue());
    DenseMatrix64F[] mat = binSplitDataSet(dataSet, bi.getBestIndex(), bi.getBestValue());
    cn.setLeft(createTree(mat[0], ops));
    cn.setRight(createTree(mat[1], ops));
    return cn;
}
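The opening description notes that CART first grows a maximal tree and then post-prunes it, but this article stops at tree construction. Purely as a hedged sketch of what a post-pruning pass could look like on the cnode structure above (the helpers isLeaf, treeMean, squaredError, and prune are hypothetical, and the merge rule follows the collapse-if-test-error-improves idea used by the Python reference code):
// Hypothetical post-pruning sketch: collapse sibling leaves when merging them
// reduces the squared error measured on a held-out test set.
public static boolean isLeaf(cnode node) {
    return node.getSpInd() == -1;               // leaves are created with spInd == -1
}

public static double treeMean(cnode node) {
    if (isLeaf(node)) return node.getSpVal();
    return (treeMean(node.getLeft()) + treeMean(node.getRight())) / 2.0;
}

public static double squaredError(DenseMatrix64F data, double prediction) {
    double sum = 0;
    for (int i = 0; i < data.numRows; i++) {
        double diff = data.get(i, data.numCols - 1) - prediction;
        sum += diff * diff;
    }
    return sum;
}

public static cnode prune(cnode node, DenseMatrix64F testData) {
    if (isLeaf(node)) return node;
    if (testData.numRows == 0) {
        return new cnode(-1, treeMean(node));   // no test data reaches this subtree: collapse it
    }
    DenseMatrix64F[] mat = binSplitDataSet(testData, node.getSpInd(), node.getSpVal());
    node.setLeft(prune(node.getLeft(), mat[0]));
    node.setRight(prune(node.getRight(), mat[1]));
    if (isLeaf(node.getLeft()) && isLeaf(node.getRight())) {
        double errNoMerge = squaredError(mat[0], node.getLeft().getSpVal())
                          + squaredError(mat[1], node.getRight().getSpVal());
        double mergedVal = (node.getLeft().getSpVal() + node.getRight().getSpVal()) / 2.0;
        if (squaredError(testData, mergedVal) < errNoMerge) {
            return new cnode(-1, mergedVal);    // merging helps on the test set: replace the subtree with a leaf
        }
    }
    return node;
}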
OK, at this point we can test it. The snippet below assumes java.io.BufferedReader, java.io.FileReader, java.util.List, and java.util.ArrayList are imported, and reads the ex0.txt sample data set (two feature columns and one target column):
List<String> list = new ArrayList<String>();
try {
    BufferedReader br = new BufferedReader(new FileReader("D:\\machinelearninginaction-master\\Ch09\\ex0.txt"));
    String s = null;
    while ((s = br.readLine()) != null) {
        list.add(s);
    }
    br.close();
} catch (Exception e) {
    e.printStackTrace();
}
DenseMatrix64F dataMatIn = new DenseMatrix64F(list.size(), 3);
for (int i = 0; i < list.size(); i++) {
    // the sample file is tab-delimited; split on any whitespace to be safe
    String[] items = list.get(i).trim().split("\\s+");
    dataMatIn.set(i, 0, Double.parseDouble(items[0]));
    dataMatIn.set(i, 1, Double.parseDouble(items[1]));
    dataMatIn.set(i, 2, Double.parseDouble(items[2]));
}
cnode tree = createTree(dataMatIn, new double[] {1, 4});
System.out.println(tree.getSpInd());   // feature index chosen at the root
The regression tree was built successfully.
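The article ends at tree construction, but to actually use the tree you need to walk it from the root to a leaf. A minimal hedged sketch (the predict helper is hypothetical; it relies on the same conventions as createTree, namely that the left child holds rows with feature > spVal and that leaves have spInd == -1):
// Hypothetical prediction helper: descend the tree until a leaf (spInd == -1) is reached.
public static double predict(cnode tree, double[] features) {
    cnode node = tree;
    while (node.getSpInd() != -1) {
        // mirror binSplitDataSet/createTree: feature > split value goes to the left subtree
        if (features[node.getSpInd()] > node.getSpVal()) {
            node = node.getLeft();
        } else {
            node = node.getRight();
        }
    }
    return node.getSpVal();                     // at a leaf, spVal stores the predicted mean
}
For the ex0.txt data loaded above, the feature array would hold the first two columns of a row, e.g. predict(tree, new double[] {1.0, 0.45}).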