Spaces:
No application file
No application file
File size: 614 Bytes
17dd133 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import torch.nn as nn
from torchvision.models import vgg19
import config
# Perceptual ("VGG") loss on feature map phi_{5,4}: the output of the 4th
# convolution, taken after its ReLU, before the 5th max-pooling layer of VGG19.
# In torchvision's ``vgg19().features`` that activation sits at index 35, so
# slicing ``[:36]`` stops just before the final max-pool.
class VGGLoss(nn.Module):
    """MSE between frozen VGG19 feature maps of a prediction and a target.

    The VGG19 backbone is loaded with ImageNet weights, truncated after the
    conv5_4 activation, switched to eval mode and frozen
    (``requires_grad=False``). Its own weights are never updated, but
    gradients still flow through it back to ``input``.
    """

    # Number of leading layers of ``vgg19().features`` to keep (indices 0-35).
    _FEATURE_END = 36

    def __init__(self, device=None):
        """Build the frozen feature extractor and the MSE criterion.

        Args:
            device: Where to place the VGG backbone. When ``None`` (the
                default) ``config.DEVICE`` is used, matching the original
                behaviour.
        """
        super().__init__()
        if device is None:
            device = config.DEVICE
        backbone = self._pretrained_vgg19()
        self.vgg = backbone.features[: self._FEATURE_END].eval().to(device)
        self.loss = nn.MSELoss()
        # Freeze the extractor: it is a fixed feature space, not a trained part.
        self.vgg.requires_grad_(False)

    @staticmethod
    def _pretrained_vgg19():
        """Return torchvision's VGG19 with ImageNet-1k weights.

        Prefers the ``weights=`` API (torchvision >= 0.13), where the legacy
        ``pretrained=True`` flag is deprecated. Falls back to that flag on
        older releases. Both select the ``IMAGENET1K_V1`` weights.
        """
        try:
            from torchvision.models import VGG19_Weights
        except ImportError:  # torchvision < 0.13 has no multi-weight API
            return vgg19(pretrained=True)
        return vgg19(weights=VGG19_Weights.IMAGENET1K_V1)

    def forward(self, input, target):  # ``input`` name kept for API compatibility
        """Return the MSE between VGG features of ``input`` and ``target``.

        Both tensors go through the same frozen VGG19 slice. No ImageNet
        normalisation is applied here. Presumably both are image batches of
        shape (N, 3, H, W) on the backbone's device (TODO: confirm with
        callers).
        """
        vgg_input_features = self.vgg(input)
        vgg_target_features = self.vgg(target)
        return self.loss(vgg_input_features, vgg_target_features)
|