通过multiprocessing模块及时释放tensorflow的资源
程序员文章站
2022-07-04 12:02:23
...
在使用tf.data等模块时,tensorflow会产生内存泄露;当内存泄露发生时,我们期望及时保存checkpoint,返回相应的状态,然后重新启动tensorflow进行增量训练。
如果采用subprocess.call()方案在子进程中调用tensorflow,需要自行实现参数、结果的序列化和反序列化,比较麻烦。
本文给出一种通过multiprocessing模块在子进程中调用tensorflow的实现,传参数so easy
话不多说,上代码:
如果采用subprocess.call()方案在子进程中调用tensorflow,需要自行实现参数、结果的序列化和反序列化,比较麻烦。
本文给出一种通过multiprocessing模块在子进程中调用tensorflow的实现,传参数so easy
话不多说,上代码:
# coding=utf-8 ''' Created on Sep 18, 2018 @author: colinliang ''' from __future__ import absolute_import, division, print_function def run_tf(args, queue=None): print('\n\n------- beginning of tf process') print('args for tf: %s' % args) import tensorflow as tf sess = tf.Session() import psutil mem_start=psutil.virtual_memory().available batch=2000000 n=(args['epoch']+1) *batch with tf.device('/cpu:0'): v = tf.get_variable(name="tf_var", shape=[n], dtype=tf.float32, initializer=tf.random_uniform_initializer(-1, 1, 0, dtype=tf.float32)) sess.run(tf.global_variables_initializer()) # print( (mem_start-psutil.virtual_memory().available)/batch) if(mem_start-psutil.virtual_memory().available >batch*12): #内存检测,有内存泄露时及时退出 result={'exit code':-1} if(queue is not None): queue.put(result) return result import time time.sleep(10) r = sess.run(v[0]) print('sess: %s' % sess) sess.close() # tf.reset_default_graph() result={'first elem of tf var':r} if(queue is not None): queue.put(result) print('------- end of tf process') return result ##################################################### from multiprocessing import Process, Queue # 参考自https://*.com/questions/39758094/clearing-tensorflow-gpu-memory-after-model-execution for i in range(5): # Process的使用方法 https://docs.python.org/2/library/multiprocessing.html q = Queue() args = {'epoch':i} p = Process(target=run_tf, args=(args, q)) p.start() p.join() print("result: %s" % q.get())