RTVE / loss.py
ChandanaShastri's picture
Upload 6 files
17dd133 verified
raw
history blame contribute delete
614 Bytes
import torch.nn as nn
from torchvision.models import vgg19
import config
# phi_5,4: output of the 4th conv layer in VGG19's 5th block (conv5_4), taken after its ReLU activation and before the final max-pooling layer
class VGGLoss(nn.Module):
    """Perceptual (content) loss: MSE between frozen VGG19 feature maps.

    Both images are passed through the pretrained VGG19 ``features`` stack
    truncated at index 36, i.e. up to and including the ReLU that follows
    conv5_4 (phi_5,4) and before the last max-pooling layer. The loss is the
    mean squared error between the two resulting feature maps.

    The feature extractor's parameters are frozen, but gradients still flow
    through it back to ``input``, so the loss can train whatever produced
    ``input``.
    """

    def __init__(self):
        super().__init__()
        try:
            # torchvision >= 0.13 selects pretrained weights via an enum;
            # the boolean ``pretrained=`` flag is deprecated there.
            from torchvision.models import VGG19_Weights
        except ImportError:
            # Older torchvision only understands the boolean flag, which
            # loads the same ImageNet weights.
            backbone = vgg19(pretrained=True)
        else:
            backbone = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)

        # [:36] keeps layers 0..35: everything through relu5_4, excluding the
        # final max-pool at index 36. eval() is set here; if a parent module
        # later calls .train() the mode flips back, which is harmless because
        # this slice contains only Conv2d/ReLU/MaxPool2d layers.
        self.vgg = backbone.features[:36].eval().to(config.DEVICE)
        self.loss = nn.MSELoss()
        # VGG is a fixed reference network; never update its weights.
        self.vgg.requires_grad_(False)

    def forward(self, input, target):
        """Return the MSE between the VGG19 phi_5,4 features of both images.

        Args:
            input: Image batch to evaluate (e.g. generator output). Gradients
                flow back to it through the frozen VGG.
            target: Reference image batch.
                NOTE(review): both are presumably (N, 3, H, W) tensors on
                ``config.DEVICE``, normalised the way the pretrained VGG19
                expects — confirm against the caller.

        Returns:
            Scalar tensor with the mean squared feature difference.
        """
        vgg_input_features = self.vgg(input)
        vgg_target_features = self.vgg(target)
        return self.loss(vgg_input_features, vgg_target_features)