
Releasing TensorFlow resources promptly via the multiprocessing module

When using modules such as tf.data, TensorFlow can leak memory. When a leak occurs, we would like to save a checkpoint promptly, return the current state, and then restart TensorFlow to continue training incrementally.

If you invoke TensorFlow in a subprocess via subprocess.call(), you have to serialize and deserialize the arguments and results yourself, which is tedious, roughly as in the sketch below.
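For comparison, the subprocess.call() route would look something like the minimal sketch that follows; the script name run_tf_script.py and the JSON files used to pass arguments and results are hypothetical, shown only to illustrate the extra serialization work.

# a minimal sketch of the subprocess.call() approach (run_tf_script.py, args.json and
# result.json are hypothetical names used only for illustration)
import json
import subprocess
import sys

args = {'epoch': 0}
with open('args.json', 'w') as f:   # serialize the arguments by hand
    json.dump(args, f)

# run_tf_script.py would have to load args.json, run TensorFlow,
# and write its result to result.json
subprocess.call([sys.executable, 'run_tf_script.py', 'args.json', 'result.json'])

with open('result.json') as f:      # deserialize the result by hand
    result = json.load(f)
print('result: %s' % result)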
This article instead gives an implementation that calls TensorFlow in a child process via the multiprocessing module, which makes passing arguments and results easy.
Without further ado, the code:
# coding=utf-8
'''
Created on Sep 18, 2018

@author: colinliang
'''
from __future__ import absolute_import, division, print_function
def run_tf(args, queue=None):
    print('\n\n------- beginning of tf process')
    print('args for tf: %s' % args)
    # import TensorFlow inside the child process only, so the parent never loads it
    # and all of its memory is released when the child exits
    import tensorflow as tf
    sess = tf.Session()
    
    import psutil
    mem_start = psutil.virtual_memory().available  # available memory before allocating the variable
    batch = 2000000
    n = (args['epoch'] + 1) * batch  # the variable gets larger with every epoch
    with tf.device('/cpu:0'):
        v = tf.get_variable(name="tf_var", shape=[n], dtype=tf.float32, initializer=tf.random_uniform_initializer(-1, 1, 0, dtype=tf.float32))
    sess.run(tf.global_variables_initializer())
#     print( (mem_start-psutil.virtual_memory().available)/batch)
    if (mem_start - psutil.virtual_memory().available > batch * 12):  # memory check: exit promptly when a leak is detected
        result = {'exit code': -1}
        if(queue is not None):
            queue.put(result)
        return result
    
    import time 
    time.sleep(10)
    
    r = sess.run(v[0])
    print('sess: %s' % sess)
    sess.close()
#     tf.reset_default_graph()
    result={'first elem of tf var':r}
    if(queue is not None):
        queue.put(result)
    print('------- end of tf process')
    return result

#####################################################
from multiprocessing import Process, Queue
# adapted from https://stackoverflow.com/questions/39758094/clearing-tensorflow-gpu-memory-after-model-execution
# see https://docs.python.org/2/library/multiprocessing.html for how to use Process
if __name__ == '__main__':  # guard so the loop is not re-run when a child re-imports this module
    for i in range(5):
        q = Queue()
        args = {'epoch': i}
        p = Process(target=run_tf, args=(args, q))  # run TensorFlow in a fresh child process
        p.start()
        p.join()  # wait for the child to finish; its memory is fully released on exit
        print("result: %s" % q.get())

Tags: python tensorflow