欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

【推荐系统】聚类算法-K-Means算法

程序员文章站 2022-07-14 20:56:39
...

K-means算法思想

​ K-means算法是最为经典的基于划分的聚类方法。是一种比较简单的算法。其基本思想和核心内容就是在算法开始时随机给定若干(K)个中心,按照最近距离原则将样本点分配到各个中心点,之后按平均法计算聚类集的中心点位置,从而重新确定新的中心点位置。这样不断地迭代下去直至聚类集内的样本满足阈值为止

​ 单单概念无法理解k-means算法,接着看图a,如果普通人,很容易就可以区分出来两类数据。但是机器却无法区分,毕竟机器是死脑筋。如果机器来区分的话有以下步骤

  1. 机器则随机定义了两个点红点和蓝点,如图b。
  2. 所有的点计算到红点和蓝点的距离,距离哪个近即属于哪一个集合,计算完成之后分类结束得到图c。
  3. 显然图c不是最好的结果,每个集合计算该集合的质心(中心点)。然后重复步骤2。质心的位置变化越来越小,直到原来的质心到计算出来的质心的距离小于提前设置好的阈值的时候,即可认为分类结束,如图d,e,f。

【推荐系统】聚类算法-K-Means算法

K-means算法Scala实现

kmeans.txt ,这是libsvm数据格式。

0 1:0.0 2:0.0 3:0.0
1 1:0.1 2:0.1 3:0.1
2 1:0.2 2:0.2 3:0.2
3 1:9.0 2:9.0 3:9.0
4 1:9.1 2:9.1 3:9.1
5 1:9.2 2:9.2 3:9.2

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("kmeans")
      .master("local[2]")
      .getOrCreate();
    //libsvm是一种数据格式
    val dataset = spark.read.format("libsvm").load("C:\\Users\\archermind\\Desktop\\ml-1m\\kmeans.txt")
    // Trains a k-means model.
    // k 表示分为两类
    val kmeans = new KMeans().setK(2).setSeed(1l)
    val model = kmeans.fit(dataset)
    // Make predictions
    val predictions = model.transform(dataset)
    predictions.show(false)
      /* predictions的内容如下
     	 +-----+-------------------------+
        |label|features                 |
        +-----+-------------------------+
        |0.0  |(3,[],[])                |
        |1.0  |(3,[0,1,2],[0.1,0.1,0.1])|
        |2.0  |(3,[0,1,2],[0.2,0.2,0.2])|
        |3.0  |(3,[0,1,2],[9.0,9.0,9.0])|
        |4.0  |(3,[0,1,2],[9.1,9.1,9.1])|
        |5.0  |(3,[0,1,2],[9.2,9.2,9.2])|
        |6.0  |(3,[0,1,2],[3.1,3.2,3.3])|
        +-----+-------------------------+
     	*/
    //用于验证群集内一致性的一种度量。取值范围是1到-1,其中接近1的值表示一个群集中的点靠近同一群集中的其他点
    val evaluator = new ClusteringEvaluator()
    val silhouette = evaluator.evaluate(predictions)
    println(s"Silhouette with squared euclidean distance = $silhouette")
    // 输入分类结果
    println("Cluster Centers: ")
    model.clusterCenters.foreach(println)
    /** 输出结果
     * Silhouette with squared euclidean distance = 0.9997530305375207
     * 分类结果
     * Cluster Centers:
     * [0.1,0.1,0.1]
     * [9.1,9.1,9.1]
     */
  }

​ 如上所示,可以认为在三维坐标系中有6个点,分成了两类数据。
添加第6个数据,分成3类,有如下输出

0 1:0.0 2:0.0 3:0.0
1 1:0.1 2:0.1 3:0.1
2 1:0.2 2:0.2 3:0.2
3 1:9.0 2:9.0 3:9.0
4 1:9.1 2:9.1 3:9.1
5 1:9.2 2:9.2 3:9.2
6 1:3.1 2:3.2 3:3.3

输入结果如下:

Silhouette with squared euclidean distance = 0.9997530305375207
Cluster Centers:
[0.1,0.1,0.1]
[9.1,9.1,9.1]

注:libsvm是一种数据格式,格式如下

<label> <index1>:<value1> <index2>:<value2> ...

​ 其中

参考地址:

svm算法参考

libsvm格式说明