pytorch 带batch的tensor类型图像显示操作
程序员文章站
2022-06-18 10:37:21
项目场景pytorch训练时我们一般把数据集放到数据加载器里,然后分批拿出来训练。训练前我们一般还要看一下训练数据长啥样,也就是训练数据集可视化。那么如何显示dataloader里面带batch的te...
项目场景
pytorch训练时我们一般把数据集放到数据加载器里,然后分批拿出来训练。训练前我们一般还要看一下训练数据长啥样,也就是训练数据集可视化。
那么如何显示dataloader里面带batch的tensor类型的图像呢?
显示图像
绘图最常用的库就是matplotlib:
pip install matplotlib
显示图像会用到matplotlib.pyplot.imshow方法。查阅官方文档可知,该方法接收的图像的通道数要放到后面:
数据加载器中数据的维度是[b, c, h, w],我们每次只拿一个数据出来就是[c, h, w],而matplotlib.pyplot.imshow要求的输入维度是[h, w, c],所以我们需要交换一下数据维度,把通道数放到最后面,这里用到pytorch里面的permute方法(transpose方法也行,不过要交换两次,没这个方便,numpy中的transpose方法倒是可以一次交换完成)
用法示例如下:
>>> x = torch.randn(2, 3, 5) >>> x.size() torch.size([2, 3, 5]) >>> x.permute(1, 2, 0).size() torch.size([3, 5, 2])
代码示例
#%% 导入模块 import torch import matplotlib.pyplot as plt from torchvision.utils import make_grid from torch.utils.data import dataloader from torchvision import datasets, transforms #%% 下载数据集 train_file = datasets.mnist( root='./dataset/', train=true, transform=transforms.compose([ transforms.totensor(), transforms.normalize((0.1307,), (0.3081,)) ]), download=true ) #%% 制作数据加载器 train_loader = dataloader( dataset=train_file, batch_size=9, shuffle=true ) #%% 训练数据可视化 images, labels = next(iter(train_loader)) print(images.size()) # torch.size([9, 1, 28, 28]) plt.figure(figsize=(9, 9)) for i in range(9): plt.subplot(3, 3, i+1) plt.title(labels[i].item()) plt.imshow(images[i].permute(1, 2, 0), cmap='gray') plt.axis('off') plt.show()
这里以mnist数据集为例,演示一下显示效果。我这个代码其实还有一点小问题。数据增强的时候我不是进行标准化了嘛,就是在第7行代码:normalize((0.1307,), (0.3081,))。
所以,如果你想查看训练集的原始图像,还得反标准化。
标准化:image = (image-mean)/std
反标准化:image = image*std+mean
我拿imagenet中的一个蚂蚁和蜜蜂的子集做了一下实验,标准化前后的区别还是很明显的:
最终效果
补充:pil,plt显示tensor类型的图像
该方法针对显示dataloader读取的图像
pil 与plt中对应操作不同,但原理是一样的,我试过用下方代码image的方法在plt上show失败了,原因暂且不知。
# 方法1:image.show() # transforms.topilimage()中有一句 # npimg = np.transpose(pic.numpy(), (1, 2, 0)) # 因此pic只能是3-d tensor,所以要用image[0]消去batch那一维 img = transforms.topilimage(image[0]) img.show() # 方法2:plt.imshow(ndarray) img = image[0] # plt.imshow()只能接受3-d tensor,所以也要用image[0]消去batch那一维 img = img.numpy() # floattensor转为ndarray img = np.transpose(img, (1,2,0)) # 把channel那一维放到最后 # 显示图片 plt.imshow(img) plt.show() cnt += 1
以上为个人经验,希望能给大家一个参考,也希望大家多多支持。