欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

tensorflow:不要在session中定义运算

程序员文章站 2022-07-08 09:42:26
...

最近在做项目时,总是会有程序崩溃的问题,系统也没有任何提示。最后通过监控系统发现是内存溢出造成的。

追查下去,发现一段类似这样的代码,在session中调用tensorflow的api进行运算:

import tensorflow as tf
X = tf.constant([[1,2,3], [3,2,4]], dtype=tf.float32)
W = tf.constant([[1,1],[2,2],[3,3]], dtype=tf.float32)
bias = tf.constant([1, 2], dtype=tf.float32)
y = tf.nn.softmax(tf.matmul(X, W) + bias)

with tf.Session() as sess:

    for i in range(10):
        print(i)
        sess.run(tf.nn.softmax(tf.matmul(X, W) + bias))

    writer = tf.compat.v1.summary.FileWriter("./graph", sess.graph)
    writer.close()

使用tensorboard查看内存泄漏的原因:

tensorflow:不要在session中定义运算

将计算图展开为

tensorflow:不要在session中定义运算

当然,这里只是展开了softmax,其他节点也可以类似展开。

可以看到,在session中定义计算节点,存在一个很大的风险,就是会在计算图中产生新的图节点,如果像我这样使用for循环运算,那么节点数会无限增加,注意不仅仅是softmax节点在增加,其他计算节点也在增加,这样的开销会越来越大,直至程序崩溃。

为了解决这个问题,我们应该使用上面定义的y的等式,在进入session前就已经将计算图定义好,在session中直接调用,而不是重新搭建。