欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pyqt5结合keras+tensorflow实现深度学习训练过程GUI界面编写

程序员文章站 2024-03-19 18:48:40
...

pyqt5结合keras+tensorflow实现深度学习训练过程GUI界面编写

主要实现功能

第一次写博客,主要是想记录下最近踩的坑。最近想做一个集深度学习训练过程与缺陷检测过程为一体的界面,但是中间遇到许多问题,其中解决耗时最长的问题就是如何将深度学习训练过程实时显示在GUI界面的Textbrowser上,实现Textbrowser作为控制台输出的功能。

直接上代码

这里放的是GUI运行核心代码,其他的代码我将上传到CSDN下载中,有需要的小伙伴可以去下载

import ctypes
import win32con
import sys
from PyQt5.QtWidgets import QMainWindow, QApplication, QDialog, QFileDialog, QMessageBox
from PyQt5 import QtCore, QtGui
from PyQt5.QtCore import QThread, pyqtSignal
from mainwindow import Ui_MainWindow
from Model_training import Ui_Dialog
from Detection import Ui_Dialog1
import global_var as gl
from model_train import model_training
from model_prediction import predict, prediction
import pandas as pd

class EmittingStream(QtCore.QObject):
    textWritten = QtCore.pyqtSignal(str)
    def write(self, text):
        self.textWritten.emit(str(text))
    def flush(self):  # real signature unknown; restored from __doc__
        """ flush(self) """
        pass


class MainUI(QMainWindow, Ui_MainWindow):
    def __init__(self):
        super(MainUI, self).__init__()
        self.setupUi(self)
        self.pushButton_Training.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_Detection.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_Tuichu.setStyleSheet('color:red')
        self.pushButton_Tuichu.clicked.connect(self.close)
        self.Exit.triggered.connect(self.close)
        # self.pushButton_Detection.clicked.connect()


class Training_Dialog(QDialog, Ui_Dialog):
    def __init__(self):
        super(Training_Dialog, self).__init__()
        self.setupUi(self)
        self.pushButton_training.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_validation.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_Start.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_Stop.setStyleSheet('background:rgb(255, 0, 0)')
        self.comboBox_BS.addItems(['2', '4', '8', '16', '32', '64', '128'])
        self.comboBox_EP.addItems(['1', '20', '50', '100'])
        self.comboBox_LR.addItems(['0.1', '0.01', '0.001', '0.0001', '0.00001'])
        self.radioButton_AlexNet.setChecked(True)
        self.pushButton_Start.setEnabled(False)
        self.pushButton_training.clicked.connect(self.openfile)
        self.pushButton_validation.clicked.connect(self.openfile1)
        sys.stdout = EmittingStream(textWritten=self.normalOutputWritten)
        sys.stder = EmittingStream(textWritten=self.normalOutputWritten)
        self.pushButton_Start.clicked.connect(self.run_training)
        self.pushButton_Stop.clicked.connect(self.stop_training)
        self.my_thread = MyThread()  # 实例化线程对象

    def hyper_para(self):
        Epoch = self.comboBox_EP.currentText()
        gl.Epoch = int(Epoch)
        print('迭代次数为 %d' % gl.Epoch)
        batch_size = self.comboBox_BS.currentText()
        gl.batch_size = int(batch_size)
        print('批量尺寸为 %d' % gl.batch_size)
        Learning_rate = self.comboBox_LR.currentText()
        gl.learning_rate = float(Learning_rate)
        print('学习率为 %f' % gl.learning_rate)

    def stop_training(self):
        self.my_thread.is_on = False
        ret = ctypes.windll.kernel32.TerminateThread(  # @UndefinedVariable
            self.my_thread.handle, 0)
        print('终止训练', self.my_thread.handle, ret)

    def openfile(self):
        directory = QFileDialog.getExistingDirectory(self, "请选择文件夹路径",
                                                     "F:/Deep_CarrotNet/Carrot_resize_clear_split_clear")
        gl.gl_str_i = directory
        if len(gl.gl_str_i1) == 0:
            QMessageBox.critical(self, '提示', '请选择正确文件夹')

        print('成功加载训练文件', '训练文件夹所在位置:%s' % gl.gl_str_i)

    def openfile1(self):
        directory = QFileDialog.getExistingDirectory(self, "请选择文件夹路径",
                                                     "F:/Deep_CarrotNet/Carrot_resize_clear_split_clear")
        gl.gl_str_i1 = directory
        if len(gl.gl_str_i1) == 0:
            QMessageBox.critical(self, '提示', '请选择正确文件夹')
        else:
            self.pushButton_Start.setEnabled(True)
        print('成功加载验证文件', '验证文件夹所在位置:%s' % gl.gl_str_i1)

    def i_count(self):
        if self.radioButton_CarrotNet.text() == 'CarrotNet':
            if self.radioButton_CarrotNet.isChecked() == True:
                gl.gl_int_i = 2
                print('model is CarrotNet')

            elif self.radioButton_AlexNet.text() == 'AlexNet':
                if self.radioButton_AlexNet.isChecked() == True:
                    gl.gl_int_i = 1
                    print('model is AlexNet')

    def run_training(self):
        self.pushButton_Start.setEnabled(False)
        self.textBrowser.clear()
        self.i_count()
        self.hyper_para()
        if gl.gl_str_i == 'one':
            QMessageBox.critical(self, '错误', '请加载训练图片')
            self.my_thread.is_on = False
        elif gl.gl_str_i1 == 'one':
            QMessageBox.critical(self, '错误', '请加载验证图片')
            self.my_thread.is_on = False
        else:
            self.my_thread.is_on = True
        self.my_thread.start()  # 启动线程
        self.pushButton_Start.setEnabled(True)

    def normalOutputWritten(self, text):
        """Append text to the QTextEdit."""
        # Maybe QTextEdit.append() works as well, but this is how I do it:
        cursor = self.textBrowser.textCursor()
        cursor.movePosition(QtGui.QTextCursor.End)
        cursor.insertText(text)
        self.textBrowser.setTextCursor(cursor)
        self.textBrowser.ensureCursorVisible()


class MyThread(QThread):  # 线程类
    # my_signal = pyqtSignal(str)  # 自定义信号对象。参数str就代表这个信号可以传一个字符串

    def __init__(self):
        super(MyThread, self).__init__()
        # self.count = 0
        self.is_on = True

    def run(self):  # 线程执行函数
        self.handle = ctypes.windll.kernel32.OpenThread(  # @UndefinedVariable
            win32con.PROCESS_ALL_ACCESS, False, int(QThread.currentThreadId()))
        while self.is_on:
            model_training(gl.gl_int_i, gl.gl_str_i, gl.gl_str_i1, gl.Epoch,
                           gl.batch_size, gl.learning_rate)
            self.is_on = False


class EmittingStream1(QtCore.QObject):
    textWritten = QtCore.pyqtSignal(str)

    def write(self, text):
        self.textWritten.emit(str(text))

    def flush(self):  # real signature unknown; restored from __doc__
        """ flush(self) """
        pass


class Detection_Dialog(QDialog, Ui_Dialog1):
    def __init__(self):
        super(Detection_Dialog, self).__init__()
        self.setupUi(self)
        self.pushButton_start_detection.setStyleSheet('background:rgb(0, 255, 0)')
        self.pushButton_model.setStyleSheet('background:rgb(255, 0, 0)')
        self.pushButton_picture.setStyleSheet('background:rgb(255, 0, 0)')
        self.radioButton.setChecked(True)
        self.pushButton_model.setEnabled(True)
        self.pushButton_picture.setEnabled(False)
        self.pushButton_start_detection.setEnabled(False)
        self.pushButton_exit.setEnabled(False)
        self.pushButton_save.setEnabled(False)

        sys.stdout = EmittingStream1(textWritten=self.normalOutputWritten1)
        sys.stder = EmittingStream1(textWritten=self.normalOutputWritten1)
        # print('请先选择逐批检测还是逐个检测')
        self.pushButton_model.clicked.connect(self.message)
        self.pushButton_model.clicked.connect(self.load_moad)
        self.pushButton_picture.clicked.connect(self.load_image)
        self.my_thread1 = My_Thread1()  # 实例化线程对象
        self.pushButton_start_detection.clicked.connect(self.detection)
        self. pushButton_save.clicked.connect(self.save_result)
        self.pushButton_exit.clicked.connect(self.close)

    def save_result(self):
        path = QFileDialog.getExistingDirectory(self, "请选择文件路径")
        data = pd.DataFrame(gl.Y)
        data.to_csv(path + '/' + 'detection_result.csv', index=True)


    def message(self):
        QMessageBox.question(self, '提示', '请先选择逐批检测还是逐个检测')
        self.pushButton_model.setEnabled(True)
        self.pushButton_picture.setEnabled(True)
        self.pushButton_start_detection.setEnabled(True)
        self.pushButton_exit.setEnabled(True)

    def load_image(self):
        if self.radioButton.text() == '逐批检测':
            if self.radioButton.isChecked() == True:
                directory1 = QFileDialog.getExistingDirectory(self, "请选择文件路径")
                gl.gl_str_i3 = directory1
                print('成功导入检测文件', '检测文件所在位置:%s' % gl.gl_str_i3)
            elif self.radioButton_2.text() == '逐个检测':
                if self.radioButton_2.isChecked() == True:
                    fname, _ = QFileDialog.getOpenFileName(self, '选择图片', 'c:\\', 'Image files(*.jpg *.gif *.png)')
                    gl.gl_str_i4 = fname
                    print('成功导入检测图片', '检测文件所在位置:%s' % gl.gl_str_i4)
                else:
                    print('请正确选择检测文件路径')


    def load_moad(self):
        directory = QFileDialog.getExistingDirectory(self, "请选择文件路径")
        gl.gl_str_i2 = directory
        print('成功加载模型', '模型所在位置:%s' % gl.gl_str_i2)

    def normalOutputWritten1(self, text):
        """Append text to the QTextEdit."""
        # Maybe QTextEdit.append() works as well, but this is how I do it:
        cursor1 = self.textBrowser1.textCursor()
        cursor1.movePosition(QtGui.QTextCursor.End)
        cursor1.insertText(text)
        self.textBrowser1.setTextCursor(cursor1)
        self.textBrowser1.ensureCursorVisible()

    def detection(self):
        self.pushButton_start_detection.setEnabled(False)
        if self.radioButton.text() == '逐批检测':
            if self.radioButton.isChecked() == True:
                gl.i = 0
            elif self.radioButton_2.text() == '逐个检测':
                if self.radioButton_2.isChecked() == True:
                    gl.i = 1

        self.my_thread1.start()  # 启动线程
        self.pushButton_start_detection.setEnabled(True)
        self.pushButton_save.setEnabled(True)


class My_Thread1(QThread):
    def __init__(self):
        super(My_Thread1, self).__init__()

    def run(self):  # 线程执行函数
        print('测试开始')
        if gl.i == 0:
            prediction(gl.gl_str_i2, gl.gl_str_i3)
        else:
            predict(gl.gl_str_i2, gl.gl_str_i4)


if __name__ == "__main__":
    app = QApplication(sys.argv)
    main = MainUI()
    Training = Training_Dialog()
    Detection = Detection_Dialog()
    main.pushButton_Training.clicked.connect(Training.show)
    main.pushButton_Detection.clicked.connect(Detection.show)
    main.pushButton_Tuichu.clicked.connect(Training.close)
    main.pushButton_Tuichu.clicked.connect(Detection.close)
    main.Exit.triggered.connect(Training.close)
    main.Exit.triggered.connect(Detection.close)
    main.show()
    sys.exit(app.exec_())

重要代码

这里是关于如何将深度学习训练过程实时显示到GUI的Textbrowser上

class EmittingStream(QtCore.QObject):
    textWritten = QtCore.pyqtSignal(str)

    def write(self, text):
        self.textWritten.emit(str(text))

    def flush(self):  # real signature unknown; restored from __doc__
        """ flush(self) """
        pass

一定要加上flush函数的定义,之前在CSDN上找了很久,都没有这行,导致GUI界面上的Textbrowsers只能输出深度学习训练过程的第一行,不能实现实时刷新的功能,加上这个定义就可以完美解决

sys.stdout = EmittingStream(textWritten=self.normalOutputWritten)
sys.stder = EmittingStream(textWritten=self.normalOutputWritten)
class MyThread(QThread):  # 线程类
    # my_signal = pyqtSignal(str)  # 自定义信号对象。参数str就代表这个信号可以传一个字符串

    def __init__(self):
        super(MyThread, self).__init__()
        # self.count = 0
        self.is_on = True

    def run(self):  # 线程执行函数
        self.handle = ctypes.windll.kernel32.OpenThread(  # @UndefinedVariable
            win32con.PROCESS_ALL_ACCESS, False, int(QThread.currentThreadId())) # 是为了后面结束进程使用
        while self.is_on:
            model_training(gl.gl_int_i, gl.gl_str_i, gl.gl_str_i1, gl.Epoch,
                           gl.batch_size, gl.learning_rate)
            self.is_on = False
def stop_training(self):
        self.my_thread.is_on = False
        ret = ctypes.windll.kernel32.TerminateThread(  # @UndefinedVariable
            self.my_thread.handle, 0)
        print('终止训练', self.my_thread.handle, ret)

self.handle = ctypes.windll.kernel32.OpenThread( # @UndefinedVariable
win32con.PROCESS_ALL_ACCESS, False, int(QThread.currentThreadId())) # 是为了后面结束进程使用
def stop_training(self): 终止训练过程

出错解决

用前面的代码理论上是可以实现实时显示深度学习训练过程的,但我在刚开始使用时,总会出现== finished with exit code -1073740791 (0xC0000409)==,当把 sys.stdout = EmittingStream1(textWritten=self.normalOutputWritten1)
sys.stder = EmittingStream1(textWritten=self.normalOutputWritten1)这两行注释掉时程序可以正常运行,只不过内容没有输出到textbrowser上。在网上搜了一大圈,也没有发现适合我程序的,最后才发现是keras的版本和Tensorflow的版本不匹配造成的,但是之前不在GUI内运行不报错,在GUI框架下运行就会报错,==最终选择Keras2.2.5,tensorflow1.14.0 ==解决了问题,但是运行程序时会出现一大串警告,不过不影响最终结果

最终效果

pyqt5结合keras+tensorflow实现深度学习训练过程GUI界面编写

未解决问题

我这个有两个子界面,每个子界面都有一个Textbrowser,而且都想达到实时刷新的效果,但是当同时使用时会出现两个Textbrowser内容相互干扰的现象。哪位大神知道如何玩解决的话,还望不吝赐教

补充说明

本界面还使用了全局变量实现不同函数之间的互相传值,具体方法是先建个global_var.py文件,将需要传值的参数预先定义。此后各个文件import使用就行了

# coding=utf-8
# 在别的文件使用方法:
# import global_var_model as gl
#  gl.gl_int_i += 4,可以通过访问和修改gl.gl_int_i来实现python的全局变量,或者叫静态变量访问
# gl.gl_int_i
import numpy as np
gl_int_i = 1  # 这里的gl_int_i是最常用的用于标记的全局变量
gl_str_i = 'one'
gl_str_i1 = 'one'
gl_str_i2 = 'one'
gl_str_i3 = 'one'
gl_str_i4 = 'one'
batch_size = 1
Epoch = 1
learning_rate = 0.1
i = 0
Y = np.array([])

写在最后

第一次写博客,语言也不怎么精炼,文学功底不行,希望大家将就着看,整个GUI的全部代码我将在后续上传到CSDN上。当然这篇博客也借鉴了很多前人的经验,在此表示感谢