muqtasid87 committed on
Commit
dd534ec
·
verified ·
1 Parent(s): 9d3f1be

Update app_qwen.py

Browse files
Files changed (1) hide show
  1. app_qwen.py +146 -146
app_qwen.py CHANGED
@@ -1,147 +1,147 @@
1
- import streamlit as st
2
- from transformers import (
3
- Qwen2VLForConditionalGeneration,
4
- AutoProcessor
5
- )
6
- import torch
7
- from PIL import Image
8
- import time
9
- import os
10
-
11
-
12
-
13
- @st.cache_resource
14
- def load_model():
15
- """Load the model and processor (cached to prevent reloading)"""
16
- model = Qwen2VLForConditionalGeneration.from_pretrained(
17
- "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
18
- torch_dtype=torch.bfloat16,
19
- device_map="auto"
20
- ).eval()
21
- processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4")
22
- return model, processor
23
-
24
- def process_image(image, prompt, model, processor):
25
- """Process the image and return the model's output"""
26
- start_time = time.time()
27
-
28
- conversation = [
29
- {
30
- "role": "user",
31
- "content": [
32
- {"type": "image"},
33
- {"type": "text", "text": prompt},
34
- ],
35
- },
36
- ]
37
-
38
- text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
39
- inputs = processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt").to("cuda")
40
-
41
- output_ids = model.generate(**inputs, max_new_tokens=100)
42
- generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]
43
- output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
44
-
45
- inference_time = time.time() - start_time
46
- return output_text[0].strip(), inference_time
47
-
48
- def main():
49
- # Compact header
50
- st.markdown("<h1 style='font-size: 24px;'>πŸ” Image Analysis with Qwen2-VL</h1>", unsafe_allow_html=True)
51
-
52
- # Load model and processor
53
- with st.spinner("Loading model... This might take a minute."):
54
- model, processor = load_model()
55
-
56
- # Initialize session state
57
- if 'selected_image' not in st.session_state:
58
- st.session_state.selected_image = None
59
- if 'result' not in st.session_state:
60
- st.session_state.result = None
61
- if 'inference_time' not in st.session_state:
62
- st.session_state.inference_time = None
63
-
64
- # Main content area
65
- col1, col2, col3 = st.columns([1, 1.5, 1])
66
-
67
- with col1:
68
- # Input method selection
69
- input_option = st.radio("Choose input method:", ["Use example image", "Upload image"], label_visibility="collapsed")
70
-
71
- if input_option == "Upload image":
72
- uploaded_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"], label_visibility="collapsed")
73
- image_source = uploaded_file
74
- if uploaded_file:
75
- st.session_state.selected_image = uploaded_file
76
- else:
77
- image_source = st.session_state.selected_image
78
-
79
- # Default prompt and analysis section
80
- default_prompt = "What type of vehicle is this? Choose only from: car, pickup, bus, truck, motorbike, van. Answer only in one word."
81
- prompt = st.text_area("Enter prompt:", value=default_prompt, height=100)
82
-
83
- analyze_col1, analyze_col2 = st.columns([1, 2])
84
- with analyze_col1:
85
- analyze_button = st.button("Analyze Image", use_container_width=True, disabled=image_source is None)
86
-
87
- # Display selected image and results
88
- if image_source:
89
- try:
90
- if isinstance(image_source, str):
91
- image = Image.open(image_source).convert("RGB")
92
- else:
93
- image = Image.open(image_source).convert("RGB")
94
- st.image(image, caption="Selected Image", width=300)
95
- except Exception as e:
96
- st.error(f"Error loading image: {str(e)}")
97
-
98
- # Analysis results
99
- if analyze_button and image_source:
100
- with st.spinner("Analyzing..."):
101
- try:
102
- result, inference_time = process_image(image, prompt, model, processor)
103
- st.session_state.result = result
104
- st.session_state.inference_time = inference_time
105
- except Exception as e:
106
- st.error(f"Error: {str(e)}")
107
-
108
- if st.session_state.result:
109
- st.success("Analysis Complete!")
110
- st.markdown(f"**Result:**\n{st.session_state.result}")
111
- st.markdown(f"*Inference time: {st.session_state.inference_time:.2f} seconds*")
112
-
113
- # Example images section
114
- if input_option == "Use example image":
115
- st.markdown("### Example Images")
116
- example_images = [f for f in os.listdir("images") if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
117
-
118
- if example_images:
119
- # Create grid of images
120
- cols = st.columns(4) # Adjust number of columns as needed
121
- for idx, img_name in enumerate(example_images):
122
- with cols[idx % 4]:
123
- img_path = os.path.join("images", img_name)
124
- img = Image.open(img_path)
125
- img.thumbnail((150, 150))
126
-
127
- # Make image clickable
128
- if st.button(
129
- "πŸ“·",
130
- key=f"img_{idx}",
131
- help=img_name,
132
- use_container_width=True
133
- ):
134
- st.session_state.selected_image = img_path
135
- st.rerun()
136
-
137
- # Display image with conditional styling
138
- st.image(
139
- img,
140
- caption=img_name,
141
- use_container_width=True,
142
- )
143
- else:
144
- st.error("No example images found in the 'images' directory")
145
-
146
- if __name__ == "__main__":
147
  main()
 
1
+ import streamlit as st
2
+ from transformers import (
3
+ Qwen2VLForConditionalGeneration,
4
+ AutoProcessor
5
+ )
6
+ import torch
7
+ from PIL import Image
8
+ import time
9
+ import os
10
+
11
+
12
+
13
@st.cache_resource
def load_model():
    """Create the Qwen2-VL model and its processor.

    The function is wrapped in ``st.cache_resource`` so the returned pair is
    cached rather than being rebuilt each time the script runs.

    Returns:
        A ``(model, processor)`` tuple; ``eval()`` has been called on the
        model before it is returned.
    """
    # Model and processor are loaded from the same checkpoint.
    checkpoint = "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4"

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        checkpoint,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    model = model.eval()

    processor = AutoProcessor.from_pretrained(checkpoint)
    return model, processor
23
+
24
def process_image(image, prompt, model, processor):
    """Run one image/prompt pair through the model and time the call.

    Args:
        image: Image passed to ``processor`` as the only entry of ``images``.
        prompt: Text placed after the image placeholder in the user turn.
        model: Object providing ``generate(**inputs, max_new_tokens=...)``.
            Its ``device`` attribute, when present, decides where the inputs
            are moved; otherwise ``"cpu"`` is used.
        processor: Object providing ``apply_chat_template``, ``__call__``
            (returning something with ``.to(device)`` and ``.input_ids``)
            and ``batch_decode``.

    Returns:
        A ``(text, seconds)`` tuple: the first decoded completion with
        surrounding whitespace stripped, and the elapsed time of the whole
        call in seconds.
    """
    start_time = time.perf_counter()

    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    text_prompt = processor.apply_chat_template(
        conversation, add_generation_prompt=True
    )

    # The model is loaded with device_map="auto", so a hard-coded "cpu" or
    # "cuda" can put the inputs on a different device than the weights.
    # Follow the model instead.
    device = getattr(model, "device", "cpu")
    inputs = processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt",
    ).to(device)

    output_ids = model.generate(**inputs, max_new_tokens=100)

    # Each output row starts with the prompt tokens; keep only what follows.
    generated_ids = [
        full_ids[len(prompt_ids):]
        for prompt_ids, full_ids in zip(inputs.input_ids, output_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

    inference_time = time.perf_counter() - start_time
    return output_text[0].strip(), inference_time
47
+
48
def _list_example_images(directory="images"):
    """Return the sorted image file names in ``directory`` (``[]`` if it is missing)."""
    if not os.path.isdir(directory):
        return []
    return sorted(
        name
        for name in os.listdir(directory)
        if name.lower().endswith((".jpg", ".jpeg", ".png"))
    )


def main():
    """Render the page: choose an image, edit the prompt, run the model, show the answer.

    The selected image, the last result and its inference time are kept in
    ``st.session_state``.

    NOTE(review): the listing this was rebuilt from had lost its indentation,
    so which widgets sit under ``with col1:`` is a reconstruction -- confirm
    against the intended layout.
    """
    # Compact header
    st.markdown(
        "<h1 style='font-size: 24px;'>🔍 Image Analysis with Qwen2-VL</h1>",
        unsafe_allow_html=True,
    )

    # Load model and processor
    with st.spinner("Loading model... This might take a minute."):
        model, processor = load_model()

    # Initialize session state
    for key in ("selected_image", "result", "inference_time"):
        if key not in st.session_state:
            st.session_state[key] = None

    # Main content area (only the first column is used)
    col1, _col2, _col3 = st.columns([1, 1.5, 1])

    with col1:
        # Input method selection
        input_option = st.radio(
            "Choose input method:",
            ["Use example image", "Upload image"],
            label_visibility="collapsed",
        )

        if input_option == "Upload image":
            uploaded_file = st.file_uploader(
                "Upload Image",
                type=["jpg", "jpeg", "png"],
                label_visibility="collapsed",
            )
            image_source = uploaded_file
            if uploaded_file:
                st.session_state.selected_image = uploaded_file
        else:
            image_source = st.session_state.selected_image

        # Default prompt and analysis section
        default_prompt = "What type of vehicle is this? Choose only from: car, pickup, bus, truck, motorbike, van. Answer only in one word."
        prompt = st.text_area("Enter prompt:", value=default_prompt, height=100)

        analyze_col1, _analyze_col2 = st.columns([1, 2])
        with analyze_col1:
            analyze_button = st.button(
                "Analyze Image",
                use_container_width=True,
                disabled=image_source is None,
            )

        # Display selected image. `image` stays None when loading fails so
        # the analysis step below never touches an unbound name.
        image = None
        if image_source:
            try:
                # Image.open takes both a path string and a file-like upload.
                image = Image.open(image_source).convert("RGB")
                st.image(image, caption="Selected Image", width=300)
            except Exception as e:
                st.error(f"Error loading image: {str(e)}")

        # Analysis results
        if analyze_button and image is not None:
            with st.spinner("Analyzing..."):
                try:
                    result, inference_time = process_image(image, prompt, model, processor)
                    st.session_state.result = result
                    st.session_state.inference_time = inference_time
                except Exception as e:
                    st.error(f"Error: {str(e)}")

        if st.session_state.result:
            st.success("Analysis Complete!")
            st.markdown(f"**Result:**\n{st.session_state.result}")
            st.markdown(f"*Inference time: {st.session_state.inference_time:.2f} seconds*")

    # Example images section
    if input_option == "Use example image":
        st.markdown("### Example Images")
        # Sorted so the per-image button keys (img_<idx>) are stable across
        # reruns; a missing directory is treated the same as an empty one.
        example_images = _list_example_images("images")

        if example_images:
            # Create grid of images
            cols = st.columns(4)  # Adjust number of columns as needed
            for idx, img_name in enumerate(example_images):
                with cols[idx % 4]:
                    img_path = os.path.join("images", img_name)
                    img = Image.open(img_path)
                    img.thumbnail((150, 150))

                    # Button above the thumbnail selects that image
                    if st.button(
                        "📷",
                        key=f"img_{idx}",
                        help=img_name,
                        use_container_width=True,
                    ):
                        st.session_state.selected_image = img_path
                        st.rerun()

                    st.image(
                        img,
                        caption=img_name,
                        use_container_width=True,
                    )
        else:
            st.error("No example images found in the 'images' directory")


if __name__ == "__main__":
    main()