leonelhs commited on
Commit
05d99b1
·
1 Parent(s): 4877588

update interface

Browse files
Files changed (4) hide show
  1. .gitignore +1 -1
  2. app.py +87 -49
  3. cutter.py +98 -0
  4. requirements.txt +2 -1
.gitignore CHANGED
@@ -1,3 +1,3 @@
1
  .idea/
2
  __pycache__/
3
- playground.py
 
1
  .idea/
2
  __pycache__/
3
+ playground.py
app.py CHANGED
@@ -1,8 +1,9 @@
1
- import PIL.Image
 
2
  import gradio as gr
3
- from PIL import ImageColor
4
- from rembg import new_session, remove
5
 
 
6
  from utils import *
7
 
8
  remove_bg_models = {
@@ -14,40 +15,53 @@ remove_bg_models = {
14
  model_choices = keys(remove_bg_models)
15
 
16
 
17
- def alpha_matting(state):
18
- if state:
19
- return 270, 20, 11
20
- return 0, 0, 0
21
-
22
 
23
- def predict(image,
24
- session,
25
- matting,
26
- only_mask,
27
- post_process_mask,
28
- foreground_threshold,
29
- background_threshold,
30
- matting_erode_size,
31
- new_bg_color,
32
- bg_color,
33
- transparency):
34
  session = new_session(remove_bg_models[session])
35
 
36
- if new_bg_color:
37
- r, g, b, _ = ImageColor.getcolor(bg_color, "RGBA")
38
- bg_color = r, g, b, transparency
39
- else:
40
- bg_color = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- return remove(data=image,
43
- session=session,
44
- alpha_matting=matting,
45
- only_mask=only_mask,
46
- post_process_mask=post_process_mask,
47
- bgcolor=bg_color,
48
- alpha_matting_foreground_threshold=foreground_threshold,
49
- alpha_matting_background_threshold=background_threshold,
50
- alpha_matting_erode_size=matting_erode_size)
 
 
 
 
51
 
52
 
53
  footer = r"""
@@ -59,33 +73,57 @@ Demo based on <a href='https://github.com/danielgatis/rembg'>Rembg</a>
59
  """
60
 
61
  with gr.Blocks(title="Remove background") as app:
 
 
 
62
  gr.HTML("<center><h1>Remove Background Tool</h1></center>")
63
  with gr.Row(equal_height=False):
64
  with gr.Column():
65
- input_img = gr.Image(type="numpy", label="Input image")
66
  drp_models = gr.Dropdown(choices=model_choices, label="Model Segment", value="U2NET")
67
  with gr.Row():
68
- chk_alm = gr.Checkbox(label="Alpha Matting", value=False)
69
- chk_psm = gr.Checkbox(label="Post process mask", value=False)
70
- chk_msk = gr.Checkbox(label="Only Mask", value=False)
71
- sld_aft = gr.Slider(0, 300, value=0, step=1, label="Alpha matting foreground threshold")
72
- sld_amb = gr.Slider(0, 50, value=0, step=1, label="Alpha matting background threshold")
73
- sld_aes = gr.Slider(0, 20, value=0, step=1, label="Alpha matting erode size")
 
74
  with gr.Box():
75
  with gr.Row():
76
- chk_col = gr.Checkbox(label="Change background color", value=False)
77
- color = gr.ColorPicker(label="Pick a new color")
78
- trans = gr.Number(label="Transparency level", value=255, precision=0, minimum=0, maximum=255)
79
  run_btn = gr.Button(value="Remove background", variant="primary")
80
  with gr.Column():
81
- output_img = gr.Image(type="pil", label="result")
82
- gr.ClearButton(components=[input_img, output_img])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- chk_alm.change(alpha_matting, inputs=[chk_alm], outputs=[sld_aft, sld_amb, sld_aes])
85
 
86
- run_btn.click(predict, [input_img, drp_models,
87
- chk_alm, chk_msk, chk_psm,
88
- sld_aft, sld_amb, sld_aes, chk_col, color, trans], [output_img])
89
 
90
  with gr.Row():
91
  gr.HTML(footer)
 
1
+ import logging
2
+
3
  import gradio as gr
4
+ from rembg import new_session
 
5
 
6
+ from cutter import remove, make_label
7
  from utils import *
8
 
9
  remove_bg_models = {
 
15
  model_choices = keys(remove_bg_models)
16
 
17
 
18
+ def predict(image, session, smoot, matting, bg_color):
 
 
 
 
19
 
 
 
 
 
 
 
 
 
 
 
 
20
  session = new_session(remove_bg_models[session])
21
 
22
+ try:
23
+ return remove(session, image, smoot, matting, bg_color)
24
+ except ValueError as err:
25
+ logging.error(err)
26
+ return make_label(str(err)), None
27
+
28
+
29
+ def change_show_mask(chk_state):
30
+ return gr.Image.update(visible=chk_state)
31
+
32
+
33
+ def change_include_matting(chk_state):
34
+ return gr.Box.update(visible=chk_state), (0, 0, 0), 0, 0, 0
35
+
36
+
37
+ def change_foreground_threshold(fg_value, value):
38
+ fg, bg, erode = value
39
+ return fg_value, bg, erode
40
+
41
+
42
+ def change_background_threshold(bg_value, value):
43
+ fg, bg, erode = value
44
+ return fg, bg_value, erode
45
+
46
+
47
+ def change_erode_size(erode_value, value):
48
+ fg, bg, erode = value
49
+ return fg, bg, erode_value
50
+
51
 
52
+ def set_dominant_color(chk_state):
53
+ return chk_state, gr.ColorPicker.update(value=False, visible=not chk_state)
54
+
55
+
56
+ def change_picker_color(picker, dominant):
57
+ if not dominant:
58
+ return picker
59
+ return dominant
60
+
61
+
62
+ def change_background_mode(chk_state):
63
+ return gr.ColorPicker.update(value=False, visible=chk_state), \
64
+ gr.Checkbox.update(value=False, visible=chk_state)
65
 
66
 
67
  footer = r"""
 
73
  """
74
 
75
  with gr.Blocks(title="Remove background") as app:
76
+ color_state = gr.State(value=False)
77
+ matting_state = gr.State(value=(0, 0, 0))
78
+
79
  gr.HTML("<center><h1>Remove Background Tool</h1></center>")
80
  with gr.Row(equal_height=False):
81
  with gr.Column():
82
+ input_img = gr.Image(type="pil", label="Input image")
83
  drp_models = gr.Dropdown(choices=model_choices, label="Model Segment", value="U2NET")
84
  with gr.Row():
85
+ chk_include_matting = gr.Checkbox(label="Matting", value=False)
86
+ chk_smoot_mask = gr.Checkbox(label="Smoot Mask", value=False)
87
+ chk_show_mask = gr.Checkbox(label="Show Mask", value=False)
88
+ with gr.Box(visible=False) as slider_matting:
89
+ slr_fg_threshold = gr.Slider(0, 300, value=270, step=1, label="Alpha matting foreground threshold")
90
+ slr_bg_threshold = gr.Slider(0, 50, value=20, step=1, label="Alpha matting background threshold")
91
+ slr_erode_size = gr.Slider(0, 20, value=11, step=1, label="Alpha matting erode size")
92
  with gr.Box():
93
  with gr.Row():
94
+ chk_change_color = gr.Checkbox(label="Change background color", value=False)
95
+ pkr_color = gr.ColorPicker(label="Pick a new color", visible=False)
96
+ chk_dominant = gr.Checkbox(label="Use dominant color", value=False, visible=False)
97
  run_btn = gr.Button(value="Remove background", variant="primary")
98
  with gr.Column():
99
+ output_img = gr.Image(type="pil", label="Image Result")
100
+ mask_img = gr.Image(type="pil", label="Image Mask", visible=False)
101
+ gr.ClearButton(components=[input_img, output_img, mask_img])
102
+
103
+ chk_include_matting.change(change_include_matting, inputs=[chk_include_matting],
104
+ outputs=[slider_matting, matting_state,
105
+ slr_fg_threshold, slr_bg_threshold, slr_erode_size])
106
+
107
+ slr_fg_threshold.change(change_foreground_threshold, inputs=[slr_fg_threshold, matting_state],
108
+ outputs=[matting_state])
109
+
110
+ slr_bg_threshold.change(change_background_threshold, inputs=[slr_bg_threshold, matting_state],
111
+ outputs=[matting_state])
112
+
113
+ slr_erode_size.change(change_erode_size, inputs=[slr_erode_size, matting_state],
114
+ outputs=[matting_state])
115
+
116
+ chk_show_mask.change(change_show_mask, inputs=[chk_show_mask], outputs=[mask_img])
117
+
118
+ chk_change_color.change(change_background_mode, inputs=[chk_change_color],
119
+ outputs=[pkr_color, chk_dominant])
120
+
121
+ pkr_color.change(change_picker_color, inputs=[pkr_color, chk_dominant], outputs=[color_state])
122
 
123
+ chk_dominant.change(set_dominant_color, inputs=[chk_dominant], outputs=[color_state, pkr_color])
124
 
125
+ run_btn.click(predict, inputs=[input_img, drp_models, chk_smoot_mask, matting_state, color_state],
126
+ outputs=[output_img, mask_img])
 
127
 
128
  with gr.Row():
129
  gr.HTML(footer)
cutter.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL
2
+ import numpy as np
3
+ from PIL import Image, ImageColor, ImageDraw
4
+ from PIL.Image import Image as PILImage
5
+ from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
6
+ from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
7
+ from pymatting.util.util import stack_images
8
+ from rembg.bg import post_process, naive_cutout, apply_background_color
9
+ from scipy.ndimage import binary_erosion
10
+
11
+
12
+ def alpha_matting_cutout(img: PILImage, trimap: np.ndarray) -> PILImage:
13
+ if img.mode == "RGBA" or img.mode == "CMYK":
14
+ img = img.convert("RGB")
15
+
16
+ img = np.asarray(img)
17
+
18
+ img_normalized = img / 255.0
19
+ trimap_normalized = trimap / 255.0
20
+
21
+ alpha = estimate_alpha_cf(img_normalized, trimap_normalized)
22
+ foreground = estimate_foreground_ml(img_normalized, alpha)
23
+ cutout = stack_images(foreground, alpha)
24
+
25
+ cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8)
26
+ return Image.fromarray(cutout)
27
+
28
+
29
+ def generate_trimap(
30
+ mask: PILImage,
31
+ foreground_threshold: int,
32
+ background_threshold: int,
33
+ erode_structure_size: int,
34
+ ) -> np.ndarray:
35
+ mask = np.asarray(mask)
36
+
37
+ is_foreground = mask > foreground_threshold
38
+ is_background = mask < background_threshold
39
+
40
+ structure = None
41
+ if erode_structure_size > 0:
42
+ structure = np.ones(
43
+ (erode_structure_size, erode_structure_size), dtype=np.uint8
44
+ )
45
+
46
+ is_foreground = binary_erosion(is_foreground, structure=structure)
47
+ is_background = binary_erosion(is_background, structure=structure, border_value=1)
48
+
49
+ trimap = np.full(mask.shape, dtype=np.uint8, fill_value=128)
50
+ trimap[is_foreground] = 255
51
+ trimap[is_background] = 0
52
+
53
+ return trimap
54
+
55
+
56
+ def get_background_dominant_color(img: PILImage, mask: PILImage) -> tuple:
57
+ negative_img = img.copy()
58
+ negative_mask = PIL.ImageOps.invert(mask)
59
+ negative_img.putalpha(negative_mask)
60
+ negative_img = negative_img.resize((1, 1))
61
+ r, g, b, a = negative_img.getpixel((0, 0))
62
+ return r, g, b, 255
63
+
64
+
65
+ def remove(session, img: PILImage, smoot: bool, matting: tuple, color) -> (PILImage, PILImage):
66
+ mask = session.predict(img)[0]
67
+
68
+ if smoot:
69
+ mask = PIL.Image.fromarray(post_process(np.array(mask)))
70
+
71
+ fg_t, bg_t, erode = matting
72
+
73
+ if fg_t > 0 or bg_t > 0 or erode > 0:
74
+ mask = generate_trimap(mask, *matting)
75
+ try:
76
+ cutout = alpha_matting_cutout(img, mask)
77
+ mask = PIL.Image.fromarray(mask)
78
+ except ValueError as err:
79
+ raise err
80
+ else:
81
+ cutout = naive_cutout(img, mask)
82
+
83
+ if color is True:
84
+ color = get_background_dominant_color(img, mask)
85
+ cutout = apply_background_color(cutout, color)
86
+ elif isinstance(color, str):
87
+ r, g, b = ImageColor.getcolor(color, "RGB")
88
+ cutout = apply_background_color(cutout, (r, g, b, 255))
89
+
90
+ return cutout, mask
91
+
92
+
93
+ def make_label(text, width=600, height=200, color="black") -> PILImage:
94
+ image = Image.new("RGB", (width, height), color)
95
+ draw = ImageDraw.Draw(image)
96
+ text_width, text_height = draw.textsize(text)
97
+ draw.text(((width-text_width)/2, height/2), text)
98
+ return image
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  rembg~=2.0.47
2
  pillow~=9.5.0
3
- opencv-python-headless
 
 
1
  rembg~=2.0.47
2
  pillow~=9.5.0
3
+ pymatting
4
+ opencv-python-headless