欢迎您访问程序员文章站!本站旨在为大家提供、分享程序员计算机编程知识!
您现在的位置是: 首页

手写数字识别:Python+BP神经网络

程序员文章站 2024-01-22 11:15:40
...

1.简述

编程语言：Python3
界面设计：PyQt5
识别方法：BP神经网络
训练数据：Mnist训练集
手写数字获取：画图板

2.环境配置

Python版本：3.6.3
相关库及版本

Package Version
numpy 1.12.0
Pillow 4.2.0
PyQt5 5.11.3
PyQt5-sip 4.19.19
PyQt5-stubs 5.11.3.0
sip 4.19.8

3.文件目录

项目文件列表:

文件/文件夹 用途
Mnist文件夹 手写数字训练集
Resource文件夹 界面设计相关图像
tmp文件夹 图像处理临时文件
bp_train.py 神经网络训练
Main.py 主程序
MainWidget.py 交互界面
NN.py 神经网络定义
PaintBoard.py 画板
recognize.py 识别算法
weights.npy 权重

4.主程序

Main.py

from MainWidget import MainWidget
from PyQt5.QtWidgets import QApplication
from PyQt5.Qt import *
import sys

def main():
    """Create the Qt application, show the main window and run the event loop.

    The process terminates with the status code returned by the Qt event
    loop (``app.exec_()``).
    """
    app = QApplication(sys.argv)
    app.setWindowIcon(QIcon('./Resource/icon/xmf.ico'))
    mainWidget = MainWidget()  # build the top-level window
    mainWidget.show()  # make it visible

    # Use sys.exit() rather than the builtin exit(): the latter is injected
    # by the `site` module for interactive use and is not guaranteed to exist
    # (e.g. under `python -S` or in frozen executables).
    sys.exit(app.exec_())  # enter the event loop

# Launch the GUI only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

5.神经网络定义

NN.py

import numpy as np

def LabelBinarizer(label):
    """One-hot encode a sequence of class indices into 10 columns.

    Args:
        label: sequence or 1-D array of integer class indices; each value
            must be a valid column index into a 10-wide row.

    Returns:
        An ``int32`` array of shape ``(len(label), 10)`` whose row ``i`` has
        a 1 in column ``label[i]`` and 0 elsewhere.

    Raises:
        IndexError: if a label is outside the valid column range.
    """
    # Explicit integer dtype so an empty input still yields a usable index.
    indices = np.asarray(label, dtype=np.intp)
    relabel = np.zeros([len(indices), 10], dtype=np.int32)
    # One fancy-index assignment sets every (row, label) cell at once.
    relabel[np.arange(len(indices)), indices] = 1
    return relabel

def tanh(x):
    """Hyperbolic-tangent activation, evaluated element-wise by NumPy."""
    activated = np.tanh(x)
    return activated

def tanh_deriv(x):
    """Derivative of tanh with respect to its input: 1 - tanh(x)**2."""
    t = np.tanh(x)
    return 1.0 - t * t

def logistic(x):
    """Logistic (sigmoid) activation: 1 / (1 + e**-x), element-wise."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

def logistic_derivative(x):
    """Derivative of the logistic function with respect to its input x."""
    s = logistic(x)
    return s * (1 - s)

class NeuralNetwork:
    """Fully connected feed-forward network trained by per-sample backprop.

    ``fit`` performs stochastic gradient descent on the squared error and
    stores the result in ``weights.npy`` in the current working directory;
    ``predict`` reloads that file before every forward pass.
    """

    def __init__(self, layers, activation='tanh'):
        """Pick the activation pair and draw random initial weights.

        Args:
            layers: unit count of every layer, input first and output last,
                e.g. ``[784, 250, 10]``.
            activation: ``'logistic'`` or ``'tanh'``.

        Raises:
            ValueError: if ``activation`` is not one of the supported names.
        """
        if activation == 'logistic':
            self.activation = logistic
            self.activation_deriv = logistic_derivative
        elif activation == 'tanh':
            self.activation = tanh
            self.activation_deriv = tanh_deriv
        else:
            # Fail here instead of with an AttributeError deep inside fit().
            raise ValueError(
                "activation must be 'logistic' or 'tanh', got %r" % (activation,))

        # One matrix per pair of consecutive layers.  The extra input row
        # matches the constant 1 that fit()/predict() append to the input;
        # every non-output layer also gets one extra unit so that the next
        # matrix's row count lines up.  For three layers this produces the
        # same shapes (and the same random draws) as the historical
        # two-matrix construction, so existing weights.npy files still fit.
        self.weights = []
        last = len(layers) - 2
        for i in range(len(layers) - 1):
            rows = layers[i] + 1
            cols = layers[i + 1] if i == last else layers[i + 1] + 1
            self.weights.append((2 * np.random.random((rows, cols)) - 1) * 0.25)

    def fit(self, X, y, learning_rate=0.15, epochs=60000):
        """Train with one randomly drawn sample per step, then save weights.

        Args:
            X: 2-D array-like, one sample per row.
            y: target rows aligned with ``X``.
            learning_rate: step size applied to every weight update.
            epochs: number of single-sample update steps.

        Side effects:
            Prints the 1-based step number on every iteration and writes the
            trained weights to ``weights.npy``.
        """
        X = np.atleast_2d(X)
        # Append a constant-1 column that feeds the bias row of the weights.
        temp = np.ones([X.shape[0], X.shape[1] + 1])
        temp[:, 0:-1] = X
        X = temp
        y = np.array(y)

        for k in range(epochs):
            print(k + 1)
            sample = np.random.randint(X.shape[0])

            # Forward pass, remembering the pre-activation of every layer so
            # the derivative can be evaluated at the right point.
            a = [X[sample]]
            z = []
            for w in self.weights:
                z.append(np.dot(a[-1], w))
                a.append(self.activation(z[-1]))

            error = y[sample] - a[-1]
            # activation_deriv is the derivative w.r.t. the activation's
            # input, so it must receive z (not the activated output a).
            deltas = [error * self.activation_deriv(z[-1])]

            # Propagate the error backwards through the hidden layers;
            # a[l] was produced from z[l - 1].
            for l in range(len(a) - 2, 0, -1):
                deltas.append(
                    deltas[-1].dot(self.weights[l].T) * self.activation_deriv(z[l - 1]))
            deltas.reverse()

            for j in range(len(self.weights)):
                layer = np.atleast_2d(a[j])
                delta = np.atleast_2d(deltas[j])
                self.weights[j] += learning_rate * layer.T.dot(delta)

        # The matrices have different shapes; pack them into an explicit
        # object array because recent NumPy refuses to build a ragged array
        # implicitly from a list.
        packed = np.empty(len(self.weights), dtype=object)
        for j, w in enumerate(self.weights):
            packed[j] = w
        np.save("weights", packed)

    def predict(self, x):
        """Return the output-layer activations for one sample.

        Args:
            x: 1-D array-like of input features (without the bias term).

        Returns:
            1-D array with one activation per output unit.

        Side effects:
            Replaces ``self.weights`` with the contents of ``weights.npy``.
        """
        # The file holds an object array, which NumPy only loads with
        # allow_pickle=True.  NOTE(security): unpickling can execute code,
        # so only load weight files from a trusted source.
        self.weights = np.load("weights.npy", allow_pickle=True)
        x = np.array(x)
        temp = np.ones(x.shape[0] + 1)  # trailing 1 is the bias input
        temp[0:-1] = x
        a = temp
        for w in self.weights:
            a = self.activation(np.dot(a, w))
        return a

fit:定义训练函数
predict:定义识别函数

5.神经网络训练

bp_train.py

import struct
from NN import *

# Reader for a pair of MNIST IDX files (labels + images)
def load_minist(labels_path,images_path):
    """Load MNIST labels and flattened images from IDX files.

    Args:
        labels_path: path to an ``idx1-ubyte`` label file.
        images_path: path to an ``idx3-ubyte`` image file.

    Returns:
        ``(images, labels)``: ``images`` is a ``uint8`` array of shape
        ``(count, rows * cols)`` and ``labels`` a ``uint8`` array of length
        ``count``.

    Raises:
        ValueError: if a file does not carry the expected IDX magic number
            (which also catches swapped arguments) or if the image count
            disagrees with the number of labels.
    """
    with open(labels_path, 'rb') as lbpath:
        # Big-endian header: magic number, item count.
        magic, _count = struct.unpack('>II', lbpath.read(8))
        if magic != 2049:
            raise ValueError(
                "%s is not an IDX label file (magic %d)" % (labels_path, magic))
        labels = np.fromfile(lbpath, dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        # Big-endian header: magic number, image count, rows, columns.
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        if magic != 2051:
            raise ValueError(
                "%s is not an IDX image file (magic %d)" % (images_path, magic))
        # Use the dimensions from the header instead of assuming 28x28.
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(num, rows * cols)

    if num != len(labels):
        raise ValueError(
            "image count %d does not match label count %d" % (num, len(labels)))
    return images, labels
# Load the training images and labels
images,labels=load_minist('./Mnist/train-labels.idx1-ubyte','./Mnist/train-images.idx3-ubyte')
# Load the test images and labels (not used further in this script)
test_images,test_labels=load_minist('./Mnist/t10k-labels.idx1-ubyte','./Mnist/t10k-images.idx3-ubyte')

# Scale the uint8 pixels from 0..255 to 0..1.  recognize.py min-max
# normalises every digit to 0..1 before calling predict(), so the network
# has to be trained on the same value range; it also keeps the logistic
# units out of saturation.
images = images / 255.0
labels = LabelBinarizer(labels)
# Build the BP network: 784 inputs, 250 hidden units, 10 outputs
nn = NeuralNetwork([784,250,10], 'logistic')
# Train; fit() writes the resulting weights to weights.npy
nn.fit(images,labels)

6.识别算法

recognize.py

from PIL import Image
from NN import *

def Normalization(dataset):
    """Min-max scale an array into the range 0..1.

    Args:
        dataset: NumPy array of any shape.

    Returns:
        Array of the same shape with the minimum mapped to 0 and the maximum
        to 1.  A constant input (max == min) yields all zeros instead of the
        NaNs a 0/0 division would produce.
    """
    # Broadcasting handles the scalar; no need to tile it to full shape.
    shifted = dataset - dataset.min()
    peak = shifted.max()
    if peak == 0:
        # Flat image (e.g. a blank canvas): nothing to stretch.
        return np.zeros(dataset.shape, dtype=float)
    return shifted / peak

def rgb2gray(rgb):
    """Collapse the first three channels of the last axis into one grey value.

    Channels beyond the third (e.g. alpha) are ignored.
    """
    channel_weights = [0.299, 0.587, 0.114]
    colour = rgb[..., :3]
    return np.dot(colour, channel_weights)

def getimage_array(filename):
    """Read an image file and return it as a flat, inverted 0..1 vector.

    Three-dimensional pixel arrays are converted to grey first; the flattened
    result is min-max normalised and then inverted (1 - value).
    """
    pixels = np.array(Image.open(filename))
    if pixels.ndim == 3:
        pixels = rgb2gray(pixels)
    flat = pixels.flatten()
    return 1 - Normalization(flat)

def JudgeEdge(img_array):
    """Return the bounding box of all positive pixels in a 2-D image.

    Args:
        img_array: 2-D array; pixels greater than 0 count as ink.

    Returns:
        ``[top, bottom, left, right]`` as inclusive row/column indices, or
        ``[-1, -1, -1, -1]`` when the image contains no ink.
    """
    ink = np.asarray(img_array) > 0
    ink_rows = np.flatnonzero(ink.any(axis=1))
    ink_cols = np.flatnonzero(ink.any(axis=0))
    if ink_rows.size == 0:
        # No ink anywhere (rows and columns are empty together).
        return [-1, -1, -1, -1]
    # Plain ints so callers get the same types as from a range() scan.
    return [int(ink_rows[0]), int(ink_rows[-1]),
            int(ink_cols[0]), int(ink_cols[-1])]

def JudgeOneNumber(img_array):
    """Tell whether the ink in the image forms one horizontally connected run.

    Args:
        img_array: 2-D array; pixels greater than 0 count as ink.

    Returns:
        True when every column between the leftmost and rightmost ink column
        contains ink; False when there is a blank column in between or when
        the image holds no ink at all.
    """
    ink_cols = (np.asarray(img_array) > 0).any(axis=0)
    hits = np.flatnonzero(ink_cols)
    if hits.size == 0:
        # Blank image: stated explicitly rather than falling out of a
        # negative-index lookup of the last column.
        return False
    return bool(ink_cols[hits[0]:hits[-1] + 1].all())

def SplitPicture(img_array,img_list):
    """Cut an image into vertical strips holding one ink run each.

    A cut is placed right after the last column of every ink run except the
    final one; the final strip extends to the right edge of the image.

    Args:
        img_array: 2-D NumPy array; pixels greater than 0 count as ink.
        img_list: list that receives the strips (views into ``img_array``).

    Returns:
        ``img_list`` itself, extended left-to-right.  An image without any
        ink adds nothing instead of raising IndexError.
    """
    ink_cols = (img_array > 0).any(axis=0)
    hits = np.flatnonzero(ink_cols)
    if hits.size == 0:
        # Nothing drawn: there is no digit to extract.
        return img_list

    last_ink = hits[-1]
    # Columns that contain ink and are followed by a blank column.
    run_ends = np.flatnonzero(ink_cols[:-1] & ~ink_cols[1:])
    start = 0
    for end in run_ends:
        if end >= last_ink:
            # The end of the final run is not a cut: that strip keeps
            # everything up to the right edge.
            break
        img_list.append(img_array[:, start:end + 1])
        start = int(end) + 1

    # Single-run images are passed through unchanged (same object).
    img_list.append(img_array if start == 0 else img_array[:, start:])
    return img_list

#Read an image: grey conversion, per-digit cropping and down-scaling
def GetCutZip(imagename):
    """Turn an image file into a list of normalised 28x28 digit arrays.

    Steps: load, convert colour images to grey, stretch contrast to 0..1,
    invert light-background images so the ink is bright, threshold, split
    into one strip per digit, crop each strip to its ink, scale the longer
    side to 20 px and centre the result on a 28x28 canvas.

    Args:
        imagename: path of the image file to read.

    Returns:
        List of 28x28 float arrays, one per detected digit, left to right.
    """
    img = Image.open(imagename)
    img_array = np.array(img)
    # Colour image -> grey
    if img_array.ndim == 3:
        img_array = rgb2gray(img_array)

    # Stretch contrast between digit and background
    img_array = Normalization(img_array)
    # Dark-on-light images are inverted to light-on-dark
    arr1 = (img_array >= 0.9)
    arr0 = (img_array <= 0.1)
    if arr1.sum() > arr0.sum():
        img_array = 1 - img_array

    # Suppress part of the noise so the digits are easier to extract
    img_array[img_array > 0.7] = 1
    img_array[img_array < 0.4] = 0
    img_list = SplitPicture(img_array, [])
    final_list = []
    for digit_array in img_list:
        edge = JudgeEdge(digit_array)
        if -1 in edge:
            # Strip without ink: nothing to crop.
            continue
        cut_array = digit_array[edge[0]:edge[1] + 1, edge[2]:edge[3] + 1]
        cut_img = Image.fromarray(np.uint8(cut_array * 255))
        src_width, src_height = cut_img.size
        # Longer side becomes 20 px; max(1, ...) keeps a very thin stroke
        # from collapsing to a zero-pixel dimension.
        if src_width <= src_height:
            target = (max(1, 20 * src_width // src_height), 20)
        else:
            target = (20, max(1, 20 * src_height // src_width))
        # Image.LANCZOS is the same filter as the ANTIALIAS alias, which was
        # removed in Pillow 10.
        zip_img = cut_img.resize(target, Image.LANCZOS)
        zip_img_array = np.array(zip_img)
        final_array = np.zeros((28, 28))
        height = len(zip_img_array)
        width = len(zip_img_array[0])
        high = (28 - height) // 2
        left = (28 - width) // 2
        final_array[high:high + height, left:left + width] = zip_img_array
        final_array = Normalization(final_array)
        final_list.append(final_array)

    return final_list

def recognize(src):
    """Recognise every digit in the image at ``src``.

    Returns:
        String made of the predicted digit of each extracted patch, in the
        order the patches are returned by GetCutZip.
    """
    network = NeuralNetwork([784,250,10], 'logistic')
    patches = GetCutZip(src)
    digits = [str(np.argmax(network.predict(patch.flatten()))) for patch in patches]
    return ''.join(digits)

# Command-line entry point: ask for an image path and print the result.
if __name__ == "__main__":
    print("识别的最终结果是:" + recognize(input("请输入图片路径:\n")))

7.画板

PaintBoard.py

from PyQt5.QtWidgets import QWidget
from PyQt5.Qt import QPixmap, QPainter, QPoint, QPaintEvent, QMouseEvent, QPen, \
    QColor, QSize
from PyQt5.QtCore import Qt


class PaintBoard(QWidget):
    """Fixed-size drawing surface backed by an off-screen QPixmap.

    Mouse drags draw line segments onto the pixmap (or paint white when
    ``EraserMode`` is set); ``paintEvent`` copies the pixmap to the widget.
    """

    def __init__(self, Parent=None):
        """Create the board as a child of ``Parent`` (may be None)."""
        super().__init__(Parent)

        self.__InitData()  # state first, then the view
        self.__InitView()

    def __InitData(self):
        """Initialise the drawing state."""
        self.__size = QSize(480, 460)

        # Off-screen pixmap of that size acts as the canvas
        self.__board = QPixmap(self.__size)
        self.__board.fill(Qt.white)  # start with a white canvas

        self.__IsEmpty = True  # the board starts out empty
        self.EraserMode = False  # eraser disabled by default

        self.__lastPos = QPoint(0, 0)  # previous mouse position
        self.__currentPos = QPoint(0, 0)  # current mouse position

        self.__painter = QPainter()  # painter reused for every drawing op

        self.__thickness = 10  # default pen width: 10 px
        self.__penColor = QColor("black")  # default pen colour: black
        # Colour names known to Qt (not read anywhere else in this class)
        self.__colorList = QColor.colorNames()

    def __InitView(self):
        """Fix the widget size to the canvas size."""
        self.setFixedSize(self.__size)

    def Clear(self):
        """Wipe the canvas to white and mark the board as empty."""
        self.__board.fill(Qt.white)
        self.update()
        self.__IsEmpty = True

    def ChangePenColor(self, color="black"):
        """Set the pen colour from a colour name accepted by QColor."""
        self.__penColor = QColor(color)

    def ChangePenThickness(self, thickness=10):
        """Set the pen width in pixels."""
        self.__thickness = thickness

    def IsEmpty(self):
        """Return whether the board is still flagged as empty."""
        return self.__IsEmpty

    def GetContentAsQImage(self):
        """Return the current canvas content as a QImage."""
        image = self.__board.toImage()
        return image

    def paintEvent(self, paintEvent):
        """Repaint the widget by copying the canvas pixmap onto it."""
        # All drawing happens between begin() and end(); the argument of
        # begin() is the paint device, here the widget itself.
        self.__painter.begin(self)
        # (0, 0) is the top-left corner at which the pixmap is drawn
        self.__painter.drawPixmap(0, 0, self.__board)
        self.__painter.end()

    def mousePressEvent(self, mouseEvent):
        """Remember the press position as the start of the next segment."""
        self.__currentPos = mouseEvent.pos()
        self.__lastPos = self.__currentPos

    def mouseMoveEvent(self, mouseEvent):
        """Draw a segment from the previous to the current mouse position."""
        self.__currentPos = mouseEvent.pos()
        self.__painter.begin(self.__board)

        if self.EraserMode:
            # Eraser: pure white pen with a fixed width of 10
            pen = QPen(Qt.white, 10)
        else:
            # Normal drawing: configured colour and width
            pen = QPen(self.__penColor, self.__thickness)
        self.__painter.setPen(pen)

        self.__painter.drawLine(self.__lastPos, self.__currentPos)
        self.__painter.end()
        self.__lastPos = self.__currentPos

        self.update()  # schedule a repaint

    def mouseReleaseEvent(self, mouseEvent):
        """Flag the board as no longer empty once the mouse is released."""
        self.__IsEmpty = False

8.交互界面

MainWidget.py

from PyQt5.Qt import *
from PyQt5.QtWidgets import QHBoxLayout, QVBoxLayout, QPushButton, QSplitter, \
    QComboBox, QLabel, QSpinBox, QFileDialog
from PaintBoard import PaintBoard

from recognize import recognize

class MainWidget(QWidget):
    """Main window: a PaintBoard on the left and a control column on the right.

    The controls clear the board, quit, save the drawing, toggle the eraser,
    run digit recognition and change pen width and colour.
    """

    def __init__(self, Parent=None):
        """Build the window as a child of ``Parent`` (may be None)."""
        super().__init__(Parent)

        self.__InitData()  # state first, then the view
        self.__InitView()

    def __InitData(self):
        """Initialise member variables."""
        self.__paintBoard = PaintBoard(self)
        # Colour names (strings) offered in the pen-colour combo box
        self.__colorList = QColor.colorNames()

    def __InitView(self):
        """Create and lay out all widgets and connect their signals."""
        self.setFixedSize(640, 480)
        self.setWindowTitle("手写数字识别系统")

        # Horizontal layout is the main layout of this window
        main_layout = QHBoxLayout(self)
        # 10 px spacing between the widgets of the main layout
        main_layout.setSpacing(10)

        # The paint board sits on the left
        main_layout.addWidget(self.__paintBoard)

        # Vertical sub-layout holding the controls
        sub_layout = QVBoxLayout()

        # 10 px margins around the sub-layout
        sub_layout.setContentsMargins(10, 10, 10, 10)

        self.__label_School = QLabel(self)
        self.__label_School.setText("踏雪寻梅")
        self.__label_School.setAlignment(Qt.AlignCenter)
        self.__label_School.setFont(QFont("楷体", 12, QFont.Bold))
        sub_layout.addWidget(self.__label_School)

        self.__label_faculty = QLabel(self)
        self.__label_faculty.setText("张熤")
        self.__label_faculty.setAlignment(Qt.AlignCenter)
        self.__label_faculty.setFont(QFont("楷体", 12, QFont.Bold))
        sub_layout.addWidget(self.__label_faculty)

        self.__label_Name = QLabel(self)
        self.__label_Name.setText("(〃'▽'〃)")
        self.__label_Name.setAlignment(Qt.AlignCenter)
        self.__label_Name.setFont(QFont("楷体", 12, QFont.Bold))
        sub_layout.addWidget(self.__label_Name)

        splitter = QSplitter(self)  # spacer
        sub_layout.addWidget(splitter)

        self.__btn_Clear = QPushButton("清空画板")
        self.__btn_Clear.setParent(self)  # this window is the parent

        # Clicking the button clears the paint board
        self.__btn_Clear.clicked.connect(self.__paintBoard.Clear)
        sub_layout.addWidget(self.__btn_Clear)

        self.__btn_Quit = QPushButton("退出")
        self.__btn_Quit.setParent(self)  # this window is the parent
        self.__btn_Quit.clicked.connect(self.Quit)
        sub_layout.addWidget(self.__btn_Quit)

        self.__btn_Save = QPushButton("保存作品")
        self.__btn_Save.setParent(self)
        self.__btn_Save.clicked.connect(self.on_btn_Save_Clicked)
        sub_layout.addWidget(self.__btn_Save)

        self.__cbtn_Eraser = QCheckBox("  使用橡皮擦")
        self.__cbtn_Eraser.setParent(self)
        self.__cbtn_Eraser.clicked.connect(self.on_cbtn_Eraser_clicked)
        sub_layout.addWidget(self.__cbtn_Eraser)

        self.__btn_Recognize = QPushButton("识别")
        self.__btn_Recognize.setParent(self)  # this window is the parent
        self.__btn_Recognize.clicked.connect(self.on_recognize_clicked)
        sub_layout.addWidget(self.__btn_Recognize)

        self.__label_result = QLabel(self)
        self.__label_result.setText("识别结果:")
        sub_layout.addWidget(self.__label_result)

        self.__label_rec_result = QLabel('',self)
        self.__label_rec_result.setAlignment(Qt.AlignCenter)
        self.__label_rec_result.setFont(QFont("Roman times", 18, QFont.Bold))
        self.__label_rec_result.setStyleSheet("color:red")
        sub_layout.addWidget(self.__label_rec_result)

        splitter = QSplitter(self)  # spacer
        sub_layout.addWidget(splitter)

        self.__label_penThickness = QLabel(self)
        self.__label_penThickness.setText("画笔粗细")
        self.__label_penThickness.setFixedHeight(20)
        sub_layout.addWidget(self.__label_penThickness)

        self.__spinBox_penThickness = QSpinBox(self)
        self.__spinBox_penThickness.setMaximum(20)
        self.__spinBox_penThickness.setMinimum(2)
        self.__spinBox_penThickness.setValue(10)  # default width is 10
        self.__spinBox_penThickness.setSingleStep(2)  # step size is 2
        # Value changes are forwarded to on_PenThicknessChange
        self.__spinBox_penThickness.valueChanged.connect(
            self.on_PenThicknessChange)
        sub_layout.addWidget(self.__spinBox_penThickness)

        self.__label_penColor = QLabel(self)
        self.__label_penColor.setText("画笔颜色")
        self.__label_penColor.setFixedHeight(20)
        sub_layout.addWidget(self.__label_penColor)

        self.__comboBox_penColor = QComboBox(self)
        self.__fillColorList(self.__comboBox_penColor)  # one entry per colour
        # Index changes are forwarded to on_PenColorChange
        self.__comboBox_penColor.currentIndexChanged.connect(self.on_PenColorChange)
        sub_layout.addWidget(self.__comboBox_penColor)

        main_layout.addLayout(sub_layout)  # add the control column

    def __fillColorList(self, comboBox):
        """Add one colour-swatch item per known colour and select black."""
        # These settings apply to the whole combo box, so set them once
        # instead of on every loop iteration.
        comboBox.setIconSize(QSize(70, 20))
        comboBox.setSizeAdjustPolicy(QComboBox.AdjustToContents)

        index_black = 0
        for index, color in enumerate(self.__colorList):
            if color == "black":
                index_black = index
            pix = QPixmap(70, 20)
            pix.fill(QColor(color))
            comboBox.addItem(QIcon(pix), None)

        comboBox.setCurrentIndex(index_black)

    def on_PenColorChange(self):
        """Apply the colour selected in the combo box to the paint board."""
        color_index = self.__comboBox_penColor.currentIndex()
        color_str = self.__colorList[color_index]
        self.__paintBoard.ChangePenColor(color_str)

    def on_PenThicknessChange(self):
        """Apply the width selected in the spin box to the paint board."""
        penThickness = self.__spinBox_penThickness.value()
        self.__paintBoard.ChangePenThickness(penThickness)

    def on_btn_Save_Clicked(self):
        """Ask for a file name and save the drawing there."""
        savePath = QFileDialog.getSaveFileName(self, 'Save Your Paint', '.\\', '*.png')
        print(savePath)
        if savePath[0] == "":
            print("Save cancel")
            return
        image = self.__paintBoard.GetContentAsQImage()
        image.save(savePath[0])

    def on_cbtn_Eraser_clicked(self):
        """Mirror the check-box state into the board's eraser mode."""
        self.__paintBoard.EraserMode = bool(self.__cbtn_Eraser.isChecked())

    def on_recognize_clicked(self):
        """Save the drawing to a temporary file, recognise it, show the result."""
        import os  # local import: only this handler needs it

        savePath = './tmp/image.png'
        # The image cannot be written if the folder is missing.
        os.makedirs(os.path.dirname(savePath), exist_ok=True)
        image = self.__paintBoard.GetContentAsQImage()
        if not image.save(savePath):
            print("Failed to save " + savePath)
            self.__label_rec_result.setText('')
            return
        print(savePath)

        try:
            predict = recognize(savePath)
        except Exception as err:
            # UI boundary: an exception escaping a slot can abort the whole
            # application, so report it and show an empty result instead.
            print("Recognition failed:", repr(err))
            predict = ''
        print(predict)
        self.__label_rec_result.setText(predict)

    def Quit(self):
        """Close the window."""
        self.close()

8.效果展示

手写数字识别:Python+BP神经网络
(`・ω・´) →项目文件已上传…

相关标签: 手写数字识别