# Hugging Face Space — status: Running on Zero (ZeroGPU)
from typing import List | |
import gradio as gr | |
import spaces | |
import supervision as sv | |
import torch | |
from PIL import Image | |
from utils.florence import load_florence_model, run_florence_inference, \ | |
FLORENCE_OPEN_VOCABULARY_DETECTION_TASK | |
from utils.sam import load_sam_image_model, run_sam_inference | |
# All models are placed on CUDA. CUDA is required: the autocast and
# device-property calls below assume a GPU as well.
DEVICE = torch.device("cuda")

# bfloat16 autocast entered for the rest of the process and deliberately never
# exited. NOTE(review): PyTorch autocast state is thread-local, so this only
# covers code executed on the importing thread; request handlers running on
# other threads/processes need their own autocast context.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

# Enable TF32 matmuls and cuDNN convolutions on GPUs with compute capability
# 8.x or newer.
_gpu_major = torch.cuda.get_device_properties(0).major
if _gpu_major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True

# Load both models once at import time; every request reuses these globals.
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
@spaces.GPU
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(
    image_input, text_input
) -> List[Image.Image]:
    """Segment every object matching a set of comma-separated text prompts.

    For each non-empty prompt, Florence-2 open-vocabulary detection yields
    bounding boxes on ``image_input``. The detections of all prompts are
    merged, and SAM is run once on the merged boxes to get one mask per box.

    The decorators request a ZeroGPU device (``spaces.GPU``), disable autograd
    tracking, and enable bfloat16 autocast inside the handler itself. A
    module-level autocast context does not reach this function, because
    autocast state is thread-local.

    Args:
        image_input: Image from the ``gr.Image(type='pil')`` component, or
            ``None`` when nothing has been uploaded.
        text_input: Comma-separated text prompts from the textbox.

    Returns:
        One ``PIL.Image.Image`` per detected object, with 255 inside the mask
        and 0 outside. Returns an empty list (after showing an info toast) when
        an input is missing or nothing is detected.
    """
    # Test for a missing image explicitly instead of relying on truthiness.
    if image_input is None:
        gr.Info("Please upload an image.")
        return []
    if not text_input:
        gr.Info("Please enter a text prompt.")
        return []

    # Skip blank entries such as those produced by "cat, , dog" or a trailing comma.
    texts = [prompt.strip() for prompt in text_input.split(",") if prompt.strip()]
    if not texts:
        gr.Info("Please enter a text prompt.")
        return []

    detections_list = []
    for text in texts:
        _, result = run_florence_inference(
            model=FLORENCE_MODEL,
            processor=FLORENCE_PROCESSOR,
            device=DEVICE,
            image=image_input,
            task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
            text=text,
        )
        detections_list.append(
            sv.Detections.from_lmm(
                lmm=sv.LMM.FLORENCE_2,
                result=result,
                resolution_wh=image_input.size,
            )
        )

    detections = sv.Detections.merge(detections_list)
    # Nothing to segment: skip SAM and return no masks.
    if len(detections) == 0:
        gr.Info("No objects found for the given prompts.")
        return []

    # Run SAM once on the merged boxes rather than per prompt, so each mask is
    # computed only once.
    detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
    return [
        Image.fromarray(mask.astype("uint8") * 255)
        for mask in detections.mask
    ]
# UI: image + prompt inputs on the left, gallery of predicted masks on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(type='pil', label='Upload image')
            text_input_component = gr.Textbox(
                label='Text prompt',
                placeholder='Enter comma separated text prompts',
            )
            submit_button_component = gr.Button(value='Submit', variant='primary')
        with gr.Column():
            gallery_output_component = gr.Gallery(label='Output masks')

    # Clicking "Submit" and pressing Enter in the textbox run the same handler
    # (registered in that order).
    for _listener in (submit_button_component.click, text_input_component.submit):
        _listener(
            fn=process_image,
            inputs=[image_input_component, text_input_component],
            outputs=[gallery_output_component],
        )

demo.launch(debug=False, show_error=True)