File size: 4,681 Bytes
05d99b1
 
4877588
05d99b1
4877588
05d99b1
4877588
 
 
 
 
 
 
 
 
 
 
05d99b1
4877588
 
 
05d99b1
 
 
 
 
 
 
 
5aff4e7
05d99b1
 
 
5aff4e7
05d99b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4877588
05d99b1
5aff4e7
05d99b1
 
 
 
 
 
 
 
 
5aff4e7
 
4877588
 
 
 
 
 
 
 
 
 
 
05d99b1
 
 
4877588
 
 
05d99b1
4877588
 
05d99b1
 
 
5aff4e7
05d99b1
 
 
5aff4e7
4877588
05d99b1
 
 
4877588
 
05d99b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4877588
05d99b1
4877588
05d99b1
 
4877588
 
 
 
5aff4e7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import logging

import gradio as gr
from rembg import new_session

from cutter import remove, make_label
from utils import *

# Maps the model names shown in the "Model Segment" dropdown to the rembg
# model identifiers that predict() passes to rembg's new_session().
remove_bg_models = {
    "U2NET": "u2net",
    "U2NET Human Seg": "u2net_human_seg",
    "U2NET Cloth Seg": "u2net_cloth_seg"
}

# Choices for the model dropdown.
# NOTE(review): `keys` is not defined in this file; it presumably comes from
# `from utils import *` and returns the dict's keys — confirm.
model_choices = keys(remove_bg_models)


# rembg sessions keyed by rembg model name. Each model is created once per
# process and then reused, instead of being rebuilt by new_session() on every
# click. Only names found in remove_bg_models are ever cached, so this holds
# at most one session per dropdown entry.
_session_cache = {}


def _get_session(model_name):
    """Return the cached rembg session for *model_name*, creating it on first use."""
    cached = _session_cache.get(model_name)
    if cached is None:
        cached = new_session(model_name)
        _session_cache[model_name] = cached
    return cached


def predict(image, session, smoot, matting, bg_color):
    """Remove the background from *image* using the model chosen in the UI.

    Args:
        image: the input image (the input component uses ``type="pil"``).
        session: the model's display name, i.e. a key of ``remove_bg_models``.
            The parameter name is kept for compatibility; it is not a
            session object.
        smoot: state of the "Smoot Mask" checkbox.
        matting: the (foreground threshold, background threshold, erode size)
            tuple held in ``matting_state``.
        bg_color: the value held in ``color_state``. This is False, a colour
            from the picker, or True when "Use dominant color" is ticked.

    Returns:
        On success, whatever ``cutter.remove`` returns (wired to the result
        image and mask outputs). If the model name is unknown, or ``remove``
        raises ``ValueError``, returns ``(make_label(message), None)``.
    """
    model_name = remove_bg_models.get(session)
    if model_name is None:
        # E.g. the dropdown was cleared; report it like other user-facing errors
        # instead of crashing with a KeyError.
        message = f"Unknown model: {session!r}"
        logging.error(message)
        return make_label(message), None

    # Fetched outside the try so only errors from remove() become a label,
    # as in the original.
    rembg_session = _get_session(model_name)

    try:
        return remove(rembg_session, image, smoot, matting, bg_color)
    except ValueError as err:
        logging.error(err)
        return make_label(str(err)), None


def change_show_mask(chk_state):
    """Return an update that shows the mask image when *chk_state* is true, hides it otherwise."""
    mask_update = gr.Image(visible=chk_state)
    return mask_update


# Alpha-matting parameters (foreground threshold, background threshold,
# erode size) restored when "Matting" is switched on. Keep in sync with the
# `value=` of the three matting sliders in the layout.
_MATTING_DEFAULTS = (270, 20, 11)
# Value stored in matting_state while matting is off (also its initial value).
_MATTING_OFF = (0, 0, 0)


def change_include_matting(chk_state):
    """React to the "Matting" checkbox.

    Returns updates for, in order: the slider group's visibility,
    ``matting_state``, and the foreground, background and erode sliders.

    When matting is switched on, the slider defaults are restored rather than
    zeroed. Otherwise ticking the box without touching the sliders would leave
    the parameters at (0, 0, 0), the same as "off". When it is switched off,
    (0, 0, 0) is stored, as before.
    """
    values = _MATTING_DEFAULTS if chk_state else _MATTING_OFF
    fg, bg, erode = values
    return gr.Group(visible=chk_state), values, fg, bg, erode


def change_foreground_threshold(fg_value, value):
    """Return the (fg, bg, erode) triple *value* with its foreground threshold set to *fg_value*."""
    _, background, erode_size = value
    updated = (fg_value, background, erode_size)
    return updated


def change_background_threshold(bg_value, value):
    """Return the (fg, bg, erode) triple *value* with its background threshold set to *bg_value*."""
    foreground, _, erode_size = value
    return foreground, bg_value, erode_size


def change_erode_size(erode_value, value):
    """Return the (fg, bg, erode) triple *value* with its erode size set to *erode_value*."""
    foreground, background, _ = value
    return foreground, background, erode_value


def set_dominant_color(chk_state):
    """React to the "Use dominant color" checkbox.

    Returns the checkbox state (stored as ``color_state``) and an update that
    hides the colour picker while the dominant colour is in use.
    """
    picker_update = gr.ColorPicker(visible=not chk_state)
    return chk_state, picker_update


def change_picker_color(picker, dominant):
    """Return the new ``color_state``.

    The result is the picked colour, unless *dominant* is truthy, in which
    case *dominant* itself is returned.
    """
    return dominant if dominant else picker


def change_background_mode(chk_state):
    """React to the "Change background color" checkbox.

    Shows or hides the colour picker and the "Use dominant color" checkbox
    together, and always resets that checkbox to unticked.
    """
    picker_update = gr.ColorPicker(visible=chk_state)
    dominant_update = gr.Checkbox(value=False, visible=chk_state)
    return picker_update, dominant_update


# HTML rendered in the bottom row of the page, crediting the rembg project.
footer = r"""
<center>
<b>
Demo based on <a href='https://github.com/danielgatis/rembg'>Rembg</a>
</b>
</center>
"""

# Build the UI: options on the left, results on the right, then wire the
# event handlers defined above. Component creation order defines the layout.
with gr.Blocks(title="Remove background") as app:
    # Background colour handed to predict() as bg_color. It is False initially,
    # then set by change_picker_color / set_dominant_color.
    color_state = gr.State(value=False)
    # (foreground threshold, background threshold, erode size) handed to
    # predict() as `matting`. It starts at (0, 0, 0).
    matting_state = gr.State(value=(0, 0, 0))

    gr.HTML("<center><h1>Remove Background Tool</h1></center>")
    with gr.Row(equal_height=False):
        # Left column: input image and processing options.
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input image")
            drp_models = gr.Dropdown(choices=model_choices, label="Model Segment", value="U2NET")
            with gr.Row():
                chk_include_matting = gr.Checkbox(label="Matting", value=False)
                chk_smoot_mask = gr.Checkbox(label="Smoot Mask", value=False)
                chk_show_mask = gr.Checkbox(label="Show Mask", value=False)
            # Matting sliders, hidden until "Matting" is ticked.
            with gr.Group(visible=False) as slider_matting:
                slr_fg_threshold = gr.Slider(0, 300, value=270, step=1, label="Alpha matting foreground threshold")
                slr_bg_threshold = gr.Slider(0, 50, value=20, step=1, label="Alpha matting background threshold")
                slr_erode_size = gr.Slider(0, 20, value=11, step=1, label="Alpha matting erode size")
            with gr.Group():
                with gr.Row():
                    chk_change_color = gr.Checkbox(label="Change background color", value=False)
                    pkr_color = gr.ColorPicker(label="Pick a new color", visible=False)
                    chk_dominant = gr.Checkbox(label="Use dominant color", value=False, visible=False)
            run_btn = gr.Button(value="Remove background", variant="primary")
        # Right column: results.
        with gr.Column():
            output_img = gr.Image(type="pil", label="Image Result")
            mask_img = gr.Image(type="pil", label="Image Mask", visible=False)
            gr.ClearButton(components=[input_img, output_img, mask_img])

    # Toggling "Matting" updates the slider group's visibility, matting_state
    # and all three sliders.
    chk_include_matting.change(change_include_matting, inputs=[chk_include_matting],
                               outputs=[slider_matting, matting_state,
                                        slr_fg_threshold, slr_bg_threshold, slr_erode_size])

    # Each slider rewrites its own component of the matting_state triple.
    slr_fg_threshold.change(change_foreground_threshold, inputs=[slr_fg_threshold, matting_state],
                            outputs=[matting_state])

    slr_bg_threshold.change(change_background_threshold, inputs=[slr_bg_threshold, matting_state],
                            outputs=[matting_state])

    slr_erode_size.change(change_erode_size, inputs=[slr_erode_size, matting_state],
                          outputs=[matting_state])

    chk_show_mask.change(change_show_mask, inputs=[chk_show_mask], outputs=[mask_img])

    # NOTE(review): unticking "Change background color" hides the picker but
    # does not reset color_state, so a colour picked earlier appears to still
    # reach predict(). Also, if "Use dominant color" was ticked, forcing it back
    # to False here presumably fires chk_dominant.change -> set_dominant_color(False),
    # which makes the picker visible again. Confirm, and consider clearing
    # color_state in this handler.
    chk_change_color.change(change_background_mode, inputs=[chk_change_color],
                            outputs=[pkr_color, chk_dominant])

    pkr_color.change(change_picker_color, inputs=[pkr_color, chk_dominant], outputs=[color_state])

    # NOTE(review): unticking "Use dominant color" sets color_state to False
    # instead of the picker's current colour, so the user must re-pick it.
    chk_dominant.change(set_dominant_color, inputs=[chk_dominant], outputs=[color_state, pkr_color])

    # Main action: predict() returns values for the result image and the mask.
    run_btn.click(predict, inputs=[input_img, drp_models, chk_smoot_mask, matting_state, color_state],
                  outputs=[output_img, mask_img])

    with gr.Row():
        gr.HTML(footer)

# Enable request queuing on the Blocks app, then start the Gradio server.
app.queue().launch(share=False, debug=True, show_error=True)