amaye15
commited on
Commit
·
54d3e5c
1
Parent(s):
b0a7877
clean up
Browse files- handler.py +0 -240
handler.py
CHANGED
@@ -166,243 +166,3 @@ class EndpointHandler:
|
|
166 |
return {"error": f"Error computing similarity scores: {e}"}
|
167 |
|
168 |
return {"image": image_embeddings, "text": text_embeddings, "scores": scores}
|
169 |
-
|
170 |
-
|
171 |
-
# import torch
|
172 |
-
# from typing import Dict, Any, List
|
173 |
-
# from PIL import Image
|
174 |
-
# import base64
|
175 |
-
# from io import BytesIO
|
176 |
-
# import logging
|
177 |
-
# from torch.utils.data import DataLoader, Dataset
|
178 |
-
# import threading
|
179 |
-
|
180 |
-
|
181 |
-
# class ImageDataset(Dataset):
|
182 |
-
# def __init__(self, images: List[Image.Image], processor):
|
183 |
-
# self.images = images
|
184 |
-
# self.processor = processor
|
185 |
-
|
186 |
-
# def __len__(self):
|
187 |
-
# return len(self.images)
|
188 |
-
|
189 |
-
# def __getitem__(self, idx):
|
190 |
-
# image = self.processor.process_images([self.images[idx]])
|
191 |
-
# return image
|
192 |
-
|
193 |
-
|
194 |
-
# class TextDataset(Dataset):
|
195 |
-
# def __init__(self, texts: List[str], processor):
|
196 |
-
# self.texts = texts
|
197 |
-
# self.processor = processor
|
198 |
-
|
199 |
-
# def __len__(self):
|
200 |
-
# return len(self.texts)
|
201 |
-
|
202 |
-
# def __getitem__(self, idx):
|
203 |
-
# text = self.processor.process_queries([self.texts[idx]])
|
204 |
-
# return text
|
205 |
-
|
206 |
-
|
207 |
-
# class EndpointHandler:
|
208 |
-
# """
|
209 |
-
# A handler class for processing image and text data, generating embeddings using a specified model and processor.
|
210 |
-
|
211 |
-
# Attributes:
|
212 |
-
# model: The pre-trained model used for generating embeddings.
|
213 |
-
# processor: The pre-trained processor used to process images and text before model inference.
|
214 |
-
# device: The device (CPU or CUDA) used to run model inference.
|
215 |
-
# default_batch_size: The default batch size for processing images and text in batches.
|
216 |
-
# """
|
217 |
-
|
218 |
-
# def __init__(self, path: str = "", default_batch_size: int = 4):
|
219 |
-
# """
|
220 |
-
# Initializes the EndpointHandler with a specified model path and default batch size.
|
221 |
-
# """
|
222 |
-
# # Initialize logging
|
223 |
-
# logging.basicConfig(level=logging.INFO)
|
224 |
-
# self.logger = logging.getLogger(__name__)
|
225 |
-
|
226 |
-
# from colpali_engine.models import ColQwen2, ColQwen2Processor
|
227 |
-
|
228 |
-
# self.logger.info("Initializing model and processor.")
|
229 |
-
# try:
|
230 |
-
# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
231 |
-
|
232 |
-
# self.model = (
|
233 |
-
# ColQwen2.from_pretrained(
|
234 |
-
# path,
|
235 |
-
# torch_dtype=torch.bfloat16,
|
236 |
-
# device_map="auto",
|
237 |
-
# )
|
238 |
-
# .to(self.device)
|
239 |
-
# .eval()
|
240 |
-
# )
|
241 |
-
|
242 |
-
# self.processor = ColQwen2Processor.from_pretrained(path)
|
243 |
-
# self.default_batch_size = default_batch_size
|
244 |
-
# self.logger.info("Initialization complete.")
|
245 |
-
# except Exception as e:
|
246 |
-
# self.logger.error(f"Failed to initialize model or processor: {e}")
|
247 |
-
# raise
|
248 |
-
|
249 |
-
# def _process_image_embeddings(
|
250 |
-
# self, images: List[Image.Image], batch_size: int
|
251 |
-
# ) -> torch.Tensor:
|
252 |
-
# """
|
253 |
-
# Processes images and generates embeddings.
|
254 |
-
|
255 |
-
# Args:
|
256 |
-
# images (List[Image.Image]): List of images to process.
|
257 |
-
# batch_size (int): Batch size for processing images.
|
258 |
-
|
259 |
-
# Returns:
|
260 |
-
# torch.Tensor: Tensor containing embeddings for each image.
|
261 |
-
# """
|
262 |
-
# self.logger.debug(f"Processing {len(images)} images.")
|
263 |
-
# try:
|
264 |
-
# image_dataset = ImageDataset(images, self.processor)
|
265 |
-
# image_loader = DataLoader(
|
266 |
-
# image_dataset, batch_size=batch_size, num_workers=4, pin_memory=True
|
267 |
-
# )
|
268 |
-
|
269 |
-
# all_embeddings = []
|
270 |
-
# with torch.no_grad():
|
271 |
-
# for batch in image_loader:
|
272 |
-
# batch_images = batch[0].to(self.device, non_blocking=True)
|
273 |
-
# with torch.cuda.amp.autocast():
|
274 |
-
# embeddings = self.model(**batch_images)
|
275 |
-
# all_embeddings.append(embeddings)
|
276 |
-
# image_embeddings = torch.cat(all_embeddings, dim=0)
|
277 |
-
# self.logger.debug("Image processing complete.")
|
278 |
-
# return image_embeddings
|
279 |
-
# except Exception as e:
|
280 |
-
# self.logger.error(f"Error processing images: {e}")
|
281 |
-
# raise
|
282 |
-
|
283 |
-
# def _process_text_embeddings(
|
284 |
-
# self, texts: List[str], batch_size: int
|
285 |
-
# ) -> torch.Tensor:
|
286 |
-
# """
|
287 |
-
# Processes text queries and generates embeddings.
|
288 |
-
|
289 |
-
# Args:
|
290 |
-
# texts (List[str]): List of text queries to process.
|
291 |
-
# batch_size (int): Batch size for processing texts.
|
292 |
-
|
293 |
-
# Returns:
|
294 |
-
# torch.Tensor: Tensor containing embeddings for each text query.
|
295 |
-
# """
|
296 |
-
# self.logger.debug(f"Processing {len(texts)} text queries.")
|
297 |
-
# try:
|
298 |
-
# text_dataset = TextDataset(texts, self.processor)
|
299 |
-
# text_loader = DataLoader(
|
300 |
-
# text_dataset, batch_size=batch_size, num_workers=4, pin_memory=True
|
301 |
-
# )
|
302 |
-
|
303 |
-
# all_embeddings = []
|
304 |
-
# with torch.no_grad():
|
305 |
-
# for batch in text_loader:
|
306 |
-
# batch_texts = batch[0].to(self.device, non_blocking=True)
|
307 |
-
# with torch.amp.autocast():
|
308 |
-
# embeddings = self.model(**batch_texts)
|
309 |
-
# all_embeddings.append(embeddings)
|
310 |
-
# text_embeddings = torch.cat(all_embeddings, dim=0)
|
311 |
-
# self.logger.debug("Text processing complete.")
|
312 |
-
# return text_embeddings
|
313 |
-
# except Exception as e:
|
314 |
-
# self.logger.error(f"Error processing texts: {e}")
|
315 |
-
# raise
|
316 |
-
|
317 |
-
# def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
318 |
-
# """
|
319 |
-
# Processes input data containing base64-encoded images and text queries, decodes them, and generates embeddings.
|
320 |
-
|
321 |
-
# Args:
|
322 |
-
# data (Dict[str, Any]): Dictionary containing input images, text queries, and optional batch size.
|
323 |
-
|
324 |
-
# Returns:
|
325 |
-
# Dict[str, Any]: Dictionary containing generated embeddings for images and text or error messages.
|
326 |
-
# """
|
327 |
-
# images_data = data.get("image", [])
|
328 |
-
# text_data = data.get("text", [])
|
329 |
-
# batch_size = data.get("batch_size", self.default_batch_size)
|
330 |
-
|
331 |
-
# images = []
|
332 |
-
# if images_data:
|
333 |
-
# self.logger.info("Decoding images from base64.")
|
334 |
-
# for img_data in images_data:
|
335 |
-
# if isinstance(img_data, str):
|
336 |
-
# try:
|
337 |
-
# image_bytes = base64.b64decode(img_data)
|
338 |
-
# image = Image.open(BytesIO(image_bytes)).convert("RGB")
|
339 |
-
# images.append(image)
|
340 |
-
# except Exception as e:
|
341 |
-
# self.logger.error(f"Invalid image data: {e}")
|
342 |
-
# return {"error": f"Invalid image data: {e}"}
|
343 |
-
# else:
|
344 |
-
# self.logger.error("Images should be base64-encoded strings.")
|
345 |
-
# return {"error": "Images should be base64-encoded strings."}
|
346 |
-
|
347 |
-
# image_embeddings = None
|
348 |
-
# text_embeddings = None
|
349 |
-
# scores = None
|
350 |
-
|
351 |
-
# def process_images():
|
352 |
-
# nonlocal image_embeddings
|
353 |
-
# if images:
|
354 |
-
# self.logger.info("Processing image embeddings.")
|
355 |
-
# try:
|
356 |
-
# image_embeddings = self._process_image_embeddings(
|
357 |
-
# images, batch_size
|
358 |
-
# )
|
359 |
-
# except Exception as e:
|
360 |
-
# self.logger.error(f"Error generating image embeddings: {e}")
|
361 |
-
|
362 |
-
# def process_texts():
|
363 |
-
# nonlocal text_embeddings
|
364 |
-
# if text_data:
|
365 |
-
# self.logger.info("Processing text embeddings.")
|
366 |
-
# try:
|
367 |
-
# text_embeddings = self._process_text_embeddings(
|
368 |
-
# text_data, batch_size
|
369 |
-
# )
|
370 |
-
# except Exception as e:
|
371 |
-
# self.logger.error(f"Error generating text embeddings: {e}")
|
372 |
-
|
373 |
-
# # Process images and texts in parallel if both are present
|
374 |
-
# threads = []
|
375 |
-
# if images_data and text_data:
|
376 |
-
# image_thread = threading.Thread(target=process_images)
|
377 |
-
# text_thread = threading.Thread(target=process_texts)
|
378 |
-
# threads.extend([image_thread, text_thread])
|
379 |
-
# image_thread.start()
|
380 |
-
# text_thread.start()
|
381 |
-
# for thread in threads:
|
382 |
-
# thread.join()
|
383 |
-
# else:
|
384 |
-
# process_images()
|
385 |
-
# process_texts()
|
386 |
-
|
387 |
-
# # Compute similarity scores if both embeddings are available
|
388 |
-
# if image_embeddings is not None and text_embeddings is not None:
|
389 |
-
# self.logger.info("Computing similarity scores.")
|
390 |
-
# try:
|
391 |
-
# with torch.no_grad(), torch.amp.autocast():
|
392 |
-
# scores = self.processor.score_multi_vector(
|
393 |
-
# text_embeddings, image_embeddings
|
394 |
-
# )
|
395 |
-
# self.logger.info("Similarity scoring complete.")
|
396 |
-
# except Exception as e:
|
397 |
-
# self.logger.error(f"Error computing similarity scores: {e}")
|
398 |
-
# return {"error": f"Error computing similarity scores: {e}"}
|
399 |
-
|
400 |
-
# result = {}
|
401 |
-
# if image_embeddings is not None:
|
402 |
-
# result["image"] = image_embeddings.cpu().tolist()
|
403 |
-
# if text_embeddings is not None:
|
404 |
-
# result["text"] = text_embeddings.cpu().tolist()
|
405 |
-
# if scores is not None:
|
406 |
-
# result["scores"] = scores.cpu().tolist()
|
407 |
-
|
408 |
-
# return result
|
|
|
166 |
return {"error": f"Error computing similarity scores: {e}"}
|
167 |
|
168 |
return {"image": image_embeddings, "text": text_embeddings, "scores": scores}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|