欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Kmeans 图像分割 by python

程序员文章站 2024-03-25 08:23:10
...

依然是神奇的numpy boardcast!!!
于是代码只有28行!!!

>
输入:k,data[n];
(1) 选择 k 个初始中心点,例如 c[0] = data[0], ……,c[k-1]=data[k-1];
(2) 对于 data[0],……,data[n],分别与 c[0],……,c[k-1]比较,假设 c[i]差值最少,就标记为 i;
(3) 对于所有标记为 i 点,重新计算 c[i]={所有标记为 i 的 data[i]z 之和}/标记为 i 的个数;
(4) 重复(2)(3),直至所有 c[j]值的变化小于给定阈值。

from scipy.misc import imread,imshow,imsave
import numpy as np
from functools import partial

def kmeans(img,K,epsilon):

    img = img.astype(np.float64)

    randpos = partial(np.random.randint,0,min(img.shape[0],img.shape[1]))
    cx,cy = [randpos(K) for i in range(2)]
    center = img[cx,cy]

    img = img.reshape(1, img.shape[0], img.shape[1], -1)
    center = center.reshape(K, 1, 1, 3)

    # ite = 0
    diff = np.inf
    pre_center = np.sum(center)
    while(diff>epsilon):

        dis = (img - center) ** 2
        pos_label = np.sum(dis, axis=3).argmin(axis=0)

        for i in range(K): center[i] = np.mean(img[0,pos_label == i],axis=0)

        diff = np.abs(np.sum(center)-pre_center)
        pre_center = np.sum(center)
        # ite+=1
        # print(ite,diff)

    for i in range(K): img[0,pos_label == i] = center[i]

    return np.squeeze(img).astype(np.float16)

if __name__ == '__main__':

    img = np.floor(imread("/home/ryan/Desktop/cat.jpg"))
    img = kmeans(img,5,0.05)
    imshow(img)

相关标签: python numpy