import math

import torch
import torch.nn as nn
import torch.nn.functional as F

networks = ['BaseNetwork', 'Discriminator', 'ASPP']


# Base model borrows from PEN-Net
# https://github.com/researchmm/PEN-Net-for-Inpainting
class BaseNetwork(nn.Module):
    def __init__(self):
        super(BaseNetwork, self).__init__()

    def print_network(self):
        if isinstance(self, list):
            self = self[0]
        num_params = 0
        for param in self.parameters():
            num_params += param.numel()
        print('Network [%s] was created. Total number of parameters: %.1f million. '
              'To see the architecture, do print(network).'
              % (type(self).__name__, num_params / 1000000))

    def init_weights(self, init_type='normal', gain=0.02):
        '''
        Initialize the network's weights.
        init_type: normal | xavier | xavier_uniform | kaiming | orthogonal | none
        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
        '''
        def init_func(m):
            classname = m.__class__.__name__
            if classname.find('InstanceNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    nn.init.constant_(m.weight.data, 1.0)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    nn.init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    nn.init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':
                    # fall back to pytorch's default init method
                    m.reset_parameters()
                else:
                    raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        # propagate to children that define their own init_weights
        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)


# Temporal PatchGAN discriminator, from "Free-form Video Inpainting with 3D Gated
# Convolution and Temporal PatchGAN" (ICCV 2019).
# todo: debug this model
class Discriminator(BaseNetwork):
    def __init__(self, in_channels=3, use_sigmoid=False, use_spectral_norm=True, init_weights=True):
        super(Discriminator, self).__init__()
        self.use_sigmoid = use_sigmoid
        nf = 64

        self.conv = nn.Sequential(
            DisBuildingBlock(in_channel=in_channels, out_channel=nf * 1, kernel_size=(3, 5, 5),
                             stride=(1, 2, 2), padding=1, use_spectral_norm=use_spectral_norm),
            # nn.InstanceNorm2d(64, track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            DisBuildingBlock(in_channel=nf * 1, out_channel=nf * 2, kernel_size=(3, 5, 5),
                             stride=(1, 2, 2), padding=(1, 2, 2), use_spectral_norm=use_spectral_norm),
            # nn.InstanceNorm2d(128, track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            DisBuildingBlock(in_channel=nf * 2, out_channel=nf * 4, kernel_size=(3, 5, 5),
                             stride=(1, 2, 2), padding=(1, 2, 2), use_spectral_norm=use_spectral_norm),
            # nn.InstanceNorm2d(256, track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            DisBuildingBlock(in_channel=nf * 4, out_channel=nf * 4, kernel_size=(3, 5, 5),
                             stride=(1, 2, 2), padding=(1, 2, 2), use_spectral_norm=use_spectral_norm),
            # nn.InstanceNorm2d(256, track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            DisBuildingBlock(in_channel=nf * 4, out_channel=nf * 4, kernel_size=(3, 5, 5),
                             stride=(1, 2, 2), padding=(1, 2, 2), use_spectral_norm=use_spectral_norm),
            # nn.InstanceNorm2d(256, track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(nf * 4, nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2))
        )

        if init_weights:
            self.init_weights()

    def forward(self, xs):
        # xs: (B, C, T, H, W)
        feat = self.conv(xs)
        if self.use_sigmoid:
            feat = torch.sigmoid(feat)
        return feat


class DisBuildingBlock(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride, padding, use_spectral_norm):
        super(DisBuildingBlock, self).__init__()
        self.block = self._getBlock(in_channel, out_channel, kernel_size, stride, padding, use_spectral_norm)

    def _getBlock(self, in_channel, out_channel, kernel_size, stride, padding, use_spectral_norm):
        feature_conv = nn.Conv3d(in_channels=in_channel, out_channels=out_channel, kernel_size=kernel_size,
                                 stride=stride, padding=padding, bias=not use_spectral_norm)
        if use_spectral_norm:
            feature_conv = nn.utils.spectral_norm(feature_conv)
        return feature_conv

    def forward(self, inputs):
        out = self.block(inputs)
        return out
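
# Hedged usage sketch (not part of the original file): a quick shape check for the
# temporal PatchGAN above. Batch size, frame count, and resolution are illustrative
# assumptions. Every stage keeps T (temporal stride 1) and roughly halves H and W
# (spatial stride 2), so a 64x64 clip collapses to a 1x1 score map per frame.
def _discriminator_smoke_test():
    net = Discriminator(in_channels=3, use_sigmoid=True)
    frames = torch.randn(2, 3, 5, 64, 64)  # (B, C, T, H, W)
    scores = net(frames)
    print(scores.shape)  # expected: torch.Size([2, 256, 5, 1, 1])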
class ASPP(nn.Module):
    def __init__(self, input_channels, output_channels, rate=[1, 2, 4, 8]):
        super(ASPP, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.rate = rate
        for i in range(len(rate)):
            self.__setattr__('conv{}'.format(str(i).zfill(2)), nn.Sequential(
                nn.Conv2d(input_channels, output_channels // len(rate), kernel_size=3,
                          dilation=rate[i], padding=rate[i]),
                nn.LeakyReLU(negative_slope=0.2, inplace=True)
            ))

    def forward(self, inputs):
        inputs = inputs / len(self.rate)
        tmp = []
        for i in range(len(self.rate)):
            tmp.append(self.__getattr__('conv{}'.format(str(i).zfill(2)))(inputs))
        output = torch.cat(tmp, dim=1)
        return output


class GatedConv2dWithActivation(torch.nn.Module):
    r"""
    Gated convolution layer with activation (default activation: LeakyReLU).
    Params: same as Conv2d
    Input: the feature map "I" from the previous layer
    Output: \phi(f(I)) * \sigmoid(g(I))
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
                 batch_norm=False, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
        super(GatedConv2dWithActivation, self).__init__()
        self.batch_norm = batch_norm
        self.activation = activation
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.mask_conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.batch_norm2d = torch.nn.BatchNorm2d(out_channels)
        self.sigmoid = torch.nn.Sigmoid()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)

    def gated(self, mask):
        return self.sigmoid(mask)

    def forward(self, inputs):
        x = self.conv2d(inputs)
        mask = self.mask_conv2d(inputs)
        if self.activation is not None:
            x = self.activation(x) * self.gated(mask)
        else:
            x = x * self.gated(mask)
        if self.batch_norm:
            return self.batch_norm2d(x)
        else:
            return x


class GatedDeConv2dWithActivation(torch.nn.Module):
    r"""
    Gated deconvolution layer with activation (default activation: LeakyReLU),
    implemented as resize + gated conv.
    Params: same as Conv2d
    Input: the feature map "I" from the previous layer
    Output: \phi(f(I)) * \sigmoid(g(I))
    """
    def __init__(self, scale_factor, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, bias=True, batch_norm=False, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
        super(GatedDeConv2dWithActivation, self).__init__()
        self.conv2d = GatedConv2dWithActivation(in_channels, out_channels, kernel_size, stride, padding, dilation,
                                                groups, bias, batch_norm, activation)
        self.scale_factor = scale_factor

    def forward(self, inputs):
        x = F.interpolate(inputs, scale_factor=self.scale_factor)
        return self.conv2d(x)


class SNGatedConv2dWithActivation(torch.nn.Module):
    """
    Gated convolution with spectral normalization.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
                 batch_norm=False, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
        super(SNGatedConv2dWithActivation, self).__init__()
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.mask_conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.activation = activation
        self.batch_norm = batch_norm
        self.batch_norm2d = torch.nn.BatchNorm2d(out_channels)
        self.sigmoid = torch.nn.Sigmoid()
        self.conv2d = torch.nn.utils.spectral_norm(self.conv2d)
        self.mask_conv2d = torch.nn.utils.spectral_norm(self.mask_conv2d)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)

    def gated(self, mask):
        return self.sigmoid(mask)

    def forward(self, inputs):
        x = self.conv2d(inputs)
        mask = self.mask_conv2d(inputs)
        if self.activation is not None:
            x = self.activation(x) * self.gated(mask)
        else:
            x = x * self.gated(mask)
        if self.batch_norm:
            return self.batch_norm2d(x)
        else:
            return x


class SNGatedDeConv2dWithActivation(torch.nn.Module):
    r"""
    Spectrally normalized gated deconvolution layer with activation
    (default activation: LeakyReLU), implemented as resize + gated conv.
    Params: same as Conv2d
    Input: the feature map "I" from the previous layer
    Output: \phi(f(I)) * \sigmoid(g(I))
    """
    def __init__(self, scale_factor, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, bias=True, batch_norm=False, activation=torch.nn.LeakyReLU(0.2, inplace=True)):
        super(SNGatedDeConv2dWithActivation, self).__init__()
        self.conv2d = SNGatedConv2dWithActivation(in_channels, out_channels, kernel_size, stride, padding, dilation,
                                                  groups, bias, batch_norm, activation)
        self.scale_factor = scale_factor

    def forward(self, inputs):
        # upsample by the configured scale factor before the gated conv
        x = F.interpolate(inputs, scale_factor=self.scale_factor)
        return self.conv2d(x)
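
# Hedged usage sketch (illustrative assumptions only): one gated down/up step built
# from the 2D layers above, showing the feature * sigmoid(gate) mechanism and the
# resize-then-conv upsampling. Channel counts and sizes are made up for the demo.
def _gated_conv2d_example():
    down = GatedConv2dWithActivation(3, 32, kernel_size=4, stride=2, padding=1)
    up = GatedDeConv2dWithActivation(2, 32, 3, kernel_size=3, stride=1, padding=1)
    x = torch.randn(1, 3, 64, 64)
    h = down(x)   # (1, 32, 32, 32): LeakyReLU(conv(x)) * sigmoid(mask_conv(x))
    y = up(h)     # nearest-resize by 2, then gated conv -> (1, 3, 64, 64)
    print(h.shape, y.shape)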
class GatedConv3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
                 activation=nn.LeakyReLU(0.2, inplace=True)):
        super(GatedConv3d, self).__init__()
        self.input_conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.gating_conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.activation = activation

    def forward(self, inputs):
        feature = self.input_conv(inputs)
        if self.activation:
            feature = self.activation(feature)
        gating = torch.sigmoid(self.gating_conv(inputs))
        return feature * gating


class GatedDeconv3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, scale_factor, dilation=1, groups=1,
                 bias=True, activation=nn.LeakyReLU(0.2, inplace=True)):
        super().__init__()
        self.scale_factor = scale_factor
        self.deconv = GatedConv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias,
                                  activation)

    def forward(self, inputs):
        inputs = F.interpolate(inputs, scale_factor=(1, self.scale_factor, self.scale_factor))
        return self.deconv(inputs)
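
# Hedged usage sketch (illustrative assumptions only): the 3D gated pair on a small
# video tensor. GatedDeconv3d upsamples only H and W by scale_factor and leaves the
# temporal axis untouched, matching the (1, s, s) interpolation above.
def _gated_conv3d_example():
    enc = GatedConv3d(3, 16, kernel_size=3, stride=(1, 2, 2), padding=1)
    dec = GatedDeconv3d(16, 3, kernel_size=3, stride=1, padding=1, scale_factor=2)
    video = torch.randn(1, 3, 4, 32, 32)  # (B, C, T, H, W)
    feat = enc(video)    # (1, 16, 4, 16, 16)
    recon = dec(feat)    # (1, 3, 4, 32, 32)
    print(feat.shape, recon.shape)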
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor
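
# Hedged usage sketch (illustrative): fill a weight tensor in place with the truncated
# normal above and check that every value stays inside [a, b]. Sizes are arbitrary.
def _trunc_normal_example():
    w = torch.empty(256, 128)
    trunc_normal_(w, mean=0.0, std=0.02, a=-0.04, b=0.04)
    assert w.min().item() >= -0.04 and w.max().item() <= 0.04
    print(w.mean().item(), w.std().item())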