muqtasid87 commited on
Commit
fad0fe2
·
verified ·
1 Parent(s): dd534ec

Update app_qwen.py

Browse files
Files changed (1) hide show
  1. app_qwen.py +12 -10
app_qwen.py CHANGED
@@ -9,19 +9,20 @@ import time
9
  import os
10
 
11
 
12
-
13
  @st.cache_resource
14
  def load_model():
15
  """Load the model and processor (cached to prevent reloading)"""
 
16
  model = Qwen2VLForConditionalGeneration.from_pretrained(
17
  "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
18
- torch_dtype=torch.bfloat16,
19
- device_map="auto"
20
- ).eval()
21
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4")
22
- return model, processor
 
23
 
24
- def process_image(image, prompt, model, processor):
25
  """Process the image and return the model's output"""
26
  start_time = time.time()
27
 
@@ -36,7 +37,7 @@ def process_image(image, prompt, model, processor):
36
  ]
37
 
38
  text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
39
- inputs = processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt").to("cpu")
40
 
41
  output_ids = model.generate(**inputs, max_new_tokens=100)
42
  generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
@@ -45,13 +46,14 @@ def process_image(image, prompt, model, processor):
45
  inference_time = time.time() - start_time
46
  return output_text[0].strip(), inference_time
47
 
 
48
  def main():
49
  # Compact header
50
  st.markdown("<h1 style='font-size: 24px;'>🔍 Image Analysis with Qwen2-VL</h1>", unsafe_allow_html=True)
51
 
52
  # Load model and processor
53
  with st.spinner("Loading model... This might take a minute."):
54
- model, processor = load_model()
55
 
56
  # Initialize session state
57
  if 'selected_image' not in st.session_state:
@@ -99,7 +101,7 @@ def main():
99
  if analyze_button and image_source:
100
  with st.spinner("Analyzing..."):
101
  try:
102
- result, inference_time = process_image(image, prompt, model, processor)
103
  st.session_state.result = result
104
  st.session_state.inference_time = inference_time
105
  except Exception as e:
@@ -144,4 +146,4 @@ def main():
144
  st.error("No example images found in the 'images' directory")
145
 
146
  if __name__ == "__main__":
147
- main()
 
9
  import os
10
 
11
 
 
12
  @st.cache_resource
13
  def load_model():
14
  """Load the model and processor (cached to prevent reloading)"""
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
  model = Qwen2VLForConditionalGeneration.from_pretrained(
17
  "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
18
+ torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
19
+ device_map="auto" if device == "cuda" else None
20
+ ).eval().to(device)
21
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4")
22
+ return model, processor, device
23
+
24
 
25
+ def process_image(image, prompt, model, processor, device):
26
  """Process the image and return the model's output"""
27
  start_time = time.time()
28
 
 
37
  ]
38
 
39
  text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
40
+ inputs = processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt").to(device)
41
 
42
  output_ids = model.generate(**inputs, max_new_tokens=100)
43
  generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
 
46
  inference_time = time.time() - start_time
47
  return output_text[0].strip(), inference_time
48
 
49
+
50
  def main():
51
  # Compact header
52
  st.markdown("<h1 style='font-size: 24px;'>🔍 Image Analysis with Qwen2-VL</h1>", unsafe_allow_html=True)
53
 
54
  # Load model and processor
55
  with st.spinner("Loading model... This might take a minute."):
56
+ model, processor, device = load_model()
57
 
58
  # Initialize session state
59
  if 'selected_image' not in st.session_state:
 
101
  if analyze_button and image_source:
102
  with st.spinner("Analyzing..."):
103
  try:
104
+ result, inference_time = process_image(image, prompt, model, processor, device)
105
  st.session_state.result = result
106
  st.session_state.inference_time = inference_time
107
  except Exception as e:
 
146
  st.error("No example images found in the 'images' directory")
147
 
148
  if __name__ == "__main__":
149
+ main()