import streamlit as st
from transformers import (
Qwen2VLForConditionalGeneration,
AutoProcessor
)
import torch
from PIL import Image
import time
import os
@st.cache_resource
def load_model():
    """Load the Qwen2-VL model and processor.

    Cached with st.cache_resource so Streamlit reruns reuse the same
    objects instead of reloading the weights on every interaction.

    Returns:
        tuple: (model, processor, device) where device is "cuda" or "cpu".
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
        device_map="auto" if device == "cuda" else None,
    ).eval()
    # Do NOT call .to(device) after loading with device_map="auto":
    # accelerate has already dispatched the weights across devices, and
    # moving such a model again raises an error in transformers.
    # Only move explicitly in the CPU path, where device_map was None.
    if device == "cpu":
        model = model.to(device)
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4")
    return model, processor, device
def process_image(image, prompt, model, processor, device):
    """Run the vision-language model on one image and prompt.

    Args:
        image: PIL image to analyze.
        prompt: Text instruction for the model.
        model: Loaded Qwen2-VL model.
        processor: Matching AutoProcessor (chat template + tokenizer).
        device: Target device string for the input tensors.

    Returns:
        tuple: (answer text stripped of whitespace, inference time in seconds).
    """
    start_time = time.time()
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt").to(device)
    # inference_mode disables autograd bookkeeping — generation needs no
    # gradients, and tracking them here wastes memory and time.
    with torch.inference_mode():
        output_ids = model.generate(**inputs, max_new_tokens=100)
    # generate() returns prompt + completion; slice off the prompt tokens.
    # (Distinct loop-variable names avoid the original's shadowing of
    # output_ids inside the comprehension.)
    trimmed_ids = [
        full_seq[len(prompt_seq):]
        for prompt_seq, full_seq in zip(inputs.input_ids, output_ids)
    ]
    output_text = processor.batch_decode(
        trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    inference_time = time.time() - start_time
    return output_text[0].strip(), inference_time
def main():
    """Streamlit page: pick/upload an image and ask Qwen2-VL about it."""
    # Compact header. NOTE: the original split this string literal across
    # physical lines, which is invalid Python for a "..." string; it is
    # reconstructed here as a single HTML literal.
    st.markdown(
        "<h3 style='text-align: center;'>🔍 Image Analysis with Qwen2-VL</h3>",
        unsafe_allow_html=True,
    )

    # Load model and processor (cached — the spinner only shows real work once).
    with st.spinner("Loading model... This might take a minute."):
        model, processor, device = load_model()

    # Initialize session state so later reads never KeyError across reruns.
    if 'selected_image' not in st.session_state:
        st.session_state.selected_image = None
    if 'result' not in st.session_state:
        st.session_state.result = None
    if 'inference_time' not in st.session_state:
        st.session_state.inference_time = None

    # Main content area
    col1, col2, col3 = st.columns([1, 1.5, 1])
    with col1:
        # Input method selection
        input_option = st.radio(
            "Choose input method:",
            ["Use example image", "Upload image"],
            label_visibility="collapsed",
        )
        if input_option == "Upload image":
            uploaded_file = st.file_uploader(
                "Upload Image", type=["jpg", "jpeg", "png"], label_visibility="collapsed"
            )
            image_source = uploaded_file
            if uploaded_file:
                st.session_state.selected_image = uploaded_file
        else:
            # Example-image path selected via the button grid below.
            image_source = st.session_state.selected_image

        # Default prompt and analysis section
        default_prompt = "What type of vehicle is this? Choose only from: car, pickup, bus, truck, motorbike, van. Answer only in one word."
        prompt = st.text_area("Enter prompt:", value=default_prompt, height=100)

        analyze_col1, analyze_col2 = st.columns([1, 2])
        with analyze_col1:
            analyze_button = st.button(
                "Analyze Image", use_container_width=True, disabled=image_source is None
            )

        # Display selected image. PIL.Image.open accepts both a filesystem
        # path (example images) and a file-like object (uploads), so the
        # original's two identical isinstance branches are collapsed.
        image = None
        if image_source:
            try:
                image = Image.open(image_source).convert("RGB")
                st.image(image, caption="Selected Image", width=300)
            except Exception as e:
                st.error(f"Error loading image: {str(e)}")

        # Analysis results — only run when the image actually loaded
        # (the original could reach process_image with an undefined image
        # if Image.open had failed).
        if analyze_button and image is not None:
            with st.spinner("Analyzing..."):
                try:
                    result, inference_time = process_image(image, prompt, model, processor, device)
                    st.session_state.result = result
                    st.session_state.inference_time = inference_time
                except Exception as e:
                    st.error(f"Error: {str(e)}")

        if st.session_state.result:
            st.success("Analysis Complete!")
            st.markdown(f"**Result:**\n{st.session_state.result}")
            st.markdown(f"*Inference time: {st.session_state.inference_time:.2f} seconds*")

    # Example images section
    if input_option == "Use example image":
        st.markdown("### Example Images")
        # Guard against a missing directory: os.listdir would otherwise
        # raise FileNotFoundError and crash the whole page.
        example_images = []
        if os.path.isdir("images"):
            example_images = [
                f for f in os.listdir("images")
                if f.lower().endswith(('.jpg', '.jpeg', '.png'))
            ]
        if example_images:
            # Create grid of images
            cols = st.columns(4)  # Adjust number of columns as needed
            for idx, img_name in enumerate(example_images):
                with cols[idx % 4]:
                    img_path = os.path.join("images", img_name)
                    img = Image.open(img_path)
                    img.thumbnail((150, 150))
                    # Make image clickable: the button stores the path and
                    # reruns so the col1 pane picks up the selection.
                    if st.button(
                        "📷",
                        key=f"img_{idx}",
                        help=img_name,
                        use_container_width=True
                    ):
                        st.session_state.selected_image = img_path
                        st.rerun()
                    st.image(
                        img,
                        caption=img_name,
                        use_container_width=True,
                    )
        else:
            st.error("No example images found in the 'images' directory")
# Standard script entry point: run the Streamlit page only when executed
# directly, not when imported as a module.
if __name__ == "__main__":
    main()