from typing import Dict, List, Optional

import torch
from datasets import Dataset
from PIL.Image import Image
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoProcessor

MODEL_NAME = "Marqo/marqo-fashionCLIP"


class FashionCLIPEncoder:
    """Wraps Marqo/marqo-fashionCLIP for batched image and text embedding."""

    def __init__(self):
        self.processor = AutoProcessor.from_pretrained(
            MODEL_NAME, trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
        self.model.eval()
        # Reuse whatever device the model was loaded onto (CPU by default).
        self.device = self.model.device
    def encode_images(
        self, images: List[Image], batch_size: Optional[int] = None
    ) -> List[List[float]]:
        # Default to a single batch containing every image.
        if batch_size is None:
            batch_size = len(images)

        def transform_fn(el: Dict):
            # Applied lazily per batch: preprocess PIL images into pixel tensors.
            return self.processor(images=el["image"], return_tensors="pt")

        dataset = Dataset.from_dict({"image": images})
        dataset.set_transform(transform_fn)
        dataloader = DataLoader(dataset, batch_size=batch_size)

        image_embeddings: List[List[float]] = []
        with torch.no_grad():
            for batch in dataloader:
                batch = {k: v.to(self.device) for k, v in batch.items()}
                image_embeddings.extend(self._encode_images(batch))
        return image_embeddings
    def encode_text(
        self, text: List[str], batch_size: Optional[int] = None
    ) -> List[List[float]]:
        # Default to a single batch containing every string.
        if batch_size is None:
            batch_size = len(text)

        def transform_fn(el: Dict):
            # Tokenize up front; padding to max_length keeps batch shapes uniform.
            kwargs = {
                "padding": "max_length",
                "return_tensors": "pt",
                "truncation": True,
            }
            return self.processor(text=el["text"], **kwargs)

        dataset = Dataset.from_dict({"text": text})
        dataset = dataset.map(
            function=transform_fn, batched=True, remove_columns=["text"]
        )
        dataset.set_format("torch")
        dataloader = DataLoader(dataset, batch_size=batch_size)

        text_embeddings: List[List[float]] = []
        with torch.no_grad():
            for batch in dataloader:
                batch = {k: v.to(self.device) for k, v in batch.items()}
                text_embeddings.extend(self._encode_text(batch))
        return text_embeddings
    def _encode_images(self, batch: Dict) -> List:
        # CLIP-style image tower; return plain Python lists for easy serialization.
        return self.model.get_image_features(**batch).detach().cpu().numpy().tolist()

    def _encode_text(self, batch: Dict) -> List:
        # CLIP-style text tower; return plain Python lists for easy serialization.
        return self.model.get_text_features(**batch).detach().cpu().numpy().tolist()
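

# A minimal usage sketch, not part of the original file: the image path and
# the text strings below are illustrative placeholders, not real assets.
if __name__ == "__main__":
    # Imported under an alias to avoid shadowing the Image type used above.
    from PIL import Image as PILImage

    encoder = FashionCLIPEncoder()

    # Hypothetical inputs; swap in your own product photos and descriptions.
    imgs = [PILImage.open("shirt.jpg").convert("RGB")]
    texts = ["a red cotton t-shirt", "blue denim jeans"]

    img_vecs = encoder.encode_images(imgs)
    txt_vecs = encoder.encode_text(texts)

    # Both towers project into the same space, so the vectors are comparable.
    print(len(img_vecs[0]), len(txt_vecs[0]))  # embedding dimensionality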