# openFashionClip/encoder.py
from typing import Dict, List, Optional

import torch
from datasets import Dataset
from PIL.Image import Image
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoProcessor

MODEL_NAME = "Marqo/marqo-fashionCLIP"

class FashionCLIPEncoder:
    """Embeds images and text into a shared space with Marqo/marqo-fashionCLIP."""

    def __init__(self):
        # trust_remote_code is required: the checkpoint ships custom modeling code.
        self.processor = AutoProcessor.from_pretrained(
            MODEL_NAME, trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(MODEL_NAME, trust_remote_code=True)
        self.model.eval()  # inference only: disable dropout etc.
        self.device = self.model.device
    def encode_images(
        self, images: List[Image], batch_size: Optional[int] = None
    ) -> List[List[float]]:
        # Default to a single batch containing every image.
        if batch_size is None:
            batch_size = len(images)

        # Applied lazily: the processor runs only when a batch is fetched.
        def transform_fn(el: Dict):
            return self.processor(
                images=[content for content in el["image"]], return_tensors="pt"
            )

        dataset = Dataset.from_dict({"image": images})
        # set_transform supersedes set_format, so no separate set_format call is needed.
        dataset.set_transform(transform_fn)
        dataloader = DataLoader(dataset, batch_size=batch_size)

        image_embeddings = []
        with torch.no_grad():
            for batch in dataloader:
                batch = {k: v.to(self.device) for k, v in batch.items()}
                embeddings = self._encode_images(batch)
                image_embeddings.extend(embeddings)
        return image_embeddings
    def encode_text(
        self, text: List[str], batch_size: Optional[int] = None
    ) -> List[List[float]]:
        # Default to a single batch containing every string.
        if batch_size is None:
            batch_size = len(text)

        def transform_fn(el: Dict):
            # Pad every sequence to the model's max length so batches stack cleanly.
            kwargs = {
                "padding": "max_length",
                "return_tensors": "pt",
                "truncation": True,
            }
            return self.processor(text=el["text"], **kwargs)

        # Unlike image preprocessing, tokenization is cheap, so it is done
        # eagerly via map() rather than lazily via set_transform().
        dataset = Dataset.from_dict({"text": text})
        dataset = dataset.map(
            function=transform_fn, batched=True, remove_columns=["text"]
        )
        dataset.set_format("torch")
        dataloader = DataLoader(dataset, batch_size=batch_size)

        text_embeddings = []
        with torch.no_grad():
            for batch in dataloader:
                batch = {k: v.to(self.device) for k, v in batch.items()}
                embeddings = self._encode_text(batch)
                text_embeddings.extend(embeddings)
        return text_embeddings
    def _encode_images(self, batch: Dict) -> List[List[float]]:
        # get_image_features returns a (batch, dim) tensor; detach it, move it
        # to CPU, and convert to plain Python lists for easy serialization.
        return self.model.get_image_features(**batch).detach().cpu().numpy().tolist()

    def _encode_text(self, batch: Dict) -> List[List[float]]:
        return self.model.get_text_features(**batch).detach().cpu().numpy().tolist()
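
if __name__ == "__main__":
    # Usage sketch, not part of the original file: embed one image and two
    # text queries, then rank the queries by cosine similarity to the image.
    # The file name "example.jpg" is a placeholder; substitute a real path.
    from PIL import Image as PILImage

    encoder = FashionCLIPEncoder()

    image = PILImage.open("example.jpg").convert("RGB")  # hypothetical input
    image_emb = torch.tensor(encoder.encode_images([image]))  # shape (1, dim)
    text_emb = torch.tensor(
        encoder.encode_text(["a red dress", "blue denim jeans"])
    )  # shape (2, dim)

    # Broadcasting compares the single image row against each text row.
    sims = torch.nn.functional.cosine_similarity(image_emb, text_emb, dim=-1)
    print(sims)  # higher score = closer match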