|
import os |
|
import sys |
|
|
|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
|
|
|
|
class AlignLoss(nn.Module):
    """L1 alignment loss between target frames and aligned reference frames.

    For each target frame, only pixels that are outside the target's mask AND
    marked by the aligned visibility map contribute to the loss.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.loss_fn = nn.L1Loss(reduction=reduction)

    def forward(self, frames, masks, aligned_vs, aligned_rs):
        """
        :param frames: The original frames (GT), (B, C, T, H, W); a 4-D
            (B, C, H, W) tensor is treated as a single frame (T == 1).
        :param masks: Original masks, laid out like ``frames``.
        :param aligned_vs: aligned visibility maps from the reference frames
            (List: B, C, T, H, W), one entry per target frame.
        :param aligned_rs: aligned reference frames (List: B, C, T, H, W),
            one entry per target frame.
        :return: sum over target frames of the per-frame alignment loss.
        """
        if len(frames.shape) == 4:
            # No time axis: insert one of length 1.
            frames = frames.unsqueeze(2)
            masks = masks.unsqueeze(2)
        # Full 5-way unpack so any other rank still raises ValueError.
        _, _, n_frames, _, _ = frames.shape

        return sum(
            self._singleFrameAlignLoss(
                frames[:, :, t], masks[:, :, t], aligned_vs[t], aligned_rs[t])
            for t in range(n_frames)
        )

    def _singleFrameAlignLoss(self, targetFrame, targetMask, aligned_v, aligned_r):
        """
        :param targetFrame: target frame to be aligned -> B, C, H, W
        :param targetMask: the mask of the target frame
        :param aligned_v: aligned visibility map from the reference frames
        :param aligned_r: aligned reference frames -> B, C, T, H, W
        :return: sum over the reference frames of the masked L1 loss
        """
        # Give the target a time axis of length 1 so it broadcasts against
        # every aligned reference frame.
        visible_in_target = (1. - targetMask).unsqueeze(2)
        joint_visibility = visible_in_target * aligned_v
        masked_target = joint_visibility * targetFrame.unsqueeze(2)
        masked_reference = joint_visibility * aligned_r
        return sum(
            self.loss_fn(masked_target[:, :, k], masked_reference[:, :, k])
            for k in range(aligned_r.shape[2])
        )
|
|
|
|
|
class HoleVisibleLoss(nn.Module):
    """L1 loss restricted to pixels where both ``masks`` and ``c_masks`` are set."""

    def __init__(self, reduction='mean'):
        super().__init__()
        self.loss_fn = nn.L1Loss(reduction=reduction)

    def forward(self, outputs, masks, GTs, c_masks):
        """Sum the per-frame loss over the time axis.

        Inputs are (B, C, T, H, W); 4-D (B, C, H, W) inputs are treated as a
        single frame.
        """
        if len(outputs.shape) == 4:
            outputs, masks, GTs, c_masks = (
                tensor.unsqueeze(2) for tensor in (outputs, masks, GTs, c_masks))
        _, _, n_frames, _, _ = outputs.shape
        return sum(
            self._singleFrameHoleVisibleLoss(
                outputs[:, :, t], masks[:, :, t], c_masks[:, :, t], GTs[:, :, t])
            for t in range(n_frames)
        )

    def _singleFrameHoleVisibleLoss(self, targetFrame, targetMask, c_mask, GT):
        # Weight is non-zero only where the frame mask AND c_mask are both set.
        weight = targetMask * c_mask
        return self.loss_fn(weight * targetFrame, weight * GT)
|
|
|
|
|
class HoleInvisibleLoss(nn.Module):
    """L1 loss restricted to pixels where ``masks`` is set and ``c_masks`` is not."""

    def __init__(self, reduction='mean'):
        super().__init__()
        self.loss_fn = nn.L1Loss(reduction=reduction)

    def forward(self, outputs, masks, GTs, c_masks):
        """Sum the per-frame loss over the time axis.

        Inputs are (B, C, T, H, W); 4-D (B, C, H, W) inputs are treated as a
        single frame.
        """
        if len(outputs.shape) == 4:
            outputs, masks, GTs, c_masks = (
                tensor.unsqueeze(2) for tensor in (outputs, masks, GTs, c_masks))
        _, _, n_frames, _, _ = outputs.shape
        return sum(
            self._singleFrameHoleInvisibleLoss(
                outputs[:, :, t], masks[:, :, t], c_masks[:, :, t], GTs[:, :, t])
            for t in range(n_frames)
        )

    def _singleFrameHoleInvisibleLoss(self, targetFrame, targetMask, c_mask, GT):
        # Complement of HoleVisibleLoss inside the mask: c_mask excluded.
        weight = targetMask * (1. - c_mask)
        return self.loss_fn(weight * targetFrame, weight * GT)
|
|
|
|
|
class NonHoleLoss(nn.Module):
    """L1 loss restricted to pixels where ``masks`` is 0."""

    def __init__(self, reduction='mean'):
        super().__init__()
        self.loss_fn = nn.L1Loss(reduction=reduction)

    def forward(self, outputs, masks, GTs):
        """Sum the per-frame loss over the time axis.

        Inputs are (B, C, T, H, W); 4-D (B, C, H, W) inputs are treated as a
        single frame.
        """
        if len(outputs.shape) == 4:
            outputs, masks, GTs = (
                tensor.unsqueeze(2) for tensor in (outputs, masks, GTs))
        _, _, n_frames, _, _ = outputs.shape
        return sum(
            self._singleNonHoleLoss(outputs[:, :, t], masks[:, :, t], GTs[:, :, t])
            for t in range(n_frames)
        )

    def _singleNonHoleLoss(self, targetFrame, targetMask, GT):
        outside = 1. - targetMask
        return self.loss_fn(outside * targetFrame, outside * GT)
|
|
|
|
|
class ReconLoss(nn.Module):
    """L1 reconstruction loss, optionally weighted by a mask."""

    def __init__(self, reduction='mean', masked=False):
        super().__init__()
        self.loss_fn = nn.L1Loss(reduction=reduction)
        self.masked = masked

    def forward(self, model_output, target, mask):
        # ``mask`` is only used when the loss was built with masked=True.
        if not self.masked:
            return self.loss_fn(model_output, target)
        return self.loss_fn(model_output * mask, target * mask)
|
|
|
|
|
class VGGLoss(nn.Module):
    """Perceptual loss: summed L1 distance between three VGG feature taps."""

    def __init__(self, vgg):
        super().__init__()
        self.l1_loss = nn.L1Loss()
        self.vgg = vgg

    def vgg_loss(self, output, target):
        # ``self.vgg`` must return an object exposing relu2_2 / relu3_3 /
        # relu4_3 attributes.
        out_feats = self.vgg(output)
        tgt_feats = self.vgg(target)
        return sum(
            self.l1_loss(getattr(out_feats, tap), getattr(tgt_feats, tap))
            for tap in ('relu2_2', 'relu3_3', 'relu4_3')
        )

    def forward(self, data_input, model_output):
        # data_input is the target, model_output the prediction.
        return self.vgg_loss(model_output, data_input)
|
|
|
|
|
class StyleLoss(nn.Module):
    """Style loss: L1 distance between Gram matrices of three VGG feature taps."""

    def __init__(self, vgg, original_channel_norm=True):
        super().__init__()
        self.l1_loss = nn.L1Loss()
        self.vgg = vgg
        self.original_channel_norm = original_channel_norm

    def gram_matrix(self, input):
        """Return the normalized Gram matrix of an (N, C, H, W) feature map.

        Features are flattened to (N * C, H * W), so the result is an
        (N * C, N * C) matrix divided by N * C * H * W.

        NOTE(review): because batch and channel are flattened together, for
        N > 1 the matrix also holds cross-sample correlations, not only the
        per-sample Gram blocks. Kept as-is to preserve the loss's values.
        """
        a, b, c, d = input.size()
        # reshape (not view): accepts non-contiguous feature maps (e.g.
        # channels_last or transposed tensors), on which view() raises;
        # for contiguous input it is the same zero-copy view.
        features = input.reshape(a * b, c * d)
        G = torch.mm(features, features.t())
        return G.div(a * b * c * d)

    def style_loss(self, output, target):
        """Sum over layers of L1(Gram(output), Gram(target)) / divider.

        The divider for layer ``i`` is ``2 ** (i + 1)`` when
        ``original_channel_norm`` is set, otherwise ``C_P ** 2`` (the layer's
        channel count squared).
        """
        output_features = self.vgg(output)
        target_features = self.vgg(target)
        layers = ['relu2_2', 'relu3_3', 'relu4_3']
        loss = 0
        for i, layer in enumerate(layers):
            output_feature = getattr(output_features, layer)
            target_feature = getattr(target_features, layer)
            B, C_P, H, W = output_feature.shape
            output_gram_matrix = self.gram_matrix(output_feature)
            target_gram_matrix = self.gram_matrix(target_feature)
            if self.original_channel_norm:
                C_P_square_divider = 2 ** (i + 1)
            else:
                C_P_square_divider = C_P ** 2
                # Pins these taps to 128 / 256 / 512 channels respectively.
                assert C_P == 128 * 2 ** i
            loss += self.l1_loss(output_gram_matrix, target_gram_matrix) / C_P_square_divider
        return loss

    def forward(self, data_input, model_output):
        # data_input is the target, model_output the prediction.
        targets = data_input
        outputs = model_output
        return self.style_loss(outputs, targets)
|
|
|
|
|
class L1LossMaskedMean(nn.Module):
    """Summed absolute error over pixels where ``mask`` is 0, divided by ``sum(1 - mask)``.

    NOTE(review): if ``mask`` has fewer channels than ``x`` and is broadcast,
    the denominator counts mask elements only, so the result is scaled by the
    channel ratio -- confirm that this is intended.
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.L1Loss(reduction='sum')

    def forward(self, x, y, mask):
        kept = 1 - mask
        l1_sum = self.l1(x * kept, y * kept)
        kept_total = torch.sum(kept)
        # When every pixel is masked out, l1_sum is 0 as well. Divide by 1
        # instead of 0 so the loss is 0 rather than NaN (0 / 0). Any other
        # denominator is left unchanged.
        return l1_sum / kept_total.masked_fill(kept_total == 0, 1)
|
|
|
|
|
class L2LossMaskedMean(nn.Module):
    """Squared error over pixels where ``mask`` is 0, divided by ``sum(1 - mask)``.

    With the default ``reduction='sum'`` this is the mean squared error over
    the kept pixels.
    """

    def __init__(self, reduction='sum'):
        super().__init__()
        self.l2 = nn.MSELoss(reduction=reduction)

    def forward(self, x, y, mask):
        kept = 1 - mask
        l2_sum = self.l2(x * kept, y * kept)
        kept_total = torch.sum(kept)
        # Avoid 0 / 0 = NaN when nothing is kept (l2_sum is 0 then too).
        return l2_sum / kept_total.masked_fill(kept_total == 0, 1)
|
|
|
|
|
class ImcompleteVideoReconLoss(nn.Module):
    """Masked L1 loss between the model's 'imcomplete_video' and half-size targets.

    Targets and masks are nearest-downsampled by 0.5 along their last two
    axes (the time axis is kept) before comparison.
    """

    def __init__(self):
        super().__init__()
        self.loss_fn = L1LossMaskedMean()

    def forward(self, data_input, model_output):
        prediction = model_output['imcomplete_video']
        # Swapping axes 1 and 2 makes interpolate's three scale factors apply
        # to (axis-2, H, W) -- presumably (B, T, C, H, W) -> (B, C, T, H, W);
        # TODO confirm the layout of data_input['targets'] / ['masks'].
        half_hw = [1, 0.5, 0.5]
        small_targets = nn.functional.interpolate(
            data_input['targets'].transpose(1, 2), scale_factor=half_hw)
        small_masks = nn.functional.interpolate(
            data_input['masks'].transpose(1, 2), scale_factor=half_hw)
        return self.loss_fn(prediction, small_targets, small_masks)
|
|
|
|
|
class CompleteFramesReconLoss(nn.Module):
    """Masked L1 loss between the model's full 'outputs' and the targets."""

    def __init__(self):
        super().__init__()
        self.loss_fn = L1LossMaskedMean()

    def forward(self, data_input, model_output):
        prediction = model_output['outputs']
        return self.loss_fn(prediction, data_input['targets'], data_input['masks'])
|
|
|
|
|
class AdversarialLoss(nn.Module):
    r"""
    Adversarial loss
    https://arxiv.org/abs/1711.10337
    """

    def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
        r"""
        type = nsgan | lsgan | hinge

        :param type: which GAN objective to use.
        :param target_real_label: label value for "real" targets (nsgan/lsgan).
        :param target_fake_label: label value for "fake" targets (nsgan/lsgan).
        :raises ValueError: if ``type`` is not one of nsgan, lsgan, hinge.
        """
        super(AdversarialLoss, self).__init__()
        self.type = type
        # Registered as buffers so the labels move with the module (.to/.cuda).
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))

        if type == 'nsgan':
            self.criterion = nn.BCELoss()
        elif type == 'lsgan':
            self.criterion = nn.MSELoss()
        elif type == 'hinge':
            self.criterion = nn.ReLU()
        else:
            # Fail fast; otherwise ``criterion`` would be missing and the
            # first call would die with an AttributeError.
            raise ValueError(
                "Unknown adversarial loss type: {!r} "
                "(expected 'nsgan', 'lsgan' or 'hinge')".format(type))

    def forward(self, outputs, is_real, is_disc=None):
        """Compute the adversarial loss for a batch of discriminator outputs.

        Implemented as ``forward`` (not an ``__call__`` override) so that
        nn.Module's ``__call__`` still runs registered hooks.

        :param outputs: discriminator outputs; for 'nsgan' they are fed to
            BCELoss directly and must therefore be probabilities in [0, 1].
        :param is_real: whether ``outputs`` should be scored as real.
        :param is_disc: hinge only -- truthy for the discriminator loss,
            falsy for the generator loss. Ignored by nsgan/lsgan.
        :return: scalar loss tensor.
        """
        if self.type == 'hinge':
            if is_disc:
                # Discriminator: mean(relu(1 - D)) for real, mean(relu(1 + D)) for fake.
                if is_real:
                    outputs = -outputs
                return self.criterion(1 + outputs).mean()
            # Generator: -mean(D).
            return (-outputs).mean()

        labels = (self.real_label if is_real else self.fake_label).expand_as(outputs)
        return self.criterion(outputs, labels)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ValidLoss(nn.Module):
    """Mean L1 loss restricted to pixels where ``mk`` is 0."""

    def __init__(self):
        super(ValidLoss, self).__init__()
        self.loss_fn = nn.L1Loss(reduction='mean')

    def forward(self, model_output, target, mk):
        keep = 1 - mk
        return self.loss_fn(model_output * keep, target * keep)
|
|
|
|
|
|
|
class TVLoss(nn.Module):
    """Squared total-variation penalty, normalized per frame by its mask sum."""

    def __init__(self):
        super(TVLoss, self).__init__()

    def forward(self, mask_input, model_output):
        """
        :param mask_input: masks, (B, C_m, T, H, W) or (B, C_m, H, W).
        :param model_output: frames, (B, C, T, H, W) or (B, C, H, W).
        :return: mean over the B * T frames of
            (sum of squared vertical + horizontal neighbour differences)
            / (sum of that frame's mask values).
            A frame whose mask sums to 0 contributes 0 instead of inf/NaN.
        """
        outputs = model_output
        # 4-D inputs are a single frame: add a time axis of length 1.
        if len(mask_input.shape) == 4:
            mask_input = mask_input.unsqueeze(2)
        if len(outputs.shape) == 4:
            outputs = outputs.unsqueeze(2)

        # Move time next to batch so frames can be flattened into the batch axis.
        outputs = outputs.permute((0, 2, 1, 3, 4)).contiguous()
        masks = mask_input.permute((0, 2, 1, 3, 4)).contiguous()

        B, L, C, H, W = outputs.shape
        x = outputs.view([B * L, C, H, W])
        mask_areas = masks.view([B * L, -1]).sum(dim=1)

        h_tv = torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2).sum(dim=(1, 2, 3))
        w_tv = torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2).sum(dim=(1, 2, 3))

        # An all-zero mask would divide by zero. Divide by 1 there (keeps the
        # gradient finite) and then zero that frame's term.
        empty = mask_areas == 0
        per_frame = (h_tv + w_tv) / mask_areas.masked_fill(empty, 1)
        return per_frame.masked_fill(empty, 0).mean()
|
|
|
|
|
|
|
def show_images(image, name):
    """Write a (C, H, W) image to the file ``name`` as a binary 0/255 uint8 image.

    Values greater than 0.5 become 255; everything else (including NaN)
    becomes 0. Accepts a torch tensor (detached and moved to CPU first) or
    anything ``np.asarray`` understands.
    """
    import cv2
    import numpy as np

    if isinstance(image, torch.Tensor):
        # np.array() cannot convert CUDA tensors or tensors that require grad.
        image = image.detach().cpu().numpy()
    image = np.asarray(image)
    # Binarize explicitly to uint8. Values <= 0.5 are no longer left as raw
    # floats that depend on cv2.imwrite's implicit 8-bit conversion, which
    # float-capable formats (e.g. TIFF/EXR) skip.
    binary = np.where(image > 0.5, 255, 0).astype(np.uint8)
    # cv2 expects channels-last (H, W, C).
    cv2.imwrite(name, binary.transpose((1, 2, 0)))
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: a 32x32 frame of ones against a GT of twos, with a 16x16
    # mask region of which the top-right 8x8 quarter is also set in c_mask.
    targetFrame = torch.ones(1, 3, 32, 32)
    GT = torch.ones(1, 3, 32, 32) + 1

    mask = torch.zeros(1, 1, 32, 32)
    mask[:, :, 8:24, 8:24] = 1.

    c_mask = torch.zeros(1, 1, 32, 32)
    c_mask[:, :, 8:16, 16:24] = 1.

    losses = (
        HoleVisibleLoss()(targetFrame, mask, GT, c_mask),
        HoleInvisibleLoss()(targetFrame, mask, GT, c_mask),
        NonHoleLoss()(targetFrame, mask, GT),
    )
    print('vis: {}, invis: {}, gt: {}'.format(*losses))
|
|