amaye15 committed
Commit 64262c3 · 1 Parent(s): dbabaf1

Logging added

Files changed (1)
  1. handler.py +150 -72
handler.py CHANGED
@@ -7,13 +7,13 @@
 
 # class EndpointHandler:
 #     """
-#     A handler class for processing image data, generating embeddings using a specified model and processor.
+#     A handler class for processing image and text data, generating embeddings using a specified model and processor.
 
 #     Attributes:
 #         model: The pre-trained model used for generating embeddings.
-#         processor: The pre-trained processor used to process images before model inference.
+#         processor: The pre-trained processor used to process images and text before model inference.
 #         device: The device (CPU or CUDA) used to run model inference.
-#         default_batch_size: The default batch size for processing images in batches.
+#         default_batch_size: The default batch size for processing images and text in batches.
 #     """
 
 #     def __init__(self, path: str = "", default_batch_size: int = 4):
@@ -22,13 +22,16 @@
 
 #         Args:
 #             path (str): Path to the pre-trained model and processor.
-#             default_batch_size (int): Default batch size for image processing.
+#             default_batch_size (int): Default batch size for processing images and text data.
 #         """
 #         from colpali_engine.models import ColQwen2, ColQwen2Processor
 
 #         self.model = ColQwen2.from_pretrained(
 #             path,
 #             torch_dtype=torch.bfloat16,
+#             device_map=(
+#                 "cuda:0" if torch.cuda.is_available() else "cpu"
+#             ),  # Set device map based on availability
 #         ).eval()
 #         self.processor = ColQwen2Processor.from_pretrained(path)
 
@@ -36,7 +39,7 @@
 #         self.model.to(self.device)
 #         self.default_batch_size = default_batch_size
 
-#     def _process_batch(self, images: List[Image.Image]) -> List[List[float]]:
+#     def _process_image_batch(self, images: List[Image.Image]) -> List[List[float]]:
 #         """
 #         Processes a batch of images and generates embeddings.
 
@@ -46,55 +49,97 @@
 #         Returns:
 #             List[List[float]]: List of embeddings for each image.
 #         """
-#         batch_images = self.processor.process_images(images)
-#         batch_images = {k: v.to(self.device) for k, v in batch_images.items()}
+#         batch_images = self.processor.process_images(images).to(self.device)
 
 #         with torch.no_grad():
 #             image_embeddings = self.model(**batch_images)
 
 #         return image_embeddings.cpu().tolist()
 
+#     def _process_text_batch(self, texts: List[str]) -> List[List[float]]:
+#         """
+#         Processes a batch of text queries and generates embeddings.
+
+#         Args:
+#             texts (List[str]): List of text queries to process.
+
+#         Returns:
+#             List[List[float]]: List of embeddings for each text query.
+#         """
+#         batch_queries = self.processor.process_queries(texts).to(self.device)
+
+#         with torch.no_grad():
+#             query_embeddings = self.model(**batch_queries)
+
+#         return query_embeddings.cpu().tolist()
+
 #     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
 #         """
-#         Processes input data containing base64-encoded images, decodes them, and generates embeddings.
+#         Processes input data containing base64-encoded images and text queries, decodes them, and generates embeddings.
 
 #         Args:
-#             data (Dict[str, Any]): Dictionary containing input images and optional batch size.
+#             data (Dict[str, Any]): Dictionary containing input images, text queries, and optional batch size.
 
 #         Returns:
-#             Dict[str, Any]: Dictionary containing generated embeddings or error messages.
+#             Dict[str, Any]: Dictionary containing generated embeddings for images and text or error messages.
 #         """
-#         images_data = data.get("inputs", [])
+#         images_data = data.get("image", [])
+#         text_data = data.get("text", [])
 #         batch_size = data.get("batch_size", self.default_batch_size)
 
-#         if not images_data:
-#             return {"error": "No images provided in 'inputs'."}
-
+#         # Decode and process images
 #         images = []
-#         for img_data in images_data:
-#             if isinstance(img_data, str):
-#                 try:
-#                     image_bytes = base64.b64decode(img_data)
-#                     image = Image.open(BytesIO(image_bytes)).convert("RGB")
-#                     images.append(image)
-#                 except Exception as e:
-#                     return {"error": f"Invalid image data: {e}"}
-#             else:
-#                 return {"error": "Images should be base64-encoded strings."}
-
-#         embeddings = []
+#         if images_data:
+#             for img_data in images_data:
+#                 if isinstance(img_data, str):
+#                     try:
+#                         image_bytes = base64.b64decode(img_data)
+#                         image = Image.open(BytesIO(image_bytes)).convert("RGB")
+#                         images.append(image)
+#                     except Exception as e:
+#                         return {"error": f"Invalid image data: {e}"}
+#                 else:
+#                     return {"error": "Images should be base64-encoded strings."}
+
+#         image_embeddings = []
 #         for i in range(0, len(images), batch_size):
 #             batch_images = images[i : i + batch_size]
-#             batch_embeddings = self._process_batch(batch_images)
-#             embeddings.extend(batch_embeddings)
+#             batch_embeddings = self._process_image_batch(batch_images)
+#             image_embeddings.extend(batch_embeddings)
+
+#         # Process text data
+#         text_embeddings = []
+#         if text_data:
+#             for i in range(0, len(text_data), batch_size):
+#                 batch_texts = text_data[i : i + batch_size]
+#                 batch_text_embeddings = self._process_text_batch(batch_texts)
+#                 text_embeddings.extend(batch_text_embeddings)
+
+#         # Compute similarity scores if both image and text embeddings are available
+#         scores = []
+#         if image_embeddings and text_embeddings:
+#             # Convert embeddings to tensors for scoring
+#             image_embeddings_tensor = torch.tensor(image_embeddings).to(self.device)
+#             text_embeddings_tensor = torch.tensor(text_embeddings).to(self.device)
+
+#             with torch.no_grad():
+#                 scores = (
+#                     self.processor.score_multi_vector(
+#                         text_embeddings_tensor, image_embeddings_tensor
+#                     )
+#                     .cpu()
+#                     .tolist()
+#                 )
+
+#         return {"image": image_embeddings, "text": text_embeddings, "scores": scores}
 
-#         return {"embeddings": embeddings}
 
 import torch
 from typing import Dict, Any, List
 from PIL import Image
 import base64
 from io import BytesIO
+import logging
 
 
 class EndpointHandler:
@@ -116,20 +161,27 @@ class EndpointHandler:
             path (str): Path to the pre-trained model and processor.
             default_batch_size (int): Default batch size for processing images and text data.
         """
-        from colpali_engine.models import ColQwen2, ColQwen2Processor
+        # Initialize logging
+        logging.basicConfig(level=logging.INFO)
+        self.logger = logging.getLogger(__name__)
 
-        self.model = ColQwen2.from_pretrained(
-            path,
-            torch_dtype=torch.bfloat16,
-            device_map=(
-                "cuda:0" if torch.cuda.is_available() else "cpu"
-            ),  # Set device map based on availability
-        ).eval()
-        self.processor = ColQwen2Processor.from_pretrained(path)
+        from colpali_engine.models import ColQwen2, ColQwen2Processor
 
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model.to(self.device)
-        self.default_batch_size = default_batch_size
+        self.logger.info("Initializing model and processor.")
+        try:
+            self.model = ColQwen2.from_pretrained(
+                path,
+                torch_dtype=torch.bfloat16,
+                device_map=("cuda:0" if torch.cuda.is_available() else "cpu"),
+            ).eval()
+            self.processor = ColQwen2Processor.from_pretrained(path)
+            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            self.model.to(self.device)
+            self.default_batch_size = default_batch_size
+            self.logger.info("Initialization complete.")
+        except Exception as e:
+            self.logger.error(f"Failed to initialize model or processor: {e}")
+            raise
 
     def _process_image_batch(self, images: List[Image.Image]) -> List[List[float]]:
         """
@@ -141,12 +193,16 @@ class EndpointHandler:
         Returns:
            List[List[float]]: List of embeddings for each image.
         """
-        batch_images = self.processor.process_images(images).to(self.device)
-
-        with torch.no_grad():
-            image_embeddings = self.model(**batch_images)
-
-        return image_embeddings.cpu().tolist()
+        self.logger.debug(f"Processing batch of {len(images)} images.")
+        try:
+            batch_images = self.processor.process_images(images).to(self.device)
+            with torch.no_grad():
+                image_embeddings = self.model(**batch_images)
+            self.logger.debug("Image batch processing complete.")
+            return image_embeddings.cpu().tolist()
+        except Exception as e:
+            self.logger.error(f"Error processing image batch: {e}")
+            raise
 
    def _process_text_batch(self, texts: List[str]) -> List[List[float]]:
        """
@@ -158,12 +214,16 @@ class EndpointHandler:
        Returns:
            List[List[float]]: List of embeddings for each text query.
        """
-        batch_queries = self.processor.process_queries(texts).to(self.device)
-
-        with torch.no_grad():
-            query_embeddings = self.model(**batch_queries)
-
-        return query_embeddings.cpu().tolist()
+        self.logger.debug(f"Processing batch of {len(texts)} text queries.")
+        try:
+            batch_queries = self.processor.process_queries(texts).to(self.device)
+            with torch.no_grad():
+                query_embeddings = self.model(**batch_queries)
+            self.logger.debug("Text batch processing complete.")
+            return query_embeddings.cpu().tolist()
+        except Exception as e:
+            self.logger.error(f"Error processing text batch: {e}")
+            raise
 
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
@@ -182,6 +242,7 @@ class EndpointHandler:
        # Decode and process images
        images = []
        if images_data:
+            self.logger.info("Decoding images from base64.")
            for img_data in images_data:
                if isinstance(img_data, str):
                    try:
@@ -189,38 +250,55 @@ class EndpointHandler:
                        image = Image.open(BytesIO(image_bytes)).convert("RGB")
                        images.append(image)
                    except Exception as e:
+                        self.logger.error(f"Invalid image data: {e}")
                        return {"error": f"Invalid image data: {e}"}
                else:
+                    self.logger.error("Images should be base64-encoded strings.")
                    return {"error": "Images should be base64-encoded strings."}
 
        image_embeddings = []
-        for i in range(0, len(images), batch_size):
-            batch_images = images[i : i + batch_size]
-            batch_embeddings = self._process_image_batch(batch_images)
-            image_embeddings.extend(batch_embeddings)
+        if images:
+            self.logger.info("Processing image embeddings.")
+            try:
+                for i in range(0, len(images), batch_size):
+                    batch_images = images[i : i + batch_size]
+                    batch_embeddings = self._process_image_batch(batch_images)
+                    image_embeddings.extend(batch_embeddings)
+            except Exception as e:
+                self.logger.error(f"Error generating image embeddings: {e}")
+                return {"error": f"Error generating image embeddings: {e}"}
 
        # Process text data
        text_embeddings = []
        if text_data:
-            for i in range(0, len(text_data), batch_size):
-                batch_texts = text_data[i : i + batch_size]
-                batch_text_embeddings = self._process_text_batch(batch_texts)
-                text_embeddings.extend(batch_text_embeddings)
+            self.logger.info("Processing text embeddings.")
+            try:
+                for i in range(0, len(text_data), batch_size):
+                    batch_texts = text_data[i : i + batch_size]
+                    batch_text_embeddings = self._process_text_batch(batch_texts)
+                    text_embeddings.extend(batch_text_embeddings)
+            except Exception as e:
+                self.logger.error(f"Error generating text embeddings: {e}")
+                return {"error": f"Error generating text embeddings: {e}"}
 
        # Compute similarity scores if both image and text embeddings are available
        scores = []
        if image_embeddings and text_embeddings:
-            # Convert embeddings to tensors for scoring
-            image_embeddings_tensor = torch.tensor(image_embeddings).to(self.device)
-            text_embeddings_tensor = torch.tensor(text_embeddings).to(self.device)
-
-            with torch.no_grad():
-                scores = (
-                    self.processor.score_multi_vector(
-                        text_embeddings_tensor, image_embeddings_tensor
+            self.logger.info("Computing similarity scores.")
+            try:
+                image_embeddings_tensor = torch.tensor(image_embeddings).to(self.device)
+                text_embeddings_tensor = torch.tensor(text_embeddings).to(self.device)
+                with torch.no_grad():
+                    scores = (
+                        self.processor.score_multi_vector(
+                            text_embeddings_tensor, image_embeddings_tensor
+                        )
+                        .cpu()
+                        .tolist()
                    )
-                    .cpu()
-                    .tolist()
-                )
+                self.logger.info("Similarity scoring complete.")
+            except Exception as e:
+                self.logger.error(f"Error computing similarity scores: {e}")
+                return {"error": f"Error computing similarity scores: {e}"}
 
        return {"image": image_embeddings, "text": text_embeddings, "scores": scores}