欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Pytorch MaxUnpooling 反最大池化操作(上采样)

程序员文章站 2022-04-30 22:41:50
...
import torch
from torch import nn
from torchvision import transforms
from PIL import Image

img = Image.open('images/train/0_1_pre_disaster.tif')

img_tensor = transforms.ToTensor()(img)
img_tensor = img_tensor.unsqueeze(0)         # MaxUnpool2d()需要输入一个4维矩阵 B * C * H * W

# 上采样
max_pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), return_indices=True, ceil_mode=True)
img_pool, indices = max_pool(img_tensor)

# 下采样
img_unpool = torch.rand_like(img_pool, dtype=torch.float)   # 输入图像的大小和上采样的大小保持一致
max_unpool = nn.MaxUnpool2d((2, 2), stride=(2, 2))   
img_unpool = max_unpool(img_unpool, indices)

img_show = transforms.ToPILImage()(img_unpool[0, ...])      # 显示反池化后的图像
img_show.show()

运行结果:

Pytorch MaxUnpooling 反最大池化操作(上采样)

相关标签: 实验室