欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

keras中的tensorboard可视化

程序员文章站 2022-06-05 22:04:57
...
import tensorflow as tf
from keras.callbacks import Callback

class Mylosscallback(Callback):
    def __init__(self, log_dir):
        super(Callback, self).__init__()
        self.writer = tf.summary.FileWriter(log_dir)
        self.num = 0
    def on_train_begin(self, logs={}):
        self.losses = {'batch': [], 'epoch': []}
        self.accuracy = {'batch': [], 'epoch': []}
        self.val_loss = {'batch': [], 'epoch': []}
        self.val_acc = {'batch': [], 'epoch': []}
    def on_batch_end(self, batch, logs={}):
        self.num = self.num + 1
        self.losses = logs.get('loss')
        self.accuracy = logs.get('acc')
        self.val_loss = logs.get('val_loss')
        self.val_acc = logs.get('val_acc')
        print('debug success!!!')
        summary = tf.Summary()
        summary.value.add(tag='losses', simple_value=self.losses)
        summary.value.add(tag='accuracy', simple_value=self.accuracy)
        summary.value.add(tag='val_loss', simple_value=self.val_loss)
        summary.value.add(tag='val_acc', simple_value=self.val_acc)
        self.writer.add_summary(summary, self.num)
        self.writer.flush()

最后使用这个命令就可以使用tensorboard可视化自己想要的东西了,需要一个回调

model.fit(X_train, y_train, epochs=50, callbacks=[Mylosscallback(log_dir='./log')])
相关标签: AI-python python