amaye15 committed
Commit 25a8604 · Parent(s): f4f3a3e

Handler Updated - Text Embeddings - Added

Files changed (1):
  1. handler.py +149 -30
handler.py CHANGED
@@ -1,3 +1,95 @@
+# import torch
+# from typing import Dict, Any, List
+# from PIL import Image
+# import base64
+# from io import BytesIO
+
+
+# class EndpointHandler:
+#     """
+#     A handler class for processing image data, generating embeddings using a specified model and processor.
+
+#     Attributes:
+#         model: The pre-trained model used for generating embeddings.
+#         processor: The pre-trained processor used to process images before model inference.
+#         device: The device (CPU or CUDA) used to run model inference.
+#         default_batch_size: The default batch size for processing images in batches.
+#     """
+
+#     def __init__(self, path: str = "", default_batch_size: int = 4):
+#         """
+#         Initializes the EndpointHandler with a specified model path and default batch size.
+
+#         Args:
+#             path (str): Path to the pre-trained model and processor.
+#             default_batch_size (int): Default batch size for image processing.
+#         """
+#         from colpali_engine.models import ColQwen2, ColQwen2Processor
+
+#         self.model = ColQwen2.from_pretrained(
+#             path,
+#             torch_dtype=torch.bfloat16,
+#         ).eval()
+#         self.processor = ColQwen2Processor.from_pretrained(path)
+
+#         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+#         self.model.to(self.device)
+#         self.default_batch_size = default_batch_size
+
+#     def _process_batch(self, images: List[Image.Image]) -> List[List[float]]:
+#         """
+#         Processes a batch of images and generates embeddings.
+
+#         Args:
+#             images (List[Image.Image]): List of images to process.
+
+#         Returns:
+#             List[List[float]]: List of embeddings for each image.
+#         """
+#         batch_images = self.processor.process_images(images)
+#         batch_images = {k: v.to(self.device) for k, v in batch_images.items()}
+
+#         with torch.no_grad():
+#             image_embeddings = self.model(**batch_images)
+
+#         return image_embeddings.cpu().tolist()
+
+#     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+#         """
+#         Processes input data containing base64-encoded images, decodes them, and generates embeddings.
+
+#         Args:
+#             data (Dict[str, Any]): Dictionary containing input images and optional batch size.
+
+#         Returns:
+#             Dict[str, Any]: Dictionary containing generated embeddings or error messages.
+#         """
+#         images_data = data.get("inputs", [])
+#         batch_size = data.get("batch_size", self.default_batch_size)
+
+#         if not images_data:
+#             return {"error": "No images provided in 'inputs'."}
+
+#         images = []
+#         for img_data in images_data:
+#             if isinstance(img_data, str):
+#                 try:
+#                     image_bytes = base64.b64decode(img_data)
+#                     image = Image.open(BytesIO(image_bytes)).convert("RGB")
+#                     images.append(image)
+#                 except Exception as e:
+#                     return {"error": f"Invalid image data: {e}"}
+#             else:
+#                 return {"error": "Images should be base64-encoded strings."}
+
+#         embeddings = []
+#         for i in range(0, len(images), batch_size):
+#             batch_images = images[i : i + batch_size]
+#             batch_embeddings = self._process_batch(batch_images)
+#             embeddings.extend(batch_embeddings)
+
+#         return {"embeddings": embeddings}
+
 import torch
 from typing import Dict, Any, List
 from PIL import Image
@@ -7,13 +99,13 @@ from io import BytesIO
 
 class EndpointHandler:
     """
-    A handler class for processing image data, generating embeddings using a specified model and processor.
+    A handler class for processing image and text data, generating embeddings using a specified model and processor.
 
     Attributes:
         model: The pre-trained model used for generating embeddings.
-        processor: The pre-trained processor used to process images before model inference.
+        processor: The pre-trained processor used to process images and text before model inference.
         device: The device (CPU or CUDA) used to run model inference.
-        default_batch_size: The default batch size for processing images in batches.
+        default_batch_size: The default batch size for processing images and text in batches.
     """
 
     def __init__(self, path: str = "", default_batch_size: int = 4):
@@ -22,13 +114,16 @@ class EndpointHandler:
 
         Args:
             path (str): Path to the pre-trained model and processor.
-            default_batch_size (int): Default batch size for image processing.
+            default_batch_size (int): Default batch size for processing images and text data.
        """
         from colpali_engine.models import ColQwen2, ColQwen2Processor
 
         self.model = ColQwen2.from_pretrained(
             path,
             torch_dtype=torch.bfloat16,
+            device_map=(
+                "cuda:0" if torch.cuda.is_available() else "cpu"
+            ),  # Set device map based on availability
         ).eval()
         self.processor = ColQwen2Processor.from_pretrained(path)
 
@@ -36,7 +131,7 @@
         self.model.to(self.device)
         self.default_batch_size = default_batch_size
 
-    def _process_batch(self, images: List[Image.Image]) -> List[List[float]]:
+    def _process_image_batch(self, images: List[Image.Image]) -> List[List[float]]:
         """
         Processes a batch of images and generates embeddings.
 
@@ -46,46 +141,70 @@
         Returns:
             List[List[float]]: List of embeddings for each image.
         """
-        batch_images = self.processor.process_images(images)
-        batch_images = {k: v.to(self.device) for k, v in batch_images.items()}
+        batch_images = self.processor.process_images(images).to(self.device)
 
         with torch.no_grad():
             image_embeddings = self.model(**batch_images)
 
         return image_embeddings.cpu().tolist()
 
+    def _process_text_batch(self, texts: List[str]) -> List[List[float]]:
+        """
+        Processes a batch of text queries and generates embeddings.
+
+        Args:
+            texts (List[str]): List of text queries to process.
+
+        Returns:
+            List[List[float]]: List of embeddings for each text query.
+        """
+        batch_queries = self.processor.process_queries(texts).to(self.device)
+
+        with torch.no_grad():
+            query_embeddings = self.model(**batch_queries)
+
+        return query_embeddings.cpu().tolist()
+
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """
-        Processes input data containing base64-encoded images, decodes them, and generates embeddings.
+        Processes input data containing base64-encoded images and text queries, decodes them, and generates embeddings.
 
         Args:
-            data (Dict[str, Any]): Dictionary containing input images and optional batch size.
+            data (Dict[str, Any]): Dictionary containing input images, text queries, and optional batch size.
 
         Returns:
-            Dict[str, Any]: Dictionary containing generated embeddings or error messages.
+            Dict[str, Any]: Dictionary containing generated embeddings for images and text or error messages.
         """
-        images_data = data.get("inputs", [])
+        images_data = data.get("image", [])
+        text_data = data.get("text", [])
         batch_size = data.get("batch_size", self.default_batch_size)
 
-        if not images_data:
-            return {"error": "No images provided in 'inputs'."}
-
+        # Decode and process images
         images = []
-        for img_data in images_data:
-            if isinstance(img_data, str):
-                try:
-                    image_bytes = base64.b64decode(img_data)
-                    image = Image.open(BytesIO(image_bytes)).convert("RGB")
-                    images.append(image)
-                except Exception as e:
-                    return {"error": f"Invalid image data: {e}"}
-            else:
-                return {"error": "Images should be base64-encoded strings."}
+        if images_data:
+            for img_data in images_data:
+                if isinstance(img_data, str):
+                    try:
+                        image_bytes = base64.b64decode(img_data)
+                        image = Image.open(BytesIO(image_bytes)).convert("RGB")
+                        images.append(image)
+                    except Exception as e:
+                        return {"error": f"Invalid image data: {e}"}
+                else:
+                    return {"error": "Images should be base64-encoded strings."}
 
-        embeddings = []
+        image_embeddings = []
         for i in range(0, len(images), batch_size):
             batch_images = images[i : i + batch_size]
-            batch_embeddings = self._process_batch(batch_images)
-            embeddings.extend(batch_embeddings)
-
-        return {"embeddings": embeddings}
+            batch_embeddings = self._process_image_batch(batch_images)
+            image_embeddings.extend(batch_embeddings)
+
+        # Process text data
+        text_embeddings = []
+        if text_data:
+            for i in range(0, len(text_data), batch_size):
+                batch_texts = text_data[i : i + batch_size]
+                batch_text_embeddings = self._process_text_batch(batch_texts)
+                text_embeddings.extend(batch_text_embeddings)
+
+        return {"image": image_embeddings, "text": text_embeddings}