tensorboardX pytorch 入门实践
程序员文章站
2022-07-06 12:38:05
...
以 PyTorch 官方 CIFAR-10 训练代码为例：
# Train for 5 epochs and log to TensorBoard via a tensorboardX SummaryWriter.
# Relies on names defined outside this snippet: trainloader, net, criterion,
# optimizer, writer and torch.
global_step = 0  # per-batch counter so scalar points do not share a step
graph_logged = False
for epoch in range(5):  # loop over the dataset multiple times
    last_batch = None
    for inputs, labels in trainloader:
        # Variable() wrapping is unnecessary since PyTorch 0.4: plain tensors
        # participate in autograd directly.
        optimizer.zero_grad()  # clear gradients from the previous step
        outputs = net(inputs)  # forward pass through the model
        loss = criterion(outputs, labels)
        loss.backward()  # backpropagate
        optimizer.step()  # update parameters

        # loss.data[0] raises IndexError on a 0-dim tensor in modern PyTorch;
        # .item() returns the plain Python number.
        loss_value = loss.item()
        writer.add_scalar('loss_value', loss_value, global_step)
        global_step += 1

        # Tracing the model is expensive and the graph does not change between
        # batches, so record it once. Inputs are passed on their original
        # device so tracing matches wherever `net` lives.
        if not graph_logged:
            writer.add_graph(net, (inputs,))
            graph_logged = True

        # Detach so the stored batch does not keep the autograd graph alive.
        last_batch = (inputs.detach(), labels, outputs.detach())

    if last_batch is None:
        continue  # empty loader: nothing to embed for this epoch
    inputs, labels, outputs = last_batch

    # Kept from the original example: append a constant column of ones to the
    # outputs (the original notes [100, 11] for a batch of 100). Its purpose is
    # not stated in the source. Create it on the outputs' device to avoid a
    # CPU/GPU mismatch, then move everything to CPU for serialization.
    features = torch.cat(
        (outputs, torch.ones(len(outputs), 1, device=outputs.device)), 1
    ).cpu()
    # One embedding per epoch (from the last batch). Repeated calls with the
    # same tag and global_step overwrite each other, so only the most complete
    # variant (metadata + thumbnail images) is logged. Labels are converted to
    # Python values so the metadata reads "3" rather than "tensor(3)".
    writer.add_embedding(
        features,
        metadata=labels.tolist(),
        label_img=inputs.cpu(),
        global_step=epoch,
    )

print('Finished Training')
writer.close()
# Reference forms of the SummaryWriter calls used above (outputs, labels,
# inputs, loss_value, epoch, net and writer come from the training loop).
# The three add_embedding calls are alternative forms: with metadata and
# thumbnail images, with metadata only, and with images only. Distinct tags
# keep them from overwriting one another when they share a global_step.
# labels.tolist() gives plain values, so the metadata reads "3" rather than "tensor(3)".
writer.add_embedding(outputs, metadata=labels.tolist(), label_img=inputs.detach().cpu(), global_step=epoch, tag='meta_and_images')
writer.add_embedding(outputs, metadata=labels.tolist(), global_step=epoch, tag='meta_only')
writer.add_embedding(outputs, label_img=inputs.detach().cpu(), global_step=epoch, tag='images_only')
# Straight ASCII quotes: the curly quotes in the original are a SyntaxError.
writer.add_scalar('loss_value', loss_value, epoch)
writer.add_graph(net, (inputs,))
writer.add_* 系列方法源代码：https://github.com/lanpa/tensorboard-pytorch/blob/master/tensorboardX/writer.py
tensorboardX 官方指导文档：http://tensorboard-pytorch.readthedocs.io/en/latest/tutorial_zh.html
tensorboardX 官方 GitHub：https://github.com/lanpa/tensorboard-pytorch
参考代码段详见: