File size: 3,662 Bytes
e0899e6
 
 
 
 
627bd31
e0899e6
d6a309a
 
a1b30ce
e0899e6
ab5f7df
 
 
e0899e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a792e8
6d1bd90
e0899e6
 
 
 
 
 
 
 
9a792e8
 
 
 
 
 
 
019e8c5
 
9a792e8
627bd31
9a792e8
 
 
 
ab5f7df
9a792e8
 
6a4af22
9a792e8
 
 
a4d5dcd
 
9a792e8
6d1bd90
e0899e6
 
 
a00bec7
e0899e6
 
 
 
 
 
 
 
 
6a4af22
e0899e6
 
 
 
3ab1a99
6d1bd90
e0899e6
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from ultralytics import YOLO
import gradio as gr
import torch
from utils.tools_gradio import fast_process
from utils.tools import format_results, box_prompt, point_prompt, text_prompt
from PIL import ImageDraw,Image
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import io

import warnings
warnings.filterwarnings(action='ignore')

# Load the model from ./weights/FastSAM-x.pt once, at import time; the same
# instance is then called by `segment_everything` for every request.
model = YOLO('./weights/FastSAM-x.pt')

# Choose the inference device in order of preference: CUDA, then Apple MPS,
# then CPU. Each availability probe runs only if the previous one failed.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

def _select_largest_boxes(bboxes, k=2):
    """Return up to *k* rows of *bboxes* with the largest (x2-x1)*(y2-y1) area."""
    # Clamp k so torch.topk never asks for more rows than exist.
    count = min(k, bboxes.shape[0])
    if count == 0:
        return bboxes[:0]
    areas = (bboxes[:, 2] - bboxes[:, 0]) * (bboxes[:, 3] - bboxes[:, 1])
    _, largest_indices = torch.topk(areas, count)
    return bboxes[largest_indices]


def _render_boxes(image, boxes, output_size=(1024, 682)):
    """Draw red rectangles for *boxes* over *image* and return an RGBA PIL image."""
    fig, ax = plt.subplots(1)
    try:
        ax.imshow(image)
        for box in boxes:
            x1, y1, x2, y2 = box[:4]
            rect = patches.Rectangle(
                (x1, y1), x2 - x1, y2 - y1,
                linewidth=2, edgecolor='r', facecolor='none',
            )
            ax.add_patch(rect)
        ax.axis('off')
        buf = io.BytesIO()
        # Save this specific figure rather than pyplot's implicit "current"
        # one, which may differ when several requests are handled at once.
        fig.savefig(buf, format='jpg', bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if drawing or saving raised.
        plt.close(fig)
    buf.seek(0)
    # NOTE(review): the fixed output size does not follow the input's aspect
    # ratio; kept as-is because the UI currently relies on it.
    return Image.open(buf).convert("RGBA").resize(output_size)


def segment_everything(
    input,
    input_size=1024,
    withContours=True,
    iou_threshold=0.7,
    conf_threshold=0.25,
    better_quality=False,
    use_retina=True,
    wider=False,
    mask_random_color=True,
):
    """Segment an image and highlight its two largest detected boxes.

    Args:
        input: PIL image to process (it is resized so that its longest side
            equals ``input_size``).
        input_size: Target size of the longest side; also passed to the model
            as ``imgsz``.
        withContours: Forwarded to ``fast_process``.
        iou_threshold: Forwarded to the model as ``iou``.
        conf_threshold: Forwarded to the model as ``conf``.
        better_quality: Forwarded to ``fast_process``.
        use_retina: Forwarded to ``fast_process``. The model call itself
            always uses ``retina_masks=True``.
        wider: Currently unused.
        mask_random_color: Forwarded to ``fast_process``.

    Returns:
        A ``(segmented_img, cropped_img)`` tuple: the output of
        ``fast_process`` (or the resized input when the model produced no
        masks), and a 1024x682 RGBA rendering of the resized input with the
        largest boxes (at most two) outlined in red.

    Raises:
        gr.Error: If no image was provided.
    """
    if input is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please provide an input image.")

    input_size = int(input_size)
    w, h = input.size
    scale = input_size / max(w, h)
    # Guard against a zero-sized side for extremely elongated images.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    input = input.resize((new_w, new_h))

    results = model(
        input,
        device=device,
        retina_masks=True,
        iou=iou_threshold,
        conf=conf_threshold,
        imgsz=input_size,
    )
    result = results[0]

    if result.masks is None:
        # Nothing was segmented: fall back to the (resized) input image.
        segmented_img = input
    else:
        segmented_img = fast_process(
            annotations=result.masks.data,
            image=input,
            device=device,
            scale=(1024 // input_size),
            better_quality=better_quality,
            mask_random_color=mask_random_color,
            bbox=None,
            use_retina=use_retina,
            withContours=withContours,
        )

    if result.boxes is None:
        largest_boxes = []
    else:
        # .tolist() yields plain Python floats, so matplotlib never receives
        # tensors that may live on a CUDA/MPS device.
        largest_boxes = _select_largest_boxes(result.boxes.data).tolist()

    for i, box in enumerate(largest_boxes):
        print(f"Largest Box {i+1}: {box}")
    print('-----------')

    cropped_img = _render_boxes(input, largest_boxes)

    return segmented_img, cropped_img

# Page header and blurb shown by the Gradio interface. The emoji are written
# as escapes (U+1F3C3 runner, U+1F917 hugging face, U+1F3AF direct hit) so
# they survive re-encoding of this file; they were previously mojibake.
title = "<center><strong><font size='8'>\U0001F3C3 Fast Segment Anything \U0001F917</font></strong></center>"
description = """ # \U0001F3AF Instructions for points mode """
# Each example is a one-element list holding an image path.
examples = [["examples/invoice3.jpeg"], ["examples/invoice2.jpeg"], ["examples/invoice1.jpeg"]]
default_example = examples[0]

# Slider for the `input_size` argument of `segment_everything`; reused in the
# interface below instead of building a second, identical component.
input_size_slider = gr.components.Slider(
    minimum=512,
    maximum=1024,
    value=1024,
    step=64,
    label='Input_size',
    info='Our model was trained on a size of 1024',
)

# NOTE(review): `css` is defined but not passed to gr.Interface below; kept
# so existing references to the name keep working.
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"

# The three inputs map positionally onto `input`, `input_size` and
# `withContours`; the remaining parameters keep their defaults.
demo = gr.Interface(
    segment_everything,
    inputs=[
        gr.Image(label="Input", type='pil'),
        input_size_slider,
        gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks'),
    ],
    outputs=[
        gr.Image(label="Segmented Image", interactive=False, type='pil'),
        gr.Image(label="Cropped Image", interactive=False, type='pil'),
    ],
    title=title,
    description=description,
    examples=examples,
)
# Launched at module level (no __main__ guard) to preserve existing behaviour.
demo.launch()