Pytorch MaxUnpooling 反最大池化操作(上采样)
程序员文章站
2022-04-30 22:41:50
...
import torch
from torch import nn
from torchvision import transforms
from PIL import Image
img = Image.open('images/train/0_1_pre_disaster.tif')
img_tensor = transforms.ToTensor()(img)
img_tensor = img_tensor.unsqueeze(0) # MaxUnpool2d()需要输入一个4维矩阵 B * C * H * W
# 上采样
max_pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), return_indices=True, ceil_mode=True)
img_pool, indices = max_pool(img_tensor)
# 下采样
img_unpool = torch.rand_like(img_pool, dtype=torch.float) # 输入图像的大小和上采样的大小保持一致
max_unpool = nn.MaxUnpool2d((2, 2), stride=(2, 2))
img_unpool = max_unpool(img_unpool, indices)
img_show = transforms.ToPILImage()(img_unpool[0, ...]) # 显示反池化后的图像
img_show.show()
运行结果:
上一篇: 压缩字符串