import torch
import torch.nn as nn

import numpy as np


def add_coordConv_channels(t):
    # Append two CoordConv channels (row and column coordinates normalised
    # to [-1, 1]) so subsequent convolutions can condition on absolute
    # position within the image.
    n, c, h, w = t.size()
    xx_coord = np.ones((h, w)) * np.arange(h).reshape(h, 1)  # value = row index
    yy_coord = np.ones((h, w)) * np.arange(w).reshape(1, w)  # value = column index

    # Normalise each axis independently to [-1, 1].
    xx_coord = xx_coord / (h - 1) * 2 - 1
    yy_coord = yy_coord / (w - 1) * 2 - 1

    xx_coord = torch.from_numpy(xx_coord).float()
    yy_coord = torch.from_numpy(yy_coord).float()

    if t.is_cuda:
        xx_coord = xx_coord.cuda()
        yy_coord = yy_coord.cuda()

    xx_coord = xx_coord.unsqueeze(0).unsqueeze(0).repeat(n, 1, 1, 1)
    yy_coord = yy_coord.unsqueeze(0).unsqueeze(0).repeat(n, 1, 1, 1)

    t_cc = torch.cat((t, xx_coord, yy_coord), dim=1)

    return t_cc
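

# A minimal shape-check sketch for add_coordConv_channels; the sizes and the
# helper name are illustrative assumptions, not taken from the source.
def _demo_add_coordconv():
    x = torch.zeros(2, 3, 16, 16)
    x_cc = add_coordConv_channels(x)
    assert x_cc.shape == (2, 5, 16, 16)  # two coordinate channels appended
    # the row-coordinate channel spans exactly [-1, 1]
    assert x_cc[:, 3].min() == -1 and x_cc[:, 3].max() == 1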


class DenseBlockEncoder(nn.Module):
    def __init__(self, n_channels, n_convs, activation=nn.ReLU, args=[False]):
        super(DenseBlockEncoder, self).__init__()
        assert n_convs > 0

        self.n_channels = n_channels
        self.n_convs = n_convs
        self.layers = nn.ModuleList()
        for i in range(n_convs):
            self.layers.append(nn.Sequential(
                nn.BatchNorm2d(n_channels),
                activation(*args),
                nn.Conv2d(n_channels, n_channels, 3, stride=1, padding=1, bias=False)))

    def forward(self, inputs):
        outputs = []
        for i, layer in enumerate(self.layers):
            if i > 0:
                # Dense connectivity: each layer consumes the sum of all
                # previous layer outputs.
                next_output = 0
                for no in outputs:
                    next_output = next_output + no
                outputs.append(layer(next_output))
            else:
                outputs.append(layer(inputs))
        return outputs[-1]


class DenseBlockDecoder(nn.Module):
    def __init__(self, n_channels, n_convs, activation=nn.ReLU, args=[False]):
        super(DenseBlockDecoder, self).__init__()
        assert n_convs > 0

        self.n_channels = n_channels
        self.n_convs = n_convs
        self.layers = nn.ModuleList()
        for i in range(n_convs):
            self.layers.append(nn.Sequential(
                nn.BatchNorm2d(n_channels),
                activation(*args),
                nn.ConvTranspose2d(n_channels, n_channels, 3, stride=1, padding=1, bias=False)))

    def forward(self, inputs):
        outputs = []
        for i, layer in enumerate(self.layers):
            if i > 0:
                # Same dense-sum pattern as the encoder block.
                next_output = 0
                for no in outputs:
                    next_output = next_output + no
                outputs.append(layer(next_output))
            else:
                outputs.append(layer(inputs))
        return outputs[-1]
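

# Quick sanity sketch: both dense blocks preserve the channel count and the
# spatial size (3x3 convs with padding 1). Sizes here are illustrative
# assumptions.
def _demo_dense_blocks():
    x = torch.randn(2, 32, 16, 16)
    enc_blk = DenseBlockEncoder(32, 3).eval()
    dec_blk = DenseBlockDecoder(32, 3).eval()
    assert enc_blk(x).shape == x.shape
    assert dec_blk(x).shape == x.shape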


class DenseTransitionBlockEncoder(nn.Module):
    def __init__(self, n_channels_in, n_channels_out, mp, activation=nn.ReLU, args=[False]):
        super(DenseTransitionBlockEncoder, self).__init__()
        self.n_channels_in = n_channels_in
        self.n_channels_out = n_channels_out
        self.mp = mp
        self.main = nn.Sequential(
            nn.BatchNorm2d(n_channels_in),
            activation(*args),
            # 1x1 conv changes the channel count; max-pooling by `mp`
            # shrinks the spatial size.
            nn.Conv2d(n_channels_in, n_channels_out, 1, stride=1, padding=0, bias=False),
            nn.MaxPool2d(mp),
        )

    def forward(self, inputs):
        return self.main(inputs)


class DenseTransitionBlockDecoder(nn.Module):
    def __init__(self, n_channels_in, n_channels_out, activation=nn.ReLU, args=[False]):
        super(DenseTransitionBlockDecoder, self).__init__()
        self.n_channels_in = n_channels_in
        self.n_channels_out = n_channels_out
        self.main = nn.Sequential(
            nn.BatchNorm2d(n_channels_in),
            activation(*args),
            nn.ConvTranspose2d(n_channels_in, n_channels_out, 4, stride=2, padding=1, bias=False),
        )

    def forward(self, inputs):
        return self.main(inputs)
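

# Sketch of the two transition blocks round-tripping a feature map: the
# encoder transition halves the spatial size (mp=2), while the decoder's
# transposed conv (kernel 4, stride 2, padding 1) exactly doubles it back.
# Illustrative sizes only.
def _demo_transitions():
    x = torch.randn(2, 32, 16, 16)
    down = DenseTransitionBlockEncoder(32, 64, 2).eval()
    up = DenseTransitionBlockDecoder(64, 32).eval()
    y = down(x)
    assert y.shape == (2, 64, 8, 8)
    assert up(y).shape == (2, 32, 16, 16)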


class waspDenseEncoder128(nn.Module):
    def __init__(self, nc=1, ndf=32, ndim=128, activation=nn.LeakyReLU, args=[0.2, False],
                 f_activation=nn.Tanh, f_args=[]):
        super(waspDenseEncoder128, self).__init__()
        self.ndim = ndim

        # Note: `nc` must already include the two CoordConv channels that
        # forward() appends (callers pass in_channels + 2).
        self.main = nn.Sequential(
            # 128 -> 64
            nn.BatchNorm2d(nc),
            nn.ReLU(True),
            nn.Conv2d(nc, ndf, 4, stride=2, padding=1),

            # 64 -> 32
            DenseBlockEncoder(ndf, 6),
            DenseTransitionBlockEncoder(ndf, ndf*2, 2, activation=activation, args=args),

            # 32 -> 16
            DenseBlockEncoder(ndf*2, 12),
            DenseTransitionBlockEncoder(ndf*2, ndf*4, 2, activation=activation, args=args),

            # 16 -> 8
            DenseBlockEncoder(ndf*4, 16),
            DenseTransitionBlockEncoder(ndf*4, ndf*8, 2, activation=activation, args=args),

            # 8 -> 4
            DenseBlockEncoder(ndf*8, 16),
            DenseTransitionBlockEncoder(ndf*8, ndf*8, 2, activation=activation, args=args),

            # 4 -> 1
            DenseBlockEncoder(ndf*8, 16),
            DenseTransitionBlockEncoder(ndf*8, ndim, 4, activation=activation, args=args),
            f_activation(*f_args),
        )

    def forward(self, input):
        input = add_coordConv_channels(input)
        output = self.main(input).view(-1, self.ndim)
        return output


class waspDenseDecoder128(nn.Module):
    def __init__(self, nz=128, nc=1, ngf=32, lb=0, ub=1, activation=nn.ReLU, args=[False],
                 f_activation=nn.Hardtanh, f_args=[]):
        super(waspDenseDecoder128, self).__init__()
        # `lb`/`ub` are kept for interface compatibility but are currently
        # unused; nn.Hardtanh defaults to the range [-1, 1].
        self.main = nn.Sequential(
            # 1 -> 4
            nn.BatchNorm2d(nz),
            activation(*args),
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),

            # 4 -> 8
            DenseBlockDecoder(ngf*8, 16),
            DenseTransitionBlockDecoder(ngf*8, ngf*8),

            # 8 -> 16
            DenseBlockDecoder(ngf*8, 16),
            DenseTransitionBlockDecoder(ngf*8, ngf*4),

            # 16 -> 32
            DenseBlockDecoder(ngf*4, 12),
            DenseTransitionBlockDecoder(ngf*4, ngf*2),

            # 32 -> 64
            DenseBlockDecoder(ngf*2, 6),
            DenseTransitionBlockDecoder(ngf*2, ngf),

            # 64 -> 128
            DenseBlockDecoder(ngf, 6),
            DenseTransitionBlockDecoder(ngf, ngf),

            # final 3x3 projection to `nc` output channels at 128 x 128
            nn.BatchNorm2d(ngf),
            activation(*args),
            nn.ConvTranspose2d(ngf, nc, 3, stride=1, padding=1, bias=False),
            f_activation(*f_args),
        )

    def forward(self, inputs):
        return self.main(inputs)
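

# End-to-end sketch of the 128-resolution pair: a 3-channel 128x128 image maps
# to an `ndim`-dimensional code and back to a 2-channel 128x128 map. The
# encoder's `nc` counts the two CoordConv channels, hence 3 + 2 = 5.
# Illustrative sizes only.
def _demo_wasp128():
    enc = waspDenseEncoder128(nc=5, ndf=32, ndim=128).eval()
    dec = waspDenseDecoder128(nz=128, nc=2, ngf=32).eval()
    x = torch.randn(1, 3, 128, 128)
    z = enc(x)
    assert z.shape == (1, 128)
    y = dec(z.unsqueeze(-1).unsqueeze(-1))  # code reshaped to (N, nz, 1, 1)
    assert y.shape == (1, 2, 128, 128)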


class waspDenseEncoder512(nn.Module):
    def __init__(self, nc=1, ndf=32, ndim=128, activation=nn.LeakyReLU, args=[0.2, False],
                 f_activation=nn.Tanh, f_args=[]):
        super(waspDenseEncoder512, self).__init__()
        self.ndim = ndim

        # As in the 128 variant, `nc` must include the two CoordConv channels.
        self.main = nn.Sequential(
            # 512 -> 256
            nn.BatchNorm2d(nc),
            nn.ReLU(True),
            nn.Conv2d(nc, ndf, 4, stride=2, padding=1),

            # 256 -> 128
            DenseBlockEncoder(ndf, 6),
            DenseTransitionBlockEncoder(ndf, ndf*2, 2, activation=activation, args=args),

            # 128 -> 64
            DenseBlockEncoder(ndf*2, 12),
            DenseTransitionBlockEncoder(ndf*2, ndf*4, 2, activation=activation, args=args),

            # 64 -> 32
            DenseBlockEncoder(ndf*4, 16),
            DenseTransitionBlockEncoder(ndf*4, ndf*8, 2, activation=activation, args=args),

            # 32 -> 16
            DenseBlockEncoder(ndf*8, 16),
            DenseTransitionBlockEncoder(ndf*8, ndf*8, 2, activation=activation, args=args),

            # 16 -> 4
            DenseBlockEncoder(ndf*8, 16),
            DenseTransitionBlockEncoder(ndf*8, ndf*8, 4, activation=activation, args=args),

            # 4 -> 1
            DenseBlockEncoder(ndf*8, 16),
            DenseTransitionBlockEncoder(ndf*8, ndim, 4, activation=activation, args=args),
            f_activation(*f_args),
        )

    def forward(self, input):
        input = add_coordConv_channels(input)
        output = self.main(input).view(-1, self.ndim)
        return output


class waspDenseDecoder512(nn.Module):
    def __init__(self, nz=128, nc=1, ngf=32, lb=0, ub=1, activation=nn.ReLU, args=[False],
                 f_activation=nn.Tanh, f_args=[]):
        super(waspDenseDecoder512, self).__init__()
        self.main = nn.Sequential(
            # 1 -> 4
            nn.BatchNorm2d(nz),
            activation(*args),
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),

            # 4 -> 8
            DenseBlockDecoder(ngf*8, 16),
            DenseTransitionBlockDecoder(ngf*8, ngf*8),

            # 8 -> 16
            DenseBlockDecoder(ngf*8, 16),
            DenseTransitionBlockDecoder(ngf*8, ngf*8),

            # 16 -> 32
            DenseBlockDecoder(ngf*8, 16),
            DenseTransitionBlockDecoder(ngf*8, ngf*4),

            # 32 -> 64
            DenseBlockDecoder(ngf*4, 12),
            DenseTransitionBlockDecoder(ngf*4, ngf*2),

            # 64 -> 128
            DenseBlockDecoder(ngf*2, 6),
            DenseTransitionBlockDecoder(ngf*2, ngf),

            # 128 -> 256
            DenseBlockDecoder(ngf, 6),
            DenseTransitionBlockDecoder(ngf, ngf),

            # 256 -> 512
            DenseBlockDecoder(ngf, 6),
            DenseTransitionBlockDecoder(ngf, ngf),

            # final 3x3 projection to `nc` output channels at 512 x 512
            nn.BatchNorm2d(ngf),
            activation(*args),
            nn.ConvTranspose2d(ngf, nc, 3, stride=1, padding=1, bias=False),
            f_activation(*f_args),
        )

    def forward(self, inputs):
        return self.main(inputs)
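

# Analogous sketch for the 512-resolution pair. This is the heaviest check in
# the file (a full 512x512 forward pass); sizes are illustrative assumptions.
def _demo_wasp512():
    enc = waspDenseEncoder512(nc=5, ndf=32, ndim=128).eval()
    dec = waspDenseDecoder512(nz=128, nc=2, ngf=32).eval()
    x = torch.randn(1, 3, 512, 512)
    z = enc(x)
    assert z.shape == (1, 128)
    assert dec(z.unsqueeze(-1).unsqueeze(-1)).shape == (1, 2, 512, 512)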


class dnetccnl(nn.Module):
    # CoordConv dense encoder-decoder at 128 x 128 resolution:
    #   in_channels  -> nc   (encoder input, before the 2 CoordConv channels)
    #   filters      -> nf   (width of the first encoder / last decoder layer)
    #   img_size     -> ndim (despite the name, this sets the bottleneck size)
    #   out_channels -> oc   (decoder output channels)
    #   fc_units     -> fcu  (stored but currently unused)
    def __init__(self, img_size=448, in_channels=3, out_channels=2, filters=32, fc_units=100):
        super(dnetccnl, self).__init__()
        self.nc = in_channels
        self.nf = filters
        self.ndim = img_size
        self.oc = out_channels
        self.fcu = fc_units

        self.encoder = waspDenseEncoder128(nc=self.nc + 2, ndf=self.nf, ndim=self.ndim)
        self.decoder = waspDenseDecoder128(nz=self.ndim, nc=self.oc, ngf=self.nf)

    def forward(self, inputs):
        encoded = self.encoder(inputs)
        # Reshape the flat code to (N, ndim, 1, 1) for the conv decoder.
        encoded = encoded.unsqueeze(-1).unsqueeze(-1)
        decoded = self.decoder(encoded)
        return decoded
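

# Usage sketch for the 128-resolution network: the spatial input size is fixed
# at 128 x 128 by the encoder/decoder pair, while `img_size` only sets the
# bottleneck width. Batch and channel sizes below are illustrative assumptions.
def _demo_dnetccnl():
    net = dnetccnl(img_size=128, in_channels=3, out_channels=2).eval()
    x = torch.randn(1, 3, 128, 128)
    with torch.no_grad():
        out = net(x)
    assert out.shape == (1, 2, 128, 128)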


class dnetccnl512(nn.Module):
    # Same encoder-decoder layout as dnetccnl, but built from the
    # 512-resolution blocks, so it expects 512 x 512 inputs.
    def __init__(self, img_size=448, in_channels=3, out_channels=2, filters=32, fc_units=100):
        super(dnetccnl512, self).__init__()
        self.nc = in_channels
        self.nf = filters
        self.ndim = img_size
        self.oc = out_channels
        self.fcu = fc_units

        self.encoder = waspDenseEncoder512(nc=self.nc + 2, ndf=self.nf, ndim=self.ndim)
        self.decoder = waspDenseDecoder512(nz=self.ndim, nc=self.oc, ngf=self.nf)

    def forward(self, inputs):
        encoded = self.encoder(inputs)
        encoded = encoded.unsqueeze(-1).unsqueeze(-1)
        decoded = self.decoder(encoded)
        return decoded
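

# Run the sanity sketches above when the file is executed directly; all sizes
# in them are illustrative, not mandated by the source.
if __name__ == "__main__":
    with torch.no_grad():
        _demo_add_coordconv()
        _demo_dense_blocks()
        _demo_transitions()
        _demo_wasp128()
        _demo_wasp512()
        _demo_dnetccnl()
    print("all shape checks passed")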