"""Layers used for up-sampling or down-sampling images.

Many functions are ported from https://github.com/NVlabs/stylegan2.
"""
|
|
|
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from .op import upfirdn2d
|
|
|
|
|
|
|
def get_weight(module, shape, weight_var='weight', kernel_init=None):
  """Get/create weight tensor for a convolution or fully-connected layer."""
  # `nn.Module` has no `param()` method (a leftover from the Flax version);
  # create and register the parameter explicitly instead.
  if not hasattr(module, weight_var):
    weight = nn.Parameter(torch.zeros(*shape))
    if kernel_init is not None:
      weight.data = kernel_init(weight.data.shape)
    module.register_parameter(weight_var, weight)
  return getattr(module, weight_var)
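

# Illustrative usage sketch (not part of the ported code): `get_weight`
# lazily attaches a parameter to a module. The module instance and shape
# below are made up for demonstration.
def _example_get_weight():
  m = nn.Module()
  w = get_weight(m, (16, 8, 3, 3), weight_var='weight')
  return w.shape  # torch.Size([16, 8, 3, 3])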
|
|
|
|
|
class Conv2d(nn.Module):
  """Conv2d layer with optional up-sampling or down-sampling (StyleGAN2)."""

  def __init__(self, in_ch, out_ch, kernel, up=False, down=False,
               resample_kernel=(1, 3, 3, 1),
               use_bias=True,
               kernel_init=None):
    super().__init__()
    assert not (up and down)
    assert kernel >= 1 and kernel % 2 == 1
    self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
    if kernel_init is not None:
      self.weight.data = kernel_init(self.weight.data.shape)
    if use_bias:
      self.bias = nn.Parameter(torch.zeros(out_ch))

    self.up = up
    self.down = down
    self.resample_kernel = resample_kernel
    self.kernel = kernel
    self.use_bias = use_bias

  def forward(self, x):
    if self.up:
      x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
    elif self.down:
      x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
    else:
      x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)

    if self.use_bias:
      x = x + self.bias.reshape(1, -1, 1, 1)

    return x
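

# Illustrative usage sketch (not part of the ported code): a fused 2x
# upsampling 3x3 convolution. Inputs are NCHW; `torch.randn` stands in for a
# real kernel initializer, and the compiled `upfirdn2d` op from `.op` must be
# available for the `up=True` path.
def _example_fused_up_conv():
  conv = Conv2d(in_ch=8, out_ch=16, kernel=3, up=True,
                resample_kernel=(1, 3, 3, 1), kernel_init=torch.randn)
  x = torch.randn(2, 8, 16, 16)
  y = conv(x)
  return y.shape  # (2, 16, 32, 32): spatial size doubled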
|
|
|
|
|
def naive_upsample_2d(x, factor=2):
  """Nearest-neighbor upsampling by pixel repetition."""
  _N, C, H, W = x.shape
  x = torch.reshape(x, (-1, C, H, 1, W, 1))
  x = x.repeat(1, 1, 1, factor, 1, factor)
  return torch.reshape(x, (-1, C, H * factor, W * factor))
|
|
|
|
|
def naive_downsample_2d(x, factor=2):
  """Downsampling by `factor x factor` average pooling."""
  _N, C, H, W = x.shape
  x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
  return torch.mean(x, dim=(3, 5))
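

# Illustrative sketch: the two naive resamplers are exact inverses for
# factor-aligned inputs (pixels are repeated on the way up and averaged back
# on the way down).
def _example_naive_round_trip():
  x = torch.randn(1, 4, 8, 8)
  up = naive_upsample_2d(x, factor=2)        # (1, 4, 16, 16)
  down = naive_downsample_2d(up, factor=2)   # (1, 4, 8, 8)
  return torch.allclose(down, x)             # True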
|
|
|
|
|
def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
  """Fused `upsample_2d()` followed by a 2D convolution.

  Padding is performed only once at the beginning, not between the two
  operations. The fused op is considerably more efficient than performing
  the same calculation with standard ops, and it supports gradients of
  arbitrary order.

  Args:
    x: Input tensor of the shape `[N, C, H, W]`.
    w: Weight tensor of the shape `[outChannels, inChannels, filterH,
      filterW]`. Grouped convolution can be performed by `inChannels =
      x.shape[1] // numGroups`.
    k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The
      default is `[1] * factor`, which corresponds to nearest-neighbor
      upsampling.
    factor: Integer upsampling factor (default: 2).
    gain: Scaling factor for signal magnitude (default: 1.0).

  Returns:
    Tensor of the shape `[N, C, H * factor, W * factor]` with the same
    datatype as `x`.
  """
  assert isinstance(factor, int) and factor >= 1

  # Check weight shape.
  assert len(w.shape) == 4
  convH = w.shape[2]
  convW = w.shape[3]
  inC = w.shape[1]
  outC = w.shape[0]
  assert convW == convH

  # Setup filter kernel.
  if k is None:
    k = [1] * factor
  k = _setup_kernel(k) * (gain * (factor ** 2))
  p = (k.shape[0] - factor) - (convW - 1)

  # Determine output shape and padding for the transposed convolution.
  # Note: `F.conv_transpose2d` expects a 2-element stride, so the TF-style
  # `[1, 1, factor, factor]` stride is not used here.
  stride = (factor, factor)
  output_shape = ((_shape(x, 2) - 1) * factor + convH,
                  (_shape(x, 3) - 1) * factor + convW)
  output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
                    output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW)
  assert output_padding[0] >= 0 and output_padding[1] >= 0
  num_groups = _shape(x, 1) // inC

  # Transpose and flip the weights for the transposed convolution. PyTorch
  # does not support negative-step slicing, so use `torch.flip` instead of
  # `w[..., ::-1, ::-1]`.
  w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
  w = torch.flip(w, dims=[3, 4]).permute(0, 2, 1, 3, 4)
  w = torch.reshape(w, (num_groups * inC, -1, convH, convW))

  x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding,
                         padding=0)

  return upfirdn2d(x, torch.tensor(k, device=x.device),
                   pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
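

# Illustrative sketch (assumes the compiled `upfirdn2d` op): fused upsampling
# convolution with the default StyleGAN2 FIR filter. The weight layout follows
# this module's `[outC, inC, kH, kW]` convention.
def _example_upsample_conv_2d():
  x = torch.randn(1, 8, 16, 16)
  w = torch.randn(16, 8, 3, 3)
  y = upsample_conv_2d(x, w, k=(1, 3, 3, 1), factor=2)
  return y.shape  # (1, 16, 32, 32)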
|
|
|
|
|
def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
  """Fused 2D convolution followed by `downsample_2d()`.

  Padding is performed only once at the beginning, not between the two
  operations. The fused op is considerably more efficient than performing
  the same calculation with standard ops, and it supports gradients of
  arbitrary order.

  Args:
    x: Input tensor of the shape `[N, C, H, W]`.
    w: Weight tensor of the shape `[outChannels, inChannels, filterH,
      filterW]`. Grouped convolution can be performed by `inChannels =
      x.shape[1] // numGroups`.
    k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The
      default is `[1] * factor`, which corresponds to average pooling.
    factor: Integer downsampling factor (default: 2).
    gain: Scaling factor for signal magnitude (default: 1.0).

  Returns:
    Tensor of the shape `[N, C, H // factor, W // factor]` with the same
    datatype as `x`.
  """
  assert isinstance(factor, int) and factor >= 1
  _outC, _inC, convH, convW = w.shape
  assert convW == convH
  if k is None:
    k = [1] * factor
  k = _setup_kernel(k) * gain
  p = (k.shape[0] - factor) + (convW - 1)
  s = [factor, factor]

  # FIR filtering (with the combined padding), then a strided convolution.
  x = upfirdn2d(x, torch.tensor(k, device=x.device),
                pad=((p + 1) // 2, p // 2))
  return F.conv2d(x, w, stride=s, padding=0)
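

# Illustrative sketch (assumes the compiled `upfirdn2d` op): fused 3x3
# convolution plus 2x downsampling with the default StyleGAN2 FIR filter.
def _example_conv_downsample_2d():
  x = torch.randn(1, 8, 16, 16)
  w = torch.randn(16, 8, 3, 3)
  y = conv_downsample_2d(x, w, k=(1, 3, 3, 1), factor=2)
  return y.shape  # (1, 16, 8, 8)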
|
|
|
|
|
def _setup_kernel(k):
  """Build a 2D FIR kernel from a 1D (separable) or 2D filter and normalize it."""
  k = np.asarray(k, dtype=np.float32)
  if k.ndim == 1:
    k = np.outer(k, k)
  k /= np.sum(k)
  assert k.ndim == 2
  assert k.shape[0] == k.shape[1]
  return k
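

# Illustrative sketch: the default `(1, 3, 3, 1)` resampling filter becomes a
# normalized 4x4 separable kernel (outer product of the 1D filter, divided by
# its sum).
def _example_setup_kernel():
  k = _setup_kernel((1, 3, 3, 1))
  return k.shape, float(k.sum())  # (4, 4), 1.0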
|
|
|
|
|
def _shape(x, dim):
  return x.shape[dim]
|
|
|
|
|
def upsample_2d(x, k=None, factor=2, gain=1):
  r"""Upsample a batch of 2D images with the given filter.

  Accepts a batch of 2D images of the shape `[N, C, H, W]` and upsamples
  each image with the given filter. The filter is normalized so that if the
  input pixels are constant, they will be scaled by the specified `gain`.
  Pixels outside the image are assumed to be zero, and the filter is padded
  with zeros so that its shape is a multiple of the upsampling factor.

  Args:
    x: Input tensor of the shape `[N, C, H, W]`.
    k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The
      default is `[1] * factor`, which corresponds to nearest-neighbor
      upsampling.
    factor: Integer upsampling factor (default: 2).
    gain: Scaling factor for signal magnitude (default: 1.0).

  Returns:
    Tensor of the shape `[N, C, H * factor, W * factor]`.
  """
  assert isinstance(factor, int) and factor >= 1
  if k is None:
    k = [1] * factor
  k = _setup_kernel(k) * (gain * (factor ** 2))
  p = k.shape[0] - factor
  return upfirdn2d(x, torch.tensor(k, device=x.device),
                   up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
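

# Illustrative sketch (assumes the compiled `upfirdn2d` op): with the default
# filter (`k=None`), `upsample_2d` reduces to nearest-neighbor upsampling.
def _example_upsample_2d():
  x = torch.randn(1, 3, 16, 16)
  y = upsample_2d(x, factor=2)  # (1, 3, 32, 32)
  nn_up = F.interpolate(x, scale_factor=2, mode='nearest')
  return torch.allclose(y, nn_up, atol=1e-5)  # True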
|
|
|
|
|
def downsample_2d(x, k=None, factor=2, gain=1):
  r"""Downsample a batch of 2D images with the given filter.

  Accepts a batch of 2D images of the shape `[N, C, H, W]` and downsamples
  each image with the given filter. The filter is normalized so that if the
  input pixels are constant, they will be scaled by the specified `gain`.
  Pixels outside the image are assumed to be zero, and the filter is padded
  with zeros so that its shape is a multiple of the downsampling factor.

  Args:
    x: Input tensor of the shape `[N, C, H, W]`.
    k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The
      default is `[1] * factor`, which corresponds to average pooling.
    factor: Integer downsampling factor (default: 2).
    gain: Scaling factor for signal magnitude (default: 1.0).

  Returns:
    Tensor of the shape `[N, C, H // factor, W // factor]`.
  """
  assert isinstance(factor, int) and factor >= 1
  if k is None:
    k = [1] * factor
  k = _setup_kernel(k) * gain
  p = k.shape[0] - factor
  return upfirdn2d(x, torch.tensor(k, device=x.device),
                   down=factor, pad=((p + 1) // 2, p // 2))
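

# Illustrative sketch (assumes the compiled `upfirdn2d` op): with the default
# filter (`k=None`), `downsample_2d` reduces to 2x2 average pooling.
def _example_downsample_2d():
  x = torch.randn(1, 3, 32, 32)
  y = downsample_2d(x, factor=2)  # (1, 3, 16, 16)
  return torch.allclose(y, F.avg_pool2d(x, 2), atol=1e-5)  # True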
|
|