欢迎您访问程序员文章站!本站旨在为大家提供、分享程序员计算机编程知识。
您现在的位置是: 首页

西瓜书《机器学习》课后答案——chapter9 _9.4

程序员文章站 2022-07-14 20:59:35
...

编程实现 k 均值算法,设置三组不同的 k 值、三组不同的初始中心点,在西瓜数据集 4.0 上进行实验比较,并讨论什么样的初始中心有助于得到好结果。

下面所有图中,横坐标表示密度,纵坐标表示含糖率。

首先看看西瓜数据集 4.0 的样本分布:
[图:西瓜数据集 4.0 的样本分布]

代码

#-*- coding:utf-8 -*-
"""
@Author: Victoria
@Date: 2017.10.24 12:00
"""
import random
import matplotlib.pyplot as plt
import xlrd
from copy import deepcopy

# Matplotlib format-string pieces used when plotting clusters: cluster k is drawn
# with marker style[k] and colour color[k]. Both strings have the same length, so
# up to five clusters get distinct, valid format strings ("*k", "+g", "oy", ".b",
# "sm"). 'r' (red) is deliberately absent from color because red marks the
# initial centers.
style = "*+o.s"
color = "kgybm"

class KMeans():
    """Plain-Python k-means clustering that also saves matplotlib plots.

    Samples are sequences of numbers of equal length and all distances are
    squared Euclidean (see dist()). After train() the fitted state is kept on
    the instance:
        centers      -- list of k centers
        init_centers -- copy of the centers the run started from
        clusters_X   -- dict {cluster index: list of samples} produced by the
                        last assignment step
    """

    def __init__(self, k, tol=1e-3, max_iter=300):
        """Configure the model.

        Args:
            k: number of clusters, at least 1.
            tol: training stops once the summed squared movement of all
                centers in one iteration is below this value.
            max_iter: upper bound on iterations, so training always ends.

        Raises:
            ValueError: if k < 1.
        """
        if k < 1:
            raise ValueError("k must be at least 1, got {}".format(k))
        self.k = k
        self.tol = tol
        self.max_iter = max_iter

    def train(self, X, init_centers=None, plot=True):
        """Cluster the samples in X by alternating assignment and update steps.

        Args:
            X: non-empty list of samples, each a sequence of numbers of the
                same length as X[0].
            init_centers: optional list of k starting centers, for example
                [[0.403, 0.237], [0.343, 0.099], [0.478, 0.437]] for k=3.
                If omitted, k distinct samples of X are chosen at random.
            plot: if True, save a scatter plot for every iteration and the
                cost curve under figures/.

        Raises:
            ValueError: if X is empty, if init_centers does not contain k
                centers, or (random init) if k exceeds the number of samples.
        """
        if len(X) == 0:
            raise ValueError("X must contain at least one sample")
        self.N = len(X)
        self.d = len(X[0])
        self.X = X

        if init_centers is None:
            self.init()
        else:
            if len(init_centers) != self.k:
                raise ValueError("expected {} initial centers, got {}".format(
                    self.k, len(init_centers)))
            # Copy the outer list so center updates never touch the caller's list.
            self.centers = list(init_centers)
        self.init_centers = deepcopy(self.centers)
        print(self.centers)

        Js = []  # cost J recorded after every update step
        n_iter = 0
        while n_iter < self.max_iter:
            n_iter += 1
            print("iteration: {}".format(n_iter))

            # Assignment step: put every sample into the cluster of its nearest center.
            clusters_X = {k: [] for k in range(self.k)}
            for x in self.X:
                clusters_X[self.cluster_sample(x)].append(x)
            self.clusters_X = clusters_X

            if plot:
                self.plot_clusters(n_iter, clusters_X)

            # Update step: move each center to the mean of its cluster. An empty
            # cluster has no mean (it would divide by zero), so it keeps its
            # previous center and can pick up samples again later.
            old_centers = deepcopy(self.centers)
            for k in range(self.k):
                if clusters_X[k]:
                    self.centers[k] = self.compute_center(clusters_X[k])

            Js.append(self.J_cost(clusters_X))

            # Converged once the centers have (almost) stopped moving.
            shift = sum(self.dist(old, new)
                        for old, new in zip(old_centers, self.centers))
            if shift < self.tol:
                break

        print("iter:  {}".format(n_iter))
        if plot:
            # The last clustering was already saved inside the loop.
            self.plot_J(Js)

    def plot_J(self, Js):
        """Save the cost curve (one J value per iteration) to figures/k=<k>_cost.png."""
        plt.figure()
        plt.plot(range(len(Js)), Js)
        self._savefig("figures/k={}_cost.png".format(self.k))

    def plot_clusters(self, iter, clusters_X):
        """Save a scatter plot of one clustering to figures/k=<k>_cluster_iter<iter>.png.

        Only the first two coordinates of each sample are plotted. Cluster k
        uses marker/colour pair k from the module-level style and color
        strings, cycling when k exceeds their length. Samples equal to an
        initial center are drawn in red.
        """
        plt.figure()
        for k in range(self.k):
            marker = style[k % len(style)]
            fmt = marker + color[k % len(color)]
            for x in clusters_X[k]:
                if x in self.init_centers:
                    plt.plot(x[0], x[1], marker + 'r')
                else:
                    plt.plot(x[0], x[1], fmt)
        self._savefig("figures/k={}_cluster_iter{}.png".format(self.k, iter))

    def _savefig(self, path):
        """Save the current figure to path, creating its directory first, then close it."""
        import os
        directory = os.path.dirname(path)
        if directory and not os.path.isdir(directory):
            os.makedirs(directory)
        plt.savefig(path)
        plt.close()  # release the figure; one is created per iteration

    def predict(self, X=None):
        """Return the nearest-center cluster index for every sample.

        Args:
            X: samples to label; defaults to the training samples.

        Returns:
            list of cluster indices in [0, k), one per sample.
        """
        if X is None:
            X = self.X
        return [self.cluster_sample(x) for x in X]

    def init(self):
        """Pick k distinct training samples at random as the initial centers.

        Sampling without replacement matters: two identical starting centers
        would leave one cluster empty from the very first iteration.

        Raises:
            ValueError: if k exceeds the number of samples.
        """
        if self.k > self.N:
            raise ValueError("k={} exceeds the number of samples ({})".format(
                self.k, self.N))
        self.centers = [self.X[i] for i in random.sample(range(self.N), self.k)]

    def compute_center(self, X):
        """Return the component-wise mean of the samples in X as a list of floats.

        X must be non-empty (an empty list raises ZeroDivisionError); train()
        never passes an empty cluster.
        """
        n = len(X)
        return [float(sum(x[i] for x in X)) / n for i in range(self.d)]

    def cluster_sample(self, x):
        """Return the index of the center nearest to x (the lowest index on ties)."""
        return min(range(self.k), key=lambda k: self.dist(x, self.centers[k]))

    def dist(self, x, y):
        """Return the squared Euclidean distance over the first self.d coordinates."""
        return sum((x[i] - y[i]) ** 2 for i in range(self.d))

    def J_cost(self, clusters_X):
        """Return the sum of squared distances from each sample to its cluster's center."""
        return sum(self.dist(x, self.centers[k])
                   for k in range(self.k)
                   for x in clusters_X[k])

def main(path="4.0.xlsx", k=2, n_samples=30):
    """Load watermelon dataset 4.0, plot the raw samples, and run k-means.

    'Sheet1' is read column-wise: column i holds sample i, and its first two
    cells are used as the (density, sugar content) features.

    Args:
        path: workbook to read.
        k: number of clusters.
        n_samples: number of sample columns to read (30 in dataset 4.0).
    """
    import os

    workbook = xlrd.open_workbook(path)
    sheet = workbook.sheet_by_name('Sheet1')
    X = [sheet.col_values(i)[0:2] for i in range(n_samples)]

    # savefig does not create missing directories.
    if not os.path.isdir("figures"):
        os.makedirs("figures")
    plt.figure()
    for x in X:
        plt.plot(x[0], x[1], 'k.')
    plt.savefig("figures/samples.png")
    plt.close()

    k_means = KMeans(k=k)
    k_means.train(X)


if __name__ == '__main__':
    main()

k=2 时(图中红色的点表示初始中心点):
[图:k=2 时的聚类结果]

k=3 时:采用书中给出的初始中心点,但得到的结果与书中相比有一个点的归属不一样(欢迎大家来找茬)。
[图:k=3 时的聚类结果]

k=4 时:
[图:k=4 时的聚类结果]

问题:
如果某次迭代时某个类簇为空,该怎么办?若直接对空簇求均值,会出现除以 0 的错误;常见的处理方法是保留该簇上一轮的中心,或者从样本中重新选取一个点(例如距离其所属中心最远的样本)作为该簇的新中心。