yunyangx commited on
Commit
bc658ea
·
1 Parent(s): 99469b5

make some updates

Browse files
Files changed (1) hide show
  1. app.py +424 -61
app.py CHANGED
@@ -1,78 +1,441 @@
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import torch
 
4
  from torchvision.transforms import ToTensor
5
- from PIL import Image
6
-
7
- # loading EfficientSAM model
8
- model_path = "efficientsam_s_cpu.jit"
9
- with open(model_path, "rb") as f:
10
- model = torch.jit.load(f)
11
-
12
- # getting mask using points
13
- def get_sam_mask_using_points(img_tensor, pts_sampled, model):
14
- pts_sampled = torch.reshape(torch.tensor(pts_sampled), [1, 1, -1, 2])
15
- max_num_pts = pts_sampled.shape[2]
16
- pts_labels = torch.ones(1, 1, max_num_pts)
17
-
18
- predicted_logits, predicted_iou = model(
19
- img_tensor[None, ...],
20
- pts_sampled,
21
- pts_labels,
22
- )
23
- predicted_logits = predicted_logits.cpu()
24
- all_masks = torch.ge(torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5).numpy()
25
- predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy()
26
-
27
- max_predicted_iou = -1
28
- selected_mask_using_predicted_iou = None
29
- for m in range(all_masks.shape[0]):
30
- curr_predicted_iou = predicted_iou[m]
31
- if (
32
- curr_predicted_iou > max_predicted_iou
33
- or selected_mask_using_predicted_iou is None
34
- ):
35
- max_predicted_iou = curr_predicted_iou
36
- selected_mask_using_predicted_iou = all_masks[m]
37
- return selected_mask_using_predicted_iou
 
 
 
 
38
 
39
  # examples
40
- examples = [["examples/image1.jpg"], ["examples/image2.jpg"], ["examples/image3.jpg"], ["examples/image4.jpg"],
41
- ["examples/image5.jpg"], ["examples/image6.jpg"], ["examples/image7.jpg"], ["examples/image8.jpg"],
42
- ["examples/image9.jpg"], ["examples/image10.jpg"], ["examples/image11.jpg"], ["examples/image12.jpg"],
43
- ["examples/image13.jpg"], ["examples/image14.jpg"]]
 
 
 
 
 
 
 
 
 
 
 
 
44
 
 
45
 
46
- with gr.Blocks() as demo:
47
- with gr.Row():
48
- input_img = gr.Image(label="Input",height=512)
49
- output_img = gr.Image(label="Selected Segment",height=512)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  with gr.Row():
52
- gr.Markdown("Try some of the examples below ⬇️")
53
- gr.Examples(examples=examples,
54
- inputs=[input_img])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- def get_select_coords(img, evt: gr.SelectData):
57
- img_tensor = ToTensor()(img)
58
- _, H, W = img_tensor.shape
59
 
60
- visited_pixels = set()
61
- pixels_in_queue = set()
62
- pixels_in_segment = set()
 
 
63
 
64
- mask = get_sam_mask_using_points(img_tensor, [[evt.index[0], evt.index[1]]], model)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
- out = img.copy()
 
67
 
68
- out = out.astype(np.uint8)
69
- out *= mask[:,:,None]
70
- for pixel in pixels_in_segment:
71
- out[pixel[0], pixel[1]] = img[pixel[0], pixel[1]]
72
- print(out)
73
- return out
74
 
75
- input_img.select(get_select_coords, [input_img], output_img)
 
76
 
77
- if __name__ == "__main__":
78
- demo.launch()
 
1
+ import copy
2
+ import os # noqa
3
+
4
  import gradio as gr
5
  import numpy as np
6
  import torch
7
+ from PIL import ImageDraw
8
  from torchvision.transforms import ToTensor
9
+
10
+ from utils.tools import format_results, point_prompt
11
+ from utils.tools_gradio import fast_process
12
+
13
+ # Most of our demo code is from [FastSAM Demo](https://huggingface.co/spaces/An-619/FastSAM). Thanks for AN-619.
14
+
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+ gpu_checkpoint_path = "efficientsam_s_gpu.jit"
18
+ cpu_checkpoint_path = "efficientsam_s_cpu.jit"
19
+
20
+ if torch.cuda.is_available():
21
+ model = torch.jit.load(gpu_checkpoint_path)
22
+ else:
23
+ model = torch.jit.load(cpu_checkpoint_path)
24
+ model.eval()
25
+
26
+ # Description
27
+ title = "<center><strong><font size='8'>Efficient Segment Anything(EfficientSAM)<font></strong></center>"
28
+
29
+ description_e = """This is a demo of [Efficient Segment Anything(EfficientSAM) Model](https://github.com/yformer/EfficientSAM).
30
+ """
31
+
32
+ description_p = """# Interactive Instance Segmentation
33
+ - Point-prompt instruction
34
+ <ol>
35
+ <li> Click on the left image (point input), visualizing the point on the right image </li>
36
+ <li> Click the button of Segment with Point Prompt </li>
37
+ </ol>
38
+ - Box-prompt instruction
39
+ <ol>
40
+ <li> Click on the left image (one point input), visualizing the point on the right image </li>
41
+ <li> Click on the left image (another point input), visualizing the point and the box on the right image</li>
42
+ <li> Click the button of Segment with Box Prompt </li>
43
+ </ol>
44
+ - Github [link](https://github.com/yformer/EfficientSAM)
45
+ """
46
 
47
  # examples
48
+ examples = [
49
+ ["examples/image1.jpg"],
50
+ ["examples/image2.jpg"],
51
+ ["examples/image3.jpg"],
52
+ ["examples/image4.jpg"],
53
+ ["examples/image5.jpg"],
54
+ ["examples/image6.jpg"],
55
+ ["examples/image7.jpg"],
56
+ ["examples/image8.jpg"],
57
+ ["examples/image9.jpg"],
58
+ ["examples/image10.jpg"],
59
+ ["examples/image11.jpg"],
60
+ ["examples/image12.jpg"],
61
+ ["examples/image13.jpg"],
62
+ ["examples/image14.jpg"],
63
+ ]
64
 
65
+ default_example = examples[0]
66
 
67
+ css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
68
+
69
+
70
+ def segment_with_boxs(
71
+ image,
72
+ seg_image,
73
+ input_size=1024,
74
+ better_quality=False,
75
+ withContours=True,
76
+ use_retina=True,
77
+ mask_random_color=True,
78
+ ):
79
+ try:
80
+ global global_points
81
+ global global_point_label
82
+ if len(global_points) < 2:
83
+ return seg_image
84
+ print("Original Image : ", image.size)
85
+
86
+ input_size = int(input_size)
87
+ w, h = image.size
88
+ scale = input_size / max(w, h)
89
+ new_w = int(w * scale)
90
+ new_h = int(h * scale)
91
+ image = image.resize((new_w, new_h))
92
+
93
+ print("Scaled Image : ", image.size)
94
+ print("Scale : ", scale)
95
+
96
+ scaled_points = np.array(
97
+ [[int(x * scale) for x in point] for point in global_points]
98
+ )
99
+ scaled_points = scaled_points[:2]
100
+ scaled_point_label = np.array(global_point_label)[:2]
101
+
102
+ print(scaled_points, scaled_points is not None)
103
+ print(scaled_point_label, scaled_point_label is not None)
104
+
105
+ if scaled_points.size == 0 and scaled_point_label.size == 0:
106
+ print("No points selected")
107
+ return image
108
+
109
+ nd_image = np.array(image)
110
+ img_tensor = ToTensor()(nd_image)
111
+
112
+ print(img_tensor.shape)
113
+ pts_sampled = torch.reshape(torch.tensor(scaled_points), [1, 1, -1, 2])
114
+ pts_sampled = pts_sampled[:, :, :2, :]
115
+ pts_labels = torch.reshape(torch.tensor([2, 3]), [1, 1, 2])
116
+
117
+ predicted_logits, predicted_iou = model(
118
+ img_tensor[None, ...].to(device),
119
+ pts_sampled.to(device),
120
+ pts_labels.to(device),
121
+ )
122
+ predicted_logits = predicted_logits.cpu()
123
+ all_masks = torch.ge(
124
+ torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5
125
+ ).numpy()
126
+ predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy()
127
+
128
+ max_predicted_iou = -1
129
+ selected_mask_using_predicted_iou = None
130
+ selected_predicted_iou = None
131
+
132
+ for m in range(all_masks.shape[0]):
133
+ curr_predicted_iou = predicted_iou[m]
134
+ if (
135
+ curr_predicted_iou > max_predicted_iou
136
+ or selected_mask_using_predicted_iou is None
137
+ ):
138
+ max_predicted_iou = curr_predicted_iou
139
+ selected_mask_using_predicted_iou = all_masks[m : m + 1]
140
+ selected_predicted_iou = predicted_iou[m : m + 1]
141
+
142
+ results = format_results(
143
+ selected_mask_using_predicted_iou,
144
+ selected_predicted_iou,
145
+ predicted_logits,
146
+ 0,
147
+ )
148
+
149
+ annotations = results[0]["segmentation"]
150
+ annotations = np.array([annotations])
151
+ print(scaled_points.shape)
152
+ fig = fast_process(
153
+ annotations=annotations,
154
+ image=image,
155
+ device=device,
156
+ scale=(1024 // input_size),
157
+ better_quality=better_quality,
158
+ mask_random_color=mask_random_color,
159
+ use_retina=use_retina,
160
+ bbox=scaled_points.reshape([4]),
161
+ withContours=withContours,
162
+ )
163
+
164
+ global_points = []
165
+ global_point_label = []
166
+ # return fig, None
167
+ return fig
168
+ except:
169
+ return image
170
+
171
+
172
+ def segment_with_points(
173
+ image,
174
+ input_size=1024,
175
+ better_quality=False,
176
+ withContours=True,
177
+ use_retina=True,
178
+ mask_random_color=True,
179
+ ):
180
+ try:
181
+ global global_points
182
+ global global_point_label
183
+
184
+ print("Original Image : ", image.size)
185
+
186
+ input_size = int(input_size)
187
+ w, h = image.size
188
+ scale = input_size / max(w, h)
189
+ new_w = int(w * scale)
190
+ new_h = int(h * scale)
191
+ image = image.resize((new_w, new_h))
192
+
193
+ print("Scaled Image : ", image.size)
194
+ print("Scale : ", scale)
195
+
196
+ if global_points is None:
197
+ return image
198
+ if len(global_points) < 1:
199
+ return image
200
+ scaled_points = np.array(
201
+ [[int(x * scale) for x in point] for point in global_points]
202
+ )
203
+ scaled_point_label = np.array(global_point_label)
204
+
205
+ print(scaled_points, scaled_points is not None)
206
+ print(scaled_point_label, scaled_point_label is not None)
207
+
208
+ if scaled_points.size == 0 and scaled_point_label.size == 0:
209
+ print("No points selected")
210
+ return image
211
+
212
+ nd_image = np.array(image)
213
+ img_tensor = ToTensor()(nd_image)
214
+
215
+ print(img_tensor.shape)
216
+ pts_sampled = torch.reshape(torch.tensor(scaled_points), [1, 1, -1, 2])
217
+ pts_labels = torch.reshape(torch.tensor(global_point_label), [1, 1, -1])
218
+
219
+ predicted_logits, predicted_iou = model(
220
+ img_tensor[None, ...].to(device),
221
+ pts_sampled.to(device),
222
+ pts_labels.to(device),
223
+ )
224
+ predicted_logits = predicted_logits.cpu()
225
+ all_masks = torch.ge(
226
+ torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5
227
+ ).numpy()
228
+ predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy()
229
+
230
+ results = format_results(all_masks, predicted_iou, predicted_logits, 0)
231
 
232
+ annotations, _ = point_prompt(
233
+ results, scaled_points, scaled_point_label, new_h, new_w
234
+ )
235
+ annotations = np.array([annotations])
236
+
237
+ fig = fast_process(
238
+ annotations=annotations,
239
+ image=image,
240
+ device=device,
241
+ scale=(1024 // input_size),
242
+ better_quality=better_quality,
243
+ mask_random_color=mask_random_color,
244
+ points=scaled_points,
245
+ bbox=None,
246
+ use_retina=use_retina,
247
+ withContours=withContours,
248
+ )
249
+
250
+ global_points = []
251
+ global_point_label = []
252
+ # return fig, None
253
+ return fig
254
+ except:
255
+ return image
256
+
257
+
258
+ def get_points_with_draw(image, cond_image, evt: gr.SelectData):
259
+ global global_points
260
+ global global_point_label
261
+ if len(global_points) == 0:
262
+ image = copy.deepcopy(cond_image)
263
+ x, y = evt.index[0], evt.index[1]
264
+ label = "Add Mask"
265
+ point_radius, point_color = 15, (255, 255, 0) if label == "Add Mask" else (
266
+ 255,
267
+ 0,
268
+ 255,
269
+ )
270
+ global_points.append([x, y])
271
+ global_point_label.append(1 if label == "Add Mask" else 0)
272
+
273
+ print(x, y, label == "Add Mask")
274
+
275
+ if image is not None:
276
+ draw = ImageDraw.Draw(image)
277
+
278
+ draw.ellipse(
279
+ [
280
+ (x - point_radius, y - point_radius),
281
+ (x + point_radius, y + point_radius),
282
+ ],
283
+ fill=point_color,
284
+ )
285
+
286
+ return image
287
+
288
+
289
+ def get_points_with_draw_(image, cond_image, evt: gr.SelectData):
290
+ global global_points
291
+ global global_point_label
292
+ if len(global_points) == 0:
293
+ image = copy.deepcopy(cond_image)
294
+ if len(global_points) > 2:
295
+ return image
296
+ x, y = evt.index[0], evt.index[1]
297
+ label = "Add Mask"
298
+ point_radius, point_color = 15, (255, 255, 0) if label == "Add Mask" else (
299
+ 255,
300
+ 0,
301
+ 255,
302
+ )
303
+ global_points.append([x, y])
304
+ global_point_label.append(1 if label == "Add Mask" else 0)
305
+
306
+ print(x, y, label == "Add Mask")
307
+
308
+ if image is not None:
309
+ draw = ImageDraw.Draw(image)
310
+
311
+ draw.ellipse(
312
+ [
313
+ (x - point_radius, y - point_radius),
314
+ (x + point_radius, y + point_radius),
315
+ ],
316
+ fill=point_color,
317
+ )
318
+
319
+ if len(global_points) == 2:
320
+ x1, y1 = global_points[0]
321
+ x2, y2 = global_points[1]
322
+ if x1 < x2:
323
+ draw.rectangle([x1, y1, x2, y2], outline="red", width=5)
324
+ else:
325
+ draw.rectangle([x2, y2, x1, y1], outline="red", width=5)
326
+ global_points = global_points[::-1]
327
+ global_point_label = global_point_label[::-1]
328
+
329
+ return image
330
+
331
+
332
+ cond_img_p = gr.Image(label="Input with Point", value=default_example[0], type="pil")
333
+ cond_img_b = gr.Image(label="Input with Box", value=default_example[0], type="pil")
334
+
335
+ segm_img_p = gr.Image(
336
+ label="Segmented Image with Point-Prompt", interactive=False, type="pil"
337
+ )
338
+ segm_img_b = gr.Image(
339
+ label="Segmented Image with Box-Prompt", interactive=False, type="pil"
340
+ )
341
+
342
+ global_points = []
343
+ global_point_label = []
344
+
345
+ input_size_slider = gr.components.Slider(
346
+ minimum=512,
347
+ maximum=1024,
348
+ value=1024,
349
+ step=64,
350
+ label="Input_size",
351
+ info="Our model was trained on a size of 1024",
352
+ )
353
+
354
+ with gr.Blocks(css=css, title="Efficient SAM") as demo:
355
  with gr.Row():
356
+ with gr.Column(scale=1):
357
+ # Title
358
+ gr.Markdown(title)
359
+
360
+ with gr.Tab("Point mode"):
361
+ # Images
362
+ with gr.Row(variant="panel"):
363
+ with gr.Column(scale=1):
364
+ cond_img_p.render()
365
+
366
+ with gr.Column(scale=1):
367
+ segm_img_p.render()
368
+
369
+ # Submit & Clear
370
+ # ###
371
+ with gr.Row():
372
+ with gr.Column():
373
+
374
+ with gr.Column():
375
+ segment_btn_p = gr.Button(
376
+ "Segment with Point Prompt", variant="primary"
377
+ )
378
+ clear_btn_p = gr.Button("Clear", variant="secondary")
379
+
380
+ gr.Markdown("Try some of the examples below ⬇️")
381
+ gr.Examples(
382
+ examples=examples,
383
+ inputs=[cond_img_p],
384
+ examples_per_page=4,
385
+ )
386
 
387
+ with gr.Column():
388
+ # Description
389
+ gr.Markdown(description_p)
390
 
391
+ with gr.Tab("Box mode"):
392
+ # Images
393
+ with gr.Row(variant="panel"):
394
+ with gr.Column(scale=1):
395
+ cond_img_b.render()
396
 
397
+ with gr.Column(scale=1):
398
+ segm_img_b.render()
399
+
400
+ # Submit & Clear
401
+ with gr.Row():
402
+ with gr.Column():
403
+
404
+ with gr.Column():
405
+ segment_btn_b = gr.Button(
406
+ "Segment with Box Prompt", variant="primary"
407
+ )
408
+ clear_btn_b = gr.Button("Clear", variant="secondary")
409
+
410
+ gr.Markdown("Try some of the examples below ⬇️")
411
+ gr.Examples(
412
+ examples=examples,
413
+ inputs=[cond_img_b],
414
+ examples_per_page=4,
415
+ )
416
+
417
+ with gr.Column():
418
+ # Description
419
+ gr.Markdown(description_p)
420
+
421
+ cond_img_p.select(get_points_with_draw, [segm_img_p, cond_img_p], segm_img_p)
422
+
423
+ cond_img_b.select(get_points_with_draw_, [segm_img_b, cond_img_b], segm_img_b)
424
+
425
+ segment_btn_p.click(segment_with_points, inputs=[cond_img_p], outputs=segm_img_p)
426
+
427
+ segment_btn_b.click(
428
+ segment_with_boxs, inputs=[cond_img_b, segm_img_b], outputs=segm_img_b
429
+ )
430
 
431
+ def clear():
432
+ return None, None
433
 
434
+ def clear_text():
435
+ return None, None, None
 
 
 
 
436
 
437
+ clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])
438
+ clear_btn_b.click(clear, outputs=[cond_img_b, segm_img_b])
439
 
440
+ demo.queue()
441
+ demo.launch(share=True)