欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

用java实现分类回归树(CART)

程序员文章站 2022-06-18 11:19:55
...

首先我们来了解一下分类回归树(CART),如下:

算法描述:其中T代表当前样本集,当前候选属性集用T_attributelist表示。

(1)创建根节点N

(2)为N分配类别

(3)如果T中的样本都属于同一类别,或者T中只剩下一个样本,则返回N作为叶节点;否则为N分配测试属性

(4)对T_attributelist中的每一个属性,执行该属性上的一个划分,并计算此划分的GINI系数

(5)N的测试属性test_attribute=T_attributelist中最小GINI系数的属性

(6)划分T得到T1 T2子集

(7)对于T1重复(1)-(6)

(8)对于T2重复(1)-(6)

CART算法考虑到每个节点都有成为叶子节点的可能,对每个节点都分配类别。分配类别的方法可以用当前节点中出现最多的类别,也可以参考当前节点的分类错误或者其他更复杂的方法。

CART算法仍然使用后剪枝。在树的生成过程中,多展开一层就会有多一些的信息被发现,CART算法运行到不能再长出分支为止,从而得到一棵最大的决策树。然后对这棵大树进行剪枝。

上面的描述已经比较详细,这里不再赘述。需要注意的是,下面的代码构建的是回归树,选择划分时比较的是目标值的总方差(平方误差),而不是GINI系数。我们先看看分类回归树(CART)的python实现代码,如下:

#数据集划分
def binSplitDataSet(dataSet, feature, value):
    """Split the rows of ``dataSet`` in two by thresholding one column.

    Args:
        dataSet: 2-D NumPy matrix/array; every row is one sample.
        feature: index of the column that is compared against ``value``.
        value: the threshold.

    Returns:
        ``(mat0, mat1)`` where ``mat0`` holds every row whose ``feature``
        column is ``> value`` and ``mat1`` every row whose ``feature`` column
        is ``<= value``. Either side may have zero rows.
    """
    # Keep ALL matching rows. A trailing ``[0]`` after the row selection would
    # keep only the first matching row and raise IndexError when nothing matches.
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1


#求平均数
def regLeaf(dataSet):
    """Return the value stored in a leaf: the mean of the last column of ``dataSet``."""
    targets = dataSet[:, -1]
    return mean(targets)

#求总方差
def regErr(dataSet):
    """Return the total squared error of the last column of ``dataSet``.

    Computed as the variance of the last column multiplied by the number of rows.
    """
    rowCount = shape(dataSet)[0]
    targetVariance = var(dataSet[:, -1])
    return targetVariance * rowCount

#推算最优切割
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Find the (feature, threshold) pair that gives the lowest total error.

    Args:
        dataSet: NumPy matrix whose last column is the target variable.
        leafType: callable that turns a data set into a leaf value.
        errType: callable that returns the total error of a data set.
        ops: ``(tolS, tolN)`` - the minimum error reduction required to accept
            a split, and the minimum number of rows each side must keep.

    Returns:
        ``(featureIndex, splitValue)`` for the best split, or
        ``(None, leafType(dataSet))`` when one of the stop conditions is hit.
    """
    tolS = ops[0]; tolN = ops[1]
    # exit cond 1: all the target variables are the same value
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    n = shape(dataSet)[1]
    # the choice of the best feature is driven by Reduction in RSS error from mean
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        # Flatten the column to plain floats first: iterating a matrix column
        # yields 1x1 matrices, which are unhashable and cannot go into a set.
        for splitVal in set(dataSet[:,featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # exit cond 2: the decrease (S-bestS) is below the threshold, don't split
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # exit cond 3: one side of the chosen split is too small
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    # the best feature to split on and the value used for that split
    return bestIndex, bestValue


#构建回归树
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Recursively build a regression tree from ``dataSet``.

    Assumes ``dataSet`` is a NumPy matrix so rows can be filtered by array
    indexing.

    Args:
        dataSet: training samples; passed unchanged to ``chooseBestSplit``.
        leafType: callable forwarded to ``chooseBestSplit``.
        errType: callable forwarded to ``chooseBestSplit``.
        ops: tuning tuple forwarded to ``chooseBestSplit``.

    Returns:
        The leaf value when ``chooseBestSplit`` reports no split; otherwise a
        dict with the keys ``'spInd'``, ``'spVal'``, ``'left'`` and ``'right'``.
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # None is the "no split" sentinel; test identity rather than equality.
    if feat is None:
        return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree

我们开始用java实现,首先是java实现数据集切割函数

	/**
	 * Splits the rows of {@code dataSet} into two matrices by thresholding one column.
	 *
	 * <p>Each output matrix is allocated once at its final size instead of being
	 * grown (and its contents copied) one row at a time.
	 *
	 * @param dataSet source matrix; it is only read
	 * @param feature index of the column that is compared against {@code value}
	 * @param value   the threshold
	 * @return a two-element array: element 0 holds the rows whose {@code feature}
	 *         column is {@code > value}, element 1 holds all remaining rows;
	 *         row order within each part follows {@code dataSet}
	 */
	public static DenseMatrix64F[] binSplitDataSet(DenseMatrix64F dataSet,int feature,double value) {
		final int rows = dataSet.numRows;
		final int cols = dataSet.numCols;

		// First pass: count the rows above the threshold so both sizes are known.
		int above = 0;
		for (int i = 0; i < rows; i++) {
			if (dataSet.get(i, feature) > value) {
				above++;
			}
		}

		DenseMatrix64F mat0 = new DenseMatrix64F(above, cols);
		DenseMatrix64F mat1 = new DenseMatrix64F(rows - above, cols);

		// Second pass: copy every row into the next free row of its side.
		int next0 = 0;
		int next1 = 0;
		for (int i = 0; i < rows; i++) {
			boolean goesAbove = dataSet.get(i, feature) > value;
			DenseMatrix64F target = goesAbove ? mat0 : mat1;
			int targetRow = goesAbove ? next0++ : next1++;
			for (int j = 0; j < cols; j++) {
				target.set(targetRow, j, dataSet.get(i, j));
			}
		}

		return new DenseMatrix64F[] {mat0, mat1};
	}

然后是求矩阵平均数

/**
 * Computes the arithmetic mean of the last column of {@code dataSet}.
 *
 * @param dataSet matrix whose last column is averaged
 * @return the sum of the last column divided by the number of rows
 */
public static double regLeaf(DenseMatrix64F dataSet) {
		final int targetCol = dataSet.numCols - 1;
		double total = 0;

		int row = 0;
		while (row < dataSet.numRows) {
			total += dataSet.get(row, targetCol);
			row++;
		}

		return total / dataSet.numRows;
	}

然后是求矩阵总方差

/**
 * Computes the total squared deviation of the last column of {@code dataSet}
 * from that column's mean (as returned by {@code regLeaf}).
 *
 * @param dataSet matrix whose last column is evaluated
 * @return the sum over all rows of (value - mean) squared
 */
public static double regErr(DenseMatrix64F dataSet) {
		final int targetCol = dataSet.numCols - 1;
		final double mean = regLeaf(dataSet);

		double squaredError = 0;
		for (int row = 0; row < dataSet.numRows; row++) {
			double deviation = dataSet.get(row, targetCol) - mean;
			squaredError += deviation * deviation;
		}

		return squaredError;
	}

定义最优切割参数类

package com.algorithm;

/**
 * Mutable holder for the result of a split search: a feature index and a value.
 *
 * <p>Both fields start at their Java defaults ({@code 0} and {@code 0.0}) until
 * a setter is called.
 */
public class BestInfo {

	/** Index of the chosen feature column. */
	private int bestIndex;

	/** Value associated with the chosen feature. */
	private double bestValue;

	/** @return the stored feature index */
	public int getBestIndex() {
		return this.bestIndex;
	}

	/** @return the stored value */
	public double getBestValue() {
		return this.bestValue;
	}

	/** @param index the feature index to store */
	public void setBestIndex(int index) {
		this.bestIndex = index;
	}

	/** @param value the value to store */
	public void setBestValue(double value) {
		this.bestValue = value;
	}
}

然后是最优切割函数

/**
 * Searches every feature column (all columns except the last) for the
 * threshold whose two-way split gives the lowest summed {@code regErr}.
 *
 * <p>A "leaf" result (index {@code -1}, value {@code regLeaf(dataSet)}) is
 * returned when the last column holds exactly one distinct value, when the
 * error reduction is smaller than {@code ops[0]}, or when a side of the best
 * split has fewer than {@code ops[1]} rows.
 *
 * @param dataSet samples; the last column is the target
 * @param ops     {@code ops[0]} is the minimum error reduction required to
 *                split, {@code ops[1]} the minimum number of rows per side
 * @return the best feature index and split value, or a leaf result as
 *         described above
 */
public static  BestInfo chooseBestSplit(DenseMatrix64F dataSet,double[] ops) {
		final int targetCol = dataSet.numCols - 1;

		// Exit condition 1: exactly one distinct target value.
		// Two distinct values are enough to know the answer, so stop early.
		java.util.Set<Double> distinctTargets = new java.util.HashSet<Double>();
		for (int i = 0; i < dataSet.numRows && distinctTargets.size() < 2; i++) {
			distinctTargets.add(dataSet.get(i, targetCol));
		}
		if (distinctTargets.size() == 1) {
			return leafInfo(dataSet);
		}

		final double tolS = ops[0];
		final double tolN = ops[1];

		final double S = regErr(dataSet);

		double bestS = Double.MAX_VALUE;
		int bestIndex = 0;
		double bestValue = 0;

		for (int featIndex = 0; featIndex < targetCol; featIndex++) {
			// A LinkedHashSet removes duplicates in O(1) per value and keeps
			// first-seen order, so ties between equally good splits resolve
			// exactly as they would with an ordered list.
			java.util.Set<Double> splitVals = new java.util.LinkedHashSet<Double>();
			for (int i = 0; i < dataSet.numRows; i++) {
				splitVals.add(dataSet.get(i, featIndex));
			}

			for (double splitVal : splitVals) {
				DenseMatrix64F[] mat = binSplitDataSet(dataSet, featIndex, splitVal);
				if ((mat[0].numRows < tolN) || (mat[1].numRows < tolN)) {
					continue;
				}
				double newS = regErr(mat[0]) + regErr(mat[1]);
				if (newS < bestS) {
					bestIndex = featIndex;
					bestValue = splitVal;
					bestS = newS;
				}
			}
		}

		// Exit condition 2: the error reduction is below the threshold.
		if ((S - bestS) < tolS) {
			return leafInfo(dataSet);
		}

		// Exit condition 3: a side of the chosen split is too small.
		DenseMatrix64F[] mat = binSplitDataSet(dataSet, bestIndex, bestValue);
		if ((mat[0].numRows < tolN) || (mat[1].numRows < tolN)) {
			return leafInfo(dataSet);
		}

		BestInfo bi = new BestInfo();
		bi.setBestIndex(bestIndex);
		bi.setBestValue(bestValue);
		return bi;
	}

	/** Builds the "no split" result: index -1 and the mean target value of {@code dataSet}. */
	private static BestInfo leafInfo(DenseMatrix64F dataSet) {
		BestInfo bi = new BestInfo();
		bi.setBestIndex(-1);
		bi.setBestValue(regLeaf(dataSet));
		return bi;
	}

然后是回归树的节点类

package com.algorithm;

/**
 * A node of a binary tree, carrying a split index, a split value and
 * optional left/right children.
 *
 * <p>Children are {@code null} until set through the setters.
 */
public class cnode {

	/** Split index stored in this node. */
	private int spInd;

	/** Split value stored in this node. */
	private double spVal;

	/** Left child, or {@code null} when none has been set. */
	private cnode left;

	/** Right child, or {@code null} when none has been set. */
	private cnode right;

	/** Creates a node with default field values and no children. */
	cnode() {
	}

	/**
	 * Creates a node without children.
	 *
	 * @param spInd split index to store
	 * @param spVal split value to store
	 */
	cnode(int spInd, double spVal) {
		this.spInd = spInd;
		this.spVal = spVal;
	}

	/** @return the stored split index */
	public int getSpInd() {
		return this.spInd;
	}

	/** @param index the split index to store */
	public void setSpInd(int index) {
		this.spInd = index;
	}

	/** @return the stored split value */
	public double getSpVal() {
		return this.spVal;
	}

	/** @param value the split value to store */
	public void setSpVal(double value) {
		this.spVal = value;
	}

	/** @return the left child, possibly {@code null} */
	public cnode getLeft() {
		return this.left;
	}

	/** @param child the node to attach as left child */
	public void setLeft(cnode child) {
		this.left = child;
	}

	/** @return the right child, possibly {@code null} */
	public cnode getRight() {
		return this.right;
	}

	/** @param child the node to attach as right child */
	public void setRight(cnode child) {
		this.right = child;
	}
}

然后是构建回归树

/**
 * Recursively builds a tree of {@code cnode}s from {@code dataSet}.
 *
 * <p>When {@code chooseBestSplit} reports index {@code -1} the returned node
 * has no children; otherwise the data is split on the reported feature and
 * value, and a subtree is built for each part.
 *
 * @param dataSet samples for this subtree
 * @param ops     tuning values forwarded to {@code chooseBestSplit}
 * @return the root node of the subtree built from {@code dataSet}
 */
public static cnode createTree(DenseMatrix64F dataSet,double[] ops) {
		final BestInfo split = chooseBestSplit(dataSet, ops);
		final int featureIndex = split.getBestIndex();
		final cnode node = new cnode(featureIndex, split.getBestValue());

		// Only nodes with a real split get children.
		if (featureIndex != -1) {
			DenseMatrix64F[] halves = binSplitDataSet(dataSet, featureIndex, split.getBestValue());
			node.setLeft(createTree(halves[0], ops));
			node.setRight(createTree(halves[1], ops));
		}

		return node;
	}

ok到这里可以开始测试

// Read the tab-separated training data, one sample per line.
List<String> list = new ArrayList<String>();
// try-with-resources closes the reader even when readLine() fails part-way.
try (BufferedReader br = new BufferedReader(new FileReader("D:\\machinelearninginaction-master\\Ch09\\ex0.txt"))) {
    String s;
    while ((s = br.readLine()) != null) {
        // Skip blank lines (e.g. a trailing newline) that cannot be parsed as numbers.
        if (!s.trim().isEmpty()) {
            list.add(s);
        }
    }
} catch (java.io.IOException e) {
    // Do not go on to build a tree from missing or partial data.
    throw new IllegalStateException("Failed to read training data", e);
}

// Every line is expected to hold three tab-separated numbers.
final int numCols = 3;
DenseMatrix64F dataMatIn = new DenseMatrix64F(list.size(), numCols);

for (int i = 0; i < list.size(); i++) {
    String[] items = list.get(i).split("\t");
    for (int j = 0; j < numCols; j++) {
        dataMatIn.set(i, j, Double.parseDouble(items[j]));
    }
}

// ops = {minimum error reduction, minimum rows per side}
cnode tree = createTree(dataMatIn, new double[] {1, 4});

// Print the feature index chosen at the root.
System.out.println(tree.getSpInd());

构建回归树成功

用java实现分类回归树(CART)