|
|
|
import torch |
|
import torch.nn as nn |
|
from .BaseNetwork import BaseNetwork |
|
|
|
|
|
class Discriminator(BaseNetwork):
    """Spatio-temporal discriminator built from a stack of ``Conv3d`` layers.

    Frames arrive flattened over batch and time (``[b * t, c, h, w]``). They
    are regrouped into clips of ``t`` frames and convolved jointly over time
    and space. Every conv uses kernel ``(3, 5, 5)``, stride ``(1, 2, 2)`` and
    padding ``(1, 2, 2)``. Each conv therefore keeps the temporal length and
    roughly halves height and width.
    """

    def __init__(self, in_channels, conv_type, dist_cnum, use_sigmoid=False, use_spectral_norm=True, init_weights=True):
        """
        Args:
            in_channels: Number of channels of each input frame.
            conv_type: Passed unchanged to ``BaseNetwork.__init__``; its
                meaning is defined by ``BaseNetwork``.
            dist_cnum: Base channel width ``nf``. The hidden layers use
                ``nf``, ``2 * nf`` and ``4 * nf`` channels.
            use_sigmoid: If True, apply ``torch.sigmoid`` to the output
                (true for the NSGAN loss).
            use_spectral_norm: If True, wrap every conv except the last one
                in spectral normalization and build those convs without bias.
                Should normally stay True for GAN training stability.
            init_weights: If True, call ``self.init_weights()`` (inherited
                from ``BaseNetwork``) once all layers are built.
        """
        super(Discriminator, self).__init__(conv_type)

        self.use_sigmoid = use_sigmoid

        nf = dist_cnum

        # Channel widths of consecutive normalized convs:
        # in -> nf -> 2nf -> 4nf -> 4nf -> 4nf (five conv + LeakyReLU pairs).
        widths = [in_channels, nf * 1, nf * 2, nf * 4, nf * 4, nf * 4]

        layers = []
        for cin, cout in zip(widths[:-1], widths[1:]):
            layers.append(
                spectral_norm(
                    nn.Conv3d(in_channels=cin, out_channels=cout, kernel_size=(3, 5, 5), stride=(1, 2, 2),
                              padding=(1, 2, 2),
                              bias=not use_spectral_norm),
                    use_spectral_norm))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Final projection: never spectrally normalized, always has a bias,
        # and no activation follows it.
        layers.append(
            nn.Conv3d(in_channels=nf * 4, out_channels=nf * 4, kernel_size=(3, 5, 5), stride=(1, 2, 2),
                      padding=(1, 2, 2)))

        # The layer order (Sequential indices 0..10) determines state_dict
        # keys, so it must stay fixed for checkpoint compatibility.
        self.conv = nn.Sequential(*layers)

        if init_weights:
            self.init_weights()

    def forward(self, xs, t):
        """Compute the discriminator map for a batch of frame sequences.

        Args:
            xs: Input frames flattened over batch and time, with shape
                ``[b * t, c, h, w]``.
            t: Number of frames per sequence. ``xs.shape[0]`` must be a
                multiple of ``t``.

        Returns:
            The discriminative map from the GAN, with shape
            ``[b, t, 4 * dist_cnum, h', w']``. Time is moved back to dim 1,
            and ``h'`` / ``w'`` are the spatial sizes after the six stride-2
            convs. If ``use_sigmoid`` is True, the sigmoid has been applied.

        Raises:
            ValueError: If ``t`` is not positive or ``xs.shape[0]`` is not
                divisible by ``t``.
        """
        bt, c, h, w = xs.shape
        if t <= 0 or bt % t != 0:
            raise ValueError(
                f"Discriminator.forward: batch dimension {bt} is not divisible by t={t}")
        b = bt // t

        # reshape (rather than view) also accepts non-contiguous inputs.
        # [b*t, c, h, w] -> [b, t, c, h, w] -> [b, c, t, h, w] for Conv3d.
        xs = xs.reshape(b, t, c, h, w).permute(0, 2, 1, 3, 4).contiguous()

        feat = self.conv(xs)

        if self.use_sigmoid:
            feat = torch.sigmoid(feat)

        # [b, C, t, h', w'] -> [b, t, C, h', w']
        out = torch.transpose(feat, 1, 2)

        return out
|
|
|
|
|
def spectral_norm(module, mode=True):
    """Optionally apply spectral normalization to ``module``.

    Args:
        module: The layer to (maybe) normalize.
        mode: When truthy, return ``nn.utils.spectral_norm(module)``;
            otherwise return ``module`` untouched.
    """
    return nn.utils.spectral_norm(module) if mode else module
|
|