欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

tensorflow在训练过程中出现内存溢出的问题

程序员文章站 2022-07-13 12:52:38
...

最近在使用perceptual_loss过程中,出现了模型随着训练占用的内存越来越大的问题,千万不要在遍历数据的过程中使用tf的相关操作,必要的操作放到模型中,如果在模型外执行,会不断增加新的节点,导致网络占用的内存随着训练的进行越来越大

 

tf2预训练好的模型加载放在__init__中,返回函数名。

 

今天在跑程序的时候,内存一个劲儿的涨。本地不行拿到服务器上去跑,62G内存分分钟干没了,不知道问题出在哪儿。经过在网上的一番查找,才弄清楚。一句话说:在迭代循环时,不能再包含任何张量的计算表达式,包括以tf.开头的函数(如tf.nn.embedding_lookup

如果你非得计算,请在循环体外面定义好表达式,在循环中直接run

举例:

import tensorflow as tf

a = tf.Variable(tf.truncated_normal(shape=[100,1000]),name='a')
b = tf.Variable(tf.truncated_normal(shape=[100,1000]),name='b')

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    while True:
        print(sess.run(a+b))

可以看到,在循环体中出现了a+b 这个表达式,当你在运行程序的时候,内存会慢慢的增大(当然这个程序的增长速度还不足以导致崩掉)。原因是在Tensorflow的机制中,任何张量的计算表达式(函数操作)都会被作为节点添加到计算图中。如果循环中有表达式,那么计算图中就会被不停的加入几点,导致内存上升。

正确的做法应该是:(将表达式定义在外边)

import tensorflow as tf

a = tf.Variable(tf.truncated_normal(shape=[100,1000]),name='a')
b = tf.Variable(tf.truncated_normal(shape=[100,1000]),name='b')
z=a+b
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    while True:
        print(sess.run(z))

同时TensorFlow也提供了一个办法来检查这个问题:

import tensorflow as tf

a = tf.Variable(tf.truncated_normal(shape=[100,1000]),name='a')
b = tf.Variable(tf.truncated_normal(shape=[100,1000]),name='b')

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    while True:
        print(sess.run(a+b))
        sess.graph.finalize()

此时将报错:RuntimeError: Graph is finalized and cannot be modified.
sess.graph.finalize()这个函数告诉TensorFlow,计算图我已经定义完毕。所以当循环到第二次的时候就会报错。

再例如:

import tensorflow as tf

a = tf.Variable(tf.truncated_normal(shape=[2, 3]), name='a')
b = tf.Variable(tf.truncated_normal(shape=[2, 3]), name='b')

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    sess.graph.finalize()
    c = tf.concat([a, b], axis=0)
    print(sess.run(c))

如上程序也会报错,因为tf.concat()会增加计算图中的节点,而在此之前,我已申明计算图定义完毕。这也证明,tf.开头的函数也将导致计算图中的节点增加。解决方法同上。

 

reference:

https://blog.csdn.net/qq_37876289/article/details/106157359

相关标签: tf2 tensorflow