TensorFlow/Keras 多线程同时训练多个模型
程序员文章站
2022-03-03 14:07:30
...
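研究了很久，终于实现了多线程同时训练多个模型。
核心在于理解 TensorFlow 中的 Graph 和 Session：为每个模型创建独立的 Graph 和 Session，并且每次使用模型（构建、训练、预测、保存/加载权重）之前，都先把该模型对应的 Session 和 Graph 设为默认。由于默认 Graph 是按线程维护的，训练线程内部也必须重新进入这两个上下文。
废话不多说，直接上代码，看完代码就懂了！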
class MyModel(object):
    """Hold two Keras models, each isolated in its own TF graph and session.

    Every operation on a model (build, predict, fit, save/load weights) first
    makes that model's ``tf.Session`` and ``tf.Graph`` the default. That
    isolation lets ``learn()`` train both models concurrently on two threads.

    NOTE(review): this is TF1-style code (``tf.Session``,
    ``Model._make_predict_function``). The ``[xxx]`` placeholders stand for the
    real inputs/outputs/training data and must be filled in. ``tf``, ``Model``
    and ``threading`` are assumed to be imported by the surrounding file.
    """

    def __init__(self):
        """Create one graph/session pair per model and build both models."""
        self.model1Thread = None
        self.model2Thread = None
        # Incremented by learn(). It was previously never initialised, so the
        # first learn() call raised AttributeError.
        self.generation = 0
        self.model1_graph = tf.Graph()
        self.model1_sess = tf.Session(graph=self.model1_graph)
        self.model2_graph = tf.Graph()
        self.model2_sess = tf.Session(graph=self.model2_graph)
        # build_model*() *return* the model. Keep the references, because
        # every other method goes through self.model1 / self.model2.
        self.model1 = self.build_model1()
        self.model2 = self.build_model2()

    def build_model1(self):
        """Build and compile model 1 inside model 1's session/graph; return it."""
        with self.model1_sess.as_default(), self.model1_graph.as_default():
            model = Model(inputs=[xxx], outputs=[xxx])
            model.compile()
            # Build the predict function now, while the right graph is
            # default, rather than lazily from another thread later.
            model._make_predict_function()
            return model

    def build_model2(self):
        """Build and compile model 2 inside model 2's session/graph; return it."""
        with self.model2_sess.as_default(), self.model2_graph.as_default():
            model = Model(inputs=[xxx], outputs=[xxx])
            model.compile()
            model._make_predict_function()
            return model

    def predict(self):
        """Run model 1's prediction in its own session/graph and return it."""
        with self.model1_sess.as_default(), self.model1_graph.as_default():
            # The result was previously discarded.
            return self.model1.predict([xxxx])

    def learn(self, wait=False):
        """Start training both models, each on its own daemon thread.

        Args:
            wait: if True, block until both training threads have finished
                before bumping ``generation`` and flushing the log. The
                default (False) keeps the original fire-and-forget behaviour.
        """
        # daemon=True replaces Thread.setDaemon(), deprecated since Python 3.10.
        self.model1Thread = threading.Thread(target=self.learn_model1, daemon=True)
        self.model1Thread.start()
        self.model2Thread = threading.Thread(target=self.learn_model2, daemon=True)
        self.model2Thread.start()
        if wait:
            self.model1Thread.join()
            self.model2Thread.join()
        self.generation = self.generation + 1
        # NOTE(review): flush_log is not defined in this snippet; it is
        # presumably provided elsewhere in the real class.
        self.flush_log()

    def learn_model1(self):
        """Fit model 1. Intended to run on a worker thread."""
        # TF1's default graph/session are per-thread, so the worker must
        # enter them itself rather than rely on the caller's context.
        with self.model1_sess.as_default(), self.model1_graph.as_default():
            self.model1.fit([xxxx], [xxxx])

    def learn_model2(self, obs=None, reward=None):
        """Fit model 2. Intended to run on a worker thread.

        ``obs`` and ``reward`` now default to None because ``learn()`` starts
        this method with no arguments. Before, the thread died with a
        TypeError. NOTE(review): the body does not use them; the training
        data are the ``[xxx]`` placeholders.
        """
        with self.model2_sess.as_default(), self.model2_graph.as_default():
            self.model2.fit([xxx], [xxx])

    def save_weights(self):
        """Save both models' weights, each inside its own session/graph.

        NOTE(review): Keras ``Model.save_weights`` takes a filepath. The real
        call sites must supply one.
        """
        with self.model1_sess.as_default(), self.model1_graph.as_default():
            self.model1.save_weights()
        with self.model2_sess.as_default(), self.model2_graph.as_default():
            self.model2.save_weights()

    def load_weights(self):
        """Load both models' weights, each inside its own session/graph.

        NOTE(review): Keras ``Model.load_weights`` takes a filepath. The real
        call sites must supply one.
        """
        with self.model1_sess.as_default(), self.model1_graph.as_default():
            self.model1.load_weights()
        with self.model2_sess.as_default(), self.model2_graph.as_default():
            self.model2.load_weights()
上一篇: Java 全角、半角特殊符号转换