Tobias Cornille committed on
Commit
94040eb
·
1 Parent(s): d197a83

Add Segments.ai output to Gradio

Browse files
Files changed (1) hide show
  1. app.py +88 -36
app.py CHANGED
@@ -18,21 +18,21 @@ if not os.path.exists("./sam_vit_h_4b8939.pth"):
18
  )
19
  print(f"wget sam_vit_h_4b8939.pth result = {result}")
20
 
21
- import gradio as gr
22
 
23
  import argparse
24
  import random
25
  import warnings
 
 
26
 
 
27
  import numpy as np
28
- import matplotlib.pyplot as plt
29
  import torch
30
  from torch import nn
31
  import torch.nn.functional as F
32
  from scipy import ndimage
33
  from PIL import Image
34
  from huggingface_hub import hf_hub_download
35
- from segments.export import colorize
36
  from segments.utils import bitmap2file
37
 
38
  # Grounding DINO
@@ -262,6 +262,28 @@ def sam_mask_from_points(predictor, image_array, points):
262
  return upsampled_pred
263
 
264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  def generate_panoptic_mask(
266
  image,
267
  thing_category_names_string,
@@ -271,26 +293,44 @@ def generate_panoptic_mask(
271
  segmentation_background_threshold=0.1,
272
  shrink_kernel_size=20,
273
  num_samples_factor=1000,
 
274
  ):
275
- # parse inputs
276
- thing_category_names = [
277
- thing_category_name.strip()
278
- for thing_category_name in thing_category_names_string.split(",")
279
- ]
280
- stuff_category_names = [
281
- stuff_category_name.strip()
282
- for stuff_category_name in stuff_category_names_string.split(",")
283
- ]
284
- category_names = thing_category_names + stuff_category_names
285
- category_name_to_id = {
286
- category_name: i for i, category_name in enumerate(category_names)
287
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
 
289
  image = image.convert("RGB")
290
  image_array = np.asarray(image)
291
 
292
  # detect boxes for "thing" categories using Grounding DINO
293
- thing_boxes, category_ids = dino_detection(
294
  dino_model,
295
  image,
296
  image_array,
@@ -360,14 +400,21 @@ def generate_panoptic_mask(
360
  panoptic_names = (
361
  ["background"]
362
  + stuff_category_names
363
- + [category_names[category_id] for category_id in category_ids]
364
  )
365
  subsection_label_pairs = [
366
  (panoptic_bool_masks[i], panoptic_name)
367
  for i, panoptic_name in enumerate(panoptic_names)
368
  ]
369
 
370
- return (image_array, subsection_label_pairs)
 
 
 
 
 
 
 
371
 
372
 
373
  config_file = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
@@ -465,9 +512,27 @@ if __name__ == "__main__":
465
  value=1000,
466
  step=1,
467
  )
 
 
 
468
 
469
  with gr.Column():
470
  annotated_image = gr.AnnotatedImage()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
 
472
  examples = gr.Examples(
473
  examples=[
@@ -475,21 +540,11 @@ if __name__ == "__main__":
475
  "a2d2.png",
476
  "car, bus, person",
477
  "road, sky, buildings, sidewalk",
478
- 0.3,
479
- 0.25,
480
- 0.1,
481
- 20,
482
- 1000,
483
  ],
484
  [
485
  "bxl.png",
486
  "car, tram, motorcycle, person",
487
  "road, buildings, sky",
488
- 0.3,
489
- 0.25,
490
- 0.1,
491
- 20,
492
- 1000,
493
  ],
494
  ],
495
  fn=generate_panoptic_mask,
@@ -497,13 +552,8 @@ if __name__ == "__main__":
497
  input_image,
498
  thing_category_names_string,
499
  stuff_category_names_string,
500
- box_threshold,
501
- text_threshold,
502
- segmentation_background_threshold,
503
- shrink_kernel_size,
504
- num_samples_factor,
505
  ],
506
- outputs=[annotated_image],
507
  cache_examples=True,
508
  )
509
 
@@ -518,8 +568,10 @@ if __name__ == "__main__":
518
  segmentation_background_threshold,
519
  shrink_kernel_size,
520
  num_samples_factor,
 
521
  ],
522
- outputs=[annotated_image],
 
523
  )
524
 
525
  block.launch(server_name="0.0.0.0", debug=args.debug, share=args.share)
 
18
  )
19
  print(f"wget sam_vit_h_4b8939.pth result = {result}")
20
 
 
21
 
22
  import argparse
23
  import random
24
  import warnings
25
+ import json
26
+ import tempfile
27
 
28
+ import gradio as gr
29
  import numpy as np
 
30
  import torch
31
  from torch import nn
32
  import torch.nn.functional as F
33
  from scipy import ndimage
34
  from PIL import Image
35
  from huggingface_hub import hf_hub_download
 
36
  from segments.utils import bitmap2file
37
 
38
  # Grounding DINO
 
262
  return upsampled_pred
263
 
264
 
265
def inds_to_segments_format(
    panoptic_inds, thing_category_ids, stuff_category_ids, output_file
):
    """Write the panoptic map as a segmentation bitmap and build its annotations.

    Object ids are laid out as follows: ``0`` is unlabeled, ids
    ``1..len(stuff_category_ids)`` are the "stuff" regions (in the order of
    ``stuff_category_ids``), and the ids after that are the "thing" instances
    (in the order of ``thing_category_ids``).

    Args:
        panoptic_inds: Panoptic index map; must expose ``.numpy()``
            (presumably a CPU torch tensor -- confirm against the caller).
        thing_category_ids: Category id of each detected "thing" instance.
        stuff_category_ids: Category id of each "stuff" category.
        output_file: Writable binary file object that receives the bitmap.

    Returns:
        A list of ``{"id": ..., "category_id": ...}`` dicts: the stuff
        annotations whose id occurs in the bitmap, followed by one annotation
        per thing instance.
    """
    panoptic_inds_array = panoptic_inds.numpy().astype(np.uint32)
    bitmap_file = bitmap2file(panoptic_inds_array, is_segmentation_bitmap=True)
    # NOTE(review): bitmap2file looks like it returns an in-memory file object
    # rather than raw bytes -- confirm. Passing such an object straight to
    # write() raises TypeError, so unwrap it when possible.
    bitmap_bytes = (
        bitmap_file.getvalue() if hasattr(bitmap_file, "getvalue") else bitmap_file
    )
    output_file.write(bitmap_bytes)
    # The caller passes the file's *name* on to a consumer, so the bytes must
    # actually be on disk before this function returns.
    output_file.flush()

    present_ids = set(np.unique(panoptic_inds_array).tolist())
    # Stuff region i (0-based) is stored in the bitmap as id i + 1, so the
    # presence test has to use that same 1-based id (the previous code tested
    # the 0-based index, which is off by one).
    stuff_annotations = [
        {"id": object_id, "category_id": stuff_category_id}
        for object_id, stuff_category_id in enumerate(stuff_category_ids, start=1)
        if object_id in present_ids
    ]
    first_thing_id = len(stuff_category_ids) + 1
    thing_annotations = [
        {"id": first_thing_id + i, "category_id": thing_category_id}
        for i, thing_category_id in enumerate(thing_category_ids)
    ]

    return stuff_annotations + thing_annotations
285
+
286
+
287
  def generate_panoptic_mask(
288
  image,
289
  thing_category_names_string,
 
293
  segmentation_background_threshold=0.1,
294
  shrink_kernel_size=20,
295
  num_samples_factor=1000,
296
+ task_attributes_json=None,
297
  ):
298
+ if task_attributes_json is not None:
299
+ task_attributes = json.loads(task_attributes_json)
300
+ categories = task_attributes["categories"]
301
+ category_name_to_id = {
302
+ category["name"]: category["id"] for category in categories
303
+ }
304
+ # split the categories into "stuff" categories (regions w/o instances)
305
+ # and "thing" categories (objects/instances)
306
+ stuff_categories = [
307
+ category for category in categories if not category["has_instances"]
308
+ ]
309
+ thing_categories = [
310
+ category for category in categories if category["has_instances"]
311
+ ]
312
+ stuff_category_names = [category["name"] for category in stuff_categories]
313
+ thing_category_names = [category["name"] for category in thing_categories]
314
+ else:
315
+ # parse inputs
316
+ thing_category_names = [
317
+ thing_category_name.strip()
318
+ for thing_category_name in thing_category_names_string.split(",")
319
+ ]
320
+ stuff_category_names = [
321
+ stuff_category_name.strip()
322
+ for stuff_category_name in stuff_category_names_string.split(",")
323
+ ]
324
+ category_names = thing_category_names + stuff_category_names
325
+ category_name_to_id = {
326
+ category_name: i for i, category_name in enumerate(category_names)
327
+ }
328
 
329
  image = image.convert("RGB")
330
  image_array = np.asarray(image)
331
 
332
  # detect boxes for "thing" categories using Grounding DINO
333
+ thing_boxes, thing_category_ids = dino_detection(
334
  dino_model,
335
  image,
336
  image_array,
 
400
  panoptic_names = (
401
  ["background"]
402
  + stuff_category_names
403
+ + [category_names[category_id] for category_id in thing_category_ids]
404
  )
405
  subsection_label_pairs = [
406
  (panoptic_bool_masks[i], panoptic_name)
407
  for i, panoptic_name in enumerate(panoptic_names)
408
  ]
409
 
410
+ temp_file = tempfile.NamedTemporaryFile(suffix=".png")
411
+ stuff_category_ids = [category_name_to_id[name] for name in stuff_category_names]
412
+ annotations = inds_to_segments_format(
413
+ panoptic_inds, thing_category_ids, stuff_category_ids, temp_file
414
+ )
415
+ annotations_json = json.dumps(annotations)
416
+
417
+ return (image_array, subsection_label_pairs), temp_file.name, annotations_json
418
 
419
 
420
  config_file = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
 
512
  value=1000,
513
  step=1,
514
  )
515
+ task_attributes_json = gr.Textbox(
516
+ label="Task attributes JSON",
517
+ )
518
 
519
  with gr.Column():
520
  annotated_image = gr.AnnotatedImage()
521
+ with gr.Accordion("Segmentation bitmap", open=False):
522
+ segmentation_bitmap_text = gr.Markdown(
523
+ """
524
+ The segmentation bitmap is a 32-bit RGBA png image which contains the segmentation masks.
525
+ The alpha channel is set to 255, and the remaining 24-bit values in the RGB channels correspond to the object ids in the annotations list.
526
+ Unlabeled regions have a value of 0.
527
+ Because of the large dynamic range, these png images may appear black in an image viewer.
528
+ """
529
+ )
530
+ segmentation_bitmap = gr.Image(
531
+ type="filepath", label="Segmentation bitmap"
532
+ )
533
+ annotations_json = gr.Textbox(
534
+ label="Annotations JSON",
535
+ )
536
 
537
  examples = gr.Examples(
538
  examples=[
 
540
  "a2d2.png",
541
  "car, bus, person",
542
  "road, sky, buildings, sidewalk",
 
 
 
 
 
543
  ],
544
  [
545
  "bxl.png",
546
  "car, tram, motorcycle, person",
547
  "road, buildings, sky",
 
 
 
 
 
548
  ],
549
  ],
550
  fn=generate_panoptic_mask,
 
552
  input_image,
553
  thing_category_names_string,
554
  stuff_category_names_string,
 
 
 
 
 
555
  ],
556
+ outputs=[annotated_image, segmentation_bitmap, annotations_json],
557
  cache_examples=True,
558
  )
559
 
 
568
  segmentation_background_threshold,
569
  shrink_kernel_size,
570
  num_samples_factor,
571
+ task_attributes_json,
572
  ],
573
+ outputs=[annotated_image, segmentation_bitmap, annotations_json],
574
+ api_name="segment",
575
  )
576
 
577
  block.launch(server_name="0.0.0.0", debug=args.debug, share=args.share)