Commit c01a17b (verified) · cheesyFishes committed · 1 parent: 719ef6e

even more device handling
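The headline change is device alignment in forward(): the tensors returned by prepare_inputs_for_generation may still live on CPU while the model weights sit on a GPU, so the commit moves every tensor input onto the model's own device before the forward pass. A minimal standalone sketch of the pattern (the function name here is illustrative, not part of the repo):

import torch

def move_to_model_device(model: torch.nn.Module, inputs: dict) -> dict:
    # Target whichever device the model's parameters live on.
    device = next(model.parameters()).device
    # Move tensors across; as in the diff below, non-tensor entries
    # are filtered out of the resulting dict.
    return {k: v.to(device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}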

Files changed (1): custom_st.py +121 −5
custom_st.py CHANGED
@@ -4,6 +4,7 @@ import os
 import math
 from io import BytesIO
 from typing import Any, Dict, List, Literal, Optional, Union
+from urllib.parse import urlparse
 
 import requests
 import torch
@@ -121,27 +122,142 @@ class Transformer(nn.Module):
         image_data = base64.b64decode(data)
         return Image.open(BytesIO(image_data))
 
-    def _process_input(self, texts: List[Union[str, Image.Image]]) -> tuple[List[str], List[Image.Image]]:
+    @staticmethod
+    def _is_valid_url(url: str) -> bool:
+        try:
+            result = urlparse(url)
+            # Check if scheme and netloc are present and scheme is http/https
+            return all([result.scheme in ('http', 'https'), result.netloc])
+        except Exception:
+            return False
+
+    @staticmethod
+    def _is_safe_path(path: str) -> bool:
+        try:
+            # Convert to absolute path and normalize
+            abs_path = os.path.abspath(os.path.normpath(path))
+            # Check if file exists and is a regular file (not a directory or special file)
+            return os.path.isfile(abs_path)
+        except Exception:
+            return False
+
+    @staticmethod
+    def _load_image_from_url(url: str) -> Image.Image:
+        try:
+            response = requests.get(
+                url,
+                stream=True,
+                timeout=10,  # Add timeout
+                headers={'User-Agent': 'Mozilla/5.0'}  # Add user agent
+            )
+            response.raise_for_status()
+
+            # Check content type
+            content_type = response.headers.get('content-type', '')
+            if not content_type.startswith('image/'):
+                raise ValueError(f"Invalid content type: {content_type}")
+
+            # Limit file size (e.g., 10MB)
+            content = BytesIO()
+            size = 0
+            max_size = 10 * 1024 * 1024  # 10MB
+
+            for chunk in response.iter_content(chunk_size=8192):
+                size += len(chunk)
+                if size > max_size:
+                    raise ValueError("File too large")
+                content.write(chunk)
+
+            content.seek(0)
+            return Image.open(content)
+        except Exception as e:
+            raise ValueError(f"Failed to load image from URL: {str(e)}")
+
+    @staticmethod
+    def _load_image_from_path(image_path: str) -> Image.Image:
+        try:
+            # Convert to absolute path and normalize
+            abs_path = os.path.abspath(os.path.normpath(image_path))
+
+            # Check file size before loading
+            file_size = os.path.getsize(abs_path)
+            max_size = 10 * 1024 * 1024  # 10MB
+            if file_size > max_size:
+                raise ValueError("File too large")
+
+            with Image.open(abs_path) as img:
+                # Make a copy to ensure file handle is closed
+                return img.copy()
+        except Exception as e:
+            raise ValueError(f"Failed to load image from path: {str(e)}")
+
+    @staticmethod
+    def _load_image_from_bytes(image_bytes: bytes) -> Image.Image:
+        try:
+            # Check size
+            if len(image_bytes) > 10 * 1024 * 1024:  # 10MB
+                raise ValueError("Image data too large")
+
+            return Image.open(BytesIO(image_bytes))
+        except Exception as e:
+            raise ValueError(f"Failed to load image from bytes: {str(e)}")
+
+    def _process_input(self, texts: List[Union[str, Image.Image, bytes]]) -> tuple[List[str], List[Image.Image]]:
         processed_texts = []
         processed_images = []
         dummy_image = Image.new('RGB', (56, 56))
 
         for sample in texts:
             if isinstance(sample, str):
-                processed_texts.append(self.query_prompt % sample)
-                processed_images.append(dummy_image)
+                # Check if the string is a valid URL
+                if self._is_valid_url(sample):
+                    try:
+                        img = self._load_image_from_url(sample)
+                        processed_texts.append(self.document_prompt)
+                        processed_images.append(self._resize_image(img))
+                    except Exception as e:
+                        # If URL loading fails, treat as regular text
+                        processed_texts.append(self.query_prompt % sample)
+                        processed_images.append(dummy_image)
+                # Check if the string is a valid file path
+                elif self._is_safe_path(sample):
+                    try:
+                        img = self._load_image_from_path(sample)
+                        processed_texts.append(self.document_prompt)
+                        processed_images.append(self._resize_image(img))
+                    except Exception as e:
+                        # If image loading fails, treat as regular text
+                        processed_texts.append(self.query_prompt % sample)
+                        processed_images.append(dummy_image)
+                else:
+                    # Regular text query
+                    processed_texts.append(self.query_prompt % sample)
+                    processed_images.append(dummy_image)
             elif isinstance(sample, Image.Image):
                 processed_texts.append(self.document_prompt)
                 processed_images.append(self._resize_image(sample))
+            elif isinstance(sample, bytes):
+                try:
+                    img = self._load_image_from_bytes(sample)
+                    processed_texts.append(self.document_prompt)
+                    processed_images.append(self._resize_image(img))
+                except Exception as e:
+                    # If bytes can't be converted to image, use dummy
+                    processed_texts.append(self.document_prompt)
+                    processed_images.append(dummy_image)
 
         return processed_texts, processed_images
 
     def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        cache_position = torch.arange(0, features['input_ids'].shape[0])
+        cache_position = torch.arange(0, features['input_ids'].shape[1])
         inputs = self.model.prepare_inputs_for_generation(
             **features, cache_position=cache_position, use_cache=False
         )
 
+        # ensure inputs are on the same device as the model
+        device = next(self.model.parameters()).device
+        inputs = {k: v.to(device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
+
         with torch.no_grad():
             output = self.model(
                 **inputs,
@@ -155,7 +271,7 @@ class Transformer(nn.Module):
         )
         return features
 
-    def tokenize(self, texts: List[Union[str, Image.Image]], padding: str = 'longest') -> Dict[str, torch.Tensor]:
+    def tokenize(self, texts: List[Union[str, Image.Image, bytes]], padding: str = 'longest') -> Dict[str, torch.Tensor]:
         processed_texts, processed_images = self._process_input(texts)
 
         return self.processor(
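With these changes, tokenize() (and therefore SentenceTransformer.encode) accepts plain text alongside images supplied as a PIL object, raw bytes, an http(s) URL, or a local file path; a string that fails to resolve to an image falls back to being treated as a text query. A usage sketch, assuming this custom_st.py is wired into a Sentence Transformers checkpoint loaded with trust_remote_code=True; the model id is a placeholder:

from sentence_transformers import SentenceTransformer

# Placeholder id: substitute the repository that ships this custom_st.py.
model = SentenceTransformer('<org>/<model-id>', trust_remote_code=True)

embeddings = model.encode([
    'what does the chart on page 3 show?',  # text -> query_prompt + dummy image
    'https://example.com/page3.png',        # URL -> downloaded (<= 10 MB), document_prompt
    '/data/scans/page3.png',                # local path -> loaded from disk, document_prompt
])
print(embeddings.shape)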