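# Gradio demo for text-prompted image masking: Florence-2 performs
# open-vocabulary detection for each comma-separated prompt, SAM turns the
# resulting detections into segmentation masks, and the masks are returned
# as binary PIL images for the output gallery.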
from typing import List
import gradio as gr
import spaces
import supervision as sv
import torch
from PIL import Image
from utils.florence import load_florence_model, run_florence_inference, \
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
from utils.sam import load_sam_image_model, run_sam_inference
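# Inference runs on the GPU; bfloat16 autocast is entered globally at import
# time (and again per call via the decorator on process_image).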
DEVICE = torch.device("cuda")
# DEVICE = torch.device("cpu")
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
# Allow TF32 matmuls on Ampere (compute capability 8.x) and newer GPUs.
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
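# Both models are loaded once at startup and shared across requests.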
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
@spaces.GPU
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def process_image(
    image_input, text_input
) -> List[Image.Image]:
    if not image_input:
        gr.Info("Please upload an image.")
        return []
    if not text_input:
        gr.Info("Please enter a text prompt.")
        return []

    # Run one Florence-2 open-vocabulary detection pass per comma-separated
    # prompt and refine each prompt's detections with SAM.
    texts = [prompt.strip() for prompt in text_input.split(",")]
    detections_list = []
    for text in texts:
        _, result = run_florence_inference(
            model=FLORENCE_MODEL,
            processor=FLORENCE_PROCESSOR,
            device=DEVICE,
            image=image_input,
            task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
            text=text
        )
        detections = sv.Detections.from_lmm(
            lmm=sv.LMM.FLORENCE_2,
            result=result,
            resolution_wh=image_input.size
        )
        detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
        detections_list.append(detections)

    # Merge detections across prompts, segment them with SAM, and return one
    # binary mask image per detected object.
    detections = sv.Detections.merge(detections_list)
    detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
    return [
        Image.fromarray(mask.astype("uint8") * 255)
        for mask
        in detections.mask
    ]
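# Hypothetical local test of the handler without the UI (file names and
# prompts below are placeholders, not part of the app):
#   masks = process_image(Image.open("photo.jpg"), "dog, frisbee")
#   masks[0].save("dog_mask.png")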
# UI: an image plus comma-separated prompts go in, a gallery of masks comes
# out. Clicking the button or pressing Enter in the textbox runs the same
# handler.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(
                type='pil', label='Upload image')
            text_input_component = gr.Textbox(
                label='Text prompt',
                placeholder='Enter comma-separated text prompts')
            submit_button_component = gr.Button(
                value='Submit', variant='primary')
        with gr.Column():
            gallery_output_component = gr.Gallery(label='Output masks')

    submit_button_component.click(
        fn=process_image,
        inputs=[
            image_input_component,
            text_input_component
        ],
        outputs=[
            gallery_output_component,
        ]
    )
    text_input_component.submit(
        fn=process_image,
        inputs=[
            image_input_component,
            text_input_component
        ],
        outputs=[
            gallery_output_component,
        ]
    )

demo.launch(debug=False, show_error=True)