欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

tensorboardX pytorch 入门实践

程序员文章站 2022-07-06 12:38:05
...

以 pytorch cifar-10 代码为例

for epoch in range(5):  # 循环遍历数据集多次
    for i, data in enumerate(trainloader, 0):
        # 得到输入数据
        inputs, labels = data
        # 包装数据
        inputs, labels = Variable(inputs), Variable(labels) 
        # 梯度清零
        optimizer.zero_grad()
        # net()为网络模型,通过模型得到输出
        outputs = net(inputs)
        # 计算损失
        loss = criterion(outputs, labels)
        # 反向传播
        loss.backward()
        # 参数优化
        optimizer.step()
        loss_value = loss.data[0]
    # 预设batchsize=100    
    outputs = torch.cat((outputs.data, torch.ones(len(outputs), 1)), 1)
    inputs = inputs.to(torch.device("cpu"))
    # outputs.size():[100,11]///labels.size():[100]///inputs.size():[100,3,32,32]
    writer.add_embedding(outputs, metadata=labels.data, label_img=inputs.data, global_step=epoch)
    writer.add_embedding(outputs, metadata=labels.data, global_step=epoch)
    writer.add_embedding(outputs, label_img=inputs.data, global_step=epoch)
    writer.add_scalar('loss_value',loss_value,epoch)
writer.add_graph(net,(inputs,))
print('Finished Training')
writer.close()

writer.add_embedding(outputs, metadata=labels.data, label_img=inputs.data, global_step=epoch)
tensorboardX pytorch 入门实践
writer.add_embedding(outputs, metadata=labels.data, global_step=epoch)
tensorboardX pytorch 入门实践
writer.add_embedding(outputs, label_img=inputs.data, global_step=epoch)
tensorboardX pytorch 入门实践
writer.add_scalar(‘loss_value’,loss_value,epoch)
tensorboardX pytorch 入门实践
writer.add_graph(net,(inputs,))
tensorboardX pytorch 入门实践

writer.add_*源代码:https://github.com/lanpa/tensorboard-pytorch/blob/master/tensorboardX/writer.py
tensorboardX官方指导文档:http://tensorboard-pytorch.readthedocs.io/en/latest/tutorial_zh.html
tensorboardX官方github:https://github.com/lanpa/tensorboard-pytorch
参考代码段详见: