欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Pytorch读取图片并显示

程序员文章站 2022-03-01 21:41:51
...
# -*- coding: utf-8 -*-
# @Time    : 18-3-15 下午6:43
# @Author  : zhwzhong
# @File    : model.py
# @Contact : [email protected]
# @Function:
from torchvision import transforms, datasets as ds
import torchvision as tv
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

transform = transforms.Compose(
    [
        transforms.ToTensor()
    ]
)
train_set = tv.datasets.ImageFolder(root='./', transform=transform)
data_loader = DataLoader(dataset=train_set)

to_pil_image = transforms.ToPILImage()

for image, label in data_loader:

    # 方法1:Image.show()
    # transforms.ToPILImage()中有一句
    # npimg = np.transpose(pic.numpy(), (1, 2, 0))
    # 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维
    img = to_pil_image(image[0])
    img.show()

    # 方法2:plt.imshow(ndarray)
    img = image[0]  # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
    img = img.numpy()  # FloatTensor转为ndarray
    img = np.transpose(img, (1, 2, 0))  # 把channel那一维放到最后

    # 显示图片
    plt.imshow(img)
    plt.show()

原文:https://blog.csdn.net/qq_34535410/article/details/79574327 

上一篇: 画图小技巧

下一篇: 3-matplotlib笔记