欢迎您访问程序员文章站,本站旨在为大家提供、分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

Python实现C-均值聚类

程序员文章站 2022-06-25 19:07:38
*不保证正确性,仅作记录。本文给出 C-均值聚类的 Python 实现(作者:水娃,2020-12-04):使用 numpy 计算样本之间的距离、选取初始类心并迭代更新聚类结果,再用 matplotlib 绘制聚类结果。完整代码见下文。

*不保证正确性,仅作记录

# -*- coding: utf-8 -*-
#############################
# Author : 水娃             #
# Date : 2020-12-04-01.21.51#
#############################
#导入库
import numpy as np
import matplotlib.pyplot as plt

# Compute the distance between two samples
def calculateDistance(posa,posb):
    """Return the Euclidean distance between two points.

    Both arguments are converted to float arrays first, so plain lists and
    tuples are accepted in addition to NumPy arrays, and unsigned integer
    inputs cannot wrap around when subtracted.

    Args:
        posa: Coordinates of the first point (array-like of numbers).
        posb: Coordinates of the second point, same length as ``posa``.

    Returns:
        The distance as a NumPy float scalar.
    """
    delta=np.asarray(posa,dtype=float)-np.asarray(posb,dtype=float)
    return np.sqrt(np.sum(np.square(delta)))

# Randomly generate P two-dimensional coordinates
def getRandomData(P):
    """Return ``P`` random 2-D points with coordinates in ``[0, 100)``.

    The points are drawn in one vectorised call from NumPy's global random
    state, so ``np.random.seed`` controls the output.

    Args:
        P: Number of points to generate; must be non-negative.

    Returns:
        A float array of shape ``(P, 2)``. The shape stays ``(0, 2)`` when
        ``P`` is 0, so callers can always index the coordinate axis.
    """
    return np.random.rand(P,2)*100

# Pick the initial cluster centers
def getInitCenters(data,mp,P,C):
    """Choose ``C`` initial centers by farthest-point (maximin) selection.

    The two mutually farthest samples are taken first. Every further center
    is the not-yet-chosen sample whose distance to its nearest chosen center
    is largest. Ties are resolved in favour of the lowest index.

    Args:
        data: Unused; kept so existing callers keep working.
        mp: ``P`` x ``P`` table where ``mp[i][j]`` is the distance between
            samples ``i`` and ``j``.
        P: Number of samples.
        C: Number of centers to pick, ``1 <= C <= P``.

    Returns:
        A list of ``C`` distinct sample indices (Python ints).

    Raises:
        ValueError: If ``C`` is smaller than 1 or larger than ``P``.
    """
    if C<1 or C>P:
        raise ValueError("C must satisfy 1 <= C <= P, got C=%r, P=%r"%(C,P))

    dist=np.asarray(mp,dtype=float)[:P,:P]

    # Start with the two samples that are farthest apart. The diagonal is
    # masked so that a sample is never paired with itself, even when every
    # distance in the table is zero.
    pairwise=dist.copy()
    np.fill_diagonal(pairwise,-np.inf)
    st1,st2=np.unravel_index(np.argmax(pairwise),pairwise.shape)
    st1,st2=int(st1),int(st2)

    # nearest[j] is the distance from sample j to its closest chosen center.
    nearest=np.minimum(dist[:,st1],dist[:,st2])
    currentCenters=[st1,st2][:C]

    # Pick the remaining centers
    for _ in range(C-len(currentCenters)):
        # Already chosen samples must never be selected again.
        nearest[currentCenters]=-np.inf
        candidate=int(np.argmax(nearest))
        currentCenters.append(candidate)
        nearest=np.minimum(nearest,dist[:,candidate])
    return currentCenters

# Build the initial clustering
def getOriginalCluster(data,centers,mp,P,C):
    """Assign each of the ``P`` samples to its nearest initial center.

    Args:
        data: Unused by this function.
        centers: Sample indices of the chosen centers.
        mp: Distance table where ``mp[i][j]`` is the distance between
            samples ``i`` and ``j``.
        P: Number of samples.
        C: Number of clusters.

    Returns:
        A list of ``C`` lists; entry ``k`` holds the indices of the samples
        whose nearest center is ``centers[k]``. On a tie the center with the
        lowest position in ``centers`` wins.
    """
    groups=[[] for _ in range(C)]
    for sample in range(P):
        row=mp[sample]
        # Slot 0 is the starting guess; later slots replace it only when
        # they are strictly closer.
        bestSlot=0
        bestDistance=row[centers[0]]
        for slot in range(1,C):
            candidate=row[centers[slot]]
            if candidate<bestDistance:
                bestSlot,bestDistance=slot,candidate
        groups[bestSlot].append(sample)
    return groups

# C-means clustering
def Cmeans(data,P,C):
    """Cluster the first ``P`` samples of ``data`` into ``C`` groups.

    Initial centers come from ``getInitCenters`` and the first assignment
    from ``getOriginalCluster``. The function then alternates between
    moving every center to the mean of its members and reassigning every
    sample to its nearest center, until no sample changes cluster. The
    iteration number is printed at the start of each pass.

    Args:
        data: Array-like of sample coordinates, one row per sample. Any
            number of coordinate columns is supported.
        P: Number of samples to cluster.
        C: Number of clusters.

    Returns:
        A tuple ``(centers, cluster)``: ``centers`` is a list of ``C``
        coordinate lists, and ``cluster`` is a list of ``C`` lists holding
        the sample indices assigned to each center.
    """
    points=np.asarray(data,dtype=float)

    # Distance table, initial center indices and initial clustering
    mp=[[calculateDistance(points[i],points[j])for j in range(0,P)]for i in range(0,P)]
    seeds=getInitCenters(points,mp,P,C)
    cluster=getOriginalCluster(points,seeds,mp,P,C)

    # Convert the centers from sample indices to coordinates
    centers=[list(points[seeds[i]])for i in range(0,C)]

    # Iterate until the assignment is stable
    haschanged=True
    times=0
    while haschanged:
        times+=1
        print(times)
        tmpCluster=[[]for i in range(0,C)]
        haschanged=False

        # Update the centers
        for i in range(0,C):
            members=cluster[i]
            # A cluster can lose all of its members; keep its previous
            # center then instead of dividing by zero.
            if members:
                centers[i]=list(points[members].mean(axis=0))

        # Update the clustering
        for i in range(0,C):
            for j in cluster[i]:
                # Each distance is computed once per sample.
                distances=[calculateDistance(centers[k],points[j])for k in range(0,C)]
                # A sample stays in its current cluster unless another
                # center is strictly closer.
                currentIndex=i
                for k in range(0,C):
                    if distances[k]<distances[currentIndex]:
                        currentIndex=k
                if currentIndex!=i:
                    haschanged=True
                tmpCluster[currentIndex].append(j)
        cluster=tmpCluster
    return centers,cluster
    
# Number of clusters C and number of points P
C=10
P=200

# Generate the random data set
data=getRandomData(P)

# Run the clustering
result,cluster=Cmeans(data,P,C)

# Draw every cluster as its own scatter series
for i in range(C):
    pointx=[data[k][0] for k in cluster[i]]
    pointy=[data[k][1] for k in cluster[i]]
    plt.scatter(pointx,pointy)

# Overlay the final centers as black squares
resultx=[center[0] for center in result]
resulty=[center[1] for center in result]
plt.scatter(resultx,resulty,marker='s',color='k')
plt.show()

本文地址:https://blog.csdn.net/qq_43807662/article/details/110633215