Spark MLlib Java Decision Tree Classification (DecisionTree)
程序员文章站 (Programmer Article Site)
2024-02-09 13:24:22
...
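Decision tree basics:
A decision tree uses a tree structure to classify a sample one level at a time. Each internal node tests one feature and sends the sample down one branch, until a leaf at some level gives the result. I found a very good illustration of this in another blog: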
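To make the "one level at a time" idea concrete, here is a minimal hand-built sketch in plain Java (no Spark). The class TreeWalkSketch, its Node type and the thresholds in exampleTree are hypothetical and were invented for illustration. A trained MLlib model on continuous features applies the same kind of "feature <= threshold" tests, which is also the form of the If/Else rules that toDebugString prints.

```java
import java.util.Arrays;

/**
 * Hypothetical hand-built decision tree, used only to illustrate prediction:
 * start at the root, test one feature per level, stop at a leaf.
 * The thresholds and classes are invented for illustration.
 */
public class TreeWalkSketch {

    /** Either an internal node (feature + threshold + children) or a leaf (prediction). */
    static final class Node {
        final int feature;        // index of the feature tested at this node
        final double threshold;   // go left if features[feature] <= threshold
        final Node left, right;
        final double prediction;  // only meaningful for leaves

        private Node(int feature, double threshold, Node left, Node right, double prediction) {
            this.feature = feature;
            this.threshold = threshold;
            this.left = left;
            this.right = right;
            this.prediction = prediction;
        }

        static Node split(int feature, double threshold, Node left, Node right) {
            return new Node(feature, threshold, left, right, Double.NaN);
        }

        static Node leaf(double prediction) {
            return new Node(-1, Double.NaN, null, null, prediction);
        }

        boolean isLeaf() {
            return left == null && right == null;
        }
    }

    /** Walks from the root down to a leaf, one feature test per level. */
    static double predict(Node root, double[] features) {
        Node node = root;
        while (!node.isLeaf()) {
            node = features[node.feature] <= node.threshold ? node.left : node.right;
        }
        return node.prediction;
    }

    /** Hypothetical two-level tree over 3 features with classes 0, 1 and 2. */
    static Node exampleTree() {
        return Node.split(0, 1.5,
                Node.leaf(0.0),
                Node.split(2, 3.0, Node.leaf(1.0), Node.leaf(2.0)));
    }

    public static void main(String[] args) {
        Node root = exampleTree();
        double[][] samples = { {1.0, 0.0, 9.0}, {2.0, 0.0, 2.5}, {2.0, 0.0, 4.0} };
        for (double[] s : samples) {
            System.out.println(Arrays.toString(s) + " -> " + predict(root, s));
        }
    }
}
```

The three samples in main stop at three different leaves (classes 0.0, 1.0 and 2.0), after one, two and two feature tests respectively.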
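Data used by the Spark MLlib Java program:
Training data: C:\hello\trainData.txt
In this file, the value before the comma is the target (the class label), and the values after the comma are the feature vector, separated by spaces.
Test data: C:\hello\testData.txt
This file contains only feature vectors, with values separated by spaces.
Spark MLlib DecisionTree Java program: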
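The two file formats above can be checked with a small plain-Java parser before the data is handed to Spark. This is a sketch under assumptions: the class name TrainLineParser, the Parsed holder and the sample lines are hypothetical (the original data files are not shown). The label range check reflects MLlib's requirement that classification labels take the values 0, 1, ..., numClasses - 1.

```java
import java.util.Arrays;

/**
 * Hypothetical parser for the formats described above:
 * training lines are "label,f1 f2 f3", test lines are "f1 f2 f3".
 */
public class TrainLineParser {

    /** Simple holder for a parsed training line (stands in for MLlib's LabeledPoint). */
    static final class Parsed {
        final double label;
        final double[] features;

        Parsed(double label, double[] features) {
            this.label = label;
            this.features = features;
        }
    }

    /** Parses whitespace-separated feature values; tolerates repeated spaces. */
    static double[] parseFeatures(String text) {
        String[] tokens = text.trim().split("\\s+");
        double[] features = new double[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            features[i] = Double.parseDouble(tokens[i]);
        }
        return features;
    }

    /** Parses "label,features" and rejects labels outside 0 .. numClasses - 1. */
    static Parsed parseTrainLine(String line, int numClasses) {
        String[] parts = line.split(",");
        if (parts.length != 2) {
            throw new IllegalArgumentException("expected 'label,features': " + line);
        }
        double label = Double.parseDouble(parts[0].trim());
        if (label < 0 || label >= numClasses || label != Math.rint(label)) {
            throw new IllegalArgumentException(
                    "label must be an integer in [0, " + (numClasses - 1) + "]: " + line);
        }
        return new Parsed(label, parseFeatures(parts[1]));
    }

    public static void main(String[] args) {
        // Invented sample lines, not the author's data
        Parsed p = parseTrainLine("2,5.1 3.5 1.4", 3);
        System.out.println(p.label + " " + Arrays.toString(p.features));
        System.out.println(Arrays.toString(parseFeatures("6.2 2.9  4.3")));
    }
}
```

Because the program below uses numClasses = 3, a training line whose label is 3 (or 1.5) would be rejected here rather than causing a failure later during training.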
package MLlibTest;
import java.util.HashMap;
import java.util.Map;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import scala.Tuple2;
public class DecisionTreeTest {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("DecisionTreeTest").setMaster("local[*]");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        // Read the training data: each line is "label,f1 f2 f3"
        JavaRDD<String> lines = jsc.textFile("C://hello//trainData.txt");
        JavaRDD<LabeledPoint> transdata = lines.map(new Function<String, LabeledPoint>() {
            private static final long serialVersionUID = 1L;
            @Override
            public LabeledPoint call(String str) throws Exception {
                String[] t1 = str.split(",");
                // Features are whitespace-separated; parse however many the line contains
                String[] t2 = t1[1].trim().split("\\s+");
                double[] features = new double[t2.length];
                for (int i = 0; i < t2.length; i++) {
                    features[i] = Double.parseDouble(t2[i]);
                }
                return new LabeledPoint(Double.parseDouble(t1[0].trim()), Vectors.dense(features));
            }
        });
        // Set the decision tree parameters and train the model
        int numClasses = 3; // labels must be 0.0, 1.0 or 2.0
        // An empty map means every feature is treated as continuous
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
        String impurity = "gini"; // "entropy" is the other option for classification
        int maxDepth = 5;  // maximum depth of the tree
        int maxBins = 32;  // maximum number of bins used to discretize continuous features
        final DecisionTreeModel tree_model = DecisionTree.trainClassifier(transdata, numClasses,
                categoricalFeaturesInfo, impurity, maxDepth, maxBins);
        System.out.println("Decision tree model:");
        System.out.println(tree_model.toDebugString());
        // Save the model
        tree_model.save(jsc.sc(), "C://hello//DecisionTreeModel");
        // Feed the unlabeled test data into the model: each line is "f1 f2 f3"
        JavaRDD<String> testLines = jsc.textFile("C://hello//testData.txt");
        JavaPairRDD<String, String> res = testLines.mapToPair(new PairFunction<String, String, String>() {
            private static final long serialVersionUID = 1L;
            @Override
            public Tuple2<String, String> call(String line) throws Exception {
                String[] t2 = line.trim().split("\\s+");
                double[] features = new double[t2.length];
                for (int i = 0; i < t2.length; i++) {
                    features[i] = Double.parseDouble(t2[i]);
                }
                Vector v = Vectors.dense(features);
                double prediction = tree_model.predict(v);
                // Pair each input line with its predicted class label
                return new Tuple2<String, String>(line, Double.toString(prediction));
            }
        }).cache();
        // Print the results (in local mode this goes to the console; on a cluster it goes to the executor logs)
        res.foreach(new VoidFunction<Tuple2<String, String>>() {
            private static final long serialVersionUID = 1L;
            @Override
            public void call(Tuple2<String, String> a) throws Exception {
                System.out.println(a._1 + " : " + a._2);
            }
        });
        // Save the results locally (fails if the output directory already exists)
        res.saveAsTextFile("C://hello//res");
        jsc.stop();
    }
}
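The program sets impurity = "gini" without explaining it. The Gini impurity of a node is 1 - sum(p_k^2), where p_k is the fraction of the node's samples that belong to class k; it is 0 for a pure node. The sketch below scores every midpoint threshold on one continuous feature by the size-weighted Gini impurity of the two children. The class name GiniSplitSketch and the toy data are hypothetical. This brute-force search is a simplification: MLlib instead discretizes each continuous feature into at most maxBins bins, evaluates thresholds only at the bin boundaries, and picks the split with the largest information gain (parent impurity minus weighted child impurity). For a fixed parent, that is the same as the smallest weighted child impurity.

```java
import java.util.Arrays;

/**
 * Hypothetical illustration of Gini impurity and of choosing the best
 * "x <= threshold" split on a single continuous feature by brute force.
 */
public class GiniSplitSketch {

    /** Gini impurity 1 - sum_k p_k^2, given per-class counts and their total. */
    static double gini(int[] counts, int total) {
        if (total == 0) return 0.0;
        double sumSq = 0.0;
        for (int c : counts) {
            double p = (double) c / total;
            sumSq += p * p;
        }
        return 1.0 - sumSq;
    }

    /**
     * Returns {threshold, weightedChildImpurity} for the best split, or null
     * if no split is possible (all feature values equal). Labels are 0..numClasses-1.
     */
    static double[] bestSplit(double[] x, int[] y, int numClasses) {
        int n = x.length;
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) order[i] = i;
        Arrays.sort(order, (a, b) -> Double.compare(x[a], x[b]));

        int[] leftCounts = new int[numClasses];
        int[] rightCounts = new int[numClasses];
        for (int label : y) rightCounts[label]++; // initially every sample is on the right

        double[] best = null;
        for (int k = 0; k < n - 1; k++) {
            int i = order[k];
            leftCounts[y[i]]++;   // move sample i from the right child to the left child
            rightCounts[y[i]]--;
            double cur = x[i], next = x[order[k + 1]];
            if (cur == next) continue; // no threshold separates equal values
            int nLeft = k + 1, nRight = n - nLeft;
            double impurity = (nLeft * gini(leftCounts, nLeft) + nRight * gini(rightCounts, nRight)) / n;
            if (best == null || impurity < best[1]) {
                best = new double[] { (cur + next) / 2.0, impurity };
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] x = {1.0, 2.0, 3.0, 10.0, 11.0, 12.0};
        int[] y = {0, 0, 0, 1, 1, 1};
        System.out.println("root gini = " + gini(new int[] {3, 3}, 6));
        System.out.println("best split = " + Arrays.toString(bestSplit(x, y, 2)));
    }
}
```

With the toy data in main, the root has Gini impurity 0.5, and the best threshold is 6.5, which separates the two classes perfectly (weighted child impurity 0.0). Growing a full tree repeats this search at each child node until maxDepth is reached or a node is pure.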
Closing remarks:
This was put together in a hurry. Please point out any mistakes so that we can all learn from each other.