import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from transformers import PreTrainedModel

from src import loss
from src import vision_model
from src.config import TinyCLIPConfig, TinyCLIPTextConfig, TinyCLIPVisionConfig


class Projection(nn.Module):
    """Residual MLP projection head: LayerNorm(W1 x + Dropout(W2 GELU(W1 x)))."""

    def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_in, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        embed1 = self.linear1(x)
        embed2 = self.drop(self.linear2(F.gelu(embed1)))
        embeds = self.layer_norm(embed1 + embed2)  # residual connection
        return embeds


def projection_layers(d_in: int, d_out: int, num_layers: int) -> nn.Module:
    """Build num_layers Projection blocks; only the last maps d_in -> d_out."""
    layers: list[nn.Module] = []
    for _ in range(num_layers - 1):
        layers.extend([Projection(d_in, d_in), nn.GELU()])
    layers += [Projection(d_in, d_out)]
    return nn.Sequential(*layers)
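
# A sketch of what projection_layers builds; the dimensions below are
# hypothetical, chosen only to illustrate the structure:
#
#   projection_layers(d_in=512, d_out=128, num_layers=3) builds
#   nn.Sequential(
#       Projection(512, 512), nn.GELU(),
#       Projection(512, 512), nn.GELU(),
#       Projection(512, 128),
#   )
#
# i.e. num_layers - 1 dimension-preserving blocks followed by one block that
# maps into the shared embedding space.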


def mean_pooling(
    text_representation: torch.FloatTensor, attention_mask: torch.LongTensor
) -> torch.FloatTensor:
    """Mask-aware mean of token representations over the sequence dimension."""
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(text_representation.size()).float()
    return torch.sum(text_representation * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )  # type: ignore
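
# Shape contract, as a sketch: text_representation is (batch, seq_len, hidden),
# attention_mask is (batch, seq_len). Padding positions are zeroed before the
# average, and the clamp guards against division by zero for an all-padding row.
# Hypothetical example:
#
#   reps = torch.randn(2, 4, 8)
#   mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]])
#   mean_pooling(reps, mask)  # -> shape (2, 8)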


class TinyCLIPTextEncoder(PreTrainedModel):
    """Text tower: a pretrained Hugging Face language model plus a projection head."""

    config_class = TinyCLIPTextConfig

    def __init__(self, config: TinyCLIPTextConfig):
        super().__init__(config)
        self.base = transformers.AutoModel.from_pretrained(config.text_model)
        self.cls_type = config.cls_type
        self.projection = projection_layers(
            self.base.config.hidden_size, config.embed_dims, config.projection_layers
        )

    def forward(self, x: dict[str, torch.Tensor]) -> torch.Tensor:
        out = self.base(**x).last_hidden_state
        if self.cls_type:
            out = out[:, 0]  # get CLS token output
        else:
            out = mean_pooling(out, x["attention_mask"])  # type: ignore
        projected_vec = self.projection(out)
        return F.normalize(projected_vec, dim=-1)
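
# Illustrative usage, assuming an encoder built from some TinyCLIPTextConfig
# named config; the input dict is a standard tokenizer output:
#
#   tok = transformers.AutoTokenizer.from_pretrained(config.text_model)
#   batch = tok(["a photo of a cat"], return_tensors="pt", padding=True)
#   embeds = encoder(dict(batch))  # (1, config.embed_dims), L2-normalized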


class TinyCLIPVisionEncoder(PreTrainedModel):
    """Vision tower: a backbone from vision_model plus a projection head."""

    config_class = TinyCLIPVisionConfig

    def __init__(self, config: TinyCLIPVisionConfig):
        super().__init__(config)
        base, num_features = vision_model.get_vision_base(config)
        self.base = base
        self.projection = projection_layers(
            num_features, config.embed_dims, config.projection_layers
        )

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        projected_vec = self.projection(self.base(images))
        return F.normalize(projected_vec, dim=-1)


class TinyCLIP(PreTrainedModel):
    """Two-tower CLIP-style model pairing the text and vision encoders."""

    config_class = TinyCLIPConfig

    def __init__(self, config: TinyCLIPConfig):
        super().__init__(config)
        self.text_encoder = TinyCLIPTextEncoder(config.text_config)
        self.vision_encoder = TinyCLIPVisionEncoder(config.vision_config)

        # Optionally freeze the pretrained backbones; the projection heads
        # remain trainable in either case.
        if config.freeze_text_base:
            self.text_encoder.base.eval()
            for param in self.text_encoder.base.parameters():
                param.requires_grad = False
        if config.freeze_vision_base:
            self.vision_encoder.base.eval()
            for param in self.vision_encoder.base.parameters():
                param.requires_grad = False
        self.loss_fn = loss.get_loss(config.loss_type)

    def forward(
        self,
        text_input: dict[str, torch.Tensor],
        vision_input: torch.Tensor,
        return_loss: bool = False,
    ) -> dict[str, torch.Tensor]:
        text_output = self.text_encoder(text_input)
        vision_output = self.vision_encoder(vision_input)
        out = {"text_output": text_output, "vision_output": vision_output}
        if return_loss:
            out["loss"] = self.loss_fn(vision_output, text_output)
        return out


if __name__ == "__main__":
    model = TinyCLIP(TinyCLIPConfig())
    print(model)
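
    # A hedged smoke test: the tokenizer checkpoint is read from the default
    # config, but the 224x224 input size is an assumption about the default
    # vision backbone; adjust both to match your configuration.
    tokenizer = transformers.AutoTokenizer.from_pretrained(model.config.text_config.text_model)
    text_input = dict(
        tokenizer(["a photo of a cat", "a photo of a dog"], return_tensors="pt", padding=True)
    )
    vision_input = torch.randn(2, 3, 224, 224)  # dummy pre-processed image batch
    with torch.no_grad():
        out = model(text_input, vision_input, return_loss=True)
    # Both towers L2-normalize, so image-text cosine similarity is a matrix product.
    print(out["text_output"] @ out["vision_output"].T)
    print("loss:", out["loss"])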
print("Done!")