import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class VanillaConv(nn.Module):
    """Conv2d followed by an optional activation and normalization.

    If padding == -1, a "same"-style padding is derived from kernel_size and
    dilation. norm may be "BN", "IN", "SN" (spectral norm on the conv weights),
    or anything else to disable normalization.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
        groups=1, bias=True, norm="SN", activation=nn.LeakyReLU(0.2, inplace=True)
    ):
        super().__init__()
        if padding == -1:
            if isinstance(kernel_size, int):
                kernel_size = (kernel_size, kernel_size)
            if isinstance(dilation, int):
                dilation = (dilation, dilation)
        self.padding = (
            tuple(((np.array(kernel_size) - 1) * np.array(dilation)) // 2)
            if padding == -1 else padding
        )
        self.featureConv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride, self.padding, dilation, groups, bias)

        self.norm = norm
        if norm == "BN":
            self.norm_layer = nn.BatchNorm2d(out_channels)
        elif norm == "IN":
            self.norm_layer = nn.InstanceNorm2d(out_channels, track_running_stats=True)
        elif norm == "SN":
            # Spectral norm wraps the conv itself; no separate norm layer is applied.
            self.norm = None
            self.featureConv = nn.utils.spectral_norm(self.featureConv)
        else:
            self.norm = None

        self.activation = activation

    def forward(self, xs):
        out = self.featureConv(xs)
        if self.activation:
            out = self.activation(out)
        if self.norm is not None:
            out = self.norm_layer(out)
        return out


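# VanillaConv usage sketch (illustrative, assumes NCHW float tensors):
#   conv = VanillaConv(3, 64, kernel_size=3, padding=-1)   # padding=-1 -> "same"-style padding
#   y = conv(torch.rand(1, 3, 64, 64))                     # y.shape == (1, 64, 64, 64)

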
class VanillaDeconv(nn.Module):
    """Nearest-neighbor upsampling by scale_factor followed by a VanillaConv."""

    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
        groups=1, bias=True, norm="SN", activation=nn.LeakyReLU(0.2, inplace=True),
        scale_factor=2
    ):
        super().__init__()
        self.conv = VanillaConv(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, norm, activation)
        self.scale_factor = scale_factor

    def forward(self, xs):
        xs_resized = F.interpolate(xs, scale_factor=self.scale_factor)
        return self.conv(xs_resized)


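# VanillaDeconv note and sketch (illustrative): upsampling is done with nearest-
# neighbor interpolation followed by a regular conv, not a transposed conv:
#   deconv = VanillaDeconv(64, 32, kernel_size=3, padding=-1, scale_factor=2)
#   y = deconv(torch.rand(1, 64, 16, 16))   # y.shape == (1, 32, 32, 32)

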
class GatedConv(VanillaConv):
    """Gated convolution: a parallel gating conv followed by a sigmoid produces a
    soft per-pixel mask that modulates the (activated) feature conv output.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
        groups=1, bias=True, norm="SN", activation=nn.LeakyReLU(0.2, inplace=True)
    ):
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, norm, activation
        )
        self.gatingConv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride, self.padding, dilation, groups, bias)
        if norm == 'SN':
            self.gatingConv = nn.utils.spectral_norm(self.gatingConv)
        self.sigmoid = nn.Sigmoid()
        self.store_gated_values = False

    def gated(self, mask):
        out = self.sigmoid(mask)
        if self.store_gated_values:
            # Keep a detached CPU copy of the gate maps for inspection.
            self.gated_values = out.detach().cpu()
        return out

    def forward(self, xs):
        gating = self.gatingConv(xs)
        feature = self.featureConv(xs)
        if self.activation:
            feature = self.activation(feature)
        out = self.gated(gating) * feature
        if self.norm is not None:
            out = self.norm_layer(out)
        return out


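# GatedConv sketch (illustrative): out = sigmoid(gatingConv(x)) * activation(featureConv(x)),
# optionally followed by BN/IN. Set store_gated_values = True to keep the gate maps:
#   gconv = GatedConv(3, 64, kernel_size=3, padding=-1)
#   gconv.store_gated_values = True
#   y = gconv(torch.rand(1, 3, 64, 64))   # gconv.gated_values holds the detached gates

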
class GatedDeconv(VanillaDeconv):
    """Nearest-neighbor upsampling by scale_factor followed by a GatedConv."""

    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
        groups=1, bias=True, norm="SN", activation=nn.LeakyReLU(0.2, inplace=True),
        scale_factor=2
    ):
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, norm, activation, scale_factor
        )
        # The VanillaConv built by the parent constructor is replaced by a GatedConv.
        self.conv = GatedConv(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, norm, activation)


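# GatedDeconv sketch (illustrative): upsample by scale_factor, then gated conv:
#   up = GatedDeconv(64, 32, kernel_size=3, padding=-1)
#   y = up(torch.rand(1, 64, 16, 16))   # y.shape == (1, 32, 32, 32)

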
class PartialConv(VanillaConv):
    """Partial convolution: the input is masked before the convolution, the output
    is renormalized by the number of valid (mask == 1) pixels in each window, and
    an updated mask is returned alongside the output.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, bias=True, norm="SN", activation=nn.LeakyReLU(0.2, inplace=True)):
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, norm, activation
        )
        # Fixed all-ones convolution used to count valid mask pixels per output location.
        self.mask_sum_conv = nn.Conv2d(1, 1, kernel_size,
                                       stride, self.padding, dilation, groups, False)
        nn.init.constant_(self.mask_sum_conv.weight, 1.0)

        # The mask-counting kernel is never trained.
        for param in self.mask_sum_conv.parameters():
            param.requires_grad = False

    def forward(self, input_tuple):
        inp, mask = input_tuple

        # Convolve the masked input; invalid pixels contribute zero.
        output = self.featureConv(mask * inp)

        if self.featureConv.bias is not None:
            output_bias = self.featureConv.bias.view(1, -1, 1, 1)
        else:
            output_bias = torch.zeros([1, 1, 1, 1]).to(inp.device)

        with torch.no_grad():
            # Number of valid mask pixels covered by each sliding window.
            mask_sum = self.mask_sum_conv(mask)

        # Locations whose receptive field contains no valid pixel.
        no_update_holes = (mask_sum == 0)

        # Avoid division by zero; these locations are zeroed out below anyway.
        mask_sum_no_zero = mask_sum.masked_fill_(no_update_holes, 1.0)

        # Renormalize by the valid-pixel count, then re-add the bias.
        output = (output - output_bias) / mask_sum_no_zero + output_bias
        output = output.masked_fill_(no_update_holes, 0.0)

        # The updated mask marks a location as valid if any pixel in its window was valid.
        new_mask = torch.ones_like(mask_sum)
        new_mask = new_mask.masked_fill_(no_update_holes, 0.0)

        if self.activation is not None:
            output = self.activation(output)
        if self.norm is not None:
            output = self.norm_layer(output)
        return output, new_mask


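# PartialConv sketch (illustrative): with a binary mask M (1 = valid), the layer
# computes conv(M * X), divides (output - bias) by the per-window count sum(M),
# re-adds the bias, and zeroes locations whose window has no valid pixel:
#   pconv = PartialConv(3, 64, kernel_size=3, padding=-1)
#   x, m = torch.rand(1, 3, 64, 64), torch.ones(1, 1, 64, 64)
#   y, m_new = pconv((x, m))   # y: (1, 64, 64, 64), m_new: (1, 1, 64, 64)

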
class PartialDeconv(nn.Module):
    """Nearest-neighbor upsampling of both the input and the mask, followed by a
    PartialConv.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, bias=True, norm="SN", activation=nn.LeakyReLU(0.2, inplace=True),
                 scale_factor=2):
        super().__init__()
        self.conv = PartialConv(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, norm, activation)
        self.scale_factor = scale_factor

    def forward(self, input_tuple):
        inp, mask = input_tuple
        inp_resized = F.interpolate(inp, scale_factor=self.scale_factor)
        with torch.no_grad():
            mask_resized = F.interpolate(mask, scale_factor=self.scale_factor)
        return self.conv((inp_resized, mask_resized))
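

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): runs each
    # layer on random NCHW tensors with "same"-style padding (padding=-1).
    x = torch.rand(2, 3, 64, 64)
    mask = torch.ones(2, 1, 64, 64)
    mask[:, :, 16:48, 16:48] = 0.0  # square hole in the mask

    print(VanillaConv(3, 8, 3, padding=-1)(x).shape)        # (2, 8, 64, 64)
    print(VanillaDeconv(3, 8, 3, padding=-1)(x).shape)      # (2, 8, 128, 128)
    print(GatedConv(3, 8, 3, padding=-1)(x).shape)          # (2, 8, 64, 64)
    print(GatedDeconv(3, 8, 3, padding=-1)(x).shape)        # (2, 8, 128, 128)
    out, new_mask = PartialConv(3, 8, 3, padding=-1)((x, mask))
    print(out.shape, new_mask.shape)                        # (2, 8, 64, 64) (2, 1, 64, 64)
    out, new_mask = PartialDeconv(3, 8, 3, padding=-1)((x, mask))
    print(out.shape, new_mask.shape)                        # (2, 8, 128, 128) (2, 1, 128, 128)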