muqtasid87
committed on
Update app_qwen.py
app_qwen.py +12 -10
app_qwen.py
CHANGED
@@ -9,19 +9,20 @@ import time
 import os
 
 
-
 @st.cache_resource
 def load_model():
     """Load the model and processor (cached to prevent reloading)"""
+    device = "cuda" if torch.cuda.is_available() else "cpu"
     model = Qwen2VLForConditionalGeneration.from_pretrained(
         "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
-        torch_dtype=torch.bfloat16,
-        device_map="auto"
-    ).eval()
+        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
+        device_map="auto" if device == "cuda" else None
+    ).eval().to(device)
     processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4")
-    return model, processor
+    return model, processor, device
+
 
-def process_image(image, prompt, model, processor):
+def process_image(image, prompt, model, processor, device):
     """Process the image and return the model's output"""
     start_time = time.time()
 
@@ -36,7 +37,7 @@ def process_image(image, prompt, model, processor):
     ]
 
     text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
-    inputs = processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt").to(
+    inputs = processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt").to(device)
 
     output_ids = model.generate(**inputs, max_new_tokens=100)
     generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
@@ -45,13 +46,14 @@ def process_image(image, prompt, model, processor):
     inference_time = time.time() - start_time
     return output_text[0].strip(), inference_time
 
+
 def main():
     # Compact header
     st.markdown("<h1 style='font-size: 24px;'>🔍 Image Analysis with Qwen2-VL</h1>", unsafe_allow_html=True)
 
     # Load model and processor
     with st.spinner("Loading model... This might take a minute."):
-        model, processor = load_model()
+        model, processor, device = load_model()
 
     # Initialize session state
     if 'selected_image' not in st.session_state:
@@ -99,7 +101,7 @@ def main():
     if analyze_button and image_source:
         with st.spinner("Analyzing..."):
             try:
-                result, inference_time = process_image(image, prompt, model, processor)
+                result, inference_time = process_image(image, prompt, model, processor, device)
                 st.session_state.result = result
                 st.session_state.inference_time = inference_time
             except Exception as e:
@@ -144,4 +146,4 @@ def main():
         st.error("No example images found in the 'images' directory")
 
 if __name__ == "__main__":
-    main()
+    main()
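After this change, load_model() also returns the device it selected, and process_image() expects that device so the input tensors end up where the model is. A minimal usage sketch of the updated helpers; the image path, prompt, and the idea of calling them outside the Streamlit UI are illustrative and not part of the commit:

```python
from PIL import Image

# Hypothetical standalone use of the updated helpers (not part of the commit).
model, processor, device = load_model()    # device is "cuda" or "cpu"
image = Image.open("images/example.jpg")   # illustrative path
result, seconds = process_image(image, "Describe this image.", model, processor, device)
print(f"{result} ({seconds:.1f}s on {device})")
```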
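The hunks only show fragments of process_image(). For context, a hedged sketch of the full generation path around the changed line, assuming the standard Qwen2-VL chat-template flow for the parts the diff does not show (the conversation list and the decode step):

```python
import time

def process_image(image, prompt, model, processor, device):
    """Sketch only: undiffed parts follow standard Qwen2-VL usage, not this repo's exact code."""
    start_time = time.time()
    conversation = [  # assumed structure; the real list sits outside the diff hunks
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}
    ]
    text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    # The commit's change: move the input tensors to the same device as the model.
    inputs = processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt").to(device)
    output_ids = model.generate(**inputs, max_new_tokens=100)
    # Drop the prompt tokens so only the newly generated text is decoded.
    generated_ids = [out[len(inp):] for inp, out in zip(inputs.input_ids, output_ids)]
    output_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    inference_time = time.time() - start_time
    return output_text[0].strip(), inference_time
```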