|
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from ..utils import xavier_init
|
|
from .registry import UPSAMPLE_LAYERS
|
|
|
|
# Both interpolation modes are served by ``nn.Upsample``. The registry key is
# later reused as the ``mode`` argument by ``build_upsample_layer``.
for _mode in ('nearest', 'bilinear'):
    UPSAMPLE_LAYERS.register_module(_mode, module=nn.Upsample)
del _mode  # keep the loop variable out of the module namespace
|
|
|
|
|
|
@UPSAMPLE_LAYERS.register_module(name='pixel_shuffle')
class PixelShufflePack(nn.Module):
    """Upsample by a channel-expanding conv followed by ``F.pixel_shuffle``.

    A ``nn.Conv2d`` first multiplies the channel count by
    ``scale_factor * scale_factor``. ``F.pixel_shuffle`` then rearranges those
    extra channels into a feature map that is ``scale_factor`` times larger
    along each spatial axis, with ``out_channels`` channels.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        scale_factor (int): Upsample ratio.
        upsample_kernel (int): Kernel size of the conv layer that expands
            the channels.
    """

    def __init__(self, in_channels, out_channels, scale_factor,
                 upsample_kernel):
        super().__init__()
        # Record the constructor arguments as public attributes.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.scale_factor = scale_factor
        self.upsample_kernel = upsample_kernel

        # pixel_shuffle folds scale_factor**2 channels into one output pixel
        # block, so the conv must produce that many channels per output one.
        expanded_channels = out_channels * scale_factor * scale_factor
        # "Same"-style padding. With Conv2d's default stride/dilation of 1
        # this preserves H and W only for odd kernel sizes.
        same_padding = (upsample_kernel - 1) // 2
        self.upsample_conv = nn.Conv2d(
            in_channels,
            expanded_channels,
            upsample_kernel,
            padding=same_padding)
        self.init_weights()

    def init_weights(self):
        """Initialise the expansion conv with uniform Xavier initialisation."""
        xavier_init(self.upsample_conv, distribution='uniform')

    def forward(self, x):
        """Expand channels with the conv, then pixel-shuffle them."""
        return F.pixel_shuffle(self.upsample_conv(x), self.scale_factor)
|
|
|
|
|
|
def build_upsample_layer(cfg, *args, **kwargs):
    """Build upsample layer.

    Args:
        cfg (dict): The upsample layer config, which should contain:

            - type (str): Layer type.
            - scale_factor (int): Upsample ratio, which is not applicable to
              deconv.
            - layer args: Args needed to instantiate an upsample layer.
        args (argument list): Arguments passed to the ``__init__``
            method of the corresponding upsample layer.
        kwargs (keyword arguments): Keyword arguments passed to the
            ``__init__`` method of the corresponding upsample layer.

    Returns:
        nn.Module: Created upsample layer.

    Raises:
        TypeError: If ``cfg`` is not a dict.
        KeyError: If ``cfg`` has no ``"type"`` key, or its type is not
            registered in ``UPSAMPLE_LAYERS``.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        raise KeyError(
            f'the cfg dict must contain the key "type", but got {cfg}')
    # Work on a shallow copy so popping 'type' (and setting 'mode' below)
    # never mutates the caller's config.
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in UPSAMPLE_LAYERS:
        raise KeyError(f'Unrecognized upsample type {layer_type}')
    upsample = UPSAMPLE_LAYERS.get(layer_type)

    # Several registry keys (e.g. 'nearest', 'bilinear') map to nn.Upsample.
    # The key itself selects the interpolation mode, overriding any 'mode'
    # given in cfg.
    if upsample is nn.Upsample:
        cfg_['mode'] = layer_type
    # NOTE: a keyword present in both ``kwargs`` and ``cfg`` makes this call
    # raise TypeError ("got multiple values for keyword argument").
    layer = upsample(*args, **kwargs, **cfg_)
    return layer
|
|
|