# File size: 3,129 Bytes
# e0899e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from ultralytics import YOLO
import gradio as gr
import torch
from utils.tools_gradio import fast_process
from utils.tools import format_results, box_prompt, point_prompt, text_prompt
from PIL import ImageDraw
import numpy as np

# Load the FastSAM-x checkpoint through the Ultralytics YOLO wrapper once, at
# import time; segment_everything() reuses this module-level instance for
# every request. The weights path is relative to the working directory.
model = YOLO('./weights/FastSAM-x.pt')

# Choose the inference device in order of preference: CUDA, then Apple's
# Metal backend (MPS), falling back to the CPU when neither is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

def segment_everything(
    input,
    input_size=1024,
    withContours=True,
    iou_threshold=0.7,
    conf_threshold=0.25,
    better_quality=False,
    use_retina=True,
    text="",
    wider=False,
    mask_random_color=True,
):
    """Segment an image with the module-level FastSAM model and render the masks.

    The image is resized so that its longer side equals ``input_size``, passed
    through ``model``, and the resulting masks are drawn with ``fast_process``.
    When ``text`` is non-empty the masks are first narrowed with
    ``text_prompt``.

    Args:
        input: Image exposing ``.size`` and ``.resize`` (the Gradio inputs in
            this file deliver a PIL image). ``None`` is rejected.
        input_size: Target length of the longer image side; also passed to the
            model as ``imgsz``. Converted with ``int()``.
        withContours: Forwarded to ``fast_process``.
        iou_threshold: Forwarded to the model as ``iou``.
        conf_threshold: Forwarded to the model as ``conf``.
        better_quality: Forwarded to ``fast_process``.
        use_retina: Forwarded to ``fast_process`` only; inference itself always
            runs with ``retina_masks=True``.
        text: Optional text prompt. Empty or ``None`` means "use every mask".
        wider: Forwarded to ``text_prompt``.
        mask_random_color: Forwarded to ``fast_process``.

    Returns:
        The value returned by ``fast_process``, or the resized input image
        unchanged when the model produced no masks.

    Raises:
        gr.Error: If ``input`` is ``None`` (no image was supplied).
    """
    if input is None:
        # Gradio passes None when the image component is empty; surface a
        # readable message instead of an AttributeError on ``.size``.
        raise gr.Error("Please provide an input image.")

    input_size = int(input_size)

    # Resize so the longer side matches input_size, preserving aspect ratio.
    w, h = input.size
    scale = input_size / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    input = input.resize((new_w, new_h))

    results = model(
        input,
        device=device,
        retina_masks=True,
        iou=iou_threshold,
        conf=conf_threshold,
        imgsz=input_size,
    )

    result = results[0]
    if result.masks is None:
        # Nothing was detected, so there are no masks to format or draw.
        # Return the resized image rather than failing on ``None.data``.
        return input

    if text:
        formatted = format_results(result, 0)
        annotations, _ = text_prompt(formatted, text, input, device=device, wider=wider)
        # Wrap the selected annotation in a leading axis so it is passed to
        # fast_process as an array of annotations, like the other branch.
        annotations = np.array([annotations])
    else:
        annotations = result.masks.data

    fig = fast_process(
        annotations=annotations,
        image=input,
        device=device,
        # Integer ratio relative to the 1024 reference size.
        scale=(1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )
    return fig

# Page header; Gradio renders the title as HTML and the description as Markdown.
title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"
# NOTE(review): the description mentions "points mode", but the interface below
# only exposes an image, a size slider and a contour checkbox - confirm wording.
description = """ # 🎯 Instructions for points mode """

# Example rows; each row holds a single image path for the image input.
examples = [["examples/sa_8776.jpg"], ["examples/sa_414.jpg"], ["examples/sa_1309.jpg"], ["examples/sa_11025.jpg"],
            ["examples/sa_561.jpg"], ["examples/sa_192.jpg"], ["examples/sa_10039.jpg"], ["examples/sa_862.jpg"]]
default_example = examples[0]

# Components are created once here and reused by the interface below, so the
# configuration of each widget lives in exactly one place.
cond_img_e = gr.Image(label="Input", value=default_example[0], type='pil')
segm_img_e = gr.Image(label="Segmented Image", interactive=False, type='pil')

input_size_slider = gr.components.Slider(minimum=512,maximum=1024,value=1024,step=64,label='Input_size',info='Our model was trained on a size of 1024')

# NOTE(review): this stylesheet is defined but not passed to gr.Interface in
# the visible code; kept as-is in case it is referenced elsewhere.
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"

# The three inputs map positionally onto segment_everything's first three
# parameters (input, input_size, withContours); the rest keep their defaults.
demo = gr.Interface(
    segment_everything,
    inputs=[
        cond_img_e,
        input_size_slider,
        gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks'),
    ],
    outputs=[
        segm_img_e,
    ],
    title=title,
    description=description,
    examples=examples,
)
demo.launch()