import gradio as gr
import numpy as np
import torch
from PIL import ImageDraw
from ultralytics import YOLO

from utils.tools import box_prompt, format_results, point_prompt, text_prompt
from utils.tools_gradio import fast_process
# Load the FastSAM-x checkpoint through the Ultralytics YOLO interface.
model = YOLO('./weights/FastSAM-x.pt')

# Prefer CUDA, then Apple MPS, and fall back to CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
def segment_everything(
    input,
    input_size=1024,
    withContours=True,
    iou_threshold=0.7,
    conf_threshold=0.25,
    better_quality=False,
    use_retina=True,
    text="",
    wider=False,
    mask_random_color=True,
):
    """Run FastSAM on a PIL image and return the annotated result.

    If `text` is non-empty, only the masks matching the text prompt are kept;
    otherwise every predicted mask is drawn.
    """
    # Resize so the longer side matches the requested input size.
    input_size = int(input_size)
    w, h = input.size
    scale = input_size / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    input = input.resize((new_w, new_h))

    # Run FastSAM inference.
    results = model(
        input,
        device=device,
        retina_masks=True,
        iou=iou_threshold,
        conf=conf_threshold,
        imgsz=input_size,
    )

    if len(text) > 0:
        # Text prompt: keep only the masks that match the description.
        results = format_results(results[0], 0)
        annotations, _ = text_prompt(results, text, input, device=device, wider=wider)
        annotations = np.array([annotations])
    else:
        # "Segment everything" mode: keep all predicted masks.
        annotations = results[0].masks.data
    # Render the masks (and optionally their contours) onto the image.
    fig = fast_process(
        annotations=annotations,
        image=input,
        device=device,
        scale=(1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )
    return fig
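
# A minimal sketch of calling segment_everything directly, outside Gradio
# (assumes an example image exists locally and that fast_process returns a
# PIL.Image, as the gr.Image(type='pil') output below suggests):
#
#   from PIL import Image
#   img = Image.open("examples/sa_8776.jpg")
#   out = segment_everything(img, input_size=1024, text="a dog")
#   out.save("segmented.png")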
title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"

description = """# 🎯 Segment everything in an image, optionally guided by a text prompt."""

examples = [["examples/sa_8776.jpg"], ["examples/sa_414.jpg"], ["examples/sa_1309.jpg"], ["examples/sa_11025.jpg"],
            ["examples/sa_561.jpg"], ["examples/sa_192.jpg"], ["examples/sa_10039.jpg"], ["examples/sa_862.jpg"]]
default_example = examples[0]

# UI components shared by the interface below.
cond_img_e = gr.Image(label="Input", value=default_example[0], type='pil')
segm_img_e = gr.Image(label="Segmented Image", interactive=False, type='pil')
input_size_slider = gr.Slider(minimum=512, maximum=1024, value=1024, step=64,
                              label='Input_size', info='Our model was trained on a size of 1024')

css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
# Wire the segmentation function into a simple Gradio interface.
demo = gr.Interface(
    segment_everything,
    inputs=[
        cond_img_e,
        input_size_slider,
        gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks'),
    ],
    outputs=[
        segm_img_e,
    ],
    title=title,
    description=description,
    examples=examples,
    css=css,
)

demo.launch()
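
# launch() serves the app locally by default; passing share=True would create
# a temporary public Gradio link if remote access is needed.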