欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

tensorflow学习笔记二---k近邻分类器

程序员文章站 2024-03-23 13:20:40
...

使用Tensorflow实现k近邻分类器模型
1.k近邻模型的基本原理
  1. 距离度量
tensorflow学习笔记二---k近邻分类器

      2.k值的选择

tensorflow学习笔记二---k近邻分类器


3 .分类决策规则

tensorflow学习笔记二---k近邻分类器

2.Tensorflow实现k近邻分类代码
  1. inference()-构建学习器模型前向预测过程(从输入到输出的计算图路径)
  2. evaluate()-在测试集数据上对模型的预测性能进行评估
  3. 此模型没有添加loss也没有train
    

3.计算步骤
  •    算距离:给定测试样本的特征向量,计算他与训练集中每个样本特征向量的距离,得到一个一维张量
tensorflow学习笔记二---k近邻分类器tensorflow学习笔记二---k近邻分类器


  • 找近邻:圈定最近的k个训练样本作为测试样本近邻
  • 作分类:根据k个近邻的归属主要类别,来对测试做主要分类
tensorflow学习笔记二---k近邻分类器

4.总结
  • tensorflow实现k近邻算法主要有以下几个步骤
    • 算距离:计算测试样本与每一个训练样本的距离,缩减求和后得到一个一维数组存储
    • 找近邻:划定k值的大小,选取k个训练样本做为测试样本的近邻
    • 做分类:根据k个近邻,对测试样本做分类(距离最小的索引)
    • 做评估:与真实的标签进行比较,计算准确率
  • 核心代码
tensorflow学习笔记二---k近邻分类器
import numpy as np
import  os
import  tensorflow as tf
# 防止意外报错
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 导入mnist数据集
from tensorflow.examples.tutorials.mnist import input_data
# onehot标签标识一个长度为n的数组,只有一个元素是1,其他的都是0,用来表示mnist中标签数据
mnist = input_data.read_data_sets("mnist_data/", one_hot=True)
# 对mnist数据集做一个数量限制
Xtrain,Ytrain=mnist.train.next_batch(5000)#使用5000个训练数据
Xtest,Ytest=mnist.train.next_batch(200) # 使用200个测试数据
print('Xtrain.shape: ', Xtrain.shape, ', Xtest.shape: ',Xtest.shape)
print('Ytrain.shape: ', Ytrain.shape, ', Ytest.shape: ',Ytest.shape)
# 计算图输入占位符
#train 使用全部样本,test 逐个样本进行测试
xtrain=tf.placeholder("float",[None,784])#图片训练集
xtest=tf.placeholder("float",[784])#测试集
#使用L1距离进行最近邻计算
#计算L1距离
distance=tf.reduce_sum(tf.abs(tf.add(xtrain,tf.negative(xtest))),axis=1)
# 预测: 获得最小距离的索引 (根据最近邻的类标签进行判断)
pred = tf.arg_min(distance, 0)
#评估:判断给定的一条测试样本是否预测正确
#评估正确率
accuracy=0
# 初始化节点
init = tf.global_variables_initializer()
#启动会话
with tf.Session() as sess:
    sess.run(init)
    Ntest=len(Xtest)#测试样本的数量
    for i in range(Ntest):
      # 获取当前测试样本的最近邻
        nn_index = sess.run(pred, feed_dict={xtrain: Xtrain, xtest: Xtest[i, :]})#一个样本一个样本的输入
      # 获得最近邻预测标签,然后与真实的类标签比较,由于是 one_hot 编码,所以要用 argmax 将类标取出
        pred_class_label = np.argmax(Ytrain[nn_index])
        true_class_label = np.argmax(Ytest[i])
        print("Test", i, "Predicted Class Label:", pred_class_label,
          "True Class Label:", true_class_label)
          # 计算准确率
        if pred_class_label == true_class_label:
            accuracy += 1

    print("Done!")
    accuracy /= Ntest
    print("Accuracy:", accuracy)

训练结果准确率是0.925,使用的数据是mnist5000的训练数据和200的测试数据

tensorflow学习笔记二---k近邻分类器

关于这部分也就介绍到这,我只是代码的搬运工,嘿嘿嘿,把学到的东西分享出来,温故而知新。

总觉得很快乐。

持续更新,大家可以一起讨论哦。

相关标签: tensorflow入门