An example of Spark's LDA clustering algorithm
LDA (Latent Dirichlet Allocation) is a topic model typically used to cluster very large collections of documents by topic. Each document is assumed to contain hidden topics, and LDA finds the most probable ones; in doing so it also recovers the probability of each word under each topic. For background, see http://blog.csdn.net/qq_34531825/article/details/52608003
Below is an example of the LDA algorithm, adapted from the example on the official Spark website. The sample data is in libsvm format (document-id word-index:count ...), as follows:
0 1:1 2:2 3:6 4:0 5:2 6:3 7:1 8:1 9:0 10:0 11:3
1 1:1 2:3 3:0 4:1 5:3 6:0 7:0 8:2 9:0 10:0 11:1
2 1:1 2:4 3:1 4:0 5:0 6:4 7:9 8:0 9:1 10:2 11:0
3 1:2 2:1 3:0 4:3 5:0 6:0 7:5 8:0 9:2 10:3 11:9
4 1:3 2:1 3:1 4:9 5:3 6:0 7:2 8:0 9:0 10:1 11:3
5 1:4 2:2 3:0 4:3 5:4 6:5 7:1 8:1 9:1 10:4 11:0
6 1:2 2:1 3:0 4:3 5:0 6:0 7:5 8:0 9:2 10:2 11:9
7 1:1 2:1 3:1 4:9 5:2 6:1 7:2 8:0 9:0 10:1 11:3
8 1:4 2:4 3:0 4:3 5:4 6:2 7:1 8:3 9:0 10:0 11:0
9 1:2 2:8 3:2 4:0 5:3 6:0 7:2 8:0 9:2 10:7 11:2
10 1:1 2:1 3:1 4:9 5:0 6:2 7:2 8:0 9:0 10:3 11:3
11 1:4 2:1 3:0 4:0 5:4 6:5 7:1 8:3 9:0 10:1 11:0
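If the sample file is not at hand, the same kind of input can be built in memory. The sketch below is not part of the original post (the class and variable names are illustrative); it reproduces the first two data rows above as a label/features DataFrame, which is the same shape the libsvm loader produces:
import java.util.Arrays;
import java.util.List;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
public class InMemoryLdaData {
    public static Dataset<Row> build(SparkSession spark) {
        // One row per document: label = document id, features = the word counts
        // at vocabulary indices 1..11 of the libsvm lines (11 terms in total)
        List<Row> rows = Arrays.asList(
                RowFactory.create(0.0, Vectors.dense(1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3)),
                RowFactory.create(1.0, Vectors.dense(1, 3, 0, 1, 3, 0, 0, 2, 0, 0, 1)));
        StructType schema = new StructType(new StructField[]{
                new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                new StructField("features", new VectorUDT(), false, Metadata.empty())});
        return spark.createDataFrame(rows, schema);
    }
}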
The code is as follows:
package spark;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.api.java.function.ForeachFunction;
import org.apache.spark.ml.clustering.LDA;
import org.apache.spark.ml.clustering.LDAModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import scala.collection.mutable.WrappedArray;
public class JavaLDAExample {
    public static void main(String[] args) {
        // Set the log level and turn off the Jetty container's logging
        Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
        Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF);
        SparkSession spark = SparkSession.builder().master("local[2]").appName("JavaLDAExample").getOrCreate();
        // Load the data as labeled vectors. The label can be read as the document id.
        // Each line has the form: document-id word-index:count ...
        Dataset<Row> dataset = spark.read().format("libsvm")
                .load("F:/spark-2.1.0-bin-hadoop2.6/data/mllib/sample_lda_libsvm_data.txt");
        // Print each loaded row (the cast selects the ForeachFunction overload)
        dataset.foreach((ForeachFunction<Row>) row -> System.out.println(row));
        // Train the LDA model with 3 topics and 10 iterations
        LDA lda = new LDA().setK(3).setMaxIter(10);
        LDAModel model = lda.fit(dataset);
        // Log likelihood: the higher, the better
        double ll = model.logLikelihood(dataset);
        // Perplexity: the lower, the better
        double lp = model.logPerplexity(dataset);
        System.out.println("The lower bound on the log likelihood of the entire corpus: " + ll);
        System.out.println("The upper bound on perplexity: " + lp);
        org.apache.spark.ml.linalg.Matrix matrix = model.topicsMatrix();
        System.out.println("------------------------");
        System.out.println("Each column of the topics matrix is a topic; there are " + matrix.numCols() + " topics");
        System.out.println("Each row of the topics matrix is a word; there are " + matrix.numRows() + " words");
        System.out.println("Each entry of the topics matrix is the weight of a word within a topic");
        for (int topic = 0; topic < matrix.numCols(); topic++) {
            System.out.print("Topic " + topic + ":");
            for (int word = 0; word < model.vocabSize(); word++) {
                System.out.print(" " + matrix.apply(word, topic));
            }
            System.out.println();
        }
System.out.println("------------------------");
Dataset<Row> topicss = model.describeTopics();
topicss.foreach(func -> {
int topic = func.getInt(0);
WrappedArray<Long> words = (WrappedArray<Long>)func.get(1);
WrappedArray<Double> distributes = (WrappedArray<Double>)func.get(2);
System.out.print("主题 " + topic + ",单词(按照概率从高到低排布)[");
for (int i = 0; i < words.length(); i++) {
System.out.print(words.apply(i) + " ");
}
System.out.print("],分布概率[");
for (int i = 0; i < distributes.length(); i++) {
System.out.print(distributes.apply(i) + " ");
}
System.out.print("]\n");
});
System.out.println("------------------------");
// 描述主题只展示概率前三的单词
Dataset<Row> topics = model.describeTopics(3);
System.out.println("The topics described by their top-weighted terms:");
topics.show(false);
// 对文档进行聚类,并展示主题分布结果。lable表示的是文档的序号
Dataset<Row> transformed = model.transform(dataset);
transformed.show(false);
        // Estimated Dirichlet prior over the per-document topic distributions (alpha)
        double[] arr = model.estimatedDocConcentration().toArray();
        for (double d : arr) {
            System.out.println(d);
        }
        //System.out.println(model.getTopicConcentration());
        spark.stop();
    }
}
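The example stops at printing the hyperparameters, but a trained model can also be persisted for later use. A minimal sketch, not from the original example: the save path is hypothetical, and the cast assumes the default "online" optimizer, under which fit() returns a LocalLDAModel (which implements MLWritable):
import java.io.IOException;
import org.apache.spark.ml.clustering.LDAModel;
import org.apache.spark.ml.clustering.LocalLDAModel;
public class LdaModelIo {
    public static LDAModel saveAndReload(LDAModel model) throws IOException {
        // Assumes the default "online" optimizer, so the model is a LocalLDAModel
        ((LocalLDAModel) model).write().overwrite().save("F:/tmp/ldaModel"); // hypothetical path
        return LocalLDAModel.load("F:/tmp/ldaModel");
    }
}
With the EM optimizer, fit() returns a DistributedLDAModel instead; calling toLocal() on it first yields the same writable form.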