基于西瓜数据集的KNN算法实现
程序员文章站
2022-06-26 17:41:43
因为本人在学习这块内容之后,发现网络上大部分现有代码的不简洁以及运行报错,再者想要的表达方法的不同,所以自己动手结合网络上已有的代码改写了一个,运行正常。代码及数据集以上传到GitHub:https://github.com/zhurui-king/aaa# -*- coding:utf-8 -*-# Author: 非鱼子焉# Creation_time: 2020.11.11# Content: 基于西瓜数据集的KNN算法实现# Blog: https://zhu-rui.blog.csdn...
因为本人在学习这块内容之后,发现网络上大部分现有代码的不简洁以及运行报错,再者想要的表达方法的不同,所以自己动手结合网络上已有的代码改写了一个,运行正常。
代码及数据集以上传到GitHub:https://github.com/zhurui-king/aaa
# -*- coding:utf-8 -*-
# Author: 非鱼子焉
# Creation_time: 2020.11.11
# Content: 基于西瓜数据集的KNN算法实现
# Blog: https://zhu-rui.blog.csdn.net/
# GitHub: https://github.com/zhurui-king/aaa
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# KNN算法类
class KNN(object):
def __init__(self, x, y, K):#x:密度;y:含糖率;k:近邻数
self.x = x
self.y = y
self.K = K
self.n = len(x)
# 计算距离
def distance(self, p1, p2):
return np.linalg.norm(np.array(p1) - np.array(p2))
#算法实现
def knn(self, x):
distance = []
for i in range(self.n):
dist = self.distance(x, self.x[i])
distance.append([self.x[i], self.y[i], dist])
distance.sort(key=lambda x: x[2])
neighbors = []
neighbors_labels = []
for k in range(self.K):
neighbors.append(distance[k][0]) # 近邻具体数据
neighbors_labels.append(distance[k][1]) # 近邻标记
return neighbors, neighbors_labels
#选择多数投票数
def vote(self, x):
neighbors, neighbors_labels = self.knn(x)
vote = {} # 投票法
for label in neighbors_labels:
vote[label] = vote.get(label, 0) + 1
sort_vote = sorted(vote.items(), key=lambda x:x[1], reverse=True)
return sort_vote[0][0] # 返回投票数最多的标记
#对应标记
def fit(self):
labels = []
for sample in self.x:
label = self.vote(sample)
labels.append(label)
return labels # 返回所有样本的标记
# 计算正确率
def accuracy(self):
predict_labels = self.fit()
real_labels = self.y
correct = 0
for predict, real in zip(predict_labels, real_labels):
if int(predict) == int(real):
correct += 1
return correct / self.n
#读取数据
def getdata(path):
dataSet = pd.read_csv(path, delimiter=",")
X = dataSet[['density', 'sugar_rate']].values
Y = dataSet['label']
return X,Y
# 进行绘图
def drawpictures(x_positive, y_positive,x_negative, y_negative):
plt.scatter(x_positive, y_positive, marker='o', color='red', label='1')
plt.scatter(x_negative, y_negative, marker='o', color='blue', label='0')
plt.xlabel('密度')
plt.ylabel('含糖率')
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.legend(loc='upper left')
plt.show()
#训练数据
def train(X,Y):
for k in range(1, 9):
print("*****第%d次*****" %k)
print('本次knn的k值选取为{}'.format(k))
knn = KNN(X, Y, k)
predict = knn.fit()
print('本次knn的正确率为{}'.format(knn.accuracy()))
x_positive = []
y_positive = []
x_negative = []
y_negative = []
for i in range(len(X)):
if int(predict[i]) == 1:
x_positive.append(X[i][0])
y_positive.append(X[i][1])
else:
x_negative.append(X[i][0])
y_negative.append(X[i][1])
drawpictures(x_positive, y_positive,x_negative, y_negative)
if __name__ == '__main__':
X,Y = getdata('watermelon3_0a.csv')
train(X,Y)
print("************程序运行结束************")
最终结果输出为所选择近邻K的对应的正确率,并且进行plot可视化,其中循环遍历每一次的K值(K从1开始到所设定的值-1为止)
参考博文:https://blog.csdn.net/weixin_42152526/article/details/93528560
参考书籍:MACHINE LEARNING 机器学习(周志华)清华大学出版社
本文地址:https://blog.csdn.net/zhu_rui/article/details/109634016
推荐阅读
-
用JAVA语言实现的凝聚式层次聚类算法 ——基于数据结构中的线性结构和树形结构
-
TensorFlow系列(4)——基于MNIST数据集的CNN实现
-
基于python的BP神经网络算法对mnist数据集的识别--批量处理版
-
Python编程实现ID3算法,使用西瓜数据集产生结果
-
基于西瓜数据集的KNN算法实现
-
Python实现基于KNN算法的笔迹识别功能详解
-
Redis3.0集群crc16算法php客户端实现方法(php获得redis3.0集群中redis数据所在的redis分区插槽,并根据分区插槽取得分区所在redis服务器地址)
-
Redis3.0集群crc16算法php客户端实现方法(php获得redis3.0集群中redis数据所在的redis分区插槽,并根据分区插槽取得分区所在redis服务器地址)
-
算法与数据结构之基于数组实现的数组队列和循环队列Java版
-
基于python的BP神经网络算法对mnist数据集的识别--批量处理版