欢迎您访问程序员文章站！本站旨在为大家提供、分享程序员计算机编程知识！
您现在的位置是: 首页

图像分类之KNN算法

程序员文章站 2022-04-07 22:01:10
...

简介

这两天我用KNN方法对CIFAR-10数据集进行分类,结果却不尽如人意,只有30%左右的正确率。

KNN算法的"训练"只是将训练数据集存储起来,所以训练几乎不需要花费时间;但测试时,每个测试样本都要与全部训练样本逐一计算距离,因此测试需要花费大量时间。
图像分类之KNN算法
对于MNIST数据集,该分类器效果很好。我认为主要原因是MNIST中都是灰度(黑白)的手写数字图像,而KNN本质上是根据图像之间的像素差来计算距离的,因此MNIST图像的像素差包含了较多与类别相关的信息。
图像分类之KNN算法

代码

my_utils.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File  : my_utils.py
# @Author: Fly_dragon
# @Date  : 2019/11/29
# @Desc  : 

import numpy as np


def getXmean(x_train):
    """Return the mean image of a batch, computed pixel by pixel.

    Each sample is flattened to a 1-D row first, so the result has one entry
    per pixel (per value, for multi-channel images): the average of that
    position over all samples in ``x_train``.

    Args:
        x_train: array whose first axis indexes samples.

    Returns:
        1-D array of length ``x_train[0].size``.
    """
    n_samples = x_train.shape[0]
    rows = np.reshape(x_train, (n_samples, -1))  # one flattened image per row
    return np.mean(rows, axis=0)                 # column-wise mean = per-pixel mean


def centralized(x_test, mean_image):
    """Flatten each sample and subtract a mean image from it.

    Args:
        x_test: array whose first axis indexes samples; it is not modified.
        mean_image: array broadcastable against the flattened rows, e.g. the
            value returned by ``getXmean``.

    Returns:
        float64 array of shape (num_samples, num_features) with the mean removed.
    """
    x_test = np.reshape(x_test, (x_test.shape[0], -1))
    # Convert before subtracting: image data is often uint8, and in-place
    # subtraction of a float mean would fail or wrap around.
    # np.float64 replaces the removed np.float alias (NumPy >= 1.24 raises
    # AttributeError); np.float was float64, so the dtype is unchanged.
    x_test = x_test.astype(np.float64)
    x_test -= mean_image  # zero-centre each pixel position
    return x_test

#%% KNN class
class Knn:
    """k-nearest-neighbour classifier over flattened feature vectors.

    ``fit`` only memorises the training set; all distance computation happens
    in ``predict``.
    """

    def __init__(self):
        pass

    def fit(self, X_train, y_train):
        """Store the training samples and their labels.

        Args:
            X_train: 2-D array of shape (num_train, num_features).
            y_train: sequence of num_train labels, indexable by integer position.
        """
        self.Xtr = X_train
        self.ytr = y_train

    def predict(self, k, dis, X_test):
        """Label each row of ``X_test`` by majority vote of its k nearest training rows.

        Args:
            k: number of neighbours that vote; must be >= 1.
            dis: distance metric, 'E' for Euclidean (L2) or 'M' for Manhattan (L1).
            X_test: 2-D array of shape (num_test, num_features).

        Returns:
            np.ndarray holding one predicted label per test row.

        Raises:
            ValueError: if ``dis`` is not 'E' or 'M', or ``k`` < 1.
        """
        # Explicit checks instead of assert: asserts vanish under `python -O`.
        if dis not in ('E', 'M'):
            raise ValueError("dis must be 'E' (Euclidean) or 'M' (Manhattan), got %r" % (dis,))
        if k < 1:
            raise ValueError('k must be >= 1, got %r' % (k,))

        label_list = []
        for test_row in X_test:
            # Broadcasting subtracts test_row from every training row; no np.tile copy.
            diff = self.Xtr - test_row
            if dis == 'E':
                # Euclidean (L2) distance
                distances = np.sqrt(np.sum(diff ** 2, axis=1))
            else:
                # Manhattan (L1) distance
                distances = np.sum(np.abs(diff), axis=1)

            top_k = np.argsort(distances)[:k]
            class_count = {}
            for idx in top_k:
                label = self.ytr[idx]
                class_count[label] = class_count.get(label, 0) + 1
            # max() returns the first label with the highest count in insertion
            # order (i.e. nearest-first), matching the previous stable
            # sort-by-count tie-breaking.
            label_list.append(max(class_count, key=class_count.get))

        return np.array(label_list)

CIFAR-10主程序,在Scientific Mode(sci mode)下按 #%% 单元逐段运行

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
import torchvision.datasets as datasets

from my_utils import *

batch_size = 100
# Raw string: in a normal literal the backslashes form invalid escape sequences
# ("\p", "\深" -> SyntaxWarning on Python 3.12+), and any path segment starting
# with n/t/r/... would be silently corrupted. The resulting value is unchanged.
path = r'D:\python\深度学习与图像识别\pycifar'


#%% define and load the data
train_datasets = datasets.CIFAR10(root=path,
                                  train=True,
                                  download=False)
test_datasets = datasets.CIFAR10(root=path,
                                 train=False,
                                 download=False)


# load the data
train_loader = DataLoader(train_datasets, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_datasets, batch_size=batch_size, shuffle=True)

#%% Look at the pictures
# Preview the same sample index from the training set and then the test set,
# each titled with its class name.
pic_num = 7
for loader in (train_loader, test_loader):
    digit = loader.dataset[pic_num]
    plt.imshow(digit[0], cmap=plt.cm.binary)
    plt.title(loader.dataset.classes[digit[1]])
    plt.show()

#%% prepare the data
# Training data: flatten every image and subtract the per-pixel training mean.
x_train = train_loader.dataset.data
mean_image = getXmean(x_train)
x_train = centralized(x_train, mean_image)
y_train = train_loader.dataset.targets
# Test data: take the first num_test samples and centre them with the *training*
# mean. Re-computing the mean on the (tiny) test subset would put train and test
# in different coordinate frames, adding a constant offset to every test-to-train
# difference and distorting the distances KNN depends on.
num_test = 10
x_test = test_loader.dataset.data[:num_test]
x_test = centralized(x_test, mean_image)
y_test = test_loader.dataset.targets[:num_test]

print(x_train.shape)
print(len(y_train))
print(x_test.shape)
print(len(y_test))

#%% show the results using KNN
# fit() only stores the data, so build and fit the classifier once and reuse it
# for every k instead of re-creating it inside the loop.
classifier = Knn()
classifier.fit(x_train, y_train)
for k in range(1, 8, 2):
    # y_pred keeps the predictions of the last k; the next cell inspects them.
    y_pred = classifier.predict(k, 'E', x_test)
    num_correct = np.sum(y_pred == y_test)
    accuracy = float(num_correct) / num_test
    print(k, ':', accuracy)

#%% show the false picture
# For every misclassified test image: show it titled with its true class,
# then print the class the classifier predicted.
result = y_pred == y_test
for i in range(num_test):
    if result[i]:
        continue  # correctly classified, nothing to show
    digit = test_loader.dataset[i]
    true_name = test_loader.dataset.classes[digit[1]]
    plt.imshow(digit[0], cmap=plt.cm.binary)
    plt.title(true_name)
    plt.show()

    print(test_loader.dataset.classes[y_pred[i]])



MNIST主程序

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File  : KNN_MNIST.py
# @Author: Fly_dragon
# @Date  : 2019/11/30
# @Desc  :

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from my_utils import *

batch_size = 100
# Raw string: "\p", "\d" and "\m" are invalid escape sequences in a normal
# literal (SyntaxWarning on Python 3.12+), and a segment starting with
# n/t/r/... would be silently corrupted. The resulting value is unchanged.
path = r'D:\python\datasets\mnist_data'

#%% define and load the data
train_datasets = datasets.MNIST(root=path,
                                train=True,
                                download=True)
test_datasets = datasets.MNIST(root=path,
                                train=False,
                                download=True)


# load the data
train_loader = DataLoader(train_datasets, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_datasets, batch_size=batch_size, shuffle=True)

#%% Look at the pictures
# Preview the same sample index from the training set and then the test set,
# each titled with its class name.
pic_num = 7
for loader in (train_loader, test_loader):
    digit = loader.dataset[pic_num]
    plt.imshow(digit[0], cmap=plt.cm.binary)
    plt.title(loader.dataset.classes[digit[1]])
    plt.show()

#%% prepare the data
# Training data: convert to numpy, flatten every image and subtract the
# per-pixel training mean.
x_train = train_loader.dataset.data.numpy()
mean_image = getXmean(x_train)
x_train = centralized(x_train, mean_image)
y_train = train_loader.dataset.targets.numpy()
# Test data: take the first num_test samples and centre them with the *training*
# mean. Using the test subset's own mean would shift train and test into
# different coordinate frames and distort the distances KNN depends on.
num_test = 200
x_test = test_loader.dataset.data[:num_test].numpy()
x_test = centralized(x_test, mean_image)
y_test = test_loader.dataset.targets[:num_test].numpy()

print(x_train.shape)
print(len(y_train))
print(x_test.shape)
print(len(y_test))

#%% show the results using KNN
# fit() only stores the data, so build and fit the classifier once and reuse it
# for every k instead of re-creating it inside the loop.
classifier = Knn()
classifier.fit(x_train, y_train)
for k in range(1, 4, 2):
    # y_pred keeps the predictions of the last k; the next cell inspects them.
    y_pred = classifier.predict(k, 'E', x_test)
    num_correct = np.sum(y_pred == y_test)
    accuracy = float(num_correct) / num_test
    print(k, ':', accuracy)

#%% show the false picture
# For every misclassified test image: show it titled with its true class,
# then print the class the classifier predicted.
result = y_pred == y_test
for i in range(num_test):
    if result[i]:
        continue  # correctly classified, nothing to show
    digit = test_loader.dataset[i]
    true_name = test_loader.dataset.classes[digit[1]]
    plt.imshow(digit[0], cmap=plt.cm.binary)
    plt.title(true_name)
    plt.show()

    print(test_loader.dataset.classes[y_pred[i]])