
PyTorch data augmentation

程序员文章站 2022-05-19 11:30:06
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

path2data = "./data"
# load the MNIST training set (download=True fetches it only if it is not already in path2data)
train_data = datasets.MNIST(path2data, train=True, download=True)

# define the transformations:
# p=1 makes each flip deterministic (always applied), which is handy for visualizing the effect
data_transform = transforms.Compose([transforms.RandomHorizontalFlip(p=1),
                                     transforms.RandomVerticalFlip(p=1),
                                     transforms.ToTensor()])

# get a sample image from the training set; without a transform it is a PIL image
img = train_data[1][0]
print(img)
img_tr = data_transform(img)   # float tensor of shape (1, 28, 28), values in [0, 1]
img_tr_np = img_tr.numpy()

# show the original and transformed images side by side
plt.subplot(121)
plt.imshow(img, cmap="gray")
plt.title("original")
plt.subplot(122)
plt.imshow(img_tr_np[0], cmap="gray")   # [0] drops the channel dimension
plt.title("transformed")
plt.show()

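# to augment every sample, pass the transform to the dataset; it is applied each time a sample is loaded
data_transform = transforms.Compose([transforms.RandomHorizontalFlip(p=1),
                                     transforms.RandomVerticalFlip(p=1),
                                     transforms.ToTensor()])

train_data = datasets.MNIST(path2data, train=True, download=True, transform=data_transform)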

Result (output of the `print` call):

<PIL.Image.Image image mode=L size=28x28 at 0x1D9911C50B8>

[Figure: the sample image, "original" (left) and "transformed" (right)]

Original article: https://blog.csdn.net/qq_28368377/article/details/107427965
