hperkins committed
Commit acc9b5d · verified · 1 Parent(s): d9d7db9

Update handler.py: replace the pipeline-based VQA handler with direct Qwen2VLForConditionalGeneration generation, parse <image>-delimited inputs, and add URL/base64 image loading plus ffmpeg-based video frame extraction

Files changed (1):
  1. handler.py +110 -57
handler.py CHANGED
@@ -1,75 +1,128 @@
-import json
-import torch
-from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, pipeline, AutoImageProcessor
 from qwen_vl_utils import process_vision_info

-class EndpointHandler:
-    def __init__(self, model_dir):
-        # Setup device configuration
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-        try:
-            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
-                model_dir,
-                torch_dtype=torch.float16,
-                device_map="auto"
-            )
-            self.model.to(self.device)
-        except Exception as e:
-            print(f"Error loading model: {e}")
-            raise

-        try:
-            self.processor = AutoProcessor.from_pretrained(model_dir)
-            self.image_processor = AutoImageProcessor.from_pretrained(model_dir)  # Ensure you have the correct processor
-        except Exception as e:
-            print(f"Error loading processor: {e}")
-            raise

-        self.vqa_pipeline = pipeline(
-            task="visual-question-answering",
-            model=self.model,
-            image_processor=self.image_processor,  # Explicit image processor if needed
-            device=0 if torch.cuda.is_available() else -1
         )
-
-    def preprocess(self, request_data):
-        messages = request_data.get('messages')
-        if not messages:
-            raise ValueError("Missing 'messages' in request data.")
-
         image_inputs, video_inputs = process_vision_info(messages)
-        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

         inputs = self.processor(
             text=[text],
             images=image_inputs,
             videos=video_inputs,
             padding=True,
-            return_tensors="pt"
-        ).to(self.device)

-        return inputs

-    def inference(self, inputs):
-        with torch.no_grad():
-            result = self.vqa_pipeline(
-                images=inputs.get("images", None),
-                videos=inputs.get("videos", None),
-                question=inputs["text"]
-            )
-            return result

-    def postprocess(self, inference_output):
-        return json.dumps(inference_output)

-    def __call__(self, request):
         try:
-            request_data = json.loads(request)
-            inputs = self.preprocess(request_data)
-            result = self.inference(inputs)
-            return self.postprocess(result)
-        except Exception as e:
-            error_message = f"Error: {str(e)}"
-            print(error_message)
-            return json.dumps({"error": error_message})
+from typing import Dict, List, Any
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+from modelscope import snapshot_download
 from qwen_vl_utils import process_vision_info
+import torch
+import os
+import base64
+import io
+from PIL import Image
+import ffmpeg
+import logging
+import requests

+class EndpointHandler():
+    def __init__(self, path=""):
+        self.model_dir = path
+        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
+            self.model_dir, torch_dtype="auto", device_map="auto"
+        )
+        self.processor = AutoProcessor.from_pretrained(self.model_dir)

+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        data args:
+            inputs (str): The input text, including any image or video references.
+            max_new_tokens (int, optional): Maximum number of tokens to generate. Defaults to 128.
+        Return:
+            A dictionary containing the generated text.
+        """
+        inputs = data.get("inputs")
+        max_new_tokens = data.get("max_new_tokens", 128)

+        # Construct the messages list from the input string
+        messages = [{"role": "user", "content": self._parse_input(inputs)}]

+        text = self.processor.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
         )
         image_inputs, video_inputs = process_vision_info(messages)

         inputs = self.processor(
             text=[text],
             images=image_inputs,
             videos=video_inputs,
             padding=True,
+            return_tensors="pt",
+        )
+        inputs = inputs.to("cuda" if torch.cuda.is_available() else "cpu")

+        generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
+        generated_ids_trimmed = [
+            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        output_text = self.processor.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )[0]  # Return a single string

+        return {"generated_text": output_text}

+    def _parse_input(self, input_string):
+        """Parses the input string to identify image/video references and text."""
+        content = []
+        parts = input_string.split("<image>")
+        for i, part in enumerate(parts):
+            if i % 2 == 0:  # Text part
+                content.append({"type": "text", "text": part.strip()})
+            else:  # Image/video part
+                if part.startswith("video:"):
+                    video_path = part.split("video:")[1].strip()
+                    video_frames = self._extract_video_frames(video_path)
+                    if video_frames:
+                        content.append({"type": "video", "video": video_frames, "fps": 1})  # Add fps
+                else:
+                    image = self._load_image(part.strip())
+                    if image:
+                        content.append({"type": "image", "image": image})
+        return content

+    def _load_image(self, image_data):
+        """Loads an image from a URL or base64 encoded string."""
+        if image_data.startswith("http"):
+            try:
+                image = Image.open(requests.get(image_data, stream=True).raw)
+            except Exception as e:
+                logging.error(f"Error loading image from URL: {e}")
+                return None
+        elif image_data.startswith("data:image"):
+            try:
+                image_data = image_data.split(",")[1]
+                image_bytes = base64.b64decode(image_data)
+                image = Image.open(io.BytesIO(image_bytes))
+            except Exception as e:
+                logging.error(f"Error loading image from base64: {e}")
+                return None
+        else:
+            logging.error("Invalid image data format. Must be URL or base64 encoded.")
+            return None
+        return image
+
+    def _extract_video_frames(self, video_path, fps=1):
+        """Extracts frames from a video at the specified FPS."""
         try:
+            probe = ffmpeg.probe(video_path)
+            video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
+            if not video_stream:
+                logging.error(f"No video stream found in {video_path}")
+                return None
+
+            width = int(video_stream['width'])
+            height = int(video_stream['height'])
+
+            out, _ = (
+                ffmpeg
+                .input(video_path)
+                .filter('fps', fps=fps)
+                .output('pipe:', format='rawvideo', pix_fmt='rgb24')
+                .run(capture_stdout=True)
+            )
+            frames = []
+            for i in range(0, len(out), width * height * 3):
+                frame_data = out[i:i + width * height * 3]
+                frame = Image.frombytes('RGB', (width, height), frame_data)
+                frames.append(frame)
+            return frames
+
+        except ffmpeg.Error as e:
+            logging.error(f"Error extracting video frames: {e}")
+            return None
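
For reference, a minimal smoke test for the updated handler (not part of this commit; the model directory, example URL, and video path below are hypothetical). The payload format follows the _parse_input convention in the new code: each image or video reference sits between two <image> markers, and video references carry a "video:" prefix.

# Hypothetical local test -- assumes Qwen2-VL weights are available in ./qwen2-vl.
from handler import EndpointHandler

handler = EndpointHandler(path="./qwen2-vl")

# Image question: the URL is wrapped between two <image> markers, so
# _parse_input sees it as the odd-indexed (reference) part of the split.
print(handler({
    "inputs": "Describe this picture. <image>https://example.com/cat.jpg<image>",
    "max_new_tokens": 64,
}))

# Video question: the "video:" prefix routes the reference to
# _extract_video_frames, which samples 1 frame per second via ffmpeg.
print(handler({
    "inputs": "Summarize the clip. <image>video:/tmp/clip.mp4<image>",
}))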