| | ''' |
| | Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu. |
| | BSD License. All rights reserved. |
| | |
| | Redistribution and use in source and binary forms, with or without |
| | modification, are permitted provided that the following conditions are met: |
| | |
| | * Redistributions of source code must retain the above copyright notice, this |
| | list of conditions and the following disclaimer. |
| | |
| | * Redistributions in binary form must reproduce the above copyright notice, |
| | this list of conditions and the following disclaimer in the documentation |
| | and/or other materials provided with the distribution. |
| | |
| | THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL |
| | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE. |
| | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL |
| | DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, |
| | WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING |
| | OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. |
| | ''' |
| | import torch |
| | import torch.nn as nn |
| | import functools |
| | import numpy as np |
| | import pytorch_lightning as pl |
| |
|
| |
|
| | |
| | |
| | |
def weights_init(m):
    """Re-initialise one module in place; meant to be passed to ``Module.apply``.

    Dispatch is by class-name substring:

    * any class whose name contains ``'Conv'`` (so ``ConvTranspose2d`` too)
      gets its weight resampled from N(0, 0.02); its bias is left untouched;
    * any class whose name contains ``'BatchNorm2d'`` gets its weight resampled
      from N(1, 0.02) and its bias zeroed;
    * every other module is left unchanged.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm2d' in kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
| |
|
| |
|
def get_norm_layer(norm_type='instance'):
    """Return a factory that builds a 2-D normalisation layer from a channel count.

    Args:
        norm_type: ``'batch'`` for ``BatchNorm2d`` with learnable affine
            parameters, or ``'instance'`` for ``InstanceNorm2d`` without them.

    Returns:
        A ``functools.partial`` that takes the number of channels.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False)
    raise NotImplementedError('normalization layer [%s] is not found' %
                              norm_type)
| |
|
| |
|
def define_G(input_nc,
             output_nc,
             ngf,
             netG,
             n_downsample_global=3,
             n_blocks_global=9,
             n_local_enhancers=1,
             n_blocks_local=3,
             norm='instance',
             gpu_ids=None,
             last_op=nn.Tanh()):
    """Construct a generator, optionally move it to a GPU, and initialise it.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        ngf: base number of generator filters.
        netG: architecture name: ``'global'`` (GlobalGenerator), ``'local'``
            (LocalEnhancer) or ``'encoder'`` (Encoder).
        n_downsample_global: downsampling steps of the global generator; also
            used as the Encoder's ``n_downsampling``.
        n_blocks_global: residual blocks in the global generator.
        n_local_enhancers: number of local enhancer branches (``'local'`` only).
        n_blocks_local: residual blocks per local enhancer (``'local'`` only).
        norm: normalisation type, resolved through ``get_norm_layer``.
        gpu_ids: sequence of CUDA device ids. If non-empty, the network is
            moved to ``gpu_ids[0]``. ``None`` (the default) or an empty
            sequence leaves it where it was created.
        last_op: final output module; only forwarded to ``GlobalGenerator``.

    Returns:
        The constructed module, with ``weights_init`` applied to every submodule.

    Raises:
        NotImplementedError: if ``netG`` (or ``norm``) is not recognised.
        AssertionError: if ``gpu_ids`` is non-empty but CUDA is unavailable.
    """
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == 'global':
        net = GlobalGenerator(input_nc,
                              output_nc,
                              ngf,
                              n_downsample_global,
                              n_blocks_global,
                              norm_layer,
                              last_op=last_op)
    elif netG == 'local':
        net = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global,
                            n_blocks_global, n_local_enhancers,
                            n_blocks_local, norm_layer)
    elif netG == 'encoder':
        net = Encoder(input_nc, output_nc, ngf, n_downsample_global,
                      norm_layer)
    else:
        # Raising a bare string is a TypeError in Python 3; raise a real
        # exception, consistent with get_norm_layer.
        raise NotImplementedError('generator [%s] is not implemented' % netG)

    # Explicit len() check (not truthiness) so array-like id lists also work.
    if gpu_ids is not None and len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        net.cuda(gpu_ids[0])
    net.apply(weights_init)
    return net
| |
|
| |
|
def print_network(net):
    """Print a module's repr followed by its total parameter count.

    Args:
        net: a module exposing ``parameters()``, or a list whose first
            element is such a module (only that element is reported).
    """
    module = net[0] if isinstance(net, list) else net
    total = sum(p.numel() for p in module.parameters())
    print(module)
    print('Total number of parameters: %d' % total)
| |
|
| |
|
| | |
| | |
| | |
class LocalEnhancer(pl.LightningModule):
    """Coarse-to-fine generator: a truncated GlobalGenerator refined by
    ``n_local_enhancers`` local branches.

    The coarse branch (``self.model``) is a GlobalGenerator without its last
    three layers. It runs on the input after ``n_local_enhancers`` rounds of
    ``self.downsample``.

    Each local branch ``n`` has two parts:

    * ``model<n>_1`` encodes the input at the next finer scale down to half
      resolution, producing ``2 * ngf * 2**(n_local_enhancers - n)`` channels;
    * ``model<n>_2`` takes that encoding plus the running feature map, applies
      residual blocks and a transposed conv, and (for the last branch only)
      emits ``output_nc`` channels through a 7x7 conv and Tanh.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        ngf: base filter count of the finest branch. The coarse generator uses
            ``ngf * 2**n_local_enhancers``.
        n_downsample_global: stride-2 steps inside the coarse generator.
        n_blocks_global: residual blocks inside the coarse generator.
        n_local_enhancers: number of local refinement branches.
        n_blocks_local: residual blocks in each local branch.
        norm_layer: callable mapping a channel count to a normalisation module.
        padding_type: padding used by the local branches' residual blocks.
            NOTE(review): it is not forwarded to the coarse GlobalGenerator,
            which therefore always uses its own default.
    """

    def __init__(self,
                 input_nc,
                 output_nc,
                 ngf=32,
                 n_downsample_global=3,
                 n_blocks_global=9,
                 n_local_enhancers=1,
                 n_blocks_local=3,
                 norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        super(LocalEnhancer, self).__init__()
        self.n_local_enhancers = n_local_enhancers

        # Coarse branch. Drop the last three layers (final pad, output conv
        # and Tanh) so it yields ngf * 2**n_local_enhancers feature channels.
        coarse_width = ngf * (2**n_local_enhancers)
        coarse_layers = GlobalGenerator(input_nc, output_nc, coarse_width,
                                        n_downsample_global, n_blocks_global,
                                        norm_layer).model
        self.model = nn.Sequential(*list(coarse_layers)[:-3])

        for level in range(1, n_local_enhancers + 1):
            width = ngf * (2**(n_local_enhancers - level))

            # Front half: full-resolution stem, then one stride-2 conv whose
            # output width (2 * width) matches the running feature map.
            front = [
                nn.ReflectionPad2d(3),
                nn.Conv2d(input_nc, width, kernel_size=7, padding=0),
                norm_layer(width),
                nn.ReLU(True),
                nn.Conv2d(width, width * 2, kernel_size=3, stride=2,
                          padding=1),
                norm_layer(width * 2),
                nn.ReLU(True),
            ]

            # Back half: residual blocks, then upsample back to full width.
            back = [
                ResnetBlock(width * 2,
                            padding_type=padding_type,
                            norm_layer=norm_layer)
                for _ in range(n_blocks_local)
            ]
            back.extend([
                nn.ConvTranspose2d(width * 2, width, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                norm_layer(width),
                nn.ReLU(True),
            ])
            if level == n_local_enhancers:
                # The finest branch produces the final image.
                back.extend([
                    nn.ReflectionPad2d(3),
                    nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                    nn.Tanh(),
                ])

            setattr(self, 'model' + str(level) + '_1', nn.Sequential(*front))
            setattr(self, 'model' + str(level) + '_2', nn.Sequential(*back))

        # 3x3 average pool, stride 2, used to build the input pyramid.
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1],
                                       count_include_pad=False)

    def forward(self, input):
        """Run the coarse branch on the smallest scale, then refine upwards.

        ``pyramid[k]`` is ``input`` passed through ``self.downsample`` k times.
        Branch ``n`` reads scale ``n_local_enhancers - n`` and adds its
        encoding to the previous features before decoding.
        """
        pyramid = [input]
        for _ in range(self.n_local_enhancers):
            pyramid.append(self.downsample(pyramid[-1]))

        features = self.model(pyramid[-1])
        for level in range(1, self.n_local_enhancers + 1):
            front = getattr(self, 'model' + str(level) + '_1')
            back = getattr(self, 'model' + str(level) + '_2')
            scale = pyramid[self.n_local_enhancers - level]
            features = back(front(scale) + features)
        return features
| |
|
| |
|
class GlobalGenerator(pl.LightningModule):
    """Encoder / residual-bottleneck / decoder image generator.

    The layer stack in ``self.model`` is:

    1. a 7x7 stem;
    2. ``n_downsampling`` stride-2 convs, each doubling the channels;
    3. ``n_blocks`` ResnetBlocks at ``ngf * 2**n_downsampling`` channels;
    4. ``n_downsampling`` transposed convs, each halving the channels;
    5. a 7x7 conv to ``output_nc`` channels, optionally followed by ``last_op``.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        ngf: filter count of the stem.
        n_downsampling: number of stride-2 down/up steps.
        n_blocks: number of residual blocks; must be >= 0.
        norm_layer: callable mapping a channel count to a normalisation module.
        padding_type: padding used inside the residual blocks.
        last_op: module appended at the end, or ``None`` for none. The default
            ``nn.Tanh()`` instance is created once and shared by all calls
            that omit this argument.
    """

    def __init__(self,
                 input_nc,
                 output_nc,
                 ngf=64,
                 n_downsampling=3,
                 n_blocks=9,
                 norm_layer=nn.BatchNorm2d,
                 padding_type='reflect',
                 last_op=nn.Tanh()):
        assert (n_blocks >= 0)
        super(GlobalGenerator, self).__init__()
        # A single ReLU instance fills every activation slot below.
        relu = nn.ReLU(True)

        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            relu,
        ]

        # Downsampling path.
        for step in range(n_downsampling):
            c_in = ngf * 2**step
            c_out = c_in * 2
            layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                norm_layer(c_out),
                relu,
            ])

        # Residual bottleneck at the widest channel count.
        bottleneck = ngf * 2**n_downsampling
        layers.extend([
            ResnetBlock(bottleneck,
                        padding_type=padding_type,
                        activation=relu,
                        norm_layer=norm_layer)
            for _ in range(n_blocks)
        ])

        # Upsampling path, mirroring the downsampling steps.
        for step in reversed(range(1, n_downsampling + 1)):
            c_in = ngf * 2**step
            c_out = int(c_in / 2)
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                norm_layer(c_out),
                relu,
            ])

        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
        ])
        if last_op is not None:
            layers.append(last_op)
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Apply the full layer stack to ``input``."""
        out = self.model(input)
        return out
| |
|
| |
|
| | |
class ResnetBlock(pl.LightningModule):
    """Residual block computing ``x + F(x)``.

    ``F`` is: padded 3x3 conv -> norm -> activation -> [Dropout(0.5)] ->
    padded 3x3 conv -> norm. Both convs keep ``dim`` channels.

    Args:
        dim: channel count of input and output.
        padding_type: ``'reflect'``, ``'replicate'`` or ``'zero'``.
        norm_layer: callable mapping a channel count to a normalisation module.
        activation: module used after the first norm. The default
            ``nn.ReLU(True)`` instance is shared by all blocks that omit it.
        use_dropout: insert ``Dropout(0.5)`` between the two convs.
    """

    def __init__(self,
                 dim,
                 padding_type,
                 norm_layer,
                 activation=nn.ReLU(True),
                 use_dropout=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer,
                                                activation, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, activation,
                         use_dropout):
        """Return ``F`` as an ``nn.Sequential``.

        Raises:
            NotImplementedError: if ``padding_type`` is not recognised.
        """

        def padded_conv():
            # [pad module, 3x3 conv] for explicit padding, or a single conv
            # with built-in zero padding.
            if padding_type == 'reflect':
                pad = nn.ReflectionPad2d(1)
            elif padding_type == 'replicate':
                pad = nn.ReplicationPad2d(1)
            elif padding_type == 'zero':
                return [nn.Conv2d(dim, dim, kernel_size=3, padding=1)]
            else:
                raise NotImplementedError('padding [%s] is not implemented' %
                                          padding_type)
            return [pad, nn.Conv2d(dim, dim, kernel_size=3, padding=0)]

        layers = padded_conv() + [norm_layer(dim), activation]
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        layers += padded_conv() + [norm_layer(dim)]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x + conv_block(x)``."""
        return x + self.conv_block(x)
| |
|
| |
|
class Encoder(pl.LightningModule):
    """Convolutional encoder-decoder whose output is averaged per instance.

    ``self.model`` is:

    1. a 7x7 stem;
    2. ``n_downsampling`` stride-2 convs;
    3. ``n_downsampling`` transposed convs;
    4. a 7x7 conv to ``output_nc`` channels followed by Tanh.

    ``forward`` then replaces each instance's pixels by the instance-wise mean
    of those features.

    Args:
        input_nc: number of input channels.
        output_nc: number of feature channels produced.
        ngf: filter count of the stem.
        n_downsampling: number of stride-2 down/up steps.
        norm_layer: callable mapping a channel count to a normalisation module.
    """

    def __init__(self,
                 input_nc,
                 output_nc,
                 ngf=32,
                 n_downsampling=4,
                 norm_layer=nn.BatchNorm2d):
        super(Encoder, self).__init__()
        self.output_nc = output_nc

        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            nn.ReLU(True)
        ]

        # Downsampling: each step doubles the channel count.
        for i in range(n_downsampling):
            mult = 2**i
            model += [
                nn.Conv2d(ngf * mult,
                          ngf * mult * 2,
                          kernel_size=3,
                          stride=2,
                          padding=1),
                norm_layer(ngf * mult * 2),
                nn.ReLU(True)
            ]

        # Upsampling: each step halves the channel count back down to ngf.
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [
                nn.ConvTranspose2d(ngf * mult,
                                   ngf * mult // 2,
                                   kernel_size=3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                norm_layer(ngf * mult // 2),
                nn.ReLU(True)
            ]

        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh()
        ]
        self.model = nn.Sequential(*model)

    def forward(self, input, inst):
        """Encode ``input`` and average the features within each instance.

        Args:
            input: batch fed to ``self.model``; ``input.size(0)`` is used as
                the batch size.
            inst: instance-id map indexed per sample as ``inst[b]``, whose
                last two dims are the spatial (row, col) positions of the
                encoder output. Presumably ``(B, 1, H, W)`` integer ids (a
                ``(B, H, W)`` map also works) -- TODO confirm with callers.

        Returns:
            A copy of ``self.model(input)`` in which, for each sample and
            each instance id present in it, every pixel of that instance
            holds the per-channel mean of the encoder output over the
            instance's pixels.
        """
        outputs = self.model(input)
        outputs_mean = outputs.clone()

        # Distinct instance ids across the whole batch (computed on CPU).
        inst_ids = np.unique(inst.cpu().numpy().astype(int))
        batch_size = input.size(0)
        for inst_id in inst_ids:
            for b in range(batch_size):
                mask = inst[b] == int(inst_id)
                # Last two index tensors are the instance's (row, col) coords.
                ys, xs = mask.nonzero(as_tuple=True)[-2:]
                if ys.numel() == 0:
                    # Instance absent from this sample: nothing to average.
                    continue
                # Gather every output channel at once: shape (channels, N).
                region = outputs[b, :, ys, xs]
                channel_mean = region.mean(dim=1, keepdim=True)
                outputs_mean[b, :, ys, xs] = channel_mean.expand_as(region)
        return outputs_mean
| |
|