Spark MLlib in Java: Decision Tree Classification (DecisionTree)

程序员文章站 2024-02-09 13:24:22

Decision tree basics:

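A decision tree uses a tree structure: at each level it makes a judgment based on one feature, and at some level it arrives at a result. I came across a very good illustration of this in another blog:
[Figure: decision tree illustration]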
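
To make the level-by-level idea concrete, here is a minimal hand-written sketch of a tiny tree in plain Java. The feature indices, thresholds, and class labels are all made up for illustration; a trained model learns them from the data instead.

```java
// Hypothetical hand-written tree: thresholds and labels are illustrative only,
// they do not come from the article's data or from a trained model.
final class TinyDecisionTree {

    /** Walks from the root to a leaf, testing one feature per level. */
    static int classify(double[] features) {
        // Level 1: test feature 0
        if (features[0] <= 2.5) {
            return 0; // leaf: the decision is already made at the first level
        }
        // Level 2: test feature 1
        if (features[1] <= 1.0) {
            return 1; // leaf
        }
        // Level 3: test feature 2
        return features[2] <= 4.0 ? 1 : 2;
    }

    public static void main(String[] args) {
        System.out.println(classify(new double[] {1.0, 0.0, 0.0})); // prints 0
        System.out.println(classify(new double[] {3.0, 2.0, 5.0})); // prints 2
    }
}
```

Different samples can stop at different depths, which is what "reaching a result at some level" means.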

Data used by the Spark MLlib Java program:

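   Training data: C:\hello\trainData.txt

[Image: sample of the training data]
In this data, the value before the comma is the target (the class label), and the values after the comma are the feature vector (space-separated).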
   Test data: C:\hello\testData.txt
[Image: sample of the test data]
This data contains only the feature vectors, space-separated.
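
The two line formats described above can be sketched in plain Java, without Spark, as follows. The class name, the helper names, and the sample lines in `main` are hypothetical; the original data files are not reproduced here.

```java
// A minimal sketch of the two line formats, assuming "label,f1 f2 ..." for
// training lines and "f1 f2 ..." for test lines. All names are hypothetical.
final class LineFormats {

    /** One parsed training record: a class label plus its feature values. */
    static final class Sample {
        final double label;
        final double[] features;

        Sample(double label, double[] features) {
            this.label = label;
            this.features = features;
        }
    }

    /** Parses a space-separated feature list (the test-data format). */
    static double[] parseFeatures(String text) {
        String[] tokens = text.trim().split("\\s+");
        double[] values = new double[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            values[i] = Double.parseDouble(tokens[i]);
        }
        return values;
    }

    /** Parses a "label,features" line (the training-data format). */
    static Sample parseTrainingLine(String line) {
        String[] parts = line.split(",", 2);
        if (parts.length != 2) {
            throw new IllegalArgumentException("Expected 'label,features': " + line);
        }
        return new Sample(Double.parseDouble(parts[0].trim()), parseFeatures(parts[1]));
    }

    public static void main(String[] args) {
        // Made-up sample lines, for illustration only
        Sample s = parseTrainingLine("1,2.0 3.5 1.0");
        System.out.println(s.label + " -> " + s.features.length + " features");
        System.out.println(parseFeatures("0.5 1 2").length + " features");
    }
}
```

Unlike the program below, this sketch does not hard-code three features, so it also accepts lines with a different number of feature values.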

Spark MLlib DecisionTree Java program:

package MLlibTest;

import java.util.HashMap;
import java.util.Map;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;

import scala.Tuple2;

public class DecisionTreeTest{

	public static void main(String[] args) {
		// Run locally, using all available cores
		SparkConf conf = new SparkConf().setAppName("DecisionTreeTest").setMaster("local[*]");
		JavaSparkContext jsc = new JavaSparkContext(conf);
		// Load the training data; each line has the form "label,f1 f2 f3"
		JavaRDD<String> lines = jsc.textFile("C://hello//trainData.txt");
		JavaRDD<LabeledPoint> transdata = lines.map(new Function<String, LabeledPoint>() {
			private static final long serialVersionUID = 1L;

			@Override
			public LabeledPoint call(String str) throws Exception {
				String[] t1 = str.split(",");   // t1[0] = label, t1[1] = features
				String[] t2 = t1[1].split(" "); // exactly three features are expected
				return new LabeledPoint(Double.parseDouble(t1[0]),
						Vectors.dense(Double.parseDouble(t2[0]), Double.parseDouble(t2[1]), Double.parseDouble(t2[2])));
			}
		});
		// Set the decision tree parameters and train the model
		Integer numClasses = 3; // labels must be 0, 1, ..., numClasses - 1
		// An empty map means all features are treated as continuous
		Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
		String impurity = "gini";
		Integer maxDepth = 5;
		Integer maxBins = 32;
		final DecisionTreeModel tree_model = DecisionTree.trainClassifier(transdata, numClasses,
				categoricalFeaturesInfo, impurity, maxDepth, maxBins);
		System.out.println("Decision tree model:");
		System.out.println(tree_model.toDebugString());
		// Save the model
		tree_model.save(jsc.sc(), "C://hello//DecisionTreeModel");
        
        
		// Apply the model to the unlabeled test data
		JavaRDD<String> testLines = jsc.textFile("C://hello//testData.txt");
		JavaPairRDD<String, String> res = testLines.mapToPair(new PairFunction<String, String, String>() {
			private static final long serialVersionUID = 1L;

			@Override
			public Tuple2<String, String> call(String line) throws Exception {
				String[] t2 = line.split(" ");
				Vector v = Vectors.dense(Double.parseDouble(t2[0]), Double.parseDouble(t2[1]),
						Double.parseDouble(t2[2]));
				// Named "prediction" so it does not shadow the outer RDD variable "res"
				double prediction = tree_model.predict(v);
				return new Tuple2<String, String>(line, Double.toString(prediction));
			}
		}).cache();
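		// Print the results as "features : predicted label"
		res.foreach(new VoidFunction<Tuple2<String, String>>() {
			private static final long serialVersionUID = 1L;

			@Override
			public void call(Tuple2<String, String> a) throws Exception {
				System.out.println(a._1 + " : " + a._2);
			}
		});
		// Save the results locally (this fails if the output directory already exists)
		res.saveAsTextFile("C://hello/res");
		// Release Spark resources
		jsc.stop();
	}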

}
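
The program above passes "gini" as the impurity parameter. As a rough illustration of what that measure computes, here is a plain-Java sketch of the standard Gini impurity definition, 1 minus the sum of squared class proportions. It is not MLlib's internal implementation, and the class and method names are hypothetical.

```java
// A minimal sketch of Gini impurity for the class counts at one tree node.
// This follows the textbook definition and is not taken from Spark's source.
final class GiniImpurity {

    /** Returns 1 - sum(p_i^2), where p_i is the fraction of samples in class i. */
    static double gini(int[] classCounts) {
        long total = 0;
        for (int count : classCounts) {
            total += count;
        }
        if (total == 0) {
            return 0.0; // an empty node is treated as pure in this sketch
        }
        double sumOfSquares = 0.0;
        for (int count : classCounts) {
            double p = (double) count / total;
            sumOfSquares += p * p;
        }
        return 1.0 - sumOfSquares;
    }

    public static void main(String[] args) {
        System.out.println(gini(new int[] {10, 0, 0})); // pure node: 0.0
        System.out.println(gini(new int[] {5, 5}));     // evenly mixed two classes: 0.5
        System.out.println(gini(new int[] {1, 1, 1}));  // evenly mixed three classes: about 0.667
    }
}
```

A node containing only one class has impurity 0, and the value grows as the classes become more evenly mixed, so a tree learner prefers splits that lower it.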

Conclusion:

     This was put together in a hurry. If you spot any mistakes, please point them out so we can learn from each other.