sachin committed on
Commit
1ccc202
·
1 Parent(s): 9d7268a

Added clip loss functions

Browse files
Files changed (1) hide show
  1. src/loss.py +95 -0
src/loss.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+
6
def contrastive_loss(logits, dim):
    """Return the mean negative log-softmax taken over the diagonal of *logits*.

    Args:
        logits: Score matrix; entry ``[i, i]`` is scored as the positive pair.
        dim: Dimension along which the softmax is normalised.

    Returns:
        A scalar tensor: ``-mean(diag(log_softmax(logits, dim)))``.
    """
    log_probs = F.log_softmax(logits, dim=dim)
    # Only the diagonal (index i paired with index i) contributes to the loss.
    matched_log_probs = torch.diag(log_probs)
    return torch.neg(matched_log_probs.mean())
9
+
10
+
11
def contrastive_sigmoid_loss(logits):
    """Return the mean element-wise sigmoid (binary cross-entropy) contrastive loss.

    Every entry of *logits* is treated as an independent binary decision:
    diagonal entries are labelled 1 (matching pair) and all others 0.

    Args:
        logits: Square score matrix (the identity target is ``n x n`` with
            ``n = logits.shape[0]``).

    Returns:
        A scalar tensor with the mean binary cross-entropy over all entries.
    """
    # Build the target on the same device and dtype as the logits; a default
    # (CPU) identity matrix would fail for logits living on an accelerator.
    targets = torch.eye(logits.shape[0], dtype=logits.dtype, device=logits.device)
    return F.binary_cross_entropy_with_logits(logits, targets, reduction="mean")
13
+
14
+
15
class CLIPLoss(nn.Module):
    """Symmetric softmax contrastive loss between image and text features.

    The similarity matrix ``image_features @ text_features.T`` is divided by a
    learnable temperature and scored with ``contrastive_loss`` along both
    dimensions; the two results are averaged.

    Args:
        logit_temperature: Initial value of the learnable temperature logit.
            The temperature actually applied is ``sigmoid(logit_temperature)``.
    """

    def __init__(self, logit_temperature: float = -1.0):
        super().__init__()
        # nn.Parameter only accepts tensors, so wrap the Python float first.
        self.logit_temperature = nn.Parameter(torch.tensor(float(logit_temperature)))

    def forward(self, image_features: torch.Tensor, text_features: torch.Tensor):
        """Compute the loss for a batch of paired features.

        Args:
            image_features: 2-D tensor; row ``i`` pairs with row ``i`` of
                *text_features*. Presumably ``(batch, dim)`` and already
                normalised -- TODO confirm against the caller.
            text_features: 2-D tensor with the same layout.

        Returns:
            A scalar tensor: the mean of the two directional losses.
        """
        temperature = self.logit_temperature.sigmoid()
        similarity_matrix = image_features @ text_features.T
        # Scale once and reuse for both softmax directions.
        scaled_similarity = similarity_matrix / temperature
        caption_loss = contrastive_loss(scaled_similarity, dim=0)
        image_loss = contrastive_loss(scaled_similarity, dim=1)

        return 0.5 * (caption_loss + image_loss)
27
+
28
+
29
class CyCLIP(nn.Module):
    """Softmax contrastive loss plus two consistency regularisers.

    On top of the symmetric ``contrastive_loss`` term, this adds:

    * a symmetry term: MSE between the image/text similarity matrix and its
      transpose, weighted by ``lambda_1``;
    * a modality-difference term: MSE between the image/image and text/text
      similarity matrices, weighted by ``lambda_2``.

    Args:
        logit_temperature: Initial value of the learnable temperature logit.
            The temperature actually applied is ``sigmoid(logit_temperature)``.
        lambda_1: Weight of the symmetry term.
        lambda_2: Weight of the modality-difference term.
    """

    def __init__(
        self,
        logit_temperature: float = -1.0,
        lambda_1: float = 1.0,
        lambda_2: float = 1.0,
    ):
        super().__init__()
        # nn.Parameter only accepts tensors, so wrap the Python float first.
        self.logit_temperature = nn.Parameter(torch.tensor(float(logit_temperature)))
        self.lambda_1: float = lambda_1
        self.lambda_2: float = lambda_2

    def forward(self, image_features: torch.Tensor, text_features: torch.Tensor):
        """Compute the combined loss for a batch of paired features.

        Args:
            image_features: 2-D tensor; row ``i`` pairs with row ``i`` of
                *text_features*. Presumably ``(batch, dim)`` and already
                normalised -- TODO confirm against the caller.
            text_features: 2-D tensor with the same layout.

        Returns:
            A scalar tensor: contrastive term plus the weighted regularisers.
        """
        temperature = self.logit_temperature.sigmoid()
        similarity_matrix = image_features @ text_features.T
        # Scale once and reuse for both softmax directions.
        scaled_similarity = similarity_matrix / temperature
        caption_loss = contrastive_loss(scaled_similarity, dim=0)
        image_loss = contrastive_loss(scaled_similarity, dim=1)

        # The regularisers operate on the unscaled similarities.
        symmetry_loss = F.mse_loss(similarity_matrix, similarity_matrix.T)
        modality_difference_loss = F.mse_loss(
            image_features @ image_features.T, text_features @ text_features.T
        )

        return (
            0.5 * (caption_loss + image_loss)
            + self.lambda_1 * symmetry_loss
            + self.lambda_2 * modality_difference_loss
        )
52
+
53
+
54
class SigLIPLoss(nn.Module):
    """Sigmoid contrastive loss between image and text features.

    The similarity matrix ``image_features @ text_features.T`` is divided by a
    learnable temperature and passed to ``contrastive_sigmoid_loss``.

    Args:
        logit_temperature: Initial value of the learnable temperature logit.
            The temperature actually applied is ``sigmoid(logit_temperature)``.
    """

    def __init__(self, logit_temperature: float = -1.0):
        super().__init__()
        # nn.Parameter only accepts tensors, so wrap the Python float first.
        self.logit_temperature = nn.Parameter(torch.tensor(float(logit_temperature)))

    def forward(self, image_features: torch.Tensor, text_features: torch.Tensor):
        """Compute the loss for a batch of paired features.

        Args:
            image_features: 2-D tensor; row ``i`` pairs with row ``i`` of
                *text_features*. Presumably ``(batch, dim)`` and already
                normalised -- TODO confirm against the caller.
            text_features: 2-D tensor with the same layout.

        Returns:
            A scalar tensor with the sigmoid contrastive loss.
        """
        temperature = self.logit_temperature.sigmoid()
        similarity_matrix = image_features @ text_features.T
        return contrastive_sigmoid_loss(similarity_matrix / temperature)
63
+
64
+
65
class CySigLIPLoss(nn.Module):
    """Sigmoid contrastive loss plus two consistency regularisers.

    On top of the ``contrastive_sigmoid_loss`` term, this adds:

    * a symmetry term: MSE between the image/text similarity matrix and its
      transpose, weighted by ``lambda_1``;
    * a modality-difference term: MSE between the image/image and text/text
      similarity matrices, weighted by ``lambda_2``.

    Args:
        logit_temperature: Initial value of the learnable temperature logit.
            The temperature actually applied is ``sigmoid(logit_temperature)``.
        lambda_1: Weight of the symmetry term.
        lambda_2: Weight of the modality-difference term.
    """

    def __init__(
        self,
        logit_temperature: float = -1.0,
        lambda_1: float = 1.0,
        lambda_2: float = 1.0,
    ):
        super().__init__()
        # nn.Parameter only accepts tensors, so wrap the Python float first.
        self.logit_temperature = nn.Parameter(torch.tensor(float(logit_temperature)))
        self.lambda_1: float = lambda_1
        self.lambda_2: float = lambda_2

    def forward(self, image_features: torch.Tensor, text_features: torch.Tensor):
        """Compute the combined loss for a batch of paired features.

        Args:
            image_features: 2-D tensor; row ``i`` pairs with row ``i`` of
                *text_features*. Presumably ``(batch, dim)`` and already
                normalised -- TODO confirm against the caller.
            text_features: 2-D tensor with the same layout.

        Returns:
            A scalar tensor: sigmoid term plus the weighted regularisers.
        """
        temperature = self.logit_temperature.sigmoid()
        similarity_matrix = image_features @ text_features.T
        loss = contrastive_sigmoid_loss(similarity_matrix / temperature)

        # The regularisers operate on the unscaled similarities.
        symmetry_loss = F.mse_loss(similarity_matrix, similarity_matrix.T)
        modality_difference_loss = F.mse_loss(
            image_features @ image_features.T, text_features @ text_features.T
        )

        return loss + self.lambda_1 * symmetry_loss + self.lambda_2 * modality_difference_loss
83
+
84
+
85
def get_loss(loss_type: str):
    """Return a freshly constructed loss module for *loss_type*.

    Args:
        loss_type: One of ``"clip"``, ``"cyclip"``, ``"sigmoid"`` or
            ``"cyclic_sigmoid"``.

    Returns:
        A new instance of the matching loss module, built with its defaults.

    Raises:
        ValueError: If *loss_type* is not one of the supported names.
    """
    # Map names to classes (not instances) so that only the requested module
    # is actually constructed.
    loss_classes = {
        "clip": CLIPLoss,
        "cyclip": CyCLIP,
        "sigmoid": SigLIPLoss,
        "cyclic_sigmoid": CySigLIPLoss,
    }
    loss_class = loss_classes.get(loss_type)
    if loss_class is None:
        raise ValueError("Invalid loss type")
    return loss_class()