应用SVM对MNIST数据集进行分类
MNIST是机器学习领域十分经典的一个手写数字数据集,共60000张训练图像,10000张测试图像,图像大小为28*28.
MNIST百度网盘下载地址:https://pan.baidu.com/s/1k1Ji6amaUhDG6jfdcl_kNg 提取码:nykv
将下载下来的压缩包解压后放到源代码所在的文件夹下即可。
如运行缺少相关python库,可往https://www.lfd.uci.edu/~gohlke/pythonlibs/下载
SVM分类MNIST的源代码如下:
from sklearn import svm
import numpy as np
from time import time
from sklearn.metrics import accuracy_score
from struct import unpack
from sklearn.model_selection import GridSearchCV
def readimage(path):
with open(path, 'rb') as f:
magic, num, rows, cols = unpack('>4I', f.read(16))
img = np.fromfile(f, dtype=np.uint8).reshape(num, 784)
return img
def readlabel(path):
with open(path, 'rb') as f:
magic, num = unpack('>2I', f.read(8))
lab = np.fromfile(f, dtype=np.uint8)
return lab
def main():
train_data = readimage("train-images.idx3-ubyte")
train_label = readlabel("train-labels.idx1-ubyte")
test_data = readimage("t10k-images.idx3-ubyte")
test_label = readlabel("t10k-labels.idx1-ubyte")
svc=svm.SVC()
parameters = {'kernel':['rbf'], 'C':[1]}
print("Train...")
clf=GridSearchCV(svc,parameters,n_jobs=-1)
start = time()
clf.fit(train_data, train_label)
end = time()
t = end - start
print('Train:%dmin%.3fsec' % (t//60, t - 60 * (t//60)))
prediction = clf.predict(test_data)
print("accuracy: ", accuracy_score(prediction, test_label))
accurate=[0]*10
sumall=[0]*10
i=0
while i<len(test_label):
sumall[test_label[i]]+=1
if prediction[i]==test_label[i]:
accurate[test_label[i]]+=1
i+=1
print("分类正确的:",accurate)
print("总的测试标签:",sumall)
if __name__ == '__main__':
main()
程序通过readimage和readlabel函数读入数据后创建svm分类器,并用parameter添加相应的参数,这里使用GridSearchCV将参数作为输入优化网络,这里输入的parameter对应分类器唯一,可进行添加以达到优化参数的目的,代码中使用GridSearchCV的主要目的是引入n_jobs让cpu进行多线程处理,n_jobs=-1时程序的并行数将和cpu的核数一致,从而极大的加速程序的运行。在i5-8300H的四核CPU中训练时间为26min。
源代码训练时的正确率如下:
欢迎评论区交流。
友情链接:svm.SVC参数详解:https://blog.csdn.net/weixin_41990278/article/details/93137009
GridSearchCV参数详解:https://blog.csdn.net/foneone/article/details/89985045
本文地址:https://blog.csdn.net/qq_43160985/article/details/107675241
上一篇: iOS CGPath提升阴影性能
推荐阅读
-
应用SVM对MNIST数据集进行分类
-
keras对猫、狗数据集进行分类(二)
-
python3 24.kera使用DropOut进行MNIST数据集简单分类 学习笔记
-
分别采用线性LDA、k-means和SVM算法对鸢尾花数据集和月亮数据集进行二分类可视化分析
-
Keras : 利用卷积神经网络CNN对图像进行分类,以mnist数据集为例建立模型并预测
-
ML之多分类预测之PLiR:使用PLiR实现对六类label数据集进行多分类
-
ML:基于自定义数据集利用Logistic、梯度下降算法GD、LoR逻辑回归、Perceptron感知器、SVM支持向量机、LDA线性判别分析算法进行二分类预测(决策边界可视化)
-
应用SVM对MNIST数据集进行分类
-
ML之多分类预测之PLiR:使用PLiR实现对六类label数据集进行多分类
-
ML:基于自定义数据集利用Logistic、梯度下降算法GD、LoR逻辑回归、Perceptron感知器、SVM支持向量机、LDA线性判别分析算法进行二分类预测(决策边界可视化)