import streamlit as st
from transformers import (
    Qwen2VLForConditionalGeneration,
    AutoModelForCausalLM,
    AutoProcessor
)
import torch
from PIL import Image
import time
import os
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import io
import numpy as np


@st.cache_resource
def load_models():
    """Load both models and processors (cached across Streamlit reruns)."""
    # Load Qwen model (GPTQ Int4 checkpoint, placed automatically on the available GPU)
    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
        torch_dtype=torch.bfloat16,
        device_map="auto"
    ).eval()
    qwen_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4")

    # Load Florence model
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    florence_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Florence-2-large-ft",
        torch_dtype=torch_dtype,
        trust_remote_code=True
    ).to(device)
    florence_processor = AutoProcessor.from_pretrained(
        "microsoft/Florence-2-large-ft",
        trust_remote_code=True
    )

    return qwen_model, qwen_processor, florence_model, florence_processor, device, torch_dtype


def process_qwen(image, prompt, model, processor):
    """Classify the vehicle in the image with Qwen2-VL."""
    start_time = time.time()

    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
    inputs = processor(text=[text_prompt], images=[image], padding=True, return_tensors="pt").to("cuda")

    output_ids = model.generate(**inputs, max_new_tokens=100)
    # Strip the prompt tokens so only the newly generated answer is decoded
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs.input_ids, output_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    inference_time = time.time() - start_time
    return output_text[0].strip(), inference_time


def draw_bounding_boxes(image, bboxes, labels):
    """Draw bounding boxes and labels on the image and return it as a PIL image."""
    img_array = np.array(image)
    fig, ax = plt.subplots()
    ax.imshow(img_array)

    for bbox, label in zip(bboxes, labels):
        # Boxes come as [x1, y1, x2, y2] in pixel coordinates
        x, y, x2, y2 = bbox
        width = x2 - x
        height = y2 - y
        rect = patches.Rectangle(
            (x, y), width, height,
            linewidth=2, edgecolor='red', facecolor='none'
        )
        ax.add_patch(rect)
        plt.text(
            x, y - 5, label,
            color='red', fontsize=12,
            bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=0)
        )

    plt.axis('off')
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    plt.close()
    buf.seek(0)
    return Image.open(buf)


def process_florence(image, text_input, model, processor, device, torch_dtype):
    """Locate the described vehicle in the image with Florence-2."""
    start_time = time.time()

    # Phrase-grounding task: Florence-2 localizes the phrase that follows the task tag
    task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
    prompt = task_prompt + text_input if text_input else task_prompt

    inputs = processor(
        text=prompt,
        images=image,
        return_tensors="pt"
    ).to(device, torch_dtype)

    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=2048,
        num_beams=3
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height)
    )

    inference_time = time.time() - start_time
    result = parsed_answer[task_prompt]
    annotated_image = draw_bounding_boxes(
        image,
        result['bboxes'],
        result['labels']
    )
    return result, inference_time, annotated_image
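
# For reference (illustrative values only): with a grounding task such as
# <CAPTION_TO_PHRASE_GROUNDING>, processor.post_process_generation returns a dict
# keyed by the task string, roughly:
#     {"<CAPTION_TO_PHRASE_GROUNDING>": {"bboxes": [[x1, y1, x2, y2], ...],
#                                        "labels": ["car", ...]}}
# which is why process_florence() reads result['bboxes'] and result['labels'] and
# passes them straight to draw_bounding_boxes().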


def main():
    st.markdown(
        "<h1 style='text-align: center;'>🚗 Vehicle Analysis Pipeline</h1>",
        unsafe_allow_html=True
    )

    # Load models
    with st.spinner("Loading models... This might take a minute."):
        qwen_model, qwen_processor, florence_model, florence_processor, device, torch_dtype = load_models()

    # Initialize session state
    if 'selected_image' not in st.session_state:
        st.session_state.selected_image = None
    if 'qwen_result' not in st.session_state:
        st.session_state.qwen_result = None
    if 'florence_result' not in st.session_state:
        st.session_state.florence_result = None
    if 'annotated_image' not in st.session_state:
        st.session_state.annotated_image = None

    # Image selection
    col1, col2 = st.columns([1, 2])

    with col1:
        input_option = st.radio(
            "Choose input method:",
            ["Use example image", "Upload image"],
            label_visibility="collapsed"
        )

        if input_option == "Upload image":
            uploaded_file = st.file_uploader(
                "Upload Image",
                type=["jpg", "jpeg", "png"],
                label_visibility="collapsed"
            )
            image_source = uploaded_file
            if uploaded_file:
                st.session_state.selected_image = uploaded_file
        else:
            image_source = st.session_state.selected_image

        # Default prompt for Qwen
        default_prompt = (
            "What type of vehicle is this? Choose only from: car, pickup, bus, "
            "truck, motorbike, van. Answer only in one word."
        )
        prompt = st.text_area("Enter prompt for classification:", value=default_prompt, height=100)

        analyze_button = st.button(
            "Analyze Image",
            use_container_width=True,
            disabled=image_source is None
        )

    # Display and process
    if image_source:
        try:
            # PIL accepts both a file path (example images) and an uploaded file object
            image = Image.open(image_source).convert("RGB")

            with col2:
                st.image(image, caption="Selected Image", width=300)

                if analyze_button:
                    # Step 1: Qwen Analysis
                    with st.spinner("Step 1: Classifying vehicle type..."):
                        qwen_result, qwen_time = process_qwen(image, prompt, qwen_model, qwen_processor)
                        st.session_state.qwen_result = qwen_result

                    # Step 2: Florence Analysis
                    with st.spinner("Step 2: Detecting vehicle location..."):
                        florence_result, florence_time, annotated_image = process_florence(
                            image,
                            f"Find the {qwen_result} in the image",
                            florence_model,
                            florence_processor,
                            device,
                            torch_dtype
                        )
                        st.session_state.florence_result = florence_result
                        st.session_state.annotated_image = annotated_image

                    # Display results
                    st.markdown("### Analysis Results")

                    # Qwen results
                    st.markdown("#### Step 1: Vehicle Classification")
                    st.markdown(f"**Type:** {st.session_state.qwen_result}")
                    st.markdown(f"*Classification time: {qwen_time:.2f} seconds*")

                    # Florence results
                    st.markdown("#### Step 2: Vehicle Detection")
                    st.image(annotated_image, caption="Vehicle Detection Result", use_container_width=True)
                    st.markdown(f"*Detection time: {florence_time:.2f} seconds*")
                    st.markdown("**Raw Detection Data:**")
                    st.json(florence_result)

        except Exception as e:
            st.error(f"Error processing image: {str(e)}")

    # Example images section
    if input_option == "Use example image":
        st.markdown("### Example Images")
        image_dir = "images"
        example_images = (
            [f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
            if os.path.isdir(image_dir) else []
        )

        if example_images:
            cols = st.columns(4)
            for idx, img_name in enumerate(example_images):
                with cols[idx % 4]:
                    img_path = os.path.join(image_dir, img_name)
                    img = Image.open(img_path)
                    img.thumbnail((150, 150))
                    if st.button("📷", key=f"img_{idx}", help=img_name, use_container_width=True):
                        st.session_state.selected_image = img_path
                        st.rerun()
                    st.image(img, caption=img_name, use_container_width=True)
        else:
            st.error("No example images found in the 'images' directory")


if __name__ == "__main__":
    main()