pix2pixHD代码解析
程序员文章站
2022-03-09 13:06:37
...
目录
导入数据集
-
train.py
第15行：from data.data_loader import CreateDataLoader
第30行：data_loader = CreateDataLoader(opt)
即调用了 data/data_loader.py 中的 CreateDataLoader 函数。
-
data_loader.py：定义了函数 def CreateDataLoader(opt):
第4行：data_loader = CustomDatasetDataLoader()
该函数返回一个 CustomDatasetDataLoader 类的实例。
-
custom_dataset_data_loader.py：CustomDatasetDataLoader 类
第20行：self.dataset = CreateDataset(opt)
即调用 custom_dataset_data_loader.py 中的 CreateDataset 函数。
CreateDataset 函数第5行：dataset = AlignedDataset()
因此，更换数据集时需要修改 AlignedDataset 中的数据集地址（见下方 initialize 中的 dir_A、dir_B、dir_inst）。
class CustomDatasetDataLoader(BaseDataLoader):
    """Wrap the dataset returned by ``CreateDataset`` in a torch ``DataLoader``."""

    def name(self):
        """Return this loader's identifier string."""
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        """Build the dataset and its ``torch.utils.data.DataLoader`` from ``opt``.

        Options read here:
            batchSize      -- samples per batch.
            serial_batches -- when true, samples are NOT shuffled.
            nThreads       -- passed (as int) to ``num_workers``.
        """
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)

        shuffle_samples = not opt.serial_batches
        worker_count = int(opt.nThreads)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=shuffle_samples,
            num_workers=worker_count,
        )

    def load_data(self):
        """Return the ``DataLoader`` built by ``initialize``."""
        return self.dataloader

    def __len__(self):
        """Return the dataset length, capped at ``opt.max_dataset_size``."""
        # NOTE(review): self.opt is presumably set by BaseDataLoader.initialize,
        # which is not visible here -- confirm.
        dataset_len = len(self.dataset)
        return min(dataset_len, self.opt.max_dataset_size)
class AlignedDataset(BaseDataset):
    """Dataset of label maps (A), real images (B) and optional instance maps."""

    def initialize(self, opt):
        """Collect the sorted file lists for each input directory.

        Directory names are ``opt.dataroot`` joined with ``opt.phase`` plus a
        suffix. When ``opt.label_nc`` is 0 the suffixes are ``_A`` / ``_B``;
        otherwise they are ``_label`` / ``_img``. Real images are only listed
        when ``opt.isTrain`` or ``opt.use_encoded_image`` is true, and instance
        maps only when ``opt.no_instance`` is false.
        """
        self.opt = opt
        self.root = opt.dataroot
        uses_label_maps = self.opt.label_nc != 0

        # Input A: label maps, e.g. ./datasets/cityscapes/train_label
        suffix_a = '_label' if uses_label_maps else '_A'
        self.dir_A = os.path.join(opt.dataroot, opt.phase + suffix_a)
        self.A_paths = sorted(make_dataset(self.dir_A))

        # Input B: real images, e.g. ./datasets/cityscapes/train_img
        if opt.isTrain or opt.use_encoded_image:
            suffix_b = '_img' if uses_label_maps else '_B'
            self.dir_B = os.path.join(opt.dataroot, opt.phase + suffix_b)
            self.B_paths = sorted(make_dataset(self.dir_B))

        # Instance maps, e.g. ./datasets/cityscapes/train_inst
        if not opt.no_instance:
            self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
            self.inst_paths = sorted(make_dataset(self.dir_inst))
------------ Options -------------
batchSize: 1
beta1: 0.5
checkpoints_dir: ./checkpoints
continue_train: False
data_type: 32
dataroot: ./datasets/cityscapes/
debug: False
display_freq: 100
display_winsize: 512
feat_num: 3
fineSize: 512
fp16: False
gpu_ids: [0]
input_nc: 3
instance_feat: False
isTrain: True
label_feat: False
label_nc: 35
lambda_feat: 10.0
loadSize: 1024
load_features: False
load_pretrain:
local_rank: 0
lr: 0.0002
max_dataset_size: inf
model: pix2pixHD
nThreads: 2
n_blocks_global: 9
n_blocks_local: 3
n_clusters: 10
n_downsample_E: 4
n_downsample_global: 4
n_layers_D: 3
n_local_enhancers: 1
name: label2city_512p
ndf: 64
nef: 16
netG: global
ngf: 64
niter: 100
niter_decay: 100
niter_fix_global: 0
no_flip: False
no_ganFeat_loss: False
no_html: False
no_instance: False
no_lsgan: False
no_vgg_loss: False
norm: instance
num_D: 2
output_nc: 3
phase: train
pool_size: 0
print_freq: 100
resize_or_crop: scale_width
save_epoch_freq: 10
save_latest_freq: 1000
serial_batches: False
tf_log: False
use_dropout: False
verbose: False
which_epoch: latest
-------------- End ----------------
CustomDatasetDataLoader
dataset [AlignedDataset] was created
#training images = 8
GlobalGenerator(
(model): Sequential(
(0): ReflectionPad2d((3, 3, 3, 3))
(1): Conv2d(36, 64, kernel_size=(7, 7), stride=(1, 1))
(2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(3): ReLU(inplace)
(4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(5): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(6): ReLU(inplace)
(7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(8): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(9): ReLU(inplace)
(10): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(11): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(12): ReLU(inplace)
(13): Conv2d(512, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(14): InstanceNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(15): ReLU(inplace)
(16): ResnetBlock(
(conv_block): Sequential(
(0): ReflectionPad2d((1, 1, 1, 1))
(1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
(2): InstanceNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(3): ReLU(inplace)
(4): ReflectionPad2d((1, 1, 1, 1))
(5): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
(6): InstanceNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
)
)
(17): ResnetBlock(
(conv_block): Sequential(
(0): ReflectionPad2d((1, 1, 1, 1))
(1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
(2): InstanceNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(3): ReLU(inplace)
(4): ReflectionPad2d((1, 1, 1, 1))
(5): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
(6): InstanceNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
)
)
(18): ResnetBlock(
(conv_block): Sequential(
(0): ReflectionPad2d((1, 1, 1, 1))
(1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
(2): InstanceNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(3): ReLU(inplace)
(4): ReflectionPad2d((1, 1, 1, 1))
(5): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
(6): InstanceNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
)
)
(19): ResnetBlock(
(conv_block): Sequential(
(0): ReflectionPad2d((1, 1, 1, 1))
(1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
(2): InstanceNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(3): ReLU(inplace)
(4): ReflectionPad2d((1, 1, 1, 1))
(5): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
(6): InstanceNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
)
)
(20): ResnetBlock(
(conv_block): Sequential(
(0): ReflectionPad2d((1, 1, 1, 1))
(1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
(2): InstanceNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(3): ReLU(inplace)
(4): ReflectionPad2d((1, 1, 1, 1))
(5): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
(6): InstanceNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
)
)
(21): ResnetBlock(
(conv_block): Sequential(
(0): ReflectionPad2d((1, 1, 1, 1))
(1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
(2): InstanceNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(3): ReLU(inplace)
(4): ReflectionPad2d((1, 1, 1, 1))
(5): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
(6): InstanceNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
)
)
(22): ResnetBlock(
(conv_block): Sequential(
(0): ReflectionPad2d((1, 1, 1, 1))
(1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
(2): InstanceNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(3): ReLU(inplace)
(4): ReflectionPad2d((1, 1, 1, 1))
(5): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
(6): InstanceNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
)
)
(23): ResnetBlock(
(conv_block): Sequential(
(0): ReflectionPad2d((1, 1, 1, 1))
(1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
(2): InstanceNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(3): ReLU(inplace)
(4): ReflectionPad2d((1, 1, 1, 1))
(5): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
(6): InstanceNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
)
)
(24): ResnetBlock(
(conv_block): Sequential(
(0): ReflectionPad2d((1, 1, 1, 1))
(1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
(2): InstanceNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(3): ReLU(inplace)
(4): ReflectionPad2d((1, 1, 1, 1))
(5): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
(6): InstanceNorm2d(1024, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
)
)
(25): ConvTranspose2d(1024, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
(26): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(27): ReLU(inplace)
(28): ConvTranspose2d(512, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
(29): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(30): ReLU(inplace)
(31): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
(32): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(33): ReLU(inplace)
(34): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
(35): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(36): ReLU(inplace)
(37): ReflectionPad2d((3, 3, 3, 3))
(38): Conv2d(64, 3, kernel_size=(7, 7), stride=(1, 1))
(39): Tanh()
)
)
MultiscaleDiscriminator(
(scale0_layer0): Sequential(
(0): Conv2d(39, 64, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
(1): LeakyReLU(negative_slope=0.2, inplace)
)
(scale0_layer1): Sequential(
(0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
(1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(2): LeakyReLU(negative_slope=0.2, inplace)
)
(scale0_layer2): Sequential(
(0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
(1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(2): LeakyReLU(negative_slope=0.2, inplace)
)
(scale0_layer3): Sequential(
(0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
(1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(2): LeakyReLU(negative_slope=0.2, inplace)
)
(scale0_layer4): Sequential(
(0): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
)
(scale1_layer0): Sequential(
(0): Conv2d(39, 64, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
(1): LeakyReLU(negative_slope=0.2, inplace)
)
(scale1_layer1): Sequential(
(0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
(1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(2): LeakyReLU(negative_slope=0.2, inplace)
)
(scale1_layer2): Sequential(
(0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(2, 2))
(1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(2): LeakyReLU(negative_slope=0.2, inplace)
)
(scale1_layer3): Sequential(
(0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
(1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
(2): LeakyReLU(negative_slope=0.2, inplace)
)
(scale1_layer4): Sequential(
(0): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
)
(downsample): AvgPool2d(kernel_size=3, stride=2, padding=[1, 1])
)
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /home/yhr/.cache/torch/checkpoints/vgg19-dcbb9e9d.pth
100.0%
create web directory ./checkpoints/label2city_512p/web...
上一篇: GAN生成CIFAR10