欢迎您访问程序员文章站！本站旨在为大家提供、分享程序员计算机编程知识！
您现在的位置是: 首页

TensorFlow/Keras 多线程同时训练多个模型

程序员文章站 2022-03-03 14:07:30
...

研究了很久，终于实现了多线程同时训练多个模型。

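核心在于理解 TF 1.x 中的 Graph 和 Session：为每个模型创建独立的 Graph 和 Session，并在每个线程里先进入该模型自己的上下文，再调用 Keras 接口。上下文通过 `with sess.as_default():` 和 `with graph.as_default():` 进入。这是因为 TF 1.x 的默认图和默认会话是按线程保存的，新线程不会自动继承主线程里设置的默认图。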

废话不多说，直接上代码，看完代码就懂了！


class MyModel(object):
    """Hold two Keras models, each in its own TF 1.x Graph/Session, and train
    them concurrently on separate threads.

    Every Keras call is wrapped in
    ``with sess.as_default(), graph.as_default():`` so that it runs against
    that model's own graph and session. This is required inside worker
    threads because TF 1.x keeps the default graph/session stacks per
    thread; a freshly started thread would otherwise see the global default
    graph instead of the model's graph.

    NOTE(review): ``xxx`` / ``xxxx`` are placeholders for the real model
    inputs/outputs and training data, and ``compile()`` /
    ``save_weights()`` / ``load_weights()`` are shown without their required
    arguments; fill these in before use.
    """

    def __init__(self):
        self.model1Thread = None
        self.model2Thread = None

        # Incremented by learn(); must exist before the first learn() call.
        self.generation = 0

        # One independent graph + session per model so the two models never
        # share variables or ops.
        self.model1_graph = tf.Graph()
        self.model1_sess = tf.Session(graph=self.model1_graph)

        self.model2_graph = tf.Graph()
        self.model2_sess = tf.Session(graph=self.model2_graph)

        # Keep the built models: predict/learn/save/load all use
        # self.model1 / self.model2.
        self.model1 = self.build_model1()
        self.model2 = self.build_model2()

    def build_model1(self):
        """Build and compile model 1 inside its own graph and session.

        Returns:
            The compiled Keras model.
        """
        with self.model1_sess.as_default(), self.model1_graph.as_default():
            model = Model(inputs=[xxx], outputs=[xxx])
            model.compile()
            # Pre-build the predict function now, inside this graph, so that a
            # later predict() from another thread does not build it lazily
            # (common Keras 2 / TF1 multithreading workaround).
            model._make_predict_function()
            return model

    def build_model2(self):
        """Build and compile model 2 inside its own graph and session.

        Returns:
            The compiled Keras model.
        """
        with self.model2_sess.as_default(), self.model2_graph.as_default():
            model = Model(inputs=[xxx], outputs=[xxx])
            model.compile()
            # See build_model1: pre-build the predict function in this graph.
            model._make_predict_function()
            return model

    def predict(self):
        """Run model 1's prediction inside its graph/session.

        Returns:
            Whatever ``self.model1.predict`` returns (previously discarded).
        """
        with self.model1_sess.as_default(), self.model1_graph.as_default():
            return self.model1.predict([xxxx])

    def learn(self, wait=False):
        """Train both models concurrently, one daemon thread per model.

        Args:
            wait: If True, block until both training threads have finished.
                Defaults to False, which returns immediately after starting
                the threads (the original behaviour).

        Note:
            When ``wait`` is False, ``generation`` is incremented and
            ``flush_log()`` is called as soon as the threads are started,
            not when training finishes.
        """
        self.model1Thread = threading.Thread(target=self.learn_model1, args=())
        # Daemon threads do not keep the interpreter alive at exit.
        # (Attribute form instead of the deprecated setDaemon().)
        self.model1Thread.daemon = True
        self.model1Thread.start()

        self.model2Thread = threading.Thread(target=self.learn_model2, args=())
        self.model2Thread.daemon = True
        self.model2Thread.start()

        if wait:
            self.model1Thread.join()
            self.model2Thread.join()

        self.generation = self.generation + 1
        self.flush_log()

    def learn_model1(self):
        """Fit model 1; intended to run on its own thread."""
        # The thread starts with no default graph/session of ours, so enter
        # model 1's explicitly.
        with self.model1_sess.as_default(), self.model1_graph.as_default():
            self.model1.fit([xxxx], [xxxx])

    def learn_model2(self, obs=None, reward=None):
        """Fit model 2; intended to run on its own thread.

        Args:
            obs: Currently unused. Defaults to None so the method can be used
                as a ``Thread`` target with ``args=()`` (learn() does this).
            reward: Currently unused; see ``obs``.
        """
        with self.model2_sess.as_default(), self.model2_graph.as_default():
            self.model2.fit([xxx], [xxx])

    def save_weights(self):
        """Save both models' weights, each inside its own graph/session."""
        # TODO(review): Keras save_weights() needs a filepath argument.
        with self.model1_sess.as_default(), self.model1_graph.as_default():
            self.model1.save_weights()

        with self.model2_sess.as_default(), self.model2_graph.as_default():
            self.model2.save_weights()

    def load_weights(self):
        """Load both models' weights, each inside its own graph/session."""
        # TODO(review): Keras load_weights() needs a filepath argument.
        with self.model1_sess.as_default(), self.model1_graph.as_default():
            self.model1.load_weights()

        with self.model2_sess.as_default(), self.model2_graph.as_default():
            self.model2.load_weights()

    def flush_log(self):
        """Hook called at the end of each learn(); no-op by default.

        Override in a subclass to persist logs/metrics per generation.
        """
        return None