欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

基于西瓜数据集的KNN算法实现

程序员文章站 2022-06-26 17:41:43
因为本人在学习这块内容之后,发现网络上大部分现有代码的不简洁以及运行报错,再者想要的表达方法的不同,所以自己动手结合网络上已有的代码改写了一个,运行正常。代码及数据集以上传到GitHub:https://github.com/zhurui-king/aaa# -*- coding:utf-8 -*-# Author: 非鱼子焉# Creation_time: 2020.11.11# Content: 基于西瓜数据集的KNN算法实现# Blog: https://zhu-rui.blog.csdn...

因为本人在学习这块内容之后,发现网络上大部分现有代码的不简洁以及运行报错,再者想要的表达方法的不同,所以自己动手结合网络上已有的代码改写了一个,运行正常。
代码及数据集以上传到GitHubhttps://github.com/zhurui-king/aaa

# -*- coding:utf-8 -*-
# Author: 非鱼子焉
# Creation_time: 2020.11.11
# Content: 基于西瓜数据集的KNN算法实现
# Blog: https://zhu-rui.blog.csdn.net/
# GitHub: https://github.com/zhurui-king/aaa

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# KNN算法类
class KNN(object):
    def __init__(self, x, y, K):#x:密度;y:含糖率;k:近邻数
        self.x = x
        self.y = y
        self.K = K
        self.n = len(x)
        
	# 计算距离
    def distance(self, p1, p2):
        return np.linalg.norm(np.array(p1) - np.array(p2))
        
	#算法实现
    def knn(self, x):
        distance = []
        for i in range(self.n):
            dist = self.distance(x, self.x[i])
            distance.append([self.x[i], self.y[i], dist])
        distance.sort(key=lambda x: x[2])
        neighbors = []
        neighbors_labels = []
        for k in range(self.K):
            neighbors.append(distance[k][0])  # 近邻具体数据
            neighbors_labels.append(distance[k][1])  # 近邻标记
        return neighbors, neighbors_labels
        
	#选择多数投票数
    def vote(self, x):
        neighbors, neighbors_labels = self.knn(x)
        vote = {}  # 投票法
        for label in neighbors_labels:
            vote[label] = vote.get(label, 0) + 1
        sort_vote = sorted(vote.items(), key=lambda x:x[1], reverse=True)
        return sort_vote[0][0]  # 返回投票数最多的标记
        
	#对应标记
    def fit(self):
        labels = []
        for sample in self.x:
            label = self.vote(sample)
            labels.append(label)
        return labels  # 返回所有样本的标记
        
	# 计算正确率
    def accuracy(self):
        predict_labels = self.fit()
        real_labels = self.y
        correct = 0
        for predict, real in zip(predict_labels, real_labels):
            if int(predict) == int(real):
                correct += 1
        return correct / self.n
        
#读取数据
def getdata(path):
    dataSet = pd.read_csv(path, delimiter=",")
    X = dataSet[['density', 'sugar_rate']].values
    Y = dataSet['label']
    return X,Y
    
# 进行绘图
def drawpictures(x_positive, y_positive,x_negative, y_negative): 
    plt.scatter(x_positive, y_positive, marker='o', color='red', label='1')
    plt.scatter(x_negative, y_negative, marker='o', color='blue', label='0')
    plt.xlabel('密度')
    plt.ylabel('含糖率')
    plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
    plt.legend(loc='upper left')
    plt.show()
    
#训练数据
def train(X,Y):
    for k in range(1, 9):
        print("*****第%d次*****" %k)
        print('本次knn的k值选取为{}'.format(k))
        knn = KNN(X, Y, k)
        predict = knn.fit()
        print('本次knn的正确率为{}'.format(knn.accuracy()))
        x_positive = []
        y_positive = []
        x_negative = []
        y_negative = []
        for i in range(len(X)):
            if int(predict[i]) == 1:
                x_positive.append(X[i][0])
                y_positive.append(X[i][1])
            else:
                x_negative.append(X[i][0])
                y_negative.append(X[i][1])
        drawpictures(x_positive, y_positive,x_negative, y_negative)

if __name__ == '__main__':
    X,Y = getdata('watermelon3_0a.csv')
    train(X,Y)
    print("************程序运行结束************")

最终结果输出为所选择近邻K的对应的正确率,并且进行plot可视化,其中循环遍历每一次的K值(K从1开始到所设定的值-1为止)

参考博文:https://blog.csdn.net/weixin_42152526/article/details/93528560
参考书籍:MACHINE LEARNING 机器学习(周志华)清华大学出版社

本文地址:https://blog.csdn.net/zhu_rui/article/details/109634016