File size: 618 Bytes
f6ca457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32

import torch
from data_setup import gram_matrix

def content_loss(target, content):
  """Return the mean squared difference between target and content.

  Args:
    target: Tensor produced for the image being optimised.
    content: Reference tensor to compare against; it must be subtractable
      from target.

  Returns:
    A scalar tensor holding the mean of the element-wise squared difference.
  """
  difference = target - content
  squared_error = difference ** 2
  return torch.mean(squared_error)


def style_loss(target_features, style_grams, *, layer_weights=None, default_weight=0.2):
  """Accumulate the weighted Gram-matrix loss over every layer in target_features.

  For each layer, the Gram matrix of the target feature map (computed with
  gram_matrix) is compared with the stored style Gram matrix by mean squared
  difference. The result is scaled by the layer's weight and divided by
  c * h * w, the product of the last three dimensions of the feature map.

  Args:
    target_features: Mapping from layer key to a 4-D feature tensor of the
      image being optimised.
    style_grams: Mapping with the same layer keys, holding the Gram matrix
      each layer should match.
    layer_weights: Optional mapping from layer key to that layer's weight.
      Layers missing from the mapping fall back to default_weight.
    default_weight: Weight used for layers without an explicit entry.
      Defaults to 0.2, the value that used to be hard-coded here.

  Returns:
    The summed loss over all layers; the integer 0 when target_features
    is empty.

  Raises:
    KeyError: If style_grams is a dict and lacks a layer that is present in
      target_features.
  """
  weights = {} if layer_weights is None else layer_weights
  loss = 0

  for layer, target_f in target_features.items():
    target_gram = gram_matrix(target_f)
    style_gram = style_grams[layer]
    # The leading dimension is not used in the normalisation, only the
    # remaining three are.
    _, c, h, w = target_f.shape
    weight = weights.get(layer, default_weight)
    layer_loss = weight * torch.mean((target_gram - style_gram) ** 2)
    loss += layer_loss / (c * h * w)

  return loss


def total_loss(content_loss, style_loss, alpha, beta):
  """Combine the two loss terms into a single weighted sum.

  Note that the content_loss and style_loss parameters are already-computed
  loss values; inside this function they shadow the module-level functions
  of the same names.

  Args:
    content_loss: Content loss value.
    style_loss: Style loss value.
    alpha: Multiplier applied to content_loss.
    beta: Multiplier applied to style_loss.

  Returns:
    alpha * content_loss + beta * style_loss.
  """
  weighted_content = alpha * content_loss
  weighted_style = beta * style_loss
  return weighted_content + weighted_style