MNIST handwritten digit classification with sklearn (Logistic regression) and Keras (fully-connected neural network)
Contents
- Use Logistic regression to classify handwritten digits (8×8×1 images); sample size 1797
- Use a neural network (no hidden layer, softmax activation, cross-entropy loss, batch gradient descent) for the same classification
- The two models are formally equivalent (see the formula sketch below), but because the optimization algorithms differ, their final parameters and classification accuracies differ
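To make the equivalence claim concrete (my summary, not from the original post): both models compute the same softmax-linear function of the flattened image, and differ only in how the weights are fitted. LogisticRegression uses a solver such as liblinear or lbfgs depending on the sklearn version, while the Keras model below is trained by gradient descent. For an input x ∈ R^64, both predict

P(y = k \mid x) = \frac{\exp(w_k^\top x + b_k)}{\sum_{j=0}^{9} \exp(w_j^\top x + b_j)}, \qquad k = 0, 1, \dots, 9

with per-class weights w_k ∈ R^64 and biases b_k.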
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from keras.models import Sequential
from keras import optimizers
from keras import layers
from keras.utils import to_categorical
Using TensorFlow backend.
Loading the data
Handwritten digit classification data (1797 samples, labels 0-9; each image is 8×8×1). Note that load_digits is sklearn's small 8×8 digits dataset, an MNIST-style set rather than the full 28×28 MNIST.
mnist = load_digits()
X = mnist["data"]
Y = mnist["target"]
X.shape, Y.shape
X[:5, :5]
Y[:5]
np.unique(Y)
((1797, 64), (1797,))
array([[ 0., 0., 5., 13., 9.],
[ 0., 0., 0., 12., 13.],
[ 0., 0., 0., 4., 15.],
[ 0., 0., 7., 15., 13.],
[ 0., 0., 0., 1., 11.]])
array([0, 1, 2, 3, 4])
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
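One detail worth noting before modeling (my addition, easy to verify): load_digits stores pixel intensities as floats in [0, 16], not [0, 255].

X.min(), X.max()  # pixel intensities span 0.0 to 16.0 in load_digits

This matters later for the gradient-descent model, where unscaled inputs interact with the learning rate.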
Plot the first 100 handwritten digits
fig, axes = plt.subplots(10, 10, figsize = (10, 10))
fig.subplots_adjust(hspace = 0.1, wspace = 0.1)
for i, ax in enumerate(axes.flat):
    ax.imshow(X[i].reshape(8, 8), cmap = "binary")  # each row of X is a flattened 8x8 image
    ax.text(0.5, 1, str(Y[i]), color = "red")       # true label drawn in red
    ax.set_xticks([])
    ax.set_yticks([])
plt.show();
Split into a training set (1000) and a test set (797)
np.random.seed(1)
# shuffle the sample order
index = np.random.permutation(X.shape[0])
X = X[index]
Y = Y[index]
X_train = X[:1000]
X_test = X[1000:]
Y_train = Y[:1000]
Y_test = Y[1000:]
X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
((1000, 64), (797, 64), (1000,), (797,))
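The manual permutation split above works fine; for reference, sklearn's built-in splitter does the same job in one call. A minimal sketch (not part of the original notebook):

from sklearn.model_selection import train_test_split

# train_size as an int keeps exactly 1000 samples for training;
# random_state makes the shuffle reproducible.
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size = 1000, random_state = 1)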
Logistic regression
For sklearn, Y_train and Y_test stay as one-dimensional integer label vectors (sparse label encoding, as opposed to one-hot)
Logit = LogisticRegression().fit(X_train, Y_train)
Y_pred = Logit.predict(X_test)
print("Accuracy on the test set: %.2f" % accuracy_score(Y_test, Y_pred))  # accuracy_score expects (y_true, y_pred)
Accuracy on the test set: 0.97
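Not in the original post, but a natural follow-up: a confusion matrix shows which digits get mixed up, using sklearn's standard confusion_matrix:

from sklearn.metrics import confusion_matrix

# Rows are true digits, columns are predicted digits;
# off-diagonal entries count the misclassifications.
print(confusion_matrix(Y_test, Y_pred))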
Fully-connected neural network
For Keras with a categorical cross-entropy loss, Y_train and Y_test must be converted to one-hot encoding
Y_train = to_categorical(Y_train)
Y_test = to_categorical(Y_test)
Y_train.shape, Y_test.shape
((1000, 10), (797, 10))
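to_categorical simply maps each integer label to a one-hot row; a tiny illustration of the expected output (my example, not run in the original notebook):

to_categorical([0, 2, 1], num_classes = 3)
# array([[1., 0., 0.],
#        [0., 0., 1.],
#        [0., 1., 0.]], dtype=float32)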
Model structure: no hidden layer, softmax activation
model = Sequential()
model.add(layers.Dense(units = 10, input_shape = (64, ), activation = "softmax"))
model.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_1 (Dense)              (None, 10)                650
=================================================================
Total params: 650
Trainable params: 650
Non-trainable params: 0
_________________________________________________________________
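A quick sanity check on the parameter count: the dense layer has 64 × 10 = 640 weights plus 10 biases, which gives the 650 trainable parameters reported above.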
Model training: cross-entropy loss, batch gradient descent (batch_size equals the 1000 training samples, so every step uses the full batch); after 20 epochs the accuracy on the test set is about 0.93
model.compile(loss = "categorical_crossentropy", optimizer = optimizers.SGD(lr = 1e-1), metrics = ["accuracy"])
history = model.fit(X_train, Y_train, epochs = 20, batch_size = 1000, validation_data = (X_test, Y_test))  # batch_size = full training set => batch gradient descent
Train on 1000 samples, validate on 797 samples
Epoch 1/20
1000/1000 [==============================] - 0s 55us/step - loss: 17.8020 - accuracy: 0.0810 - val_loss: 26.5950 - val_accuracy: 0.0903
Epoch 2/20
1000/1000 [==============================] - 0s 5us/step - loss: 23.8534 - accuracy: 0.0920 - val_loss: 36.3915 - val_accuracy: 0.3513
Epoch 3/20
1000/1000 [==============================] - 0s 7us/step - loss: 37.6768 - accuracy: 0.3650 - val_loss: 40.8177 - val_accuracy: 0.2572
Epoch 4/20
1000/1000 [==============================] - 0s 6us/step - loss: 43.0619 - accuracy: 0.2180 - val_loss: 44.6566 - val_accuracy: 0.3325
Epoch 5/20
1000/1000 [==============================] - 0s 5us/step - loss: 41.9507 - accuracy: 0.2950 - val_loss: 43.6071 - val_accuracy: 0.3764
Epoch 6/20
1000/1000 [==============================] - 0s 5us/step - loss: 39.2078 - accuracy: 0.4090 - val_loss: 35.6318 - val_accuracy: 0.3237
Epoch 7/20
1000/1000 [==============================] - 0s 6us/step - loss: 33.3949 - accuracy: 0.3600 - val_loss: 30.8554 - val_accuracy: 0.4529
Epoch 8/20
1000/1000 [==============================] - 0s 4us/step - loss: 31.2371 - accuracy: 0.4380 - val_loss: 20.0968 - val_accuracy: 0.4793
Epoch 9/20
1000/1000 [==============================] - 0s 7us/step - loss: 19.6809 - accuracy: 0.4550 - val_loss: 11.8599 - val_accuracy: 0.4241
Epoch 10/20
1000/1000 [==============================] - 0s 6us/step - loss: 11.7020 - accuracy: 0.4310 - val_loss: 10.9439 - val_accuracy: 0.5922
Epoch 11/20
1000/1000 [==============================] - 0s 5us/step - loss: 11.6388 - accuracy: 0.5880 - val_loss: 10.7407 - val_accuracy: 0.6010
Epoch 12/20
1000/1000 [==============================] - 0s 5us/step - loss: 12.3840 - accuracy: 0.5710 - val_loss: 11.2764 - val_accuracy: 0.7967
Epoch 13/20
1000/1000 [==============================] - 0s 5us/step - loss: 9.9372 - accuracy: 0.7860 - val_loss: 6.8957 - val_accuracy: 0.8243
Epoch 14/20
1000/1000 [==============================] - 0s 6us/step - loss: 6.0211 - accuracy: 0.8340 - val_loss: 3.2951 - val_accuracy: 0.8206
Epoch 15/20
1000/1000 [==============================] - 0s 4us/step - loss: 2.8981 - accuracy: 0.8310 - val_loss: 0.7951 - val_accuracy: 0.9059
Epoch 16/20
1000/1000 [==============================] - 0s 5us/step - loss: 0.8095 - accuracy: 0.8980 - val_loss: 0.6546 - val_accuracy: 0.9147
Epoch 17/20
1000/1000 [==============================] - 0s 5us/step - loss: 0.5772 - accuracy: 0.9100 - val_loss: 0.5997 - val_accuracy: 0.9222
Epoch 18/20
1000/1000 [==============================] - 0s 5us/step - loss: 0.5163 - accuracy: 0.9150 - val_loss: 0.5893 - val_accuracy: 0.9247
Epoch 19/20
1000/1000 [==============================] - 0s 6us/step - loss: 0.4728 - accuracy: 0.9190 - val_loss: 0.5540 - val_accuracy: 0.9285
Epoch 20/20
1000/1000 [==============================] - 0s 4us/step - loss: 0.4451 - accuracy: 0.9220 - val_loss: 0.5476 - val_accuracy: 0.9260
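Note the loss explosion over the first ten or so epochs before training recovers: with raw pixel values in [0, 16], a 0.1 learning rate takes very large initial steps. A common remedy, sketched below but not part of the original post, is to scale the inputs to [0, 1] before fitting; this typically gives a smoother loss curve, though the exact numbers will vary from run to run.

# Hypothetical variant: rescale pixels from [0, 16] to [0, 1] and retrain a fresh model
model_s = Sequential()
model_s.add(layers.Dense(units = 10, input_shape = (64, ), activation = "softmax"))
model_s.compile(loss = "categorical_crossentropy", optimizer = optimizers.SGD(lr = 1e-1), metrics = ["accuracy"])
history_s = model_s.fit(X_train / 16.0, Y_train, epochs = 20, batch_size = 1000,
                        validation_data = (X_test / 16.0, Y_test))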
Original article: https://blog.csdn.net/weixin_40575651/article/details/107357352