欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

应用SVM对MNIST数据集进行分类

程序员文章站 2022-08-06 08:03:32
MNIST是机器学习领域十分经典的一个手写数字数据集,共60000张训练图像,10000张测试图像,图像大小为28*28.MNIST百度网盘下载地址:https://pan.baidu.com/s/1k1Ji6amaUhDG6jfdcl_kNg 提取码:nykv将下载下来的压缩包解压后放到源代码所在的文件夹下即可。如运行缺少相关python库,可往https://www.lfd.uci.edu/~gohlke/pythonlibs/下载SVM分类MNIST的源代码如下:from sk....

MNIST是机器学习领域十分经典的一个手写数字数据集,共60000张训练图像,10000张测试图像,图像大小为28*28.

MNIST百度网盘下载地址:https://pan.baidu.com/s/1k1Ji6amaUhDG6jfdcl_kNg  提取码:nykv

将下载下来的压缩包解压后放到源代码所在的文件夹下即可。

如运行缺少相关python库,可往https://www.lfd.uci.edu/~gohlke/pythonlibs/下载

SVM分类MNIST的源代码如下:

from sklearn import svm
import numpy as np
from time import time
from sklearn.metrics import accuracy_score
from struct import unpack
from sklearn.model_selection import GridSearchCV

def readimage(path):
    with open(path, 'rb') as f:
        magic, num, rows, cols = unpack('>4I', f.read(16))
        img = np.fromfile(f, dtype=np.uint8).reshape(num, 784)
    return img

def readlabel(path):
    with open(path, 'rb') as f:
        magic, num = unpack('>2I', f.read(8))
        lab = np.fromfile(f, dtype=np.uint8)
    return lab

def main():
    train_data  = readimage("train-images.idx3-ubyte")
    train_label = readlabel("train-labels.idx1-ubyte")
    test_data   = readimage("t10k-images.idx3-ubyte")
    test_label  = readlabel("t10k-labels.idx1-ubyte")
    svc=svm.SVC()
    parameters = {'kernel':['rbf'], 'C':[1]}
    print("Train...")
    clf=GridSearchCV(svc,parameters,n_jobs=-1)
    start = time()
    clf.fit(train_data, train_label)
    end = time()
    t = end - start
    print('Train:%dmin%.3fsec' % (t//60, t - 60 * (t//60)))
    prediction = clf.predict(test_data)
    print("accuracy: ", accuracy_score(prediction, test_label))
    accurate=[0]*10
    sumall=[0]*10
    i=0
    while i<len(test_label):
        sumall[test_label[i]]+=1
        if prediction[i]==test_label[i]:
            accurate[test_label[i]]+=1
        i+=1
    print("分类正确的:",accurate)
    print("总的测试标签:",sumall)

if __name__ == '__main__':
    main()

程序通过readimage和readlabel函数读入数据后创建svm分类器,并用parameter添加相应的参数,这里使用GridSearchCV将参数作为输入优化网络,这里输入的parameter对应分类器唯一,可进行添加以达到优化参数的目的,代码中使用GridSearchCV的主要目的是引入n_jobs让cpu进行多线程处理,n_jobs=-1时程序的并行数将和cpu的核数一致,从而极大的加速程序的运行。在i5-8300H的四核CPU中训练时间为26min。

源代码训练时的正确率如下:

应用SVM对MNIST数据集进行分类

欢迎评论区交流。

友情链接:svm.SVC参数详解:https://blog.csdn.net/weixin_41990278/article/details/93137009

                  GridSearchCV参数详解:https://blog.csdn.net/foneone/article/details/89985045

本文地址:https://blog.csdn.net/qq_43160985/article/details/107675241

相关标签: 机器学习