"""Gradio demo: text-prompted object masking with Florence-2 and SAM."""

from typing import List, Optional

import gradio as gr
import spaces
import supervision as sv
import torch
from PIL import Image

from utils.florence import (
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
    load_florence_model,
    run_florence_inference,
)
from utils.sam import load_sam_image_model, run_sam_inference

# The script is CUDA-only: the autocast context and the TF32 probe below both
# address a CUDA device directly.
DEVICE = torch.device("cuda")

# Enter bfloat16 autocast for the lifetime of the process (it is deliberately
# never exited), so the module-level model loading below runs under it too.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

# Allow TF32 matmul/cuDNN kernels when device 0 reports compute capability
# 8.x or newer.
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Models are loaded once at import time and shared by every request.
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)


@spaces.GPU
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(
    image_input: Optional[Image.Image],
    text_input: Optional[str],
) -> List[Image.Image]:
    """Return one mask image per object detected for the given prompts.

    Each comma-separated prompt is sent to Florence with the open-vocabulary
    detection task; the resulting detections are passed through SAM and the
    masks are converted to 8-bit images (0 for background, 255 for object).

    Args:
        image_input: The uploaded image, or ``None`` when nothing was uploaded.
        text_input: Comma-separated text prompts. Blank fragments (for example
            from a trailing comma) are ignored.

    Returns:
        A list of mask images. The list is empty when the image or the prompt
        is missing, when nothing was detected, or when no masks were produced.
    """
    if image_input is None:
        gr.Info("Please upload an image.")
        return []

    # Drop blank fragments so inputs such as "cat,,dog" or "cat," do not
    # trigger an inference call with an empty prompt.
    stripped = (part.strip() for part in (text_input or "").split(","))
    prompts = [prompt for prompt in stripped if prompt]
    if not prompts:
        gr.Info("Please enter a text prompt.")
        return []

    detections_list = []
    for prompt in prompts:
        _, result = run_florence_inference(
            model=FLORENCE_MODEL,
            processor=FLORENCE_PROCESSOR,
            device=DEVICE,
            image=image_input,
            task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
            text=prompt,
        )
        detections = sv.Detections.from_lmm(
            lmm=sv.LMM.FLORENCE_2,
            result=result,
            resolution_wh=image_input.size,
        )
        # Nothing found for this prompt: there are no boxes to hand to SAM.
        if len(detections) == 0:
            continue
        detections = run_sam_inference(
            SAM_IMAGE_MODEL, image_input, detections
        )
        detections_list.append(detections)

    if not detections_list:
        return []

    merged = sv.Detections.merge(detections_list)
    # NOTE(review): SAM already ran per prompt above, so this second pass over
    # the merged detections looks redundant. It is kept because the helper's
    # implementation is not visible here -- confirm before removing either.
    merged = run_sam_inference(SAM_IMAGE_MODEL, image_input, merged)
    if merged.mask is None:
        return []

    # Scale each mask to 0/255 so it renders as a black-and-white image.
    return [
        Image.fromarray(mask.astype("uint8") * 255)
        for mask in merged.mask
    ]


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(
                type='pil', label='Upload image')
            text_input_component = gr.Textbox(
                label='Text prompt',
                placeholder='Enter comma separated text prompts')
            submit_button_component = gr.Button(
                value='Submit', variant='primary')
        with gr.Column():
            gallery_output_component = gr.Gallery(label='Output masks')

    # Clicking the button and pressing Enter in the textbox run the same
    # handler with the same inputs and outputs.
    for register_event in (
        submit_button_component.click,
        text_input_component.submit,
    ):
        register_event(
            fn=process_image,
            inputs=[
                image_input_component,
                text_input_component,
            ],
            outputs=[
                gallery_output_component,
            ],
        )

demo.launch(debug=False, show_error=True)