from functools import reduce

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.BaseNetwork import BaseNetwork
from models.transformer_base.ffn_base import FusionFeedForward
from models.transformer_base.attention_flow import SWMHSA_depthGlobalWindowConcatLN_qkFlow_reweightFlow
from models.transformer_base.attention_base import TMHSA
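
# Network definition for the FGT (flow-guided transformer) generator. The top-level
# Model wraps FGT and builds it from a config dict; FGT itself combines a CNN frame
# encoder and flow encoder, a stack of temporal / flow-guided spatial transformer
# blocks operating on patch tokens, and a CNN decoder that reconstructs the frames.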


class Model(nn.Module):
    def __init__(self, config):
        super(Model, self).__init__()
        self.net = FGT(config['tw'], config['sw'], config['gd'], config['input_resolution'], config['in_channel'],
                       config['cnum'], config['flow_inChannel'], config['flow_cnum'], config['frame_hidden'],
                       config['flow_hidden'], config['PASSMASK'],
                       config['numBlocks'], config['kernel_size'], config['stride'], config['padding'],
                       config['num_head'], config['conv_type'], config['norm'],
                       config['use_bias'], config['ape'],
                       config['mlp_ratio'], config['drop'], config['init_weights'])

    def forward(self, frames, flows, masks):
        ret = self.net(frames, flows, masks)
        return ret


class Encoder(nn.Module):
    def __init__(self, in_channels):
        super(Encoder, self).__init__()
        self.group = [1, 2, 4, 8, 1]
        self.layers = nn.ModuleList([
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
            nn.LeakyReLU(0.2, inplace=True)
        ])

    def forward(self, x):
        bt, c, h, w = x.size()
        h, w = h // 4, w // 4
        out = x
        for i, layer in enumerate(self.layers):
            if i == 8:
                x0 = out
            if i > 8 and i % 2 == 0:
                g = self.group[(i - 8) // 2]
                x = x0.view(bt, g, -1, h, w)
                o = out.view(bt, g, -1, h, w)
                out = torch.cat([x, o], 2).view(bt, -1, h, w)
            out = layer(out)
        return out
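# Channel bookkeeping for the grouped skip connections in Encoder.forward: the
# 256-channel feature cached at i == 8 (x0) is interleaved group-wise with the
# running activation before each later convolution, i.e.
#   i = 10: 384 + 256 = 640 -> 512 (groups=2)
#   i = 12: 512 + 256 = 768 -> 384 (groups=4)
#   i = 14: 384 + 256 = 640 -> 256 (groups=8)
#   i = 16: 256 + 256 = 512 -> 128 (groups=1)
# The two stride-2 convolutions reduce the spatial size by 4, hence h // 4, w // 4.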


class AddPosEmb(nn.Module):
    def __init__(self, h, w, in_channels, out_channels):
        super(AddPosEmb, self).__init__()
        self.proj = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True, groups=out_channels)
        self.h, self.w = h, w

    def forward(self, x, h=0, w=0):
        B, N, C = x.shape
        if h == 0 and w == 0:
            assert N == self.h * self.w, 'Wrong input size'
        else:
            assert N == h * w, 'Wrong input size during inference'
        feat_token = x
        if h == 0 and w == 0:
            cnn_feat = feat_token.transpose(1, 2).view(B, C, self.h, self.w)
        else:
            cnn_feat = feat_token.transpose(1, 2).view(B, C, h, w)
        x = self.proj(cnn_feat) + cnn_feat
        x = x.flatten(2).transpose(1, 2)
        return x
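# AddPosEmb is a convolutional position embedding: tokens are reshaped to a
# (B, C, h, w) map, passed through a 3x3 convolution with groups=out_channels
# (depthwise, since it is instantiated with in_channels == out_channels), added
# residually, and flattened back to (B, N, C). Passing h == w == 0 means "use the
# training-time token grid (self.h, self.w)"; non-zero h, w support other token
# grids at inference time.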


class Vec2Patch(nn.Module):
    def __init__(self, channel, hidden, output_size, kernel_size, stride, padding):
        super(Vec2Patch, self).__init__()
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        c_out = reduce((lambda x, y: x * y), kernel_size) * channel
        self.embedding = nn.Linear(hidden, c_out)
        self.restore = nn.Fold(output_size=output_size, kernel_size=kernel_size, stride=stride, padding=padding)
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x, output_h=0, output_w=0):
        feat = self.embedding(x)
        feat = feat.permute(0, 2, 1)
        if output_h != 0 or output_w != 0:
            feat = F.fold(feat, output_size=(output_h, output_w), kernel_size=self.kernel_size, stride=self.stride,
                          padding=self.padding)
        else:
            feat = self.restore(feat)
        return feat
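# Vec2Patch inverts the patch embedding: each token is projected to a
# kernel_size[0] * kernel_size[1] * channel patch and folded back onto the feature
# map, with overlapping patches summed by nn.Fold / F.fold. Non-zero output_h /
# output_w trigger an on-the-fly fold for non-default resolutions; otherwise the
# precomputed self.restore is used. self.relu is defined but unused in forward.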


class TemporalTransformer(nn.Module):
    def __init__(self, token_size, frame_hidden, num_heads, t_groupSize, mlp_ratio, dropout, n_vecs,
                 t2t_params):
        super(TemporalTransformer, self).__init__()
        self.attention = TMHSA(token_size=token_size, group_size=t_groupSize, d_model=frame_hidden, head=num_heads,
                               p=dropout)
        self.ffn = FusionFeedForward(frame_hidden, mlp_ratio, n_vecs, t2t_params, p=dropout)
        self.norm1 = nn.LayerNorm(frame_hidden)
        self.norm2 = nn.LayerNorm(frame_hidden)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, t, h, w, output_size):
        token_size = h * w
        s = self.norm1(x)
        x = x + self.dropout(self.attention(s, t, h, w))
        y = self.norm2(x)
        x = x + self.ffn(y, token_size, output_size[0], output_size[1])
        return x


class SpatialTransformer(nn.Module):
    def __init__(self, token_size, frame_hidden, flow_hidden, num_heads, s_windowSize, g_downSize, mlp_ratio,
                 dropout, n_vecs, t2t_params):
        super(SpatialTransformer, self).__init__()
        self.attention = SWMHSA_depthGlobalWindowConcatLN_qkFlow_reweightFlow(token_size=token_size,
                                                                              window_size=s_windowSize,
                                                                              kernel_size=g_downSize,
                                                                              d_model=frame_hidden,
                                                                              flow_dModel=flow_hidden,
                                                                              head=num_heads, p=dropout)
        self.ffn = FusionFeedForward(frame_hidden, mlp_ratio, n_vecs, t2t_params, p=dropout)
        self.norm = nn.LayerNorm(frame_hidden)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, f, t, h, w, output_size):
        token_size = h * w
        x = x + self.dropout(self.attention(x, f, t, h, w))
        y = self.norm(x)
        x = x + self.ffn(y, token_size, output_size[0], output_size[1])
        return x
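# The spatial pass runs window attention (window_size = s_windowSize) over each
# frame and, per the attention class name, uses the flow tokens f to modulate the
# query/key computation and to reweight the attention map. Unlike the temporal
# block, the attention branch here is not pre-normalised; only the FFN input goes
# through self.norm.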


class TransformerBlock(nn.Module):
    def __init__(self, token_size, frame_hidden, flow_hidden, num_heads, t_groupSize, s_windowSize, g_downSize,
                 mlp_ratio, dropout, n_vecs, t2t_params):
        super(TransformerBlock, self).__init__()
        self.t_transformer = TemporalTransformer(token_size=token_size, frame_hidden=frame_hidden,
                                                 num_heads=num_heads, t_groupSize=t_groupSize, mlp_ratio=mlp_ratio,
                                                 dropout=dropout, n_vecs=n_vecs, t2t_params=t2t_params)
        self.s_transformer = SpatialTransformer(token_size=token_size, frame_hidden=frame_hidden,
                                                flow_hidden=flow_hidden, num_heads=num_heads,
                                                s_windowSize=s_windowSize, g_downSize=g_downSize,
                                                mlp_ratio=mlp_ratio, dropout=dropout, n_vecs=n_vecs,
                                                t2t_params=t2t_params)

    def forward(self, inputs):
        x, f, t = inputs['x'], inputs['f'], inputs['t']
        h, w = inputs['h'], inputs['w']
        output_size = inputs['output_size']
        x = self.t_transformer(x, t, h, w, output_size)
        x = self.s_transformer(x, f, t, h, w, output_size)
        return {'x': x, 'f': f, 't': t, 'h': h, 'w': w, 'output_size': output_size}
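# A TransformerBlock applies temporal attention followed by flow-guided spatial
# attention. The dict-in / dict-out interface lets blocks be chained with
# nn.Sequential (see FGT below) while threading the auxiliary f, t, h, w and
# output_size values through the stack unchanged.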


class Decoder(BaseNetwork):
    def __init__(self, conv_type, in_channels, out_channels, use_bias, norm=None):
        super(Decoder, self).__init__(conv_type)
        self.layer1 = self.DeconvBlock(in_channels, in_channels, kernel_size=3, padding=1, norm=norm,
                                       bias=use_bias)
        self.layer2 = self.ConvBlock(in_channels, in_channels // 2, kernel_size=3, stride=1, padding=1, norm=norm,
                                     bias=use_bias)
        self.layer3 = self.DeconvBlock(in_channels // 2, in_channels // 2, kernel_size=3, padding=1, norm=norm,
                                       bias=use_bias)
        self.final = self.ConvBlock(in_channels // 2, out_channels, kernel_size=3, stride=1, padding=1, norm=norm,
                                    bias=use_bias, activation=None)

    def forward(self, features):
        feat1 = self.layer1(features)
        feat2 = self.layer2(feat1)
        feat3 = self.layer3(feat2)
        output = self.final(feat3)
        return output


class FGT(BaseNetwork):
    def __init__(self, t_groupSize, s_windowSize, g_downSize, input_resolution, in_channels, cnum, flow_inChannel,
                 flow_cnum, frame_hidden, flow_hidden, passmask, numBlocks, kernel_size, stride, padding, num_heads,
                 conv_type, norm, use_bias, ape, mlp_ratio=4, drop=0, init_weights=True):
        super(FGT, self).__init__(conv_type)
        self.in_channels = in_channels
        self.passmask = passmask
        self.ape = ape
        self.frame_encoder = Encoder(in_channels)
        self.flow_encoder = nn.Sequential(
            nn.ReplicationPad2d(2),
            self.ConvBlock(flow_inChannel, flow_cnum, kernel_size=5, stride=1, padding=0, bias=use_bias, norm=norm),
            self.ConvBlock(flow_cnum, flow_cnum * 2, kernel_size=3, stride=2, padding=1, bias=use_bias, norm=norm),
            self.ConvBlock(flow_cnum * 2, flow_cnum * 2, kernel_size=3, stride=1, padding=1, bias=use_bias, norm=norm),
            self.ConvBlock(flow_cnum * 2, flow_cnum * 2, kernel_size=3, stride=2, padding=1, bias=use_bias, norm=norm)
        )

        self.patch2vec = nn.Conv2d(cnum * 2, frame_hidden, kernel_size=kernel_size, stride=stride, padding=padding)
        self.f_patch2vec = nn.Conv2d(flow_cnum * 2, flow_hidden, kernel_size=kernel_size, stride=stride,
                                     padding=padding)

        n_vecs = 1
        token_size = []
        output_shape = (input_resolution[0] // 4, input_resolution[1] // 4)
        for i, d in enumerate(kernel_size):
            token_nums = int((output_shape[i] + 2 * padding[i] - kernel_size[i]) / stride[i] + 1)
            n_vecs *= token_nums
            token_size.append(token_nums)

        if self.ape:
            self.add_pos_emb = AddPosEmb(token_size[0], token_size[1], frame_hidden, frame_hidden)
        self.token_size = token_size

        blocks = []
        t2t_params = {'kernel_size': kernel_size, 'stride': stride, 'padding': padding, 'output_size': output_shape}
        for i in range(numBlocks // 2 - 1):
            layer = TransformerBlock(token_size, frame_hidden, flow_hidden, num_heads, t_groupSize, s_windowSize,
                                     g_downSize, mlp_ratio, drop, n_vecs, t2t_params)
            blocks.append(layer)
        self.first_t_transformer = TemporalTransformer(token_size, frame_hidden, num_heads, t_groupSize, mlp_ratio,
                                                       drop, n_vecs, t2t_params)
        self.first_s_transformer = SpatialTransformer(token_size, frame_hidden, flow_hidden, num_heads, s_windowSize,
                                                      g_downSize, mlp_ratio, drop, n_vecs, t2t_params)
        self.transformer = nn.Sequential(*blocks)

        self.vec2patch = Vec2Patch(cnum * 2, frame_hidden, output_shape, kernel_size, stride, padding)

        self.decoder = Decoder(conv_type, cnum * 2, 3, use_bias, norm)

        if init_weights:
            self.init_weights()

    def forward(self, masked_frames, flows, masks):
        b, t, c, h, w = masked_frames.shape
        cf = flows.shape[2]
        output_shape = (h // 4, w // 4)
        if self.passmask:
            inputs = torch.cat((masked_frames, masks), dim=2)
        else:
            inputs = masked_frames
        inputs = inputs.view(b * t, self.in_channels, h, w)
        flows = flows.view(b * t, cf, h, w)
        enc_feats = self.frame_encoder(inputs)
        flow_feats = self.flow_encoder(flows)
        trans_feat = self.patch2vec(enc_feats)
        flow_patches = self.f_patch2vec(flow_feats)
        _, c, h, w = trans_feat.shape
        cf = flow_patches.shape[1]
        if h != self.token_size[0] or w != self.token_size[1]:
            new_h, new_w = h, w
        else:
            new_h, new_w = 0, 0
            output_shape = (0, 0)
        trans_feat = trans_feat.view(b * t, c, -1).permute(0, 2, 1)
        flow_patches = flow_patches.view(b * t, cf, -1).permute(0, 2, 1)
        trans_feat = self.first_t_transformer(trans_feat, t, new_h, new_w, output_shape)
        if self.ape:  # add_pos_emb is only created when the absolute position embedding is enabled
            trans_feat = self.add_pos_emb(trans_feat, new_h, new_w)
        trans_feat = self.first_s_transformer(trans_feat, flow_patches, t, new_h, new_w, output_shape)
        inputs_trans_feat = {'x': trans_feat, 'f': flow_patches, 't': t, 'h': new_h, 'w': new_w,
                             'output_size': output_shape}
        trans_feat = self.transformer(inputs_trans_feat)['x']
        trans_feat = self.vec2patch(trans_feat, output_shape[0], output_shape[1])
        enc_feats = enc_feats + trans_feat

        output = self.decoder(enc_feats)
        output = torch.tanh(output)
        return output
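
# Shape summary for FGT.forward: masked_frames is (b, t, c, h, w); when passmask is
# True the masks are concatenated along the channel axis, so in_channels should be
# c + 1. Both encoders work at (h // 4, w // 4); the 128-channel frame features
# (which implies cnum == 64) become frame_hidden-d tokens via patch2vec, the
# transformer stack refines them, Vec2Patch folds them back into a residual on
# enc_feats, and the decoder (two DeconvBlocks, each presumably upsampling by 2)
# returns a (b * t, 3, h, w) tensor squashed to [-1, 1] by tanh.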
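

# Minimal construction / smoke-test sketch (not part of the original module). The
# values below are illustrative assumptions only -- resolution, hidden sizes,
# window / group sizes, conv_type, mlp_ratio, etc. should be taken from the
# project's actual config files. Only the config keys are guaranteed to match what
# Model.__init__ reads.
if __name__ == '__main__':
    example_config = {
        'tw': 2, 'sw': 8, 'gd': 4,  # temporal group / spatial window / global downsample sizes (assumed)
        'input_resolution': (240, 432), 'in_channel': 4,  # 3 RGB channels + 1 mask channel with PASSMASK
        'cnum': 64, 'flow_inChannel': 2, 'flow_cnum': 64,
        'frame_hidden': 512, 'flow_hidden': 256, 'PASSMASK': True,
        'numBlocks': 8, 'kernel_size': (7, 7), 'stride': (3, 3), 'padding': (3, 3),
        'num_head': 4, 'conv_type': 'vanilla', 'norm': None,
        'use_bias': True, 'ape': True,
        'mlp_ratio': 4, 'drop': 0, 'init_weights': True,
    }
    model = Model(example_config)
    b, t, h, w = 1, 5, 240, 432
    frames = torch.randn(b, t, 3, h, w)
    flows = torch.randn(b, t, 2, h, w)
    masks = torch.ones(b, t, 1, h, w)
    print(model(frames, flows, masks).shape)  # expected: (b * t, 3, h, w)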