chongzhou commited on
Commit
b88e069
·
1 Parent(s): 3937e24

add result display panels for binary mask and cropped image

Browse files
Files changed (1) hide show
  1. app.py +43 -21
app.py CHANGED
@@ -9,6 +9,7 @@ from PIL import ImageDraw
9
  from utils.tools_gradio import fast_process
10
  import copy
11
  import argparse
 
12
 
13
  # Use ONNX to speed up the inference.
14
  ENABLE_ONNX = False
@@ -108,7 +109,7 @@ def reset(session_state):
108
  session_state['ori_image'] = None
109
  session_state['image_with_prompt'] = None
110
  session_state['feature'] = None
111
- return None, session_state
112
 
113
 
114
  def reset_all(session_state):
@@ -118,7 +119,7 @@ def reset_all(session_state):
118
  session_state['ori_image'] = None
119
  session_state['image_with_prompt'] = None
120
  session_state['feature'] = None
121
- return None, None, session_state
122
 
123
 
124
  def clear(session_state):
@@ -126,7 +127,7 @@ def clear(session_state):
126
  session_state['label_list'] = []
127
  session_state['box_list'] = []
128
  session_state['image_with_prompt'] = copy.deepcopy(session_state['ori_image'])
129
- return session_state['ori_image'], session_state
130
 
131
 
132
  def on_image_upload(
@@ -150,7 +151,7 @@ def on_image_upload(
150
  nd_image = np.array(image)
151
  session_state['feature'] = predictor.set_image(nd_image)
152
 
153
- return image, session_state
154
 
155
 
156
  def convert_box(xyxy):
@@ -229,7 +230,11 @@ def segment_with_points(
229
  withContours=withContours,
230
  )
231
 
232
- return seg, session_state
 
 
 
 
233
 
234
 
235
  def segment_with_box(
@@ -300,13 +305,23 @@ def segment_with_box(
300
  use_retina=use_retina,
301
  withContours=withContours,
302
  )
303
- return seg, session_state
304
- return image, session_state
 
 
 
 
305
 
306
 
307
  img_p = gr.Image(label="Input with points", type="pil")
308
  img_b = gr.Image(label="Input with box", type="pil")
309
 
 
 
 
 
 
 
310
  with gr.Blocks(css=css, title="EdgeSAM") as demo:
311
  session_state = gr.State({
312
  'coord_list': [],
@@ -339,7 +354,8 @@ with gr.Blocks(css=css, title="EdgeSAM") as demo:
339
  clear_btn_p = gr.Button("Clear", variant="secondary")
340
  reset_btn_p = gr.Button("Reset", variant="secondary")
341
  with gr.Row():
342
- gr.Markdown(description_p)
 
343
 
344
  with gr.Row():
345
  with gr.Column():
@@ -347,11 +363,13 @@ with gr.Blocks(css=css, title="EdgeSAM") as demo:
347
  gr.Examples(
348
  examples=examples,
349
  inputs=[img_p, session_state],
350
- outputs=[img_p, session_state],
351
  examples_per_page=8,
352
  fn=on_image_upload,
353
  run_on_click=True
354
  )
 
 
355
 
356
  with gr.Tab("Box mode") as tab_b:
357
  # Images
@@ -362,7 +380,9 @@ with gr.Blocks(css=css, title="EdgeSAM") as demo:
362
  with gr.Column():
363
  clear_btn_b = gr.Button("Clear", variant="secondary")
364
  reset_btn_b = gr.Button("Reset", variant="secondary")
365
- gr.Markdown(description_b)
 
 
366
 
367
  with gr.Row():
368
  with gr.Column():
@@ -370,30 +390,32 @@ with gr.Blocks(css=css, title="EdgeSAM") as demo:
370
  gr.Examples(
371
  examples=examples,
372
  inputs=[img_b, session_state],
373
- outputs=[img_b, session_state],
374
  examples_per_page=8,
375
  fn=on_image_upload,
376
  run_on_click=True
377
  )
 
 
378
 
379
  with gr.Row():
380
  with gr.Column(scale=1):
381
  gr.Markdown(
382
  "<center><img src='https://visitor-badge.laobi.icu/badge?page_id=chongzhou/edgesam' alt='visitors'></center>")
383
 
384
- img_p.upload(on_image_upload, [img_p, session_state], [img_p, session_state])
385
- img_p.select(segment_with_points, [add_or_remove, session_state], [img_p, session_state])
386
 
387
- clear_btn_p.click(clear, [session_state], [img_p, session_state])
388
- reset_btn_p.click(reset, [session_state], [img_p, session_state])
389
- tab_p.select(fn=reset_all, inputs=[session_state], outputs=[img_p, img_b, session_state])
390
 
391
- img_b.upload(on_image_upload, [img_b, session_state], [img_b, session_state])
392
- img_b.select(segment_with_box, [session_state], [img_b, session_state])
393
 
394
- clear_btn_b.click(clear, [session_state], [img_b, session_state])
395
- reset_btn_b.click(reset, [session_state], [img_b, session_state])
396
- tab_b.select(fn=reset_all, inputs=[session_state], outputs=[img_p, img_b, session_state])
397
 
398
  demo.queue()
399
  # demo.launch(server_name=args.server_name, server_port=args.port)
 
9
  from utils.tools_gradio import fast_process
10
  import copy
11
  import argparse
12
+ from PIL import Image
13
 
14
  # Use ONNX to speed up the inference.
15
  ENABLE_ONNX = False
 
109
  session_state['ori_image'] = None
110
  session_state['image_with_prompt'] = None
111
  session_state['feature'] = None
112
+ return None, None, None, session_state
113
 
114
 
115
  def reset_all(session_state):
 
119
  session_state['ori_image'] = None
120
  session_state['image_with_prompt'] = None
121
  session_state['feature'] = None
122
+ return None, None, None, None, None, None, session_state
123
 
124
 
125
  def clear(session_state):
 
127
  session_state['label_list'] = []
128
  session_state['box_list'] = []
129
  session_state['image_with_prompt'] = copy.deepcopy(session_state['ori_image'])
130
+ return session_state['ori_image'], None, None, session_state
131
 
132
 
133
  def on_image_upload(
 
151
  nd_image = np.array(image)
152
  session_state['feature'] = predictor.set_image(nd_image)
153
 
154
+ return image, None, None, session_state
155
 
156
 
157
  def convert_box(xyxy):
 
230
  withContours=withContours,
231
  )
232
 
233
+ binary_mask = np.where(annotations[0] > 0.5, 255, 0).astype(np.uint8)
234
+ mask = Image.fromarray(binary_mask)
235
+ binary_mask = np.expand_dims(binary_mask, axis=2)
236
+ crop = Image.fromarray(np.concatenate((session_state['ori_image'], binary_mask), axis=2), "RGBA")
237
+ return seg, mask, crop, session_state
238
 
239
 
240
  def segment_with_box(
 
305
  use_retina=use_retina,
306
  withContours=withContours,
307
  )
308
+ binary_mask = np.where(annotations[0] > 0.5, 255, 0).astype(np.uint8)
309
+ mask = Image.fromarray(binary_mask)
310
+ binary_mask = np.expand_dims(binary_mask, axis=2)
311
+ crop = Image.fromarray(np.concatenate((session_state['ori_image'], binary_mask), axis=2), "RGBA")
312
+ return seg, mask, crop, session_state
313
+ return image, None, None, session_state
314
 
315
 
316
  img_p = gr.Image(label="Input with points", type="pil")
317
  img_b = gr.Image(label="Input with box", type="pil")
318
 
319
+ mask_p = gr.Image(label="Mask", type="pil", interactive=False)
320
+ crop_p = gr.Image(label="Cropped image", type="pil", interactive=False)
321
+
322
+ mask_b = gr.Image(label="Mask", type="pil", interactive=False)
323
+ crop_b = gr.Image(label="Cropped image", type="pil", interactive=False)
324
+
325
  with gr.Blocks(css=css, title="EdgeSAM") as demo:
326
  session_state = gr.State({
327
  'coord_list': [],
 
354
  clear_btn_p = gr.Button("Clear", variant="secondary")
355
  reset_btn_p = gr.Button("Reset", variant="secondary")
356
  with gr.Row():
357
+ mask_p.render()
358
+ crop_p.render()
359
 
360
  with gr.Row():
361
  with gr.Column():
 
363
  gr.Examples(
364
  examples=examples,
365
  inputs=[img_p, session_state],
366
+ outputs=[img_p, mask_p, crop_p, session_state],
367
  examples_per_page=8,
368
  fn=on_image_upload,
369
  run_on_click=True
370
  )
371
+ with gr.Column():
372
+ gr.Markdown(description_p)
373
 
374
  with gr.Tab("Box mode") as tab_b:
375
  # Images
 
380
  with gr.Column():
381
  clear_btn_b = gr.Button("Clear", variant="secondary")
382
  reset_btn_b = gr.Button("Reset", variant="secondary")
383
+ with gr.Row():
384
+ mask_b.render()
385
+ crop_b.render()
386
 
387
  with gr.Row():
388
  with gr.Column():
 
390
  gr.Examples(
391
  examples=examples,
392
  inputs=[img_b, session_state],
393
+ outputs=[img_b, mask_b, crop_b, session_state],
394
  examples_per_page=8,
395
  fn=on_image_upload,
396
  run_on_click=True
397
  )
398
+ with gr.Column():
399
+ gr.Markdown(description_b)
400
 
401
  with gr.Row():
402
  with gr.Column(scale=1):
403
  gr.Markdown(
404
  "<center><img src='https://visitor-badge.laobi.icu/badge?page_id=chongzhou/edgesam' alt='visitors'></center>")
405
 
406
+ img_p.upload(on_image_upload, [img_p, session_state], [img_p, mask_p, crop_p, session_state])
407
+ img_p.select(segment_with_points, [add_or_remove, session_state], [img_p, mask_p, crop_p, session_state])
408
 
409
+ clear_btn_p.click(clear, [session_state], [img_p, mask_p, crop_p, session_state])
410
+ reset_btn_p.click(reset, [session_state], [img_p, mask_p, crop_p, session_state])
411
+ tab_p.select(fn=reset_all, inputs=[session_state], outputs=[img_p, mask_p, crop_p, img_b, mask_b, crop_b, session_state])
412
 
413
+ img_b.upload(on_image_upload, [img_b, session_state], [img_b, mask_b, crop_b, session_state])
414
+ img_b.select(segment_with_box, [session_state], [img_b, mask_b, crop_b, session_state])
415
 
416
+ clear_btn_b.click(clear, [session_state], [img_b, mask_b, crop_b, session_state])
417
+ reset_btn_b.click(reset, [session_state], [img_b, mask_b, crop_b, session_state])
418
+ tab_b.select(fn=reset_all, inputs=[session_state], outputs=[img_p, mask_p, crop_p, img_b, mask_b, crop_b, session_state])
419
 
420
  demo.queue()
421
  # demo.launch(server_name=args.server_name, server_port=args.port)