import torch
from torch import nn
import torch.nn.functional as F


def get_similarity_matrix(
    image_features: torch.Tensor, text_features: torch.Tensor
) -> torch.Tensor:
    return image_features @ text_features.T


def contrastive_loss(logits, dim):
    # Cross-entropy along one axis of the similarity matrix; the diagonal
    # entries correspond to the matched image-text pairs.
    neg_ce = torch.diag(F.log_softmax(logits, dim=dim))
    return -neg_ce.mean()


def contrastive_sigmoid_loss(logits):
    # Identity matrix marks matched image-text pairs as the positive targets.
    targets = torch.eye(len(logits), device=logits.device, dtype=logits.dtype)
    return F.binary_cross_entropy_with_logits(logits, targets, reduction="mean")


class CLIPLoss(nn.Module):
    def __init__(self, logit_temperature: float = -1.0):
        super().__init__()
        # Learnable temperature stored as a logit; sigmoid keeps it in (0, 1).
        self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))

    def forward(self, similarity_matrix: torch.Tensor, *args):
        temperature = self.logit_temperature.sigmoid()
        caption_loss = contrastive_loss(similarity_matrix / temperature, dim=0)
        image_loss = contrastive_loss(similarity_matrix / temperature, dim=1)
        return 0.5 * (caption_loss + image_loss)


class CyCLIPLoss(nn.Module):
    def __init__(self, logit_temperature: float = -1.0):
        super().__init__()
        self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
        # Weights for the cyclic-consistency regularizers.
        self.lambda_1: float = 1.0
        self.lambda_2: float = 1.0

    def forward(
        self,
        similarity_matrix: torch.Tensor,
        image_features: torch.Tensor,
        text_features: torch.Tensor,
    ):
        temperature = self.logit_temperature.sigmoid()
        caption_loss = contrastive_loss(similarity_matrix / temperature, dim=0)
        image_loss = contrastive_loss(similarity_matrix / temperature, dim=1)
        # Cross-modal consistency: image-to-text and text-to-image similarities
        # should agree.
        symmetry_loss = F.mse_loss(similarity_matrix, similarity_matrix.T)
        # In-modal consistency: image-image and text-text similarities should match.
        modality_difference_loss = F.mse_loss(
            image_features @ image_features.T, text_features @ text_features.T
        )
        return (
            0.5 * (caption_loss + image_loss)
            + self.lambda_1 * symmetry_loss
            + self.lambda_2 * modality_difference_loss
        )


class SigLIPLoss(nn.Module):
    def __init__(self, logit_temperature: float = -1.0):
        super().__init__()
        self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))

    def forward(self, similarity_matrix: torch.Tensor, *args):
        temperature = self.logit_temperature.sigmoid()
        return contrastive_sigmoid_loss(similarity_matrix / temperature)


class CySigLIPLoss(nn.Module):
    def __init__(self, logit_temperature: float = -1.0):
        super().__init__()
        self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
        self.lambda_1: float = 1.0
        self.lambda_2: float = 1.0

    def forward(
        self,
        similarity_matrix: torch.Tensor,
        image_features: torch.Tensor,
        text_features: torch.Tensor,
    ):
        temperature = self.logit_temperature.sigmoid()
        loss = contrastive_sigmoid_loss(similarity_matrix / temperature)
        symmetry_loss = F.mse_loss(similarity_matrix, similarity_matrix.T)
        modality_difference_loss = F.mse_loss(
            image_features @ image_features.T, text_features @ text_features.T
        )
        return loss + self.lambda_1 * symmetry_loss + self.lambda_2 * modality_difference_loss


def get_loss(loss_type: str) -> nn.Module:
    # Map names to loss classes and instantiate only the requested one.
    loss_classes = {
        "clip": CLIPLoss,
        "cyclip": CyCLIPLoss,
        "sigmoid": SigLIPLoss,
        "cyclic_sigmoid": CySigLIPLoss,
    }
    if loss_type not in loss_classes:
        raise ValueError(f"Invalid loss type: {loss_type!r}")
    return loss_classes[loss_type]()
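

# Minimal usage sketch: the batch size, embedding dimension, and random
# features below are illustrative assumptions. Features are L2-normalized so
# the similarity matrix holds cosine similarities.
if __name__ == "__main__":
    batch_size, embed_dim = 8, 64
    image_features = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)
    text_features = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)
    similarity_matrix = get_similarity_matrix(image_features, text_features)
    for name in ("clip", "cyclip", "sigmoid", "cyclic_sigmoid"):
        criterion = get_loss(name)
        loss = criterion(similarity_matrix, image_features, text_features)
        print(f"{name}: {loss.item():.4f}")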