amaye15 commited on
Commit
54d3e5c
·
1 Parent(s): b0a7877
Files changed (1) hide show
  1. handler.py +0 -240
handler.py CHANGED
@@ -166,243 +166,3 @@ class EndpointHandler:
166
  return {"error": f"Error computing similarity scores: {e}"}
167
 
168
  return {"image": image_embeddings, "text": text_embeddings, "scores": scores}
169
-
170
-
171
- # import torch
172
- # from typing import Dict, Any, List
173
- # from PIL import Image
174
- # import base64
175
- # from io import BytesIO
176
- # import logging
177
- # from torch.utils.data import DataLoader, Dataset
178
- # import threading
179
-
180
-
181
- # class ImageDataset(Dataset):
182
- # def __init__(self, images: List[Image.Image], processor):
183
- # self.images = images
184
- # self.processor = processor
185
-
186
- # def __len__(self):
187
- # return len(self.images)
188
-
189
- # def __getitem__(self, idx):
190
- # image = self.processor.process_images([self.images[idx]])
191
- # return image
192
-
193
-
194
- # class TextDataset(Dataset):
195
- # def __init__(self, texts: List[str], processor):
196
- # self.texts = texts
197
- # self.processor = processor
198
-
199
- # def __len__(self):
200
- # return len(self.texts)
201
-
202
- # def __getitem__(self, idx):
203
- # text = self.processor.process_queries([self.texts[idx]])
204
- # return text
205
-
206
-
207
- # class EndpointHandler:
208
- # """
209
- # A handler class for processing image and text data, generating embeddings using a specified model and processor.
210
-
211
- # Attributes:
212
- # model: The pre-trained model used for generating embeddings.
213
- # processor: The pre-trained processor used to process images and text before model inference.
214
- # device: The device (CPU or CUDA) used to run model inference.
215
- # default_batch_size: The default batch size for processing images and text in batches.
216
- # """
217
-
218
- # def __init__(self, path: str = "", default_batch_size: int = 4):
219
- # """
220
- # Initializes the EndpointHandler with a specified model path and default batch size.
221
- # """
222
- # # Initialize logging
223
- # logging.basicConfig(level=logging.INFO)
224
- # self.logger = logging.getLogger(__name__)
225
-
226
- # from colpali_engine.models import ColQwen2, ColQwen2Processor
227
-
228
- # self.logger.info("Initializing model and processor.")
229
- # try:
230
- # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
231
-
232
- # self.model = (
233
- # ColQwen2.from_pretrained(
234
- # path,
235
- # torch_dtype=torch.bfloat16,
236
- # device_map="auto",
237
- # )
238
- # .to(self.device)
239
- # .eval()
240
- # )
241
-
242
- # self.processor = ColQwen2Processor.from_pretrained(path)
243
- # self.default_batch_size = default_batch_size
244
- # self.logger.info("Initialization complete.")
245
- # except Exception as e:
246
- # self.logger.error(f"Failed to initialize model or processor: {e}")
247
- # raise
248
-
249
- # def _process_image_embeddings(
250
- # self, images: List[Image.Image], batch_size: int
251
- # ) -> torch.Tensor:
252
- # """
253
- # Processes images and generates embeddings.
254
-
255
- # Args:
256
- # images (List[Image.Image]): List of images to process.
257
- # batch_size (int): Batch size for processing images.
258
-
259
- # Returns:
260
- # torch.Tensor: Tensor containing embeddings for each image.
261
- # """
262
- # self.logger.debug(f"Processing {len(images)} images.")
263
- # try:
264
- # image_dataset = ImageDataset(images, self.processor)
265
- # image_loader = DataLoader(
266
- # image_dataset, batch_size=batch_size, num_workers=4, pin_memory=True
267
- # )
268
-
269
- # all_embeddings = []
270
- # with torch.no_grad():
271
- # for batch in image_loader:
272
- # batch_images = batch[0].to(self.device, non_blocking=True)
273
- # with torch.cuda.amp.autocast():
274
- # embeddings = self.model(**batch_images)
275
- # all_embeddings.append(embeddings)
276
- # image_embeddings = torch.cat(all_embeddings, dim=0)
277
- # self.logger.debug("Image processing complete.")
278
- # return image_embeddings
279
- # except Exception as e:
280
- # self.logger.error(f"Error processing images: {e}")
281
- # raise
282
-
283
- # def _process_text_embeddings(
284
- # self, texts: List[str], batch_size: int
285
- # ) -> torch.Tensor:
286
- # """
287
- # Processes text queries and generates embeddings.
288
-
289
- # Args:
290
- # texts (List[str]): List of text queries to process.
291
- # batch_size (int): Batch size for processing texts.
292
-
293
- # Returns:
294
- # torch.Tensor: Tensor containing embeddings for each text query.
295
- # """
296
- # self.logger.debug(f"Processing {len(texts)} text queries.")
297
- # try:
298
- # text_dataset = TextDataset(texts, self.processor)
299
- # text_loader = DataLoader(
300
- # text_dataset, batch_size=batch_size, num_workers=4, pin_memory=True
301
- # )
302
-
303
- # all_embeddings = []
304
- # with torch.no_grad():
305
- # for batch in text_loader:
306
- # batch_texts = batch[0].to(self.device, non_blocking=True)
307
- # with torch.amp.autocast():
308
- # embeddings = self.model(**batch_texts)
309
- # all_embeddings.append(embeddings)
310
- # text_embeddings = torch.cat(all_embeddings, dim=0)
311
- # self.logger.debug("Text processing complete.")
312
- # return text_embeddings
313
- # except Exception as e:
314
- # self.logger.error(f"Error processing texts: {e}")
315
- # raise
316
-
317
- # def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
318
- # """
319
- # Processes input data containing base64-encoded images and text queries, decodes them, and generates embeddings.
320
-
321
- # Args:
322
- # data (Dict[str, Any]): Dictionary containing input images, text queries, and optional batch size.
323
-
324
- # Returns:
325
- # Dict[str, Any]: Dictionary containing generated embeddings for images and text or error messages.
326
- # """
327
- # images_data = data.get("image", [])
328
- # text_data = data.get("text", [])
329
- # batch_size = data.get("batch_size", self.default_batch_size)
330
-
331
- # images = []
332
- # if images_data:
333
- # self.logger.info("Decoding images from base64.")
334
- # for img_data in images_data:
335
- # if isinstance(img_data, str):
336
- # try:
337
- # image_bytes = base64.b64decode(img_data)
338
- # image = Image.open(BytesIO(image_bytes)).convert("RGB")
339
- # images.append(image)
340
- # except Exception as e:
341
- # self.logger.error(f"Invalid image data: {e}")
342
- # return {"error": f"Invalid image data: {e}"}
343
- # else:
344
- # self.logger.error("Images should be base64-encoded strings.")
345
- # return {"error": "Images should be base64-encoded strings."}
346
-
347
- # image_embeddings = None
348
- # text_embeddings = None
349
- # scores = None
350
-
351
- # def process_images():
352
- # nonlocal image_embeddings
353
- # if images:
354
- # self.logger.info("Processing image embeddings.")
355
- # try:
356
- # image_embeddings = self._process_image_embeddings(
357
- # images, batch_size
358
- # )
359
- # except Exception as e:
360
- # self.logger.error(f"Error generating image embeddings: {e}")
361
-
362
- # def process_texts():
363
- # nonlocal text_embeddings
364
- # if text_data:
365
- # self.logger.info("Processing text embeddings.")
366
- # try:
367
- # text_embeddings = self._process_text_embeddings(
368
- # text_data, batch_size
369
- # )
370
- # except Exception as e:
371
- # self.logger.error(f"Error generating text embeddings: {e}")
372
-
373
- # # Process images and texts in parallel if both are present
374
- # threads = []
375
- # if images_data and text_data:
376
- # image_thread = threading.Thread(target=process_images)
377
- # text_thread = threading.Thread(target=process_texts)
378
- # threads.extend([image_thread, text_thread])
379
- # image_thread.start()
380
- # text_thread.start()
381
- # for thread in threads:
382
- # thread.join()
383
- # else:
384
- # process_images()
385
- # process_texts()
386
-
387
- # # Compute similarity scores if both embeddings are available
388
- # if image_embeddings is not None and text_embeddings is not None:
389
- # self.logger.info("Computing similarity scores.")
390
- # try:
391
- # with torch.no_grad(), torch.amp.autocast():
392
- # scores = self.processor.score_multi_vector(
393
- # text_embeddings, image_embeddings
394
- # )
395
- # self.logger.info("Similarity scoring complete.")
396
- # except Exception as e:
397
- # self.logger.error(f"Error computing similarity scores: {e}")
398
- # return {"error": f"Error computing similarity scores: {e}"}
399
-
400
- # result = {}
401
- # if image_embeddings is not None:
402
- # result["image"] = image_embeddings.cpu().tolist()
403
- # if text_embeddings is not None:
404
- # result["text"] = text_embeddings.cpu().tolist()
405
- # if scores is not None:
406
- # result["scores"] = scores.cpu().tolist()
407
-
408
- # return result
 
166
  return {"error": f"Error computing similarity scores: {e}"}
167
 
168
  return {"image": image_embeddings, "text": text_embeddings, "scores": scores}