File size: 159 Bytes
803ef9e
 
 
 
 
 
 
1
2
3
4
5
6
7
8
import torch.nn.functional as F


def norm_mse_loss(x0, x1, *, dim=-1):
    """Return the mean squared error between L2-normalized ``x0`` and ``x1``.

    For unit vectors ``a`` and ``b``, ``||a - b||^2 = 2 - 2 * (a . b)``.
    The loss is therefore ``2 - 2 * cosine_similarity`` averaged over all
    remaining (non-``dim``) positions. It ranges from 0 (identical
    directions) to 4 (opposite directions).

    Args:
        x0: Tensor of feature vectors.
        x1: Tensor broadcast-compatible with ``x0``.
        dim: Axis holding the feature vectors. Normalization and the dot
            product both use this axis. Defaults to the last axis, which
            matches the previous behaviour for 2-D ``(batch, features)``
            inputs.

    Returns:
        0-dim tensor with the mean loss.
    """
    # Normalize along the same axis the dot product reduces over. Otherwise
    # inputs with more than two dims are normalized along the wrong axis.
    # Zero vectors stay zero (F.normalize clamps the norm with eps), so
    # such pairs contribute 2 to the mean.
    x0 = F.normalize(x0, dim=dim)
    x1 = F.normalize(x1, dim=dim)
    cos_sim = (x0 * x1).sum(dim=dim)
    return 2 - 2 * cos_sim.mean()