Satyajithchary commited on
Commit
97d2179
·
verified ·
1 Parent(s): d68e573

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -76
app.py CHANGED
@@ -1,65 +1,29 @@
1
  import streamlit as st
2
  import cv2
3
  import numpy as np
4
- from PIL import Image
5
- import torch
6
-
7
-
8
- !git clone --recursive https://github.com/frank-xwang/UnSAM.git
9
- !python -m pip install 'git+https://github.com/MaureenZOU/detectron2-xyz.git'
10
- %cd UnSAM
11
- !python -m pip install -r requirements.txt
12
-
13
- # uncomment the following lines if you want to run with GPU
14
- %cd whole_image_segmentation/mask2former/modeling/pixel_decoder/ops
15
- !sh make.sh
16
- # clone and install Mask2Former
17
- !git clone https://github.com/facebookresearch/Mask2Former.git
18
- %cd Mask2Former
19
- !pip install -U opencv-python
20
- !pip install git+https://github.com/cocodataset/panopticapi.git
21
- !pip install -r requirements.txt
22
- %cd mask2former/modeling/pixel_decoder/ops
23
- !python setup.py build install
24
- %cd ../../../../
25
- %cd /kaggle/working/Mask2Former
26
- #%cd /kaggle/working/UnSAM/whole_image_segmentation/mask2former/modeling/pixel_decoder/ops/Mask2Former
27
-
28
- #%cd /kaggle/working/UnSAM/whole_image_segmentation/mask2former
29
- import detectron2
30
- from detectron2.utils.logger import setup_logger
31
- setup_logger()
32
- setup_logger(name="mask2former")
33
-
34
- # import some common libraries
35
- import numpy as np
36
- import cv2
37
  import torch
38
-
39
-
40
-
41
-
42
- from detectron2.engine import DefaultPredictor
43
  from detectron2.config import get_cfg
44
  from detectron2.projects.deeplab import add_deeplab_config
45
- from detectron2.utils.colormap import random_color
46
  from mask2former import add_maskformer2_config
47
- from tqdm import tqdm
 
48
 
49
- def setup_predictor(config_file, weights_path, device='cpu'):
 
50
  cfg = get_cfg()
51
  cfg.set_new_allowed(True)
52
  add_deeplab_config(cfg)
53
  add_maskformer2_config(cfg)
54
- cfg.merge_from_file(config_file)
55
  cfg.MODEL.WEIGHTS = weights_path
56
- cfg.MODEL.DEVICE = device
57
- predictor = DefaultPredictor(cfg)
58
- return predictor
59
 
60
  def area(mask):
61
- if mask.size == 0:
62
- return 0
63
  return np.count_nonzero(mask) / mask.size
64
 
65
  def vis_mask(input, mask, mask_color):
@@ -71,7 +35,7 @@ def vis_mask(input, mask, mask_color):
71
  def show_image(I, pool):
72
  already_painted = np.zeros(np.array(I).shape[:2])
73
  input = I.copy()
74
- for mask in tqdm(pool):
75
  already_painted += mask.astype(np.uint8)
76
  overlap = (already_painted == 2)
77
  if np.sum(overlap) != 0:
@@ -79,40 +43,56 @@ def show_image(I, pool):
79
  already_painted -= overlap
80
  input = vis_mask(input, mask, random_color(rgb=True))
81
  return input
 
82
 
83
- # Load UnSAM and UnSAM+ predictors
84
- unsam_predictor = setup_predictor(
85
- "/kaggle/working/UnSAM/whole_image_segmentation/configs/maskformer2_R50_bs16_50ep.yaml",
86
- "/kaggle/working/Mask2Former/unsam_sa1b_4perc_ckpt_200k.pth"
87
- )
88
- unsam_plus_predictor = setup_predictor(
89
- "/kaggle/working/UnSAM/whole_image_segmentation/configs/maskformer2_R50_bs16_50ep.yaml",
90
- "/kaggle/working/Mask2Former/unsam_plus_sa1b_1perc_ckpt_50k.pth"
91
- )
92
 
93
- st.title("Image Segmentation with UnSAM and UnSAM+")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- # Upload image
96
- uploaded_file = st.file_uploader("Choose an image...", type="png")
97
 
98
  if uploaded_file is not None:
99
- # Read the image
100
- image = np.array(Image.open(uploaded_file))
101
 
102
- # Display the original image
103
- st.image(image, caption='Original Image', use_column_width=True)
104
 
105
- # Run predictions for UnSAM+
106
- unsam_plus_outputs = unsam_plus_predictor(image)['instances']
107
- unsam_plus_masks = [mask.cpu().numpy() for mask in unsam_plus_outputs.pred_masks]
108
- sorted_unsam_plus_masks = sorted(unsam_plus_masks, key=lambda m: area(m), reverse=True)
109
- unsam_plus_image = show_image(image, sorted_unsam_plus_masks)
110
 
111
- # Run predictions for UnSAM
112
- unsam_outputs = unsam_predictor(image)['instances']
113
- unsam_masks = [mask.cpu().numpy() for mask in unsam_outputs.pred_masks]
114
- sorted_unsam_masks = sorted(unsam_masks, key=lambda m: area(m), reverse=True)
115
- unsam_image = show_image(image, sorted_unsam_masks)
116
 
117
- # Display the images side by side
118
- st.image([image, unsam_plus_image, unsam_image], caption=['Original Image', 'UnSAM+ Output', 'UnSAM Output'], use_column_width=True)
 
 
 
 
 
 
1
  import streamlit as st
2
  import cv2
3
  import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import torch
5
+ from PIL import Image
 
 
 
 
6
  from detectron2.config import get_cfg
7
  from detectron2.projects.deeplab import add_deeplab_config
8
+ from detectron2.engine import DefaultPredictor
9
  from mask2former import add_maskformer2_config
10
+ from detectron2.utils.colormap import random_color
11
+ import os
12
 
13
+ @st.cache_resource
14
+ def setup_config(weights_path):
15
  cfg = get_cfg()
16
  cfg.set_new_allowed(True)
17
  add_deeplab_config(cfg)
18
  add_maskformer2_config(cfg)
19
+ cfg.merge_from_file("configs/maskformer2_R50_bs16_50ep.yaml")
20
  cfg.MODEL.WEIGHTS = weights_path
21
+ cfg.MODEL.DEVICE = "cpu" # Use CPU for inference
22
+ cfg.freeze()
23
+ return cfg
24
 
25
  def area(mask):
26
+ if mask.size == 0: return 0
 
27
  return np.count_nonzero(mask) / mask.size
28
 
29
  def vis_mask(input, mask, mask_color):
 
35
  def show_image(I, pool):
36
  already_painted = np.zeros(np.array(I).shape[:2])
37
  input = I.copy()
38
+ for mask in pool:
39
  already_painted += mask.astype(np.uint8)
40
  overlap = (already_painted == 2)
41
  if np.sum(overlap) != 0:
 
43
  already_painted -= overlap
44
  input = vis_mask(input, mask, random_color(rgb=True))
45
  return input
46
+ import gdown
47
 
48
+ gdown.download("https://drive.google.com/uc?id=1sCZM5j2pQr34-scSEkgG7VmUaHJc8n4d", "unsam_plus_sa1b_1perc_ckpt_50k.pth", quiet=False)
49
+ gdown.download("https://drive.google.com/uc?id=1qUdZ2ELU_5SNTsmx3Q0wSA87u4SebiO4", "unsam_sa1b_4perc_ckpt_200k.pth", quiet=False)
 
 
 
 
 
 
 
50
 
51
+ @st.cache_data
52
+ def process_image(image, model_type):
53
+ if model_type == "UNSAM+":
54
+ weights_path = "unsam_plus_sa1b_1perc_ckpt_50k.pth"
55
+ else: # UNSAM
56
+ weights_path = "unsam_sa1b_4perc_ckpt_200k.pth"
57
+
58
+ cfg = setup_config(weights_path)
59
+ predictor = DefaultPredictor(cfg)
60
+
61
+ inputs = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
62
+ outputs = predictor(inputs)['instances']
63
+
64
+ masks = []
65
+ for score, mask in zip(outputs.scores, outputs.pred_masks):
66
+ if score < 0.5: continue
67
+ masks.append(mask.cpu().numpy())
68
+
69
+ sorted_masks = sorted(masks, key=lambda m: area(m), reverse=True)
70
+ result_image = show_image(np.array(image), sorted_masks)
71
+
72
+ return result_image
73
+
74
+ st.title("UNSAM and UNSAM+ Image Segmentation")
75
 
76
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
 
77
 
78
  if uploaded_file is not None:
79
+ image = Image.open(uploaded_file)
 
80
 
81
+ col1, col2, col3 = st.columns(3)
 
82
 
83
+ with col1:
84
+ st.header("Original Image")
85
+ st.image(image, use_column_width=True)
 
 
86
 
87
+ with col2:
88
+ st.header("UNSAM+ Output")
89
+ unsam_plus_output = process_image(image, "UNSAM+")
90
+ st.image(unsam_plus_output, use_column_width=True)
 
91
 
92
+ with col3:
93
+ st.header("UNSAM Output")
94
+ unsam_output = process_image(image, "UNSAM")
95
+ st.image(unsam_output, use_column_width=True)
96
+
97
+ else:
98
+ st.write("Please upload an image to see the segmentation results.")