import base64
from io import BytesIO
from typing import Any, Dict, List

import torch
from PIL import Image


class EndpointHandler:
    """
    A handler class for processing image data, generating embeddings using a
    specified model and processor.

    Attributes:
        model: The pre-trained model used for generating embeddings.
        processor: The pre-trained processor used to process images before
            model inference.
        device: The device (CPU or CUDA) used to run model inference.
        default_batch_size: The default batch size for processing images in
            batches.
    """

    def __init__(self, path: str = "", default_batch_size: int = 4):
        """
        Initializes the EndpointHandler with a specified model path and
        default batch size.

        Args:
            path (str): Path to the pre-trained model and processor.
            default_batch_size (int): Default batch size for image processing.
        """
        # Imported here rather than at module level, so importing this module
        # does not itself require colpali_engine.
        from colpali_engine.models import ColQwen2, ColQwen2Processor

        self.model = ColQwen2.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        ).eval()
        self.processor = ColQwen2Processor.from_pretrained(path)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.default_batch_size = default_batch_size

    def _process_batch(self, images: List[Image.Image]) -> List[List[float]]:
        """
        Processes a batch of images and generates embeddings.

        Args:
            images (List[Image.Image]): List of images to process.

        Returns:
            List[List[float]]: List of embeddings for each image.
            NOTE(review): the nesting depth is whatever the model output's
            ``tolist()`` yields; if the model emits one vector per token the
            result is one level deeper than annotated -- confirm.
        """
        batch_images = self.processor.process_images(images)
        # Move every processed tensor to the device the model lives on.
        batch_images = {k: v.to(self.device) for k, v in batch_images.items()}

        with torch.no_grad():
            image_embeddings = self.model(**batch_images)

        # Bring the result back to the CPU and convert it to plain Python
        # lists so it can be returned in the response dictionary.
        return image_embeddings.cpu().tolist()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Processes input data containing base64-encoded images, decodes them,
        and generates embeddings.

        Args:
            data (Dict[str, Any]): Dictionary containing input images and
                optional batch size. ``inputs`` is a list (or tuple) of
                base64-encoded image strings; a single string is accepted and
                treated as one image. ``batch_size`` is an optional positive
                integer; when missing or ``None`` the handler's default is
                used.

        Returns:
            Dict[str, Any]: ``{"embeddings": [...]}`` with one entry per input
            image, in input order, or ``{"error": "<message>"}`` when the
            request is invalid. Invalid requests are reported through the
            ``error`` key instead of raising.
        """
        images_data = data.get("inputs", [])
        if not images_data:
            return {"error": "No images provided in 'inputs'."}

        if isinstance(images_data, str):
            # Iterating a bare string would walk it character by character
            # and always fail to decode; treat it as a single image instead.
            images_data = [images_data]
        elif not isinstance(images_data, (list, tuple)):
            return {"error": "Images should be base64-encoded strings."}

        batch_size = data.get("batch_size")
        if batch_size is None:
            batch_size = self.default_batch_size
        # A step of 0 makes range() raise, and a negative step makes the
        # batching loop silently produce no embeddings, so reject both (and
        # non-integers) before doing any work.
        if not isinstance(batch_size, int) or batch_size < 1:
            return {"error": "'batch_size' must be a positive integer."}

        images = []
        for img_data in images_data:
            if not isinstance(img_data, str):
                return {"error": "Images should be base64-encoded strings."}
            try:
                image_bytes = base64.b64decode(img_data)
                # convert() returns a new, fully loaded image, so the source
                # image can be closed as soon as the conversion is done.
                with Image.open(BytesIO(image_bytes)) as raw_image:
                    images.append(raw_image.convert("RGB"))
            except Exception as e:
                # Request boundary: any decoding failure (bad base64, unknown
                # or corrupt image format) becomes an error response.
                return {"error": f"Invalid image data: {e}"}

        embeddings = []
        for start in range(0, len(images), batch_size):
            batch_images = images[start : start + batch_size]
            embeddings.extend(self._process_batch(batch_images))

        return {"embeddings": embeddings}