Update processing_phi3_v.py

#30
by rageSpin - opened
Files changed (1) hide show
  1. processing_phi3_v.py +22 -4
processing_phi3_v.py CHANGED
@@ -56,7 +56,7 @@ logger = logging.get_logger(__name__)
56
  if is_vision_available():
57
  from PIL import Image
58
 
59
- import torch
60
  import torchvision
61
 
62
  def padding_336(b):
@@ -205,6 +205,22 @@ class Phi3VImageProcessor(BaseImageProcessor):
205
  num_img_tokens = int((new_height // 336 * new_width // 336 + 1) * 144 + 1 + (new_height // 336 + 1) * 12)
206
  return num_img_tokens
207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  def preprocess(
209
  self,
210
  images: ImageInput,
@@ -245,8 +261,8 @@ class Phi3VImageProcessor(BaseImageProcessor):
245
  "torch.Tensor, tf.Tensor or jax.ndarray."
246
  )
247
 
248
- if do_convert_rgb:
249
- images = [convert_to_rgb(image) for image in images]
250
 
251
  image_sizes = []
252
  img_processor = torchvision.transforms.Compose([
@@ -256,7 +272,9 @@ class Phi3VImageProcessor(BaseImageProcessor):
256
 
257
  # PIL images
258
  # HD_transform pads images to a size that is a multiple of (336, 336)
259
- # convert to RGB first
 
 
260
  images = [image.convert('RGB') for image in images]
261
  elems = [HD_transform(im, hd_num = self.num_crops) for im in images]
262
  # tensor transform and normalize
 
56
  if is_vision_available():
57
  from PIL import Image
58
 
59
+ import torch # why is this imported twice?
60
  import torchvision
61
 
62
  def padding_336(b):
 
205
  num_img_tokens = int((new_height // 336 * new_width // 336 + 1) * 144 + 1 + (new_height // 336 + 1) * 12)
206
  return num_img_tokens
207
 
208
+ def convert_PIL(self, image):
209
+ """
210
+ Convert an image to a PIL Image object if it is not already one.
211
+
212
+ Args:
213
+ image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`): The image to be converted. Can be a numpy array, a torch tensor, or a PIL image.
214
+
215
+ Returns:
216
+ A PIL Image object.
217
+ """
218
+ if not isinstance(image, Image.Image):
219
+ return torchvision.transforms.functional.to_pil_image(image)
220
+ else:
221
+ return image
222
+
223
+
224
  def preprocess(
225
  self,
226
  images: ImageInput,
 
261
  "torch.Tensor, tf.Tensor or jax.ndarray."
262
  )
263
 
264
+ # if do_convert_rgb:
265
+ # images = [convert_to_rgb(image) for image in images]
266
 
267
  image_sizes = []
268
  img_processor = torchvision.transforms.Compose([
 
272
 
273
  # PIL images
274
  # HD_transform pads images to a size that is a multiple of (336, 336)
275
+ # check and convert if the images are in PIL format
276
+ images = [convert_PIL(image) for image in images]
277
+ # convert to RGB first (I think the argument "do_convert_rgb" is useless, since it is forced here)
278
  images = [image.convert('RGB') for image in images]
279
  elems = [HD_transform(im, hd_num = self.num_crops) for im in images]
280
  # tensor transform and normalize