Pytorch入坑指南
程序员文章站
2022-07-14 20:24:26
...
Pytorch入坑指南
pytorch使用tensorboard
- 安装Pytorch、Torchvision、Tensorboard
pip install --upgrade torch torchvision
pip install tensorboard
- 运行代码测试
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets,transforms
# 设定输出路径,不写则默认
writer=SummaryWriter('/home/bobo/Download/tensorboardDir')
# 设定 数据预处理(归一化、增广等) 的步骤
transform=transforms.Compose([
transforms.ToTensor(), #归一化
transforms.Normalize((0.5,),(0.5,))]) #均值 方差
# 采用 内置数据集
trainset=datasets.MNIST('mnist_train',train=True,download=True,transform=transform)
#数据集加载器
trainloader=torch.utils.data.DataLoader(trainset,batch_size=64,shuffle=True)
# 采用 内置网络 False代表 不使用预训练模型(因为要改第一层,故 预训练权重不适用)
model=torchvision.models.resnet50(False)
# 由于数据集是灰度图,故修改网络输入,由RGB改为灰度图
model.conv1=torch.nn.Conv2d(in_channels=1,out_channels=64,kernel_size=7,stride=2,padding=3,bias=False)
#加载 训练数据
images,labels=next(iter(trainloader))
# 设定网格 将一个batch的图像 转化为 一张网格图像
grid=torchvision.utils.make_grid(images)
# 展示训练图像
writer.add_image('images',grid,0)
#展示 模型结构图
writer.add_graph(model,images)
#一定要加
writer.close()
- 启动tensorflow
tensorboard --logdir=/home/bobo/Download/tensorboardDir --port 6006
注意:
(1)Python代码及启动命令要明确指定路径,因为tensorboard会显示该文件夹下的所有内容。
(2) 若网页无法打开,可能是该端口占用。使用ps -a
查看,并使用kill -9 端口号
即可。
参考
PyTorch 自带 TensorBoard 使用教程
PyTorch 1.1 or 1.2 使用Tensorboard
上一篇: 剑指offer-面试题6