欢迎您访问程序员文章站!本站旨在为大家提供并分享程序员计算机编程知识!
您现在的位置是: 首页

我的第一个tensorflow程序

程序员文章站 2022-04-03 22:50:28
...

我的第一个 tensorflow 程序是从网上抄来的,但还是踩了个大坑:在预测文件中把图片转换为 28*28 尺寸的时候,用 PIL(原作者的代码)一直报错,后来改用 cv2 模块的 resize,问题就解决了。这是一个关于数字识别的程序,它能够在一张只有一个 0-9 数字的图片中准确识别出数字是多少,准确率高达 99% 以上。然后我用 PyQt5 封装了一下,使其可视化。运行环境为 Python3 + tensorflow 2.0 + PyQt5。首先创建一个 python project,然后往里面添加文件夹,再在 v4_cnn 目录下创建 mainUI.py、predict.py、train.py 三个文件。ckpt 目录一开始是没有的,它是在运行训练程序后生成的。想要运行该程序,必须先运行训练代码 train.py,然后再运行主 UI 文件 mainUI.py。

我的第一个tensorflow程序

 

训练文件train.py,代码如下:

import os
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import numpy as np

# Illustration of np.argmax on two 10-element score lists: it returns the
# index of the largest entry. The results are not stored anywhere.
y1 = [0, 0.8, 0.1, 0.1] + [0] * 6
y2 = [0, 0.1, 0.1, 0.1, 0.5, 0, 0.2] + [0] * 3
np.argmax(y1)  # -> 1
np.argmax(y2)  # -> 4

class CNN(object):
    """Convolutional classifier for 28x28 single-channel images, 10 classes.

    The built Keras model is exposed as ``self.model``; its summary is
    printed once during construction.
    """

    def __init__(self):
        # Declare the whole layer stack up front rather than via add() calls.
        network = models.Sequential([
            # Conv layer 1: 32 kernels of size 3x3; 28x28 is the size of the
            # images to train on.
            layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
            layers.MaxPooling2D((2, 2)),
            # Conv layer 2: 64 kernels of size 3x3.
            layers.Conv2D(64, (3, 3), activation='relu'),
            layers.MaxPooling2D((2, 2)),
            # Conv layer 3: 64 kernels of size 3x3.
            layers.Conv2D(64, (3, 3), activation='relu'),
            # Classification head: flatten, one hidden layer, 10-way softmax.
            layers.Flatten(),
            layers.Dense(64, activation='relu'),
            layers.Dense(10, activation='softmax'),
        ])
        network.summary()
        self.model = network

class DataSource(object):
    """Loads MNIST and exposes train/test images and labels as attributes.

    Images are reshaped to (N, 28, 28, 1) and divided by 255.0; labels are
    stored exactly as returned by the loader.
    """

    def __init__(self):
        # Where the MNIST archive is stored; per the original author's note
        # it is downloaded automatically when it does not exist.
        script_dir = os.path.abspath(os.path.dirname(__file__))
        data_path = script_dir + '/../data_set_tf2/mnist.npz'
        train_split, test_split = datasets.mnist.load_data(path=data_path)
        raw_train, self.train_labels = train_split
        raw_test, self.test_labels = test_split

        # 60,000 training images and 10,000 test images; add the channel
        # axis, then map pixel values into the 0 - 1 range.
        self.train_images = raw_train.reshape((60000, 28, 28, 1)) / 255.0
        self.test_images = raw_test.reshape((10000, 28, 28, 1)) / 255.0

class Train:
    """Builds the CNN and the MNIST data source, then trains and evaluates."""

    def __init__(self):
        self.cnn = CNN()
        self.data = DataSource()

    def train(self):
        """Fit the model for 5 epochs, checkpoint the weights, print test accuracy."""
        model = self.cnn.model
        data = self.data

        # Weights-only checkpoints; `period=5` asks for a save every 5 epochs.
        # NOTE(review): `period` is reported as deprecated in later TF 2.x
        # releases in favour of `save_freq` - confirm before upgrading.
        checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
            './ckpt/cp-{epoch:04d}.ckpt',
            save_weights_only=True,
            verbose=1,
            period=5,
        )

        model.compile(
            optimizer='adam',
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy'],
        )
        model.fit(
            data.train_images,
            data.train_labels,
            epochs=5,
            callbacks=[checkpoint_cb],
        )

        _, test_acc = model.evaluate(data.test_images, data.test_labels)
        tested = len(data.test_labels)
        print("准确率: %.4f,共测试了%d张图片 " % (test_acc, tested))


# Script entry point: build the trainer and run one full training session.
if __name__ == "__main__":
    trainer = Train()
    trainer.train()

预测文件predict.py,代码如下:

import tensorflow as tf
from PIL import Image
import numpy as np
from v4_cnn.train import CNN
import cv2


class Predict(object):
    """Restores the trained CNN weights from ``./ckpt`` and classifies images.

    After ``predict`` is called, ``self.y`` holds the raw model output for
    the most recent image (``None`` before the first call).
    """

    def __init__(self):
        """Build the network and load the newest checkpoint.

        Raises:
            FileNotFoundError: if ``./ckpt`` contains no checkpoint, which
                means the training script has not been run yet.
        """
        latest = tf.train.latest_checkpoint('./ckpt')
        if latest is None:
            # Fail with an actionable message instead of passing None on to
            # load_weights.
            raise FileNotFoundError(
                "No checkpoint found in './ckpt'; run train.py first."
            )
        self.cnn = CNN()
        # Restore the network weights.
        self.cnn.model.load_weights(latest)
        self.y = None

    def predict(self, image_path):
        """Run the model on the image at *image_path*.

        The image is read as greyscale, resized to 28x28, scaled to the
        0 - 1 range and inverted before being fed to the network.

        Args:
            image_path: path of an image file containing a single digit.

        Returns:
            The model output for a batch of one image; it is also stored
            in ``self.y``.
        """
        # Read the picture as 8-bit greyscale, then resize with OpenCV
        # (the PIL-based resize was what kept failing for the author).
        img = Image.open(image_path).convert('L')
        img = np.asarray(img)
        img = cv2.resize(img, (28, 28))

        # Convert to float and scale to 0 - 1 *before* inverting. The
        # training data is divided by 255.0, and doing `1 - img` directly on
        # uint8 pixels wraps around (1 - 255 == 2) instead of inverting.
        pixels = img.astype(np.float32) / 255.0
        # NOTE(review): the inversion assumes a dark digit on a light
        # background - confirm against the images actually used.
        x = (1.0 - pixels).reshape((1, 28, 28, 1))

        # API refer: https://keras.io/models/model/
        self.y = self.cnn.model.predict(x)
        return self.y

主UI文件mainUI.py,代码如下:

import cv2
import sys
from PyQt5 import QtGui
from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import QApplication, QWidget, QLabel, QHBoxLayout, QMainWindow, QDockWidget, QPushButton, \
    QVBoxLayout, QTextEdit, QFileDialog
from v4_cnn.predict import Predict
import numpy as np

class QPixmapDemo(QMainWindow):
    """Main window: shows the chosen picture and the digit recognised in it.

    The central widget holds an image label; a dock on the left provides the
    "open file" and "start recognition" buttons plus a text box for the
    result.
    """

    def __init__(self):
        super().__init__()
        self.txt = {0:'0', 1:'1', 2:'2', 3:'3', 4:'4', 5:'5', 6:'6', 7:'7', 8:'8', 9:'9'}
        # Path of the currently loaded picture; empty until one is opened, so
        # start() can be clicked safely before any file was chosen.
        self.file = ''
        # Predictor is created lazily on first use and then reused, so the
        # model is not rebuilt and reloaded on every click.
        self._predictor = None
        self.setWindowTitle('picture')
        self.wgt = QWidget()
        # self.wgt.resize(600, 500)
        self.imgLabel = QLabel()
        self.imgLabel.resize(600, 600)  # label size; the picture is scaled to fit it
        self.hbox = QHBoxLayout()
        self.hbox.addWidget(self.imgLabel)
        self.wgt.setLayout(self.hbox)
        self.setCentralWidget(self.wgt)
        self.docker = docker(self)
        self.addDockWidget(Qt.LeftDockWidgetArea,self.docker)
        self.docker.btn_openFile.clicked.connect(self.openFile)
        self.docker.btn_startDiscern.clicked.connect(self.start)
        self.resize(800,600)

    def openFile(self):
        """Ask the user for a picture and display it.

        Cancelling the dialog, or picking a file that cannot be displayed,
        leaves the previously loaded picture (if any) in place.
        """
        chosen, _filetype = QFileDialog.getOpenFileName(self,
                                                        "选择只有一个数字的图片",
                                                        "./",
                                                        "All Files (*);;Text Files (*.txt)")
        # The dialog yields an empty string (not None) when it is cancelled.
        if not chosen:
            return
        if self.setImage(chosen):
            self.file = chosen

    def start(self):
        """Recognise the digit in the loaded picture and show it in the dock."""
        if not self.file:
            self.docker.texEdit.setText('请先打开一张图片')
            return
        if self._predictor is None:
            self._predictor = Predict()
        discern = self._predictor
        discern.predict(self.file)
        # Index of the highest score in the first (only) output row.
        num = np.argmax(discern.y[0])
        self.docker.texEdit.setText(str(num))

    def setImage(self, file):
        """Load *file* with OpenCV and show it scaled to the image label.

        Args:
            file: path of the picture to display.

        Returns:
            True if the picture was displayed, False if OpenCV could not
            read it (a message is shown in the dock's text box instead).
        """
        img = cv2.imread(file)  # read the picture with OpenCV
        if img is None:
            # imread signals failure by returning None rather than raising.
            self.docker.texEdit.setText('无法读取该图片')
            return False
        img2 = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; Qt wants RGB
        height, width = img2.shape[:2]
        # Wrap the pixel buffer in a QImage: 3 bytes per pixel, so one row
        # is width * 3 bytes.
        _image = QtGui.QImage(img2.data, width, height, width * 3,
                              QtGui.QImage.Format_RGB888)
        jpg_out = QtGui.QPixmap(_image).scaled(self.imgLabel.width(), self.imgLabel.height())  # fit the label
        self.imgLabel.setPixmap(jpg_out)  # show the picture
        return True


class docker(QDockWidget):
    """Dock panel with the open-file button, the recognise button and a text box.

    The three controls are stacked vertically inside a container widget
    that is installed as the dock's content.
    """

    def __init__(self, parent):
        super().__init__(parent)
        # Controls are public attributes so the main window can connect to
        # the buttons and write the result into the text box.
        self.btn_openFile = QPushButton('打开图片')
        self.btn_startDiscern = QPushButton('开始识别')
        self.texEdit = QTextEdit()

        # Stack the controls top to bottom in creation order.
        self.vbox = QVBoxLayout()
        for control in (self.btn_openFile, self.btn_startDiscern, self.texEdit):
            self.vbox.addWidget(control)

        # A dock takes a single widget, so wrap the layout in a container.
        self.wgt = QWidget()
        self.wgt.setLayout(self.vbox)
        self.setWidget(self.wgt)


# Script entry point: create the Qt application, show the main window and
# hand the event loop's exit status back to the interpreter.
if __name__ == '__main__':
    application = QApplication(sys.argv)
    window = QPixmapDemo()
    window.show()
    exit_status = application.exec_()
    sys.exit(exit_status)

运行结果:

我的第一个tensorflow程序