pyqt5结合keras+tensorflow实现深度学习训练过程GUI界面编写
主要实现功能
第一次写博客,主要是想记录下最近踩的坑。最近想做一个集深度学习训练过程与缺陷检测过程为一体的界面,但是中间遇到许多问题,其中解决耗时最长的问题就是如何将深度学习训练过程实时显示在GUI界面的Textbrowser上,实现Textbrowser作为控制台输出的功能。
直接上代码
这里放的是GUI运行核心代码,其他的代码我将上传到CSDN下载中,有需要的小伙伴可以去下载
import ctypes
import win32con
import sys
from PyQt5.QtWidgets import QMainWindow, QApplication, QDialog, QFileDialog, QMessageBox
from PyQt5 import QtCore, QtGui
from PyQt5.QtCore import QThread, pyqtSignal
from mainwindow import Ui_MainWindow
from Model_training import Ui_Dialog
from Detection import Ui_Dialog1
import global_var as gl
from model_train import model_training
from model_prediction import predict, prediction
import pandas as pd
class EmittingStream(QtCore.QObject):
    """File-like Qt object: everything written to it is re-emitted as a
    Qt signal, so that print output can be displayed in a text widget."""

    textWritten = QtCore.pyqtSignal(str)

    def write(self, text):
        """Forward *text* (coerced to str) through the textWritten signal."""
        payload = str(text)
        self.textWritten.emit(payload)

    def flush(self):
        """No-op; present so this object satisfies the file protocol that
        sys.stdout consumers (e.g. Keras progress bars) expect."""
class MainUI(QMainWindow, Ui_MainWindow):
    """Main window: offers the training and detection dialogs plus exit."""

    def __init__(self):
        super(MainUI, self).__init__()
        self.setupUi(self)
        # Feature buttons in green, the exit button's label in red.
        green = 'background:rgb(0, 255, 0)'
        self.pushButton_Training.setStyleSheet(green)
        self.pushButton_Detection.setStyleSheet(green)
        self.pushButton_Tuichu.setStyleSheet('color:red')
        # Both the exit button and the menu action close this window.
        self.pushButton_Tuichu.clicked.connect(self.close)
        self.Exit.triggered.connect(self.close)
class Training_Dialog(QDialog, Ui_Dialog):
    """Training dialog: choose data folders and hyper-parameters, then run
    model training on a worker thread while the console output is streamed
    into the dialog's text browser."""

    def __init__(self):
        super(Training_Dialog, self).__init__()
        self.setupUi(self)
        self.pushButton_training.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_validation.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_Start.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_Stop.setStyleSheet('background:rgb(255, 0, 0)')
        self.comboBox_BS.addItems(['2', '4', '8', '16', '32', '64', '128'])
        self.comboBox_EP.addItems(['1', '20', '50', '100'])
        self.comboBox_LR.addItems(['0.1', '0.01', '0.001', '0.0001', '0.00001'])
        self.radioButton_AlexNet.setChecked(True)
        # Start is unlocked only after a validation folder has been chosen
        # (see openfile1()).
        self.pushButton_Start.setEnabled(False)
        self.pushButton_training.clicked.connect(self.openfile)
        self.pushButton_validation.clicked.connect(self.openfile1)
        # Route stdout AND stderr into the text browser so the Keras
        # training log shows up in the GUI.
        # BUG FIX: the original assigned to sys.stder (typo), so stderr was
        # never actually redirected.
        sys.stdout = EmittingStream(textWritten=self.normalOutputWritten)
        sys.stderr = EmittingStream(textWritten=self.normalOutputWritten)
        self.pushButton_Start.clicked.connect(self.run_training)
        self.pushButton_Stop.clicked.connect(self.stop_training)
        self.my_thread = MyThread()  # worker thread that runs the training

    def hyper_para(self):
        """Read epoch count, batch size and learning rate from the combo
        boxes into the shared global state (gl)."""
        Epoch = self.comboBox_EP.currentText()
        gl.Epoch = int(Epoch)
        print('迭代次数为 %d' % gl.Epoch)
        batch_size = self.comboBox_BS.currentText()
        gl.batch_size = int(batch_size)
        print('批量尺寸为 %d' % gl.batch_size)
        Learning_rate = self.comboBox_LR.currentText()
        gl.learning_rate = float(Learning_rate)
        print('学习率为 %f' % gl.learning_rate)

    def stop_training(self):
        """Forcibly stop the training thread.

        NOTE(review): TerminateThread kills the thread with no cleanup and
        can leak resources or corrupt interpreter state; it is used here
        because a blocking Keras fit() call cannot be interrupted politely.
        """
        self.my_thread.is_on = False
        ret = ctypes.windll.kernel32.TerminateThread(  # @UndefinedVariable
            self.my_thread.handle, 0)
        print('终止训练', self.my_thread.handle, ret)

    def openfile(self):
        """Choose the training-data folder into gl.gl_str_i."""
        directory = QFileDialog.getExistingDirectory(self, "请选择文件夹路径",
                                                     "F:/Deep_CarrotNet/Carrot_resize_clear_split_clear")
        gl.gl_str_i = directory
        # BUG FIX: the original tested gl.gl_str_i1 (the validation folder)
        # here, and printed the success message even when the dialog was
        # cancelled.
        if len(gl.gl_str_i) == 0:
            QMessageBox.critical(self, '提示', '请选择正确文件夹')
        else:
            print('成功加载训练文件', '训练文件夹所在位置:%s' % gl.gl_str_i)

    def openfile1(self):
        """Choose the validation-data folder; unlocks the Start button."""
        directory = QFileDialog.getExistingDirectory(self, "请选择文件夹路径",
                                                     "F:/Deep_CarrotNet/Carrot_resize_clear_split_clear")
        gl.gl_str_i1 = directory
        if len(gl.gl_str_i1) == 0:
            QMessageBox.critical(self, '提示', '请选择正确文件夹')
        else:
            self.pushButton_Start.setEnabled(True)
            print('成功加载验证文件', '验证文件夹所在位置:%s' % gl.gl_str_i1)

    def i_count(self):
        """Record the selected network in gl.gl_int_i (1=AlexNet, 2=CarrotNet)."""
        # BUG FIX: the original dispatched on radioButton_*.text() in an
        # if/elif chain, so once the CarrotNet button's text matched, the
        # AlexNet branch was unreachable and gl.gl_int_i went stale when
        # switching back to AlexNet. Dispatch on isChecked() directly.
        if self.radioButton_CarrotNet.isChecked():
            gl.gl_int_i = 2
            print('model is CarrotNet')
        elif self.radioButton_AlexNet.isChecked():
            gl.gl_int_i = 1
            print('model is AlexNet')

    def run_training(self):
        """Validate the chosen folders and start the training thread."""
        self.pushButton_Start.setEnabled(False)
        self.textBrowser.clear()
        self.i_count()
        self.hyper_para()
        # 'one' is the sentinel meaning "no folder chosen yet"
        # (see global_var.py).
        if gl.gl_str_i == 'one':
            QMessageBox.critical(self, '错误', '请加载训练图片')
            self.my_thread.is_on = False
        elif gl.gl_str_i1 == 'one':
            QMessageBox.critical(self, '错误', '请加载验证图片')
            self.my_thread.is_on = False
        else:
            self.my_thread.is_on = True
            self.my_thread.start()  # launch the worker thread
        self.pushButton_Start.setEnabled(True)

    def normalOutputWritten(self, text):
        """Append *text* at the end of the text browser and scroll to it."""
        cursor = self.textBrowser.textCursor()
        cursor.movePosition(QtGui.QTextCursor.End)
        cursor.insertText(text)
        self.textBrowser.setTextCursor(cursor)
        self.textBrowser.ensureCursorVisible()
class MyThread(QThread):
    """Worker thread that runs a single model-training session."""

    def __init__(self):
        super(MyThread, self).__init__()
        self.is_on = True  # cleared once training has finished

    def run(self):
        # Keep a Win32 handle to this thread so stop_training() can kill
        # it via TerminateThread.
        # NOTE(review): PROCESS_ALL_ACCESS is requested on a *thread*
        # handle — presumably THREAD_ALL_ACCESS was intended; confirm on
        # Windows before changing.
        thread_id = int(QThread.currentThreadId())
        self.handle = ctypes.windll.kernel32.OpenThread(  # @UndefinedVariable
            win32con.PROCESS_ALL_ACCESS, False, thread_id)
        if self.is_on:
            model_training(gl.gl_int_i, gl.gl_str_i, gl.gl_str_i1,
                           gl.Epoch, gl.batch_size, gl.learning_rate)
            self.is_on = False
class EmittingStream1(QtCore.QObject):
    """Second stdout-redirect helper (used by the detection dialog);
    behaves exactly like EmittingStream: writes become Qt signals."""

    textWritten = QtCore.pyqtSignal(str)

    def write(self, text):
        """Re-emit *text* (as str) on the textWritten signal."""
        self.textWritten.emit(str(text))

    def flush(self):
        """No-op, required so the object can stand in for sys.stdout."""
class Detection_Dialog(QDialog, Ui_Dialog1):
    """Detection dialog: load a trained model, pick images (a folder for
    batch mode or a single file), run detection on a worker thread and
    save the results to CSV."""

    def __init__(self):
        super(Detection_Dialog, self).__init__()
        self.setupUi(self)
        self.pushButton_start_detection.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_model.setStyleSheet('background:rgb(255, 0, 0)')
        self.pushButton_picture.setStyleSheet('background:rgb(255, 0, 0)')
        self.radioButton.setChecked(True)  # default to batch detection
        # Only "load model" starts enabled; the rest unlock after the
        # reminder in message() has been shown.
        self.pushButton_model.setEnabled(True)
        self.pushButton_picture.setEnabled(False)
        self.pushButton_start_detection.setEnabled(False)
        self.pushButton_exit.setEnabled(False)
        self.pushButton_save.setEnabled(False)
        # Route stdout AND stderr into this dialog's text browser.
        # BUG FIX: the original assigned to sys.stder (typo), so stderr
        # was never redirected.
        sys.stdout = EmittingStream1(textWritten=self.normalOutputWritten1)
        sys.stderr = EmittingStream1(textWritten=self.normalOutputWritten1)
        self.pushButton_model.clicked.connect(self.message)
        self.pushButton_model.clicked.connect(self.load_model)
        self.pushButton_picture.clicked.connect(self.load_image)
        self.my_thread1 = My_Thread1()  # worker thread for detection
        self.pushButton_start_detection.clicked.connect(self.detection)
        self.pushButton_save.clicked.connect(self.save_result)
        self.pushButton_exit.clicked.connect(self.close)

    def save_result(self):
        """Dump the accumulated detection results (gl.Y) as a CSV file in
        a user-chosen folder."""
        path = QFileDialog.getExistingDirectory(self, "请选择文件路径")
        data = pd.DataFrame(gl.Y)
        data.to_csv(path + '/' + 'detection_result.csv', index=True)

    def message(self):
        """Remind the user to pick a detection mode, then unlock the UI."""
        # BUG FIX: this is an informational prompt, not a yes/no question,
        # so use information() (OK button) instead of question() (Yes/No).
        QMessageBox.information(self, '提示', '请先选择逐批检测还是逐个检测')
        self.pushButton_model.setEnabled(True)
        self.pushButton_picture.setEnabled(True)
        self.pushButton_start_detection.setEnabled(True)
        self.pushButton_exit.setEnabled(True)

    def load_image(self):
        """Pick the detection input: a folder (batch mode, gl.gl_str_i3)
        or one image file (single mode, gl.gl_str_i4)."""
        # BUG FIX: the original dispatched on radioButton.text() first, so
        # with 逐个检测 selected the outer branch still matched and the
        # single-image branch was never reached. Dispatch on isChecked().
        if self.radioButton.isChecked():  # 逐批检测
            directory1 = QFileDialog.getExistingDirectory(self, "请选择文件路径")
            gl.gl_str_i3 = directory1
            print('成功导入检测文件', '检测文件所在位置:%s' % gl.gl_str_i3)
        elif self.radioButton_2.isChecked():  # 逐个检测
            fname, _ = QFileDialog.getOpenFileName(self, '选择图片', 'c:\\',
                                                   'Image files(*.jpg *.gif *.png)')
            gl.gl_str_i4 = fname
            print('成功导入检测图片', '检测文件所在位置:%s' % gl.gl_str_i4)
        else:
            print('请正确选择检测文件路径')

    def load_model(self):
        """Pick the folder holding the trained model (gl.gl_str_i2).

        Renamed from the original ``load_moad`` (typo); the name is only
        referenced from this class's __init__, so the rename is safe.
        """
        directory = QFileDialog.getExistingDirectory(self, "请选择文件路径")
        gl.gl_str_i2 = directory
        print('成功加载模型', '模型所在位置:%s' % gl.gl_str_i2)

    def normalOutputWritten1(self, text):
        """Append *text* at the end of the text browser and scroll to it."""
        cursor1 = self.textBrowser1.textCursor()
        cursor1.movePosition(QtGui.QTextCursor.End)
        cursor1.insertText(text)
        self.textBrowser1.setTextCursor(cursor1)
        self.textBrowser1.ensureCursorVisible()

    def detection(self):
        """Record the detection mode in gl.i and start the worker thread."""
        self.pushButton_start_detection.setEnabled(False)
        # BUG FIX: as in load_image(), dispatch on isChecked() directly;
        # the original text()-based chain left gl.i stale (stuck at 0)
        # whenever the second radio button was selected.
        if self.radioButton.isChecked():
            gl.i = 0  # batch detection over a folder
        elif self.radioButton_2.isChecked():
            gl.i = 1  # single-image detection
        self.my_thread1.start()
        self.pushButton_start_detection.setEnabled(True)
        self.pushButton_save.setEnabled(True)
class My_Thread1(QThread):
    """Worker thread that runs one detection pass (batch or single)."""

    def run(self):
        print('测试开始')
        # gl.i == 0 -> batch detection over a folder,
        # anything else -> single-image detection.
        if gl.i == 0:
            prediction(gl.gl_str_i2, gl.gl_str_i3)
            return
        predict(gl.gl_str_i2, gl.gl_str_i4)
if __name__ == "__main__":
    app = QApplication(sys.argv)

    window = MainUI()
    training_dlg = Training_Dialog()
    detection_dlg = Detection_Dialog()

    # The two feature buttons open their respective dialogs.
    window.pushButton_Training.clicked.connect(training_dlg.show)
    window.pushButton_Detection.clicked.connect(detection_dlg.show)
    # Quitting — via the button or the menu action — also closes any
    # dialog that is still open.
    for quit_signal in (window.pushButton_Tuichu.clicked, window.Exit.triggered):
        quit_signal.connect(training_dlg.close)
        quit_signal.connect(detection_dlg.close)

    window.show()
    sys.exit(app.exec_())
重要代码
这里是关于如何将深度学习训练过程实时显示到GUI的Textbrowser上
class EmittingStream(QtCore.QObject):
textWritten = QtCore.pyqtSignal(str)
def write(self, text):
self.textWritten.emit(str(text))
def flush(self): # real signature unknown; restored from __doc__
""" flush(self) """
pass
一定要加上flush函数的定义,之前在CSDN上找了很久,都没有这行,导致GUI界面上的Textbrowsers只能输出深度学习训练过程的第一行,不能实现实时刷新的功能,加上这个定义就可以完美解决
sys.stdout = EmittingStream(textWritten=self.normalOutputWritten)
sys.stderr = EmittingStream(textWritten=self.normalOutputWritten)
class MyThread(QThread): # 线程类
# my_signal = pyqtSignal(str) # 自定义信号对象。参数str就代表这个信号可以传一个字符串
def __init__(self):
super(MyThread, self).__init__()
# self.count = 0
self.is_on = True
def run(self): # 线程执行函数
self.handle = ctypes.windll.kernel32.OpenThread( # @UndefinedVariable
win32con.PROCESS_ALL_ACCESS, False, int(QThread.currentThreadId())) # 是为了后面结束进程使用
while self.is_on:
model_training(gl.gl_int_i, gl.gl_str_i, gl.gl_str_i1, gl.Epoch,
gl.batch_size, gl.learning_rate)
self.is_on = False
def stop_training(self):
self.my_thread.is_on = False
ret = ctypes.windll.kernel32.TerminateThread( # @UndefinedVariable
self.my_thread.handle, 0)
print('终止训练', self.my_thread.handle, ret)
self.handle = ctypes.windll.kernel32.OpenThread( # @UndefinedVariable
win32con.PROCESS_ALL_ACCESS, False, int(QThread.currentThreadId())) # 是为了后面结束进程使用
上面的 stop_training 方法用于终止训练过程。
出错解决
用前面的代码理论上是可以实现实时显示深度学习训练过程的,但我刚开始使用时总会出现 `finished with exit code -1073740791 (0xC0000409)` 错误;把 `sys.stdout = EmittingStream1(textWritten=self.normalOutputWritten1)` 和
`sys.stderr = EmittingStream1(textWritten=self.normalOutputWritten1)` 这两行注释掉后程序可以正常运行,只是内容不再输出到 TextBrowser 上。在网上搜了一大圈,也没有找到适合我程序的方案,最后才发现是 Keras 版本和 TensorFlow 版本不匹配造成的:之前不在 GUI 内运行不报错,而在 GUI 框架下运行就会报错。最终选择 **Keras 2.2.5 + TensorFlow 1.14.0** 解决了问题;运行程序时会出现一大串警告,不过不影响最终结果。
最终效果
未解决问题
我这个程序有两个子界面,每个子界面都有一个 TextBrowser,而且都想达到实时刷新的效果,但是同时使用时会出现两个 TextBrowser 内容相互干扰的现象。哪位大神知道如何解决的话,还望不吝赐教。
补充说明
本界面还使用了全局变量实现不同函数之间的互相传值,具体方法是先建个global_var.py文件,将需要传值的参数预先定义。此后各个文件import使用就行了
# coding=utf-8
# Shared-state module ("poor man's globals"). Usage from another file:
#   import global_var_model as gl
#   gl.gl_int_i += 4   # read/modify gl.gl_int_i to share state between
#   modules, i.e. Python's equivalent of global/static variables
#   gl.gl_int_i
import numpy as np
gl_int_i = 1  # selected network: 1 = AlexNet, 2 = CarrotNet (set by i_count)
gl_str_i = 'one'  # training-data folder path; 'one' = not chosen yet (sentinel)
gl_str_i1 = 'one'  # validation-data folder path ('one' = not chosen yet)
gl_str_i2 = 'one'  # trained-model folder path ('one' = not chosen yet)
gl_str_i3 = 'one'  # batch-detection image folder path ('one' = not chosen yet)
gl_str_i4 = 'one'  # single-detection image file path ('one' = not chosen yet)
batch_size = 1  # training batch size (overwritten from the GUI combo box)
Epoch = 1  # number of training epochs (overwritten from the GUI combo box)
learning_rate = 0.1  # optimizer learning rate (overwritten from the GUI)
i = 0  # detection mode: 0 = batch over a folder, 1 = single image
Y = np.array([])  # accumulated detection results, exported to CSV by save_result
写在最后
第一次写博客,语言也不怎么精炼,文学功底不行,希望大家将就着看,整个GUI的全部代码我将在后续上传到CSDN上。当然这篇博客也借鉴了很多前人的经验,在此表示感谢