amaye15
First model version
b54bf3c
raw
history blame
1.96 kB
import torch
from typing import Dict, Any
from PIL import Image
import base64
from io import BytesIO
class EndpointHandler:
    """Inference handler that turns base64-encoded images into ColQwen2 embeddings.

    Construct once with a model location, then call with request payloads of
    the form ``{"inputs": [<base64 str>, ...]}`` (a single base64 string is
    also accepted).
    """

    def __init__(self, path: str = ""):
        """Load the ColQwen2 model and processor from ``path``.

        Args:
            path: Location passed to ``from_pretrained`` for both the model
                and the processor.
        """
        # Imported here rather than at module top; presumably to defer loading
        # colpali_engine until a handler is actually constructed.
        from colpali_engine.models import ColQwen2, ColQwen2Processor

        # Weights are loaded in bfloat16 and put in eval mode for inference.
        self.model = ColQwen2.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        ).eval()
        self.processor = ColQwen2Processor.from_pretrained(path)

        # Use a GPU when one is visible, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    @staticmethod
    def _decode_image(img_data: str) -> "Image.Image":
        """Decode one base64 string (optionally a data URL) into an RGB PIL image.

        Raises whatever ``base64.b64decode`` / ``PIL.Image.open`` raise on bad
        input; the caller turns that into an error response.
        """
        # Strip a "data:<mime>;base64," prefix. ':' is not in the base64
        # alphabet, so a genuine base64 payload can never start with "data:";
        # left in place, b64decode would drop the punctuation and decode the
        # prefix letters into garbage bytes.
        if img_data.startswith("data:"):
            _, _, img_data = img_data.partition(",")
        image_bytes = base64.b64decode(img_data)
        return Image.open(BytesIO(image_bytes)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Compute embeddings for the images in ``data["inputs"]``.

        Args:
            data: Request payload. ``data["inputs"]`` is a list/tuple of
                base64-encoded image strings, or a single such string. Each
                string may carry a ``data:<mime>;base64,`` prefix.

        Returns:
            ``{"embeddings": <nested list>}`` built from the model output on
            success, or ``{"error": <message>}`` when inputs are missing, of
            the wrong type, or cannot be decoded as images.
        """
        images_data = data.get("inputs", [])
        # Checked before wrapping so that an empty string keeps this error.
        if not images_data:
            return {"error": "No images provided in 'inputs'."}
        if isinstance(images_data, str):
            # A bare string is one image; iterating it would walk its characters.
            images_data = [images_data]
        elif not isinstance(images_data, (list, tuple)):
            # e.g. a number or an object: not iterable as a list of images.
            return {"error": "Images should be base64-encoded strings."}

        images = []
        for img_data in images_data:
            if not isinstance(img_data, str):
                return {"error": "Images should be base64-encoded strings."}
            try:
                images.append(self._decode_image(img_data))
            except Exception as e:  # request boundary: report instead of crashing
                return {"error": f"Invalid image data: {e}"}

        batch_images = self.processor.process_images(images)
        # Move tensor entries onto the model's device; pass anything else through.
        batch_images = {
            k: v.to(self.device) if isinstance(v, torch.Tensor) else v
            for k, v in batch_images.items()
        }

        with torch.inference_mode():
            image_embeddings = self.model(**batch_images)

        # NOTE(review): if the processor pads the batch to its longest sequence,
        # every image gets the same number of vectors including padded rows;
        # confirm whether clients expect those rows to be dropped.
        # Output is likely bfloat16 (model dtype); upcast losslessly to float32
        # before converting to Python floats.
        embeddings_list = image_embeddings.cpu().float().tolist()
        return {"embeddings": embeddings_list}