File size: 3,110 Bytes
d4b77ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
# Temporal PatchGAN discriminator, used to keep the flows temporally consistent.
import torch
import torch.nn as nn
from .BaseNetwork import BaseNetwork
class Discriminator(BaseNetwork):
    """Temporal patch discriminator made of six stacked 3D convolutions.

    Every convolution uses a (3, 5, 5) kernel with stride (1, 2, 2) and
    padding (1, 2, 2), so the temporal length is preserved while the spatial
    size is halved (rounded up) at each of the six layers.
    """

    def __init__(self, in_channels, conv_type, dist_cnum, use_sigmoid=False, use_spectral_norm=True, init_weights=True):
        """
        Args:
            in_channels: Number of channels of the input fed to the first convolution
            conv_type: Forwarded unchanged to ``BaseNetwork.__init__``
            dist_cnum: Base channel width; the layers output 1x, 2x, 4x, 4x, 4x, 4x this value
            use_sigmoid: Whether ``forward`` applies a sigmoid to the output
                (the original note says: true for the nsgan)
            use_spectral_norm: Wrap the first five convolutions in spectral norm;
                those convolutions drop their bias when this is enabled
            init_weights: Call ``self.init_weights()`` once the layers are built
                (presumably provided by ``BaseNetwork`` -- confirm there)
        """
        super(Discriminator, self).__init__(conv_type)
        self.use_sigmoid = use_sigmoid

        base_width = dist_cnum
        # Geometry shared by every convolution in the stack.
        conv_kwargs = dict(kernel_size=(3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2))
        # Channel plan of the five (optionally) spectral-normalised stages.
        widths = [in_channels, base_width * 1, base_width * 2, base_width * 4, base_width * 4, base_width * 4]

        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            conv = nn.Conv3d(in_channels=c_in, out_channels=c_out, bias=not use_spectral_norm, **conv_kwargs)
            stages.append(spectral_norm(conv, use_spectral_norm))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
        # The output convolution is never spectral-normalised and keeps its default bias.
        stages.append(nn.Conv3d(in_channels=base_width * 4, out_channels=base_width * 4, **conv_kwargs))
        self.conv = nn.Sequential(*stages)

        if init_weights:
            self.init_weights()

    def forward(self, xs, t):
        """
        Args:
            xs: Input feature, with shape of [bt, c, h, w] (batch and time folded together)
            t: Number of frames per sample; the batch size is recovered as bt // t
        Returns: The discriminative map from the GAN, with the time axis moved
            in front of the channel axis: [b, t, c_out, h_out, w_out]
        """
        frames, channels, height, width = xs.shape
        batch = frames // t
        # Unfold the time axis, then move channels ahead of time as Conv3d expects.
        clip = xs.view(batch, t, channels, height, width)
        clip = clip.permute(0, 2, 1, 3, 4).contiguous()
        scores = self.conv(clip)
        if self.use_sigmoid:
            scores = torch.sigmoid(scores)
        # Swap channel and time axes back: [b, c, t, h, w] -> [b, t, c, h, w]
        return torch.transpose(scores, 1, 2)
def spectral_norm(module, mode=True):
    """Optionally wrap ``module`` with ``nn.utils.spectral_norm``.

    Args:
        module: The layer to wrap
        mode: When truthy, return the spectrally-normalised module;
            otherwise return ``module`` untouched
    Returns: The wrapped module, or the very same ``module`` object when ``mode`` is falsy
    """
    return nn.utils.spectral_norm(module) if mode else module
|