PyTorch 数据扩增
程序员文章站
2022-05-19 11:30:06
摘要:from torchvision import datasets, transforms | import matplotlib.pylab as plt | import torch | path2data = "./data" | # loading MNIST training dataset | train_data = datasets.MNIST(path2data, train=True, download=False) | # define transformations | data_transform = trans...
from torchvision import datasets,transforms
import matplotlib.pylab as plt
import torch
# Directory that holds the MNIST files; download=False below means the
# dataset is read from this directory rather than fetched.
path2data = "./data"

# Training split with no transform attached, so indexing yields the raw
# (image, label) pair.
train_data = datasets.MNIST(root=path2data, train=True, download=False)

# Augmentation pipeline: with p=1 both flips are applied on every call,
# then the image is converted to a tensor.
flip_steps = [
    transforms.RandomHorizontalFlip(p=1),
    transforms.RandomVerticalFlip(p=1),
    transforms.ToTensor(),
]
data_transform = transforms.Compose(flip_steps)

# Take the image part of sample #1 and print a fresh lookup of the same
# sample to show what type the untransformed dataset returns.
img = train_data[1][0]
print(train_data[1][0])

# Run the pipeline on the sample and convert the result to a NumPy array
# so matplotlib can draw it.
img_tr = data_transform(img)
img_tr_np = img_tr.numpy()

# Side-by-side comparison: left panel is the untouched image, right panel
# is entry 0 along the leading axis of the transformed array
# (ToTensor puts the channel axis first).
panels = (
    (img, "original"),
    (img_tr_np[0], "transformed"),
)
for column, (picture, caption) in enumerate(panels, start=1):
    plt.subplot(1, 2, column)
    plt.imshow(picture, cmap="gray")
    plt.title(caption)
plt.show()
# Build the flip-then-tensor pipeline (both flips use probability 1, so
# they are always applied) and attach it to the dataset, so every sample
# fetched by index goes through the transform.
always_flip_horizontal = transforms.RandomHorizontalFlip(p=1)
always_flip_vertical = transforms.RandomVerticalFlip(p=1)
data_transform = transforms.Compose(
    [always_flip_horizontal, always_flip_vertical, transforms.ToTensor()]
)
train_data = datasets.MNIST(
    root=path2data,
    train=True,
    download=False,
    transform=data_transform,
)
结果:
<PIL.Image.Image image mode=L size=28x28 at 0x1D9911C50B8>
本文地址:https://blog.csdn.net/qq_28368377/article/details/107427965
上一篇: 选做题 - ATM机