手写数字识别:Python+BP神经网络
程序员文章站
2024-01-22 11:15:40
...
1.简述
编程语言:Python3
界面设计:PyQt5
识别方法:BP神经网络
训练数据:Mnist训练集
手写数字获取:画图板
2.环境配置
Python版本:3.6.3
相关库及版本:
Package | Version |
---|---|
numpy | 1.12.0 |
Pillow | 4.2.0 |
PyQt5 | 5.11.3 |
PyQt5-sip | 4.19.19 |
PyQt5-stubs | 5.11.3.0 |
sip | 4.19.8 |
3.文件目录
项目文件列表:
文件/文件夹 | 用途 |
---|---|
Mnist文件夹 | MNIST手写数字训练集与测试集 |
Resource文件夹 | 界面设计相关图像 |
tmp文件夹 | 图像处理临时文件 |
bp_train.py | 神经网络训练 |
Main.py | 主程序 |
MainWidget.py | 交互界面 |
NN.py | 神经网络定义 |
PaintBoard.py | 画板 |
recognize.py | 识别算法 |
weights.npy | 权重 |
4.主程序
Main.py
from MainWidget import MainWidget
from PyQt5.QtWidgets import QApplication
from PyQt5.Qt import *
import sys
def main():
    """Start the Qt application, show the main window and run the event loop.

    The process exits with the event loop's return code.
    """
    app = QApplication(sys.argv)
    app.setWindowIcon(QIcon('./Resource/icon/xmf.ico'))
    mainWidget = MainWidget()  # create the main window
    mainWidget.show()  # show it
    # sys.exit rather than the builtin exit(): exit() is injected by the `site`
    # module and is absent under `python -S` or in frozen executables.
    sys.exit(app.exec_())  # enter the Qt event loop


if __name__ == '__main__':
    main()
5.神经网络定义
NN.py
import numpy as np
def LabelBinarizer(label, num_classes=10):
    """One-hot encode integer class labels.

    :param label: 1-D sequence/array of integer class ids; each id is used as a
        column index into a ``num_classes``-wide row.
    :param num_classes: number of classes, i.e. the width of each one-hot row
        (10 for MNIST digits).
    :return: int32 array of shape (len(label), num_classes) with a single 1 per row.
    :raises IndexError: if an id is out of range for ``num_classes`` columns.
    """
    label = np.asarray(label, dtype=np.intp)
    relabel = np.zeros([len(label), num_classes], dtype=np.int32)
    # Fancy indexing sets relabel[i, label[i]] = 1 for every row at once.
    relabel[np.arange(len(label)), label] = 1
    return relabel
def tanh(x):
    """Hyperbolic tangent activation, applied element-wise."""
    return np.tanh(x)


def tanh_deriv(x):
    """Derivative of tanh at the pre-activation ``x``: 1 - tanh(x)**2."""
    t = np.tanh(x)
    return 1.0 - t * t


def logistic(x):
    """Logistic sigmoid 1 / (1 + exp(-x)), applied element-wise.

    For large negative ``x`` np.exp may emit an overflow RuntimeWarning; the
    result still correctly tends to 0.
    """
    return 1 / (1 + np.exp(-x))


def logistic_derivative(x):
    """Derivative of the logistic sigmoid at the pre-activation ``x``: s * (1 - s)."""
    s = logistic(x)  # computed once instead of twice
    return s * (1 - s)


class NeuralNetwork:
    """Fully connected feed-forward network trained by per-sample back-propagation.

    ``layers`` lists the unit count of every layer, input first, e.g.
    ``[784, 250, 10]``. A constant 1 is appended to every input vector, and every
    hidden layer has one extra unit (``layers[i] + 1`` columns). That extra unit
    is computed through the activation like the others; it is not held at 1.
    Each weight matrix therefore has ``units_of_previous_layer + 1`` rows.

    Trained weights are written to ``weights.npy`` in the current working
    directory, and :meth:`predict` reads them back from there.
    """

    def __init__(self, layers, activation='tanh'):
        """Create randomly initialised weights in [-0.25, 0.25).

        :param layers: unit counts per layer; needs at least input and output.
        :param activation: ``'tanh'`` or ``'logistic'``.
        :raises ValueError: for an unknown activation or fewer than two layers.
        """
        if activation == 'logistic':
            self.activation = logistic
            self.activation_deriv = logistic_derivative
            # Derivative written in terms of the activation output o = logistic(z).
            self._deriv_from_output = lambda o: o * (1 - o)
        elif activation == 'tanh':
            self.activation = tanh
            self.activation_deriv = tanh_deriv
            # o = tanh(z)  =>  d tanh / dz = 1 - o**2.
            self._deriv_from_output = lambda o: 1.0 - o * o
        else:
            raise ValueError("activation must be 'tanh' or 'logistic', got %r" % (activation,))
        if len(layers) < 2:
            raise ValueError("layers needs at least an input and an output size")
        self.weights = []
        # Hidden layers: previous layer's units + 1 rows, layers[i] + 1 columns.
        # (The old loop appended two matrices per hidden layer, which gave
        # mismatched shapes for networks with more than one hidden layer.)
        for i in range(1, len(layers) - 1):
            self.weights.append((2 * np.random.random((layers[i - 1] + 1, layers[i] + 1)) - 1) * 0.25)
        # Output layer: no extra unit on the output side.
        self.weights.append((2 * np.random.random((layers[-2] + 1, layers[-1])) - 1) * 0.25)

    def fit(self, X, y, learning_rate=0.15, epochs=60000):
        """Train by stochastic gradient descent on the squared error, then save the weights.

        Each epoch draws one random sample, runs a forward pass, back-propagates
        the error and updates every weight matrix. At the end the weights are
        saved to ``weights.npy`` in the current working directory.

        :param X: samples, one per row (shape (n_samples, n_features)).
        :param y: targets, one row per sample with ``layers[-1]`` entries.
        :param learning_rate: step size of every update.
        :param epochs: number of single-sample updates (not passes over X).
        """
        X = np.atleast_2d(X)
        # Append a constant-1 column so the first weight matrix's last row acts as a bias.
        temp = np.ones([X.shape[0], X.shape[1] + 1])
        temp[:, 0:-1] = X
        X = temp
        y = np.array(y)
        for k in range(epochs):
            print(k + 1)  # progress: one line per update
            sample = np.random.randint(X.shape[0])
            # Forward pass, keeping every layer's activation for back-propagation.
            a = [X[sample]]
            for w in self.weights:
                a.append(self.activation(np.dot(a[-1], w)))
            error = y[sample] - a[-1]
            # a[l] already holds activation *outputs*, so use the derivative
            # written in terms of the output. Passing a[l] to activation_deriv
            # would apply the activation a second time.
            deltas = [error * self._deriv_from_output(a[-1])]
            for l in range(len(a) - 2, 0, -1):
                deltas.append(deltas[-1].dot(self.weights[l].T) * self._deriv_from_output(a[l]))
            deltas.reverse()
            for j in range(len(self.weights)):
                layer = np.atleast_2d(a[j])
                delta = np.atleast_2d(deltas[j])
                self.weights[j] += learning_rate * layer.T.dot(delta)
        self._save_weights("weights")

    def _save_weights(self, path):
        """Save ``self.weights`` to ``path`` (np.save appends '.npy') as a 1-D object array."""
        # The matrices have different shapes. np.save(list) relied on NumPy
        # building a ragged object array implicitly, which NumPy >= 1.24 refuses,
        # so pack them explicitly.
        packed = np.empty(len(self.weights), dtype=object)
        for j, w in enumerate(self.weights):
            packed[j] = w
        np.save(path, packed)

    def predict(self, x):
        """Run one sample through the network using the weights stored in ``weights.npy``.

        :param x: one sample as a 1-D sequence of ``layers[0]`` features.
        :return: 1-D array of output-layer activations.
        """
        # weights.npy holds an object array, which NumPy >= 1.16.3 only loads with
        # allow_pickle=True. Unpickling can execute code: only load a trusted file.
        self.weights = list(np.load("weights.npy", allow_pickle=True))
        x = np.array(x)
        temp = np.ones(x.shape[0] + 1)
        temp[0:-1] = x  # append the constant-1 bias input
        a = temp
        for w in self.weights:
            a = self.activation(np.dot(a, w))
        return a
fit:训练函数,每次随机抽取一个样本做一次反向传播并更新权重,训练结束后把权重保存到 weights.npy
predict:识别函数,从 weights.npy 读取权重,对单个样本做前向计算,返回输出层各单元的激活值
5.神经网络训练
bp_train.py
import struct
from NN import *
# Reader for MNIST files in IDX format.
def load_minist(labels_path, images_path):
    """Read one MNIST split stored in the IDX binary format.

    :param labels_path: IDX1 label file (magic number 2049), e.g. train-labels.idx1-ubyte.
    :param images_path: IDX3 image file (magic number 2051), e.g. train-images.idx3-ubyte.
    :return: ``(images, labels)``. ``images`` is a uint8 array of shape
        (num, rows * cols) with one flattened image per row; ``labels`` is a
        uint8 array of shape (num,).
    :raises ValueError: if a magic number is wrong or the headers and data
        disagree on the number of samples.
    """
    with open(labels_path, 'rb') as lbpath:
        # Header: big-endian uint32 magic number and item count.
        magic, n = struct.unpack('>II', lbpath.read(8))
        if magic != 2049:
            raise ValueError("%s: bad IDX1 label magic number %d" % (labels_path, magic))
        labels = np.fromfile(lbpath, dtype=np.uint8)
    if len(labels) != n:
        raise ValueError("%s: header says %d labels, file has %d" % (labels_path, n, len(labels)))
    with open(images_path, 'rb') as imgpath:
        # Header: magic number, image count, rows, columns (big-endian uint32).
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        if magic != 2051:
            raise ValueError("%s: bad IDX3 image magic number %d" % (images_path, magic))
        pixels = np.fromfile(imgpath, dtype=np.uint8)
    if num != n:
        raise ValueError("%d images but %d labels" % (num, n))
    # Use the header's image size instead of assuming 28x28 = 784.
    images = pixels.reshape(num, rows * cols)
    return images, labels
# Load the training images and labels.
images, labels = load_minist('./Mnist/train-labels.idx1-ubyte', './Mnist/train-images.idx3-ubyte')
# Load the test images and labels (not used by the training below).
test_images, test_labels = load_minist('./Mnist/t10k-labels.idx1-ubyte', './Mnist/t10k-images.idx3-ubyte')
# Scale pixels from 0..255 to 0..1. recognize.py feeds the network inputs
# normalized to [0, 1], so training has to use the same scale; raw 0..255
# values also drive the logistic units deep into saturation.
images = images / 255.0
test_images = test_images / 255.0
labels = LabelBinarizer(labels)
# BP network: 784 inputs (28x28 pixels), 250 hidden units, 10 outputs (digits 0-9).
nn = NeuralNetwork([784, 250, 10], 'logistic')
# Train; fit() saves the weights to weights.npy when it finishes.
nn.fit(images, labels)
6.识别算法
recognize.py
from PIL import Image
from NN import *
def Normalization(dataset):
    """Min-max scale a NumPy array so its smallest value becomes 0 and its largest 1.

    :param dataset: numeric NumPy array of any shape (must not be empty).
    :return: float array of the same shape. A constant input (e.g. an untouched
        white canvas) has no range to scale and returns all zeros instead of
        NaNs from 0/0.
    """
    # Broadcasting replaces the old np.tile copies of the scalars.
    shifted = dataset - dataset.min()
    value_range = shifted.max()
    if value_range == 0:
        return np.zeros(dataset.shape)
    return shifted / value_range
def rgb2gray(rgb):
    """Collapse the colour axis of an (..., C) image with C >= 3 into one gray channel.

    Uses the weights 0.299 R + 0.587 G + 0.114 B; any channels after the third
    (such as alpha) are ignored.
    """
    luma_weights = [0.299, 0.587, 0.114]
    color_channels = rgb[..., :3]
    return np.dot(color_channels, luma_weights)
def getimage_array(filename):
    """Load an image file as a flat float vector in [0, 1] where dark pixels are high.

    Colour images are first converted to grayscale. The values are min-max
    normalized and then inverted (1 - x), so dark ink on a light background
    ends up near 1.

    :param filename: path of the image file.
    :return: 1-D float array with one entry per pixel.
    """
    # The context manager closes the file deterministically; np.array() has
    # already copied the pixels by the time the block exits.
    with Image.open(filename) as img:
        img_array = np.array(img)
    if img_array.ndim == 3:
        img_array = rgb2gray(img_array)
    return 1 - Normalization(img_array.flatten())
def JudgeEdge(img_array):
    """Find the bounding box of every pixel with a value > 0 in a 2-D array.

    :return: ``[top, bottom, left, right]`` as inclusive row/column indices;
        each entry is -1 when the image has no pixel > 0.
    """
    ink = np.asarray(img_array) > 0
    inked_rows = np.flatnonzero(ink.any(axis=1))
    inked_cols = np.flatnonzero(ink.any(axis=0))
    if inked_rows.size:
        top, bottom = int(inked_rows[0]), int(inked_rows[-1])
    else:
        top = bottom = -1
    if inked_cols.size:
        left, right = int(inked_cols[0]), int(inked_cols[-1])
    else:
        left = right = -1
    return [top, bottom, left, right]
def JudgeOneNumber(img_array):
    """Tell whether the inked columns of a 2-D image form one unbroken horizontal run.

    A column counts as inked when any of its values is > 0. SplitPicture keeps
    an image whole when this returns True and cuts it when blank columns
    separate the ink (False).

    An image with no ink at all returns True, because there is nothing to split.
    Previously the empty case probed column -1 and returned False, which sent
    SplitPicture past the last column.

    :param img_array: 2-D NumPy array.
    :return: bool.
    """
    inked = np.flatnonzero((np.asarray(img_array) > 0).any(axis=0))
    if inked.size == 0:
        return True
    # Contiguous iff every column between the first and last inked one is inked.
    return bool(inked[-1] - inked[0] + 1 == inked.size)
def SplitPicture(img_array, img_list):
    """Cut a 2-D image into vertical strips, one per horizontally separated group of ink.

    Strips are appended to ``img_list`` from left to right, and the same list is
    returned. Each cut falls right after the first inked column (any value > 0)
    whose right neighbour is blank, so a strip keeps the blank columns to its
    left. The last strip runs to the right edge.

    :param img_array: 2-D NumPy array.
    :param img_list: list that receives the strips.
    :return: ``img_list``.
    """
    while not JudgeOneNumber(img_array):
        inked = (img_array > 0).any(axis=0)
        # Indices of inked columns whose next column is blank.
        run_ends = np.flatnonzero(inked[:-1] & ~inked[1:])
        if run_ends.size == 0:
            # No split point (e.g. no ink at all): keep the rest whole instead
            # of indexing past the last column as the old loop did.
            break
        cut = run_ends[0] + 1
        img_list.append(img_array[:, :cut])
        img_array = img_array[:, cut:]
    img_list.append(img_array)
    return img_list
# Read a picture: grayscale it, then crop and shrink every digit to MNIST format.
def GetCutZip(imagename):
    """Load a picture of handwritten digits and turn each digit into a 28x28 array.

    Steps:
    1. Convert to grayscale and stretch the contrast.
    2. Make the digits white on black.
    3. Threshold away faint noise.
    4. Split into one strip per digit.
    5. Crop each strip to its ink.
    6. Scale the longer side to 20 px, centre on a 28x28 canvas and normalize.

    :param imagename: path of the image file.
    :return: list of 28x28 float arrays, left to right. Strips without any ink
        are skipped, so a blank picture yields an empty list.
    """
    with Image.open(imagename) as img:
        img_array = np.array(img)
    # Grayscale colour images.
    if img_array.ndim == 3:
        img_array = rgb2gray(img_array)
    # Stretch the contrast between digits and background.
    img_array = Normalization(img_array)
    # More near-white than near-black pixels means a white background:
    # invert to white digits on black, like MNIST.
    if (img_array >= 0.9).sum() > (img_array <= 0.1).sum():
        img_array = 1 - img_array
    # Snap strong strokes to 1 and faint noise to 0 to ease digit extraction.
    img_array[img_array > 0.7] = 1
    img_array[img_array < 0.4] = 0
    final_list = []
    for piece in SplitPicture(img_array, []):
        top, bottom, left, right = JudgeEdge(piece)
        if top == -1:
            continue  # strip without ink: nothing to recognize
        cut_array = piece[top:bottom + 1, left:right + 1]
        final_list.append(_zip_to_mnist_canvas(cut_array))
    return final_list


def _zip_to_mnist_canvas(cut_array):
    """Scale a cropped [0, 1] digit so its longer side is 20 px and centre it on a 28x28 canvas."""
    cut_img = Image.fromarray(np.uint8(cut_array * 255))
    width, height = cut_img.size
    # max(1, ...) keeps very thin strokes (e.g. a 1 px wide "1") from giving a
    # zero-sized resize, which PIL rejects.
    if width <= height:
        new_size = (max(1, 20 * width // height), 20)
    else:
        new_size = (20, max(1, 20 * height // width))
    # LANCZOS is the filter ANTIALIAS aliased; ANTIALIAS was removed in Pillow 10.
    zip_img_array = np.array(cut_img.resize(new_size, Image.LANCZOS))
    final_array = np.zeros((28, 28))
    zip_height, zip_width = zip_img_array.shape
    high = (28 - zip_height) // 2
    left = (28 - zip_width) // 2
    final_array[high:high + zip_height, left:left + zip_width] = zip_img_array
    return Normalization(final_array)
def recognize(src):
    """Recognize the handwritten digits in an image file, left to right.

    :param src: path of the image file.
    :return: the predicted digits joined into one string ('' if no digit was found).
    """
    # Same 784-250-10 logistic architecture that bp_train.py trains;
    # predict() loads the trained weights from weights.npy.
    nn = NeuralNetwork([784, 250, 10], 'logistic')
    digits = []
    for img_array in GetCutZip(src):
        scores = nn.predict(img_array.flatten())
        # The output unit with the highest activation is the predicted digit.
        digits.append(str(np.argmax(scores)))
    return ''.join(digits)
if __name__ == "__main__":
    # Command-line use: ask for an image path and print the recognized digits.
    path = input("请输入图片路径:\n")
    digits = recognize(path)
    print("识别的最终结果是:" + digits)
7.画板
PaintBoard.py
from PyQt5.QtWidgets import QWidget
from PyQt5.Qt import QPixmap, QPainter, QPoint, QPaintEvent, QMouseEvent, QPen, \
QColor, QSize
from PyQt5.QtCore import Qt
class PaintBoard(QWidget):
    """Fixed-size (480x460) drawing surface backed by an off-screen white QPixmap.

    Dragging the mouse draws line segments into the pixmap with the current pen
    colour and width. While ``EraserMode`` is on, segments are drawn in white
    with a fixed width of 10 px.
    """

    def __init__(self, Parent=None):
        """Create the board under ``Parent``; data members are set up before the view."""
        super().__init__(Parent)
        self.__InitData()
        self.__InitView()

    def __InitData(self):
        """Create the canvas pixmap, default pen settings and mouse-position state."""
        self.__size = QSize(480, 460)
        # Off-screen canvas that strokes go into; paintEvent copies it to the widget.
        self.__board = QPixmap(self.__size)
        self.__board.fill(Qt.white)
        self.__IsEmpty = True  # cleared on mouse release, set again by Clear()
        self.EraserMode = False  # public flag, toggled from outside
        self.__lastPos = QPoint(0, 0)  # previous mouse position
        self.__currentPos = QPoint(0, 0)  # latest mouse position
        self.__painter = QPainter()  # one painter reused for canvas and widget
        self.__thickness = 10  # pen width in px
        self.__penColor = QColor("black")
        self.__colorList = QColor.colorNames()

    def __InitView(self):
        """Lock the widget to the canvas size."""
        self.setFixedSize(self.__size)

    def Clear(self):
        """Fill the canvas with white, repaint and mark the board empty."""
        self.__board.fill(Qt.white)
        self.update()
        self.__IsEmpty = True

    def ChangePenColor(self, color="black"):
        """Use ``color`` (any value QColor accepts) for later non-eraser strokes."""
        self.__penColor = QColor(color)

    def ChangePenThickness(self, thickness=10):
        """Use a pen ``thickness`` px wide for later non-eraser strokes."""
        self.__thickness = thickness

    def IsEmpty(self):
        """Return the empty flag: True until a mouse release, and again after Clear()."""
        return self.__IsEmpty

    def GetContentAsQImage(self):
        """Return the canvas contents as a QImage."""
        return self.__board.toImage()

    def paintEvent(self, paintEvent):
        """Copy the off-screen canvas onto the widget with its top-left corner at (0, 0)."""
        painter = self.__painter
        painter.begin(self)
        painter.drawPixmap(0, 0, self.__board)
        painter.end()

    def mousePressEvent(self, mouseEvent):
        """Start a stroke: the press position becomes both the current and the last position."""
        self.__currentPos = mouseEvent.pos()
        self.__lastPos = self.__currentPos

    def mouseMoveEvent(self, mouseEvent):
        """Draw a segment from the previous to the new mouse position, then repaint."""
        self.__currentPos = mouseEvent.pos()
        if self.EraserMode != False:
            pen = QPen(Qt.white, 10)  # eraser: pure white, fixed 10 px
        else:
            pen = QPen(self.__penColor, self.__thickness)
        painter = self.__painter
        painter.begin(self.__board)
        painter.setPen(pen)
        painter.drawLine(self.__lastPos, self.__currentPos)
        painter.end()
        self.__lastPos = self.__currentPos
        self.update()

    def mouseReleaseEvent(self, mouseEvent):
        """Mark the board as no longer empty once a mouse button is released."""
        self.__IsEmpty = False
8.交互界面
MainWidget.py
from PyQt5.Qt import *
from PyQt5.QtWidgets import QHBoxLayout, QVBoxLayout, QPushButton, QSplitter, \
QComboBox, QLabel, QSpinBox, QFileDialog
from PaintBoard import PaintBoard
from recognize import recognize
class MainWidget(QWidget):
    """Main window: the PaintBoard on the left and a column of controls on the right.

    The controls clear and save the drawing, toggle the eraser, run digit
    recognition on the drawing, and set the pen width and colour.
    """

    def __init__(self, Parent=None):
        super().__init__(Parent)
        self.__InitData()  # create data members first; the view lays them out
        self.__InitView()

    def __InitData(self):
        """Create the drawing board and cache Qt's list of named colours."""
        self.__paintBoard = PaintBoard(self)
        # Colour names as strings; combo-box row i shows self.__colorList[i].
        self.__colorList = QColor.colorNames()

    def __createTitleLabel(self, text):
        """Return a centred, bold, 12 pt KaiTi (楷体) label showing ``text``, parented to this window."""
        label = QLabel(self)
        label.setText(text)
        label.setAlignment(Qt.AlignCenter)
        label.setFont(QFont("楷体", 12, QFont.Bold))
        return label

    def __InitView(self):
        """Build the window: board on the left, vertical control column on the right."""
        self.setFixedSize(640, 480)
        self.setWindowTitle("手写数字识别系统")
        # Horizontal main layout with 10 px spacing between its children.
        main_layout = QHBoxLayout(self)
        main_layout.setSpacing(10)
        main_layout.addWidget(self.__paintBoard)  # board on the left
        # Vertical sub-layout for the controls, with 10 px margins.
        sub_layout = QVBoxLayout()
        sub_layout.setContentsMargins(10, 10, 10, 10)

        self.__label_School = self.__createTitleLabel("踏雪寻梅")
        sub_layout.addWidget(self.__label_School)
        self.__label_faculty = self.__createTitleLabel("张熤")
        sub_layout.addWidget(self.__label_faculty)
        self.__label_Name = self.__createTitleLabel("(〃'▽'〃)")
        sub_layout.addWidget(self.__label_Name)

        sub_layout.addWidget(QSplitter(self))  # spacer

        self.__btn_Clear = QPushButton("清空画板")
        self.__btn_Clear.setParent(self)
        self.__btn_Clear.clicked.connect(self.__paintBoard.Clear)
        sub_layout.addWidget(self.__btn_Clear)

        self.__btn_Quit = QPushButton("退出")
        self.__btn_Quit.setParent(self)
        self.__btn_Quit.clicked.connect(self.Quit)
        sub_layout.addWidget(self.__btn_Quit)

        self.__btn_Save = QPushButton("保存作品")
        self.__btn_Save.setParent(self)
        self.__btn_Save.clicked.connect(self.on_btn_Save_Clicked)
        sub_layout.addWidget(self.__btn_Save)

        self.__cbtn_Eraser = QCheckBox(" 使用橡皮擦")
        self.__cbtn_Eraser.setParent(self)
        self.__cbtn_Eraser.clicked.connect(self.on_cbtn_Eraser_clicked)
        sub_layout.addWidget(self.__cbtn_Eraser)

        self.__btn_Recognize = QPushButton("识别")
        self.__btn_Recognize.setParent(self)
        self.__btn_Recognize.clicked.connect(self.on_recognize_clicked)
        sub_layout.addWidget(self.__btn_Recognize)

        self.__label_result = QLabel(self)
        self.__label_result.setText("识别结果:")
        sub_layout.addWidget(self.__label_result)

        # Large red text that shows the recognized digits.
        self.__label_rec_result = QLabel('', self)
        self.__label_rec_result.setAlignment(Qt.AlignCenter)
        self.__label_rec_result.setFont(QFont("Roman times", 18, QFont.Bold))
        self.__label_rec_result.setStyleSheet("color:red")
        sub_layout.addWidget(self.__label_rec_result)

        sub_layout.addWidget(QSplitter(self))  # spacer

        self.__label_penThickness = QLabel(self)
        self.__label_penThickness.setText("画笔粗细")
        self.__label_penThickness.setFixedHeight(20)
        sub_layout.addWidget(self.__label_penThickness)

        # Pen width: 2..20 px in steps of 2, default 10.
        self.__spinBox_penThickness = QSpinBox(self)
        self.__spinBox_penThickness.setMaximum(20)
        self.__spinBox_penThickness.setMinimum(2)
        self.__spinBox_penThickness.setValue(10)
        self.__spinBox_penThickness.setSingleStep(2)
        self.__spinBox_penThickness.valueChanged.connect(self.on_PenThicknessChange)
        sub_layout.addWidget(self.__spinBox_penThickness)

        self.__label_penColor = QLabel(self)
        self.__label_penColor.setText("画笔颜色")
        self.__label_penColor.setFixedHeight(20)
        sub_layout.addWidget(self.__label_penColor)

        self.__comboBox_penColor = QComboBox(self)
        self.__fillColorList(self.__comboBox_penColor)  # one swatch per named colour
        self.__comboBox_penColor.currentIndexChanged.connect(self.on_PenColorChange)
        sub_layout.addWidget(self.__comboBox_penColor)

        main_layout.addLayout(sub_layout)  # controls on the right

    def __fillColorList(self, comboBox):
        """Fill ``comboBox`` with a 70x20 swatch per named colour and select black."""
        index_black = 0
        for index, color in enumerate(self.__colorList):
            if color == "black":
                index_black = index
            pix = QPixmap(70, 20)
            pix.fill(QColor(color))
            comboBox.addItem(QIcon(pix), None)
        comboBox.setIconSize(QSize(70, 20))
        comboBox.setSizeAdjustPolicy(QComboBox.AdjustToContents)
        comboBox.setCurrentIndex(index_black)

    def on_PenColorChange(self):
        """Apply the colour selected in the combo box to the board's pen."""
        color_index = self.__comboBox_penColor.currentIndex()
        color_str = self.__colorList[color_index]
        self.__paintBoard.ChangePenColor(color_str)

    def on_PenThicknessChange(self):
        """Apply the spin box value as the board's pen width."""
        penThickness = self.__spinBox_penThickness.value()
        self.__paintBoard.ChangePenThickness(penThickness)

    def on_btn_Save_Clicked(self):
        """Ask for a file name and save the drawing there."""
        import os
        savePath = QFileDialog.getSaveFileName(self, 'Save Your Paint', '.\\', '*.png')
        print(savePath)
        fileName = savePath[0]
        if fileName == "":
            print("Save cancel")
            return
        # QImage.save() picks the format from the file suffix and silently fails
        # when there is none, so default to the dialog's PNG filter.
        if not os.path.splitext(fileName)[1]:
            fileName += ".png"
        image = self.__paintBoard.GetContentAsQImage()
        if not image.save(fileName):
            QMessageBox.warning(self, 'Save Your Paint', "保存失败: " + fileName)

    def on_cbtn_Eraser_clicked(self):
        """Turn eraser mode on or off to match the check box."""
        self.__paintBoard.EraserMode = self.__cbtn_Eraser.isChecked()

    def on_recognize_clicked(self):
        """Save the drawing to ./tmp/image.png, recognize it and show the digits."""
        import os
        import traceback
        if self.__paintBoard.IsEmpty():
            # Nothing drawn yet: there is nothing to recognize.
            self.__label_rec_result.setText('')
            return
        savePath = './tmp/image.png'
        # QImage.save() fails silently when the folder is missing; create it.
        os.makedirs(os.path.dirname(savePath), exist_ok=True)
        image = self.__paintBoard.GetContentAsQImage()
        if not image.save(savePath):
            QMessageBox.warning(self, "识别", "无法保存临时图片: " + savePath)
            return
        print(savePath)
        try:
            predict = recognize(savePath)
        except Exception as err:
            # GUI boundary: PyQt5 >= 5.5 aborts the whole application when an
            # exception escapes a slot, so report the failure instead.
            traceback.print_exc()
            QMessageBox.warning(self, "识别", "识别失败: {}".format(err))
            return
        print(predict)
        self.__label_rec_result.setText(predict)

    def Quit(self):
        """Close the window."""
        self.close()
8.效果展示
(`・ω・´) →项目文件已上传…
上一篇: TensorFlow手写数字识别(一)
下一篇: kaggel_手写数字识别