欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

如何在TensorFlow2.X中使用自定义训练循环的情况下在TensorBoard中绘制网络结构图(计算图)

程序员文章站 2024-02-06 18:49:58
遇到的问题很多小伙伴在使用TensorFlow2.x的时候会进行自定义的循环,也就是自己采用for循环来逐个Epoch循环;同时又想将此时的网络图绘制在TensorBoard中。这个时候问题就出现了:TensorBoard在2.0以后的版本中的的网络图是默认在model.fit之中自动绘制的;# 使用fit函数的时候会自动绘制网络计算图model.fit(trrain_dataset, epoch=10, ......)倘若想要自定义训练循环则又需要手动绘制网络图。# 自定义寻来你循环的时候,T...

遇到的问题

很多小伙伴在使用TensorFlow2.x的时候会进行自定义的循环,也就是自己采用for循环来逐个Epoch循环;同时又想将此时的网络图绘制在TensorBoard中。这个时候问题就出现了:TensorBoard在2.0以后的版本中的的网络图是默认在model.fit之中自动绘制的;

# 使用fit函数的时候会自动绘制网络计算图
model.fit(trrain_dataset, epoch=10, ......)

倘若想要自定义训练循环则又需要手动绘制网络图。

# 自定义寻来你循环的时候,TensorFlow不会帮助我们绘制网络计算图
for epooch in range(1, EPOCHS):
	SDG...
	LOSS...
	Record...

而网络上关于TensorFlow2.x绘制网络图的说明是少之又少,于是我决定写这篇博客来帮助大家来实现网络图的绘制。

如何在TensorFlow2.x中自己绘制网络图

直接给大家展示代码

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras.datasets import mnist
from tensorflow.python.ops import summary_ops_v2  # 需要引入这个模块

logs_dir='你的自定义的日志目录'

# 你创建的模型
class ClassModel(tf.keras.Model):
    def __init__(self, ...):
        super(ClassModel, self).__init__()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(self.num_classes, activation='softmax')
        ... # 其他操作
        
    @tf.function # 需要使用tf.function
    def call(self, inputs):
        inputs = self.d1(inputs)
        output = self.d2(inputs)
        return output

# inputs可以是符合你输入数据形状的输入数据
inputs=training_dataset  
model=ClassModel()

# 开始创建网络计算图
graph_writer = tf.summary.create_file_writer(logdir=logs_dir)
with graph_writer.as_default():
    graph=model.call.get_concrete_function(inputs).graph
    summary_ops_v2.graph(graph.as_graph_def())
graph_writer.close()

通过这个流程,就可以构建出你的网络模型图了。
在这个过程中,有几点注意事项

  1. from tensorflow.python.ops import summary_ops_v2 需要引入这个模块
  2. 自定义模型中的call需要使用tf.function注解标注
  3. inputs可以为任何符合网络输入形状的数据,比如我的网络输入为(None, 32, 32, 3),那么我就可以令inputs=tf.ones((64, 32, 32, 3)),也就是说可以使用该数据跑通这个模型即可
  4. 使用tf.summary的FileWriter来进行绘制

绘制结果可以在TensorBoard的URL之中查看:
如何在TensorFlow2.X中使用自定义训练循环的情况下在TensorBoard中绘制网络结构图(计算图)

总结

其实这也是笔者找了很多文档都没发现,然后自己研究出来的方法。希望可以帮到大家。如果大家有任何问题,可以添加笔者QQ进行讨论:1574143668.
请大家在学习与工作的过程中不要忘记互联网创立的初衷——分享。

本文地址:https://blog.csdn.net/kiva12138/article/details/107072987