我的第一个tensorflow程序
程序员文章站
2022-04-03 22:50:28
...
第一个tensorflow是从网上抄来的,但是还是爬了个大坑,在预测文件中的图片转换为28*28尺寸的时候用PIL一直报错(原作者的代码),后来改用cv2模块resize问题就解决了,这是一个关于数字识别的程序,程序能够在一张只有一个0-9数字的图片中准确识别出数字是多少,准确率高达99+%,然后我用PyQt5封装了一下,使其可视化。环境为Python3+tensorflow2.0+PyQt5,首先创建一个python project,然后往里面添加文件夹,然后在v4_cnn目录下创建三个文件,mainUI.py,predict.py,train.py三个文件 ,ckpt文件目录是没有的,运行程序后生成的,想要运行该程序,必须先运行训练代码,train.py文件,然后再运行主UI文件maiUI.py文件。
训练文件train.py,代码如下:
import os
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import numpy as np
y1 = [0, 0.8, 0.1, 0.1, 0, 0, 0, 0, 0, 0]
y2 = [0, 0.1, 0.1, 0.1, 0.5, 0, 0.2, 0, 0, 0]
np.argmax(y1) # 1
np.argmax(y2) # 4
class CNN(object):
def __init__(self):
model = models.Sequential()
# 第1层卷积,卷积核大小为3*3,32个,28*28为待训练图片的大小
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
# 第2层卷积,卷积核大小为3*3,64个
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
# 第3层卷积,卷积核大小为3*3,64个
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
model.summary()
self.model = model
class DataSource(object):
def __init__(self):
# mnist数据集存储的位置,如何不存在将自动下载
data_path = os.path.abspath(os.path.dirname(__file__)) + '/../data_set_tf2/mnist.npz'
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data(path=data_path)
# 6万张训练图片,1万张测试图片
train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))
# 像素值映射到 0 - 1 之间
train_images, test_images = train_images / 255.0, test_images / 255.0
self.train_images, self.train_labels = train_images, train_labels
self.test_images, self.test_labels = test_images, test_labels
class Train:
def __init__(self):
self.cnn = CNN()
self.data = DataSource()
def train(self):
check_path = './ckpt/cp-{epoch:04d}.ckpt'
# period 每隔5epoch保存一次
save_model_cb = tf.keras.callbacks.ModelCheckpoint(check_path, save_weights_only=True, verbose=1, period=5)
self.cnn.model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
self.cnn.model.fit(self.data.train_images, self.data.train_labels, epochs=5, callbacks=[save_model_cb])
test_loss, test_acc = self.cnn.model.evaluate(self.data.test_images, self.data.test_labels)
print("准确率: %.4f,共测试了%d张图片 " % (test_acc, len(self.data.test_labels)))
if __name__ == "__main__":
app = Train()
app.train()
预测文件predict.py,代码如下:
import tensorflow as tf
from PIL import Image
import numpy as np
from v4_cnn.train import CNN
import cv2
class Predict(object):
def __init__(self):
latest = tf.train.latest_checkpoint('./ckpt')
self.cnn = CNN()
# 恢复网络权重
self.cnn.model.load_weights(latest)
def predict(self, image_path):
# 以黑白方式读取图片
img = Image.open(image_path).convert('L') #爬了个大坑
img = np.asarray(img)
img = cv2.resize(img,(28,28))
flatten_img = np.reshape(img, (28, 28, 1))
x = np.array([1 - flatten_img])
# API refer: https://keras.io/models/model/
self.y = self.cnn.model.predict(x)
主UI文件mainUI.py,代码如下:
import cv2
import sys
from PyQt5 import QtGui
from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import QApplication, QWidget, QLabel, QHBoxLayout, QMainWindow, QDockWidget, QPushButton, \
QVBoxLayout, QTextEdit, QFileDialog
from v4_cnn.predict import Predict
import numpy as np
class QPixmapDemo(QMainWindow):
def __init__(self):
super().__init__()
self.txt = {0:'0', 1:'1', 2:'2', 3:'3', 4:'4', 5:'5', 6:'6', 7:'7', 8:'8', 9:'9'}
self.setWindowTitle('picture')
self.wgt = QWidget()
# self.wgt.resize(600, 500)
self.imgLabel = QLabel()
self.imgLabel.resize(600, 600) # 设置label的大小,图片会适配label的大小
self.hbox = QHBoxLayout()
self.hbox.addWidget(self.imgLabel)
self.wgt.setLayout(self.hbox)
self.setCentralWidget(self.wgt)
self.docker = docker(self)
self.addDockWidget(Qt.LeftDockWidgetArea,self.docker)
self.docker.btn_openFile.clicked.connect(self.openFile)
self.docker.btn_startDiscern.clicked.connect(self.start)
self.resize(800,600)
def openFile(self):
self.file, filetype = QFileDialog.getOpenFileName(self,
"选择只有一个数字的图片",
"./",
"All Files (*);;Text Files (*.txt)")
if self.file is not None:
self.setImage(self.file)
def start(self):
discern = Predict()
discern.predict(self.file)
num = np.argmax(discern.y[0])
self.docker.texEdit.setText(str(num))
def setImage(self, file):
img = cv2.imread(file) # opencv读取图片
img2 = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # opencv读取的bgr格式图片转换成rgb格式
_image = QtGui.QImage(img2[:], img2.shape[1], img2.shape[0], img2.shape[1] * 3,
QtGui.QImage.Format_RGB888) # pyqt5转换成自己能放的图片格式
jpg_out = QtGui.QPixmap(_image).scaled(self.imgLabel.width(), self.imgLabel.height()) # 设置图片大小
self.imgLabel.setPixmap(jpg_out) # 设置图片显示
class docker(QDockWidget):
def __init__(self, parent):
super().__init__(parent)
self.btn_openFile = QPushButton('打开图片')
self.btn_startDiscern = QPushButton('开始识别')
self.texEdit = QTextEdit()
self.vbox = QVBoxLayout()
self.vbox.addWidget(self.btn_openFile)
self.vbox.addWidget(self.btn_startDiscern)
self.vbox.addWidget(self.texEdit)
self.wgt = QWidget()
self.wgt.setLayout(self.vbox)
self.setWidget(self.wgt)
if __name__ == '__main__':
app = QApplication(sys.argv)
win = QPixmapDemo()
win.show()
sys.exit(app.exec_())
运行结果:
上一篇: 如何用c语言输出100到200之间的素数
下一篇: windows怎么安装python
推荐阅读
-
我的第一个python web开发框架(28)——定制ORM(五)
-
我的第一个爬虫,爬取北京地区短租房信息
-
PyCharm的设置方法和第一个Python程序的建立
-
使用Python的Flask框架来搭建第一个Web应用程序
-
我的第一个netcore2.2 api项目搭建(三)
-
我的第一个python web开发框架(32)——接口代码重构
-
我的第一个python web开发框架(36)——后台菜单管理功能
-
微信小程序授权 获取用户的openid和session_key【后端使用java语言编写】,我写的是get方式,目的是测试能否获取到微信服务器中的数据,后期我会写上post请求方式。
-
java基础------环境变量的配置及编写第一个程序
-
微信小程序订阅消息,我踩过的坑都在这里了!