# Spaces: No application file
import torch.nn as nn
from torchvision.models import vgg19

import config
# phi_5,4: feature map of the 4th convolution (after activation) before the 5th max-pooling layer
class VGGLoss(nn.Module):
    """Perceptual (content) loss measured in VGG19 feature space.

    Both images are passed through a frozen, pretrained VGG19 feature
    extractor truncated to its first 36 layers, and the mean squared error
    between the two resulting feature maps is returned.
    """

    def __init__(self):
        """Build the frozen VGG19 feature extractor and the MSE criterion.

        The extractor is placed on ``config.DEVICE``; constructing it loads
        the pretrained VGG19 weights through torchvision.
        """
        super().__init__()

        # ``features[:36]`` keeps the first 36 layers of the convolutional
        # stack, i.e. everything up to (but not including) the final
        # max-pooling layer.
        # NOTE(review): newer torchvision releases deprecate ``pretrained=``
        # in favour of ``weights=``; kept as-is for compatibility with the
        # version this file was written against.
        self.vgg = vgg19(pretrained=True).features[:36].eval().to(config.DEVICE)
        self.loss = nn.MSELoss()

        # The extractor is only a fixed feature function: its weights must
        # never be updated by the optimizer that trains the main network.
        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, input, target):
        """Return the MSE between the VGG features of ``input`` and ``target``.

        Args:
            input: Batch of images to evaluate (presumably shaped
                ``(N, 3, H, W)`` and living on ``config.DEVICE`` -- confirm
                against the caller).
            target: Batch of reference images with the same shape as
                ``input``.

        Returns:
            The value produced by ``nn.MSELoss()`` (default ``mean``
            reduction) on the two feature maps.
        """
        vgg_input_features = self.vgg(input)
        vgg_target_features = self.vgg(target)
        return self.loss(vgg_input_features, vgg_target_features)