An example of the Spark LDA clustering algorithm

LDA (Latent Dirichlet Allocation) is a topic model built on the Dirichlet distribution, commonly used to cluster very large collections of documents by topic. Each document is assumed to be a mixture of hidden topics; LDA finds the most probable topics and, in doing so, also recovers the probability of each word under each topic. For more background, see http://blog.csdn.net/qq_34531825/article/details/52608003
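
To make that generative story concrete, here is a toy sketch in plain Java (not Spark; the 2-topic, 4-term probabilities below are invented purely for illustration). LDA assumes each word of a document is produced by first drawing a topic from the document's topic mixture, then drawing a word from that topic's word distribution:

import java.util.Random;

public class LdaGenerativeSketch {
	public static void main(String[] args) {
		Random rnd = new Random(42);
		// Hypothetical parameters: theta = P(topic | doc), phi = P(word | topic)
		double[] docTopicMix = { 0.7, 0.3 };
		double[][] topicWord = { { 0.5, 0.3, 0.1, 0.1 },
				{ 0.1, 0.1, 0.4, 0.4 } };
		for (int n = 0; n < 8; n++) {
			int topic = sample(docTopicMix, rnd); // pick a topic for this word slot
			int word = sample(topicWord[topic], rnd); // pick a word from that topic
			System.out.println("word " + n + ": topic=" + topic + " termId=" + word);
		}
	}

	// Draw an index from the discrete distribution given by probs
	static int sample(double[] probs, Random rnd) {
		double u = rnd.nextDouble(), cum = 0.0;
		for (int i = 0; i < probs.length; i++) {
			cum += probs[i];
			if (u < cum) return i;
		}
		return probs.length - 1;
	}
}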

Below is an LDA example, adapted from the one on the official Spark website. The example's data is as follows (the meaning of each line is explained after the data):

0 1:1 2:2 3:6 4:0 5:2 6:3 7:1 8:1 9:0 10:0 11:3
1 1:1 2:3 3:0 4:1 5:3 6:0 7:0 8:2 9:0 10:0 11:1
2 1:1 2:4 3:1 4:0 5:0 6:4 7:9 8:0 9:1 10:2 11:0
3 1:2 2:1 3:0 4:3 5:0 6:0 7:5 8:0 9:2 10:3 11:9
4 1:3 2:1 3:1 4:9 5:3 6:0 7:2 8:0 9:0 10:1 11:3
5 1:4 2:2 3:0 4:3 5:4 6:5 7:1 8:1 9:1 10:4 11:0
6 1:2 2:1 3:0 4:3 5:0 6:0 7:5 8:0 9:2 10:2 11:9
7 1:1 2:1 3:1 4:9 5:2 6:1 7:2 8:0 9:0 10:1 11:3
8 1:4 2:4 3:0 4:3 5:4 6:2 7:1 8:3 9:0 10:0 11:0
9 1:2 2:8 3:2 4:0 5:3 6:0 7:2 8:0 9:2 10:7 11:2
10 1:1 2:1 3:1 4:9 5:0 6:2 7:2 8:0 9:0 10:3 11:3
11 1:4 2:1 3:0 4:0 5:4 6:5 7:1 8:3 9:0 10:1 11:0
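
Each line above is in libsvm format: the leading number is the label (treated here as the document ID), and each index:value pair gives a 1-based term index and that term's count in the document. As a hedged sketch of the (label, features) pair Spark's libsvm reader produces for the first data line (class and variable names are made up; only spark-mllib is assumed on the classpath):

import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;

public class LibsvmLineSketch {
	public static void main(String[] args) {
		// "0 1:1 2:2 3:6 4:0 5:2 6:3 7:1 8:1 9:0 10:0 11:3" becomes
		// label 0.0 plus an 11-dimensional term-count vector
		// (libsvm indices are 1-based; Spark stores them 0-based).
		double label = 0.0;
		Vector features = Vectors.sparse(11,
				new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 },
				new double[] { 1, 2, 6, 0, 2, 3, 1, 1, 0, 0, 3 });
		System.out.println(label + " -> " + features);
	}
}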


The complete code is as follows:

package spark;

import java.util.List;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.api.java.function.ForeachFunction;
import org.apache.spark.ml.clustering.LDA;
import org.apache.spark.ml.clustering.LDAModel;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class JavaLDAExample {

	public static void main(String[] args) {

		// Set log levels and silence the Jetty container's logging
		Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
		Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF);
		SparkSession spark = SparkSession.builder().master("local[2]").appName("JavaLDAExample").getOrCreate();
		
		// Load the data as labeled vectors. The label can be treated as the document ID.
		// Each line has the format: docId termIndex:termCount ...
		Dataset<Row> dataset = spark.read().format("libsvm")
				.load("F:/spark-2.1.0-bin-hadoop2.6/data/mllib/sample_lda_libsvm_data.txt");
		// Print each parsed row (the cast selects the Java-friendly foreach overload)
		dataset.foreach((ForeachFunction<Row>) row -> System.out.println(row));

		// Train the LDA model: 3 topics, at most 10 iterations
		LDA lda = new LDA().setK(3).setMaxIter(10);
		LDAModel model = lda.fit(dataset);
		// Lower bound on the log likelihood of the corpus; larger is better
		double ll = model.logLikelihood(dataset);
		// Upper bound on the log perplexity; smaller is better
		double lp = model.logPerplexity(dataset);
		System.out.println("The lower bound on the log likelihood of the entire corpus: " + ll);
		System.out.println("The upper bound on perplexity: " + lp);

		Matrix matrix = model.topicsMatrix();
		System.out.println("------------------------");
		System.out.println("Each column of the topics matrix is a topic; there are " + matrix.numCols() + " topics");
		System.out.println("Each row of the topics matrix is a term; there are " + matrix.numRows() + " terms");

		System.out.println("The topics matrix gives each term's weight within each topic");
		for (int topic = 0; topic < matrix.numCols(); topic++) {
			System.out.print("Topic " + topic + ":");
			for (int word = 0; word < model.vocabSize(); word++) {
				System.out.print(" " + matrix.apply(word, topic));
			}
			System.out.println();
		}

		System.out.println("------------------------");

		// describeTopics() returns one row per topic: (topic, termIndices, termWeights)
		Dataset<Row> allTopics = model.describeTopics();
		allTopics.foreach((ForeachFunction<Row>) row -> {
			int topic = row.getInt(0);
			List<Integer> termIndices = row.getList(1);
			List<Double> termWeights = row.getList(2);
			System.out.println("Topic " + topic + ", term indices (highest weight first) "
					+ termIndices + ", term weights " + termWeights);
		});

		System.out.println("------------------------");
		// describeTopics(3) shows only the three highest-weighted terms per topic
		Dataset<Row> topics = model.describeTopics(3);
		System.out.println("The topics described by their top-weighted terms:");
		topics.show(false);

		// Infer each document's topic distribution; the label column is the document ID
		Dataset<Row> transformed = model.transform(dataset);
		transformed.show(false);

		// The effective docConcentration (the alpha prior) actually used for training
		double[] arr = model.getEffectiveDocConcentration();
		for (double d : arr) {
			System.out.println(d);
		}

		//System.out.println(model.getTopicConcentration());
		spark.stop();
	}
}
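
As a follow-up, the transform step adds a "topicDistribution" column holding each document's inferred topic mixture (a vector of length k, here 3). Below is a minimal sketch of reading it out, assuming the `transformed` Dataset from the listing above is in scope (a fragment to append to the main method, not a full class):

		// Each row: label (document ID), features, topicDistribution
		transformed.foreach((ForeachFunction<Row>) row -> {
			double label = row.getDouble(0);
			org.apache.spark.ml.linalg.Vector dist = row.getAs("topicDistribution");
			System.out.println("doc " + (int) label + " -> topic mixture " + dist);
		});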