yunyangx commited on
Commit
16def94
·
1 Parent(s): 5abd2fe

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -25
app.py CHANGED
@@ -69,15 +69,14 @@ css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%
69
 
70
  def segment_with_boxs(
71
  image,
72
- seg_image,
73
  input_size=1024,
74
  better_quality=False,
75
  withContours=True,
76
  use_retina=True,
77
  mask_random_color=True,
78
  ):
79
- global global_points
80
- global global_point_label
81
  if len(global_points) < 2:
82
  return seg_image
83
  print("Original Image : ", image.size)
@@ -157,19 +156,18 @@ def segment_with_boxs(
157
  global_points = []
158
  global_point_label = []
159
  # return fig, None
160
- return fig
161
 
162
 
163
  def segment_with_points(
164
- image,
165
  input_size=1024,
166
  better_quality=False,
167
  withContours=True,
168
  use_retina=True,
169
  mask_random_color=True,
170
  ):
171
- global global_points
172
- global global_point_label
173
 
174
  print("Original Image : ", image.size)
175
 
@@ -238,12 +236,11 @@ def segment_with_points(
238
  global_points = []
239
  global_point_label = []
240
  # return fig, None
241
- return fig
 
242
 
 
243
 
244
- def get_points_with_draw(image, cond_image, evt: gr.SelectData):
245
- global global_points
246
- global global_point_label
247
  if len(global_points) == 0:
248
  image = copy.deepcopy(cond_image)
249
  x, y = evt.index[0], evt.index[1]
@@ -266,11 +263,10 @@ def get_points_with_draw(image, cond_image, evt: gr.SelectData):
266
  fill=point_color,
267
  )
268
 
269
- return image
 
 
270
 
271
- def get_points_with_draw_(image, cond_image, evt: gr.SelectData):
272
- global global_points
273
- global global_point_label
274
  if len(global_points) == 0:
275
  image = copy.deepcopy(cond_image)
276
  if len(global_points) > 2:
@@ -319,7 +315,7 @@ def get_points_with_draw_(image, cond_image, evt: gr.SelectData):
319
  global_points[1][0] = x1
320
  global_points[1][1] = y1
321
 
322
- return image
323
 
324
 
325
  cond_img_p = gr.Image(label="Input with Point", value=default_example[0], type="pil")
@@ -332,6 +328,9 @@ segm_img_b = gr.Image(
332
  label="Segmented Image with Box-Prompt", interactive=False, type="pil"
333
  )
334
 
 
 
 
335
  input_size_slider = gr.components.Slider(
336
  minimum=512,
337
  maximum=1024,
@@ -342,8 +341,6 @@ input_size_slider = gr.components.Slider(
342
  )
343
 
344
  with gr.Blocks(css=css, title="Efficient SAM") as demo:
345
- global_points = []
346
- global_point_label = []
347
  with gr.Row():
348
  with gr.Column(scale=1):
349
  # Title
@@ -411,26 +408,26 @@ with gr.Blocks(css=css, title="Efficient SAM") as demo:
411
  # Description
412
  gr.Markdown(description_p)
413
 
414
- cond_img_p.select(get_points_with_draw, [segm_img_p, cond_img_p], segm_img_p)
415
 
416
- cond_img_b.select(get_points_with_draw_, [segm_img_b, cond_img_b], segm_img_b)
417
 
418
  segment_btn_p.click(
419
- segment_with_points, inputs=[cond_img_p], outputs=segm_img_p
420
  )
421
 
422
  segment_btn_b.click(
423
- segment_with_boxs, inputs=[cond_img_b, segm_img_b], outputs=segm_img_b
424
  )
425
 
426
  def clear():
427
- return None, None
428
 
429
  def clear_text():
430
  return None, None, None
431
 
432
- clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])
433
- clear_btn_b.click(clear, outputs=[cond_img_b, segm_img_b])
434
 
435
  demo.queue()
436
  demo.launch()
 
69
 
70
  def segment_with_boxs(
71
  image,
72
+ seg_image, global_points, global_point_label,
73
  input_size=1024,
74
  better_quality=False,
75
  withContours=True,
76
  use_retina=True,
77
  mask_random_color=True,
78
  ):
79
+
 
80
  if len(global_points) < 2:
81
  return seg_image
82
  print("Original Image : ", image.size)
 
156
  global_points = []
157
  global_point_label = []
158
  # return fig, None
159
+ return fig, global_points, global_point_label
160
 
161
 
162
  def segment_with_points(
163
+ image, global_points, global_point_label,
164
  input_size=1024,
165
  better_quality=False,
166
  withContours=True,
167
  use_retina=True,
168
  mask_random_color=True,
169
  ):
170
+
 
171
 
172
  print("Original Image : ", image.size)
173
 
 
236
  global_points = []
237
  global_point_label = []
238
  # return fig, None
239
+ return fig, global_points, global_point_label
240
+
241
 
242
+ def get_points_with_draw(image, cond_image, global_points, global_point_label, evt: gr.SelectData):
243
 
 
 
 
244
  if len(global_points) == 0:
245
  image = copy.deepcopy(cond_image)
246
  x, y = evt.index[0], evt.index[1]
 
263
  fill=point_color,
264
  )
265
 
266
+ return image, global_points, global_point_label
267
+
268
+ def get_points_with_draw_(image, cond_image, global_points, global_point_label, evt: gr.SelectData):
269
 
 
 
 
270
  if len(global_points) == 0:
271
  image = copy.deepcopy(cond_image)
272
  if len(global_points) > 2:
 
315
  global_points[1][0] = x1
316
  global_points[1][1] = y1
317
 
318
+ return image, global_points, global_point_label
319
 
320
 
321
  cond_img_p = gr.Image(label="Input with Point", value=default_example[0], type="pil")
 
328
  label="Segmented Image with Box-Prompt", interactive=False, type="pil"
329
  )
330
 
331
+ global_points = gr.State([])
332
+ global_point_label = gr.State([])
333
+
334
  input_size_slider = gr.components.Slider(
335
  minimum=512,
336
  maximum=1024,
 
341
  )
342
 
343
  with gr.Blocks(css=css, title="Efficient SAM") as demo:
 
 
344
  with gr.Row():
345
  with gr.Column(scale=1):
346
  # Title
 
408
  # Description
409
  gr.Markdown(description_p)
410
 
411
+ cond_img_p.select(get_points_with_draw, [segm_img_p, cond_img_p, global_points, global_point_label], [segm_img_p, global_points, global_point_label])
412
 
413
+ cond_img_b.select(get_points_with_draw_, [segm_img_b, cond_img_b, global_points, global_point_label], [segm_img_b, global_points, global_point_label])
414
 
415
  segment_btn_p.click(
416
+ segment_with_points, inputs=[cond_img_p, global_points, global_point_label], outputs=[segm_img_p, global_points, global_point_label]
417
  )
418
 
419
  segment_btn_b.click(
420
+ segment_with_boxs, inputs=[cond_img_b, segm_img_b, global_points, global_point_label], outputs=[segm_img_b,global_points, global_point_label]
421
  )
422
 
423
  def clear():
424
+ return None, None, [], []
425
 
426
  def clear_text():
427
  return None, None, None
428
 
429
+ clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p, global_points, global_point_label])
430
+ clear_btn_b.click(clear, outputs=[cond_img_b, segm_img_b, global_points, global_point_label])
431
 
432
  demo.queue()
433
  demo.launch()