JeffLiang commited on
Commit
ee2b9bc
·
1 Parent(s): d8659bc

update ovseg + sam

Browse files
open_vocab_seg/modeling/clip_adapter/utils.py CHANGED
@@ -63,7 +63,7 @@ def crop_with_mask(
63
  [image.new_full((1, b - t, r - l), fill_value=val) for val in fill]
64
  )
65
  # return image[:, t:b, l:r], mask[None, t:b, l:r]
66
- return image[:, t:b, l:r] * mask[None, t:b, l:r] + (1 - mask[None, t:b, l:r]) * new_image, mask[None, t:b, l:r]
67
 
68
 
69
  def build_clip_model(model: str, mask_prompt_depth: int = 0, frozen: bool = True):
 
63
  [image.new_full((1, b - t, r - l), fill_value=val) for val in fill]
64
  )
65
  # return image[:, t:b, l:r], mask[None, t:b, l:r]
66
+ return image[:, t:b, l:r] * mask[None, t:b, l:r] + (~ mask[None, t:b, l:r]) * new_image, mask[None, t:b, l:r]
67
 
68
 
69
  def build_clip_model(model: str, mask_prompt_depth: int = 0, frozen: bool = True):
open_vocab_seg/utils/__init__.py CHANGED
@@ -2,4 +2,4 @@
2
  # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
 
4
  from .events import setup_wandb, WandbWriter
5
- from .predictor import VisualizationDemo
 
2
  # Copyright (c) Meta Platforms, Inc. All Rights Reserved
3
 
4
  from .events import setup_wandb, WandbWriter
5
+ from .predictor import VisualizationDemo, SAMVisualizationDemo
open_vocab_seg/utils/predictor.py CHANGED
@@ -3,11 +3,19 @@
3
 
4
  import numpy as np
5
  import torch
 
 
6
 
7
  from detectron2.data import MetadataCatalog
 
8
  from detectron2.engine.defaults import DefaultPredictor
9
  from detectron2.utils.visualizer import ColorMode, Visualizer
 
10
 
 
 
 
 
11
 
12
  class OVSegPredictor(DefaultPredictor):
13
  def __init__(self, cfg):
@@ -129,4 +137,89 @@ class VisualizationDemo(object):
129
  else:
130
  raise NotImplementedError
131
 
132
- return predictions, vis_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  import numpy as np
5
  import torch
6
+ from torch.nn import functional as F
7
+ import cv2
8
 
9
  from detectron2.data import MetadataCatalog
10
+ from detectron2.structures import BitMasks
11
  from detectron2.engine.defaults import DefaultPredictor
12
  from detectron2.utils.visualizer import ColorMode, Visualizer
13
+ from detectron2.modeling.postprocessing import sem_seg_postprocess
14
 
15
+ import open_clip
16
+ from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
17
+ from open_vocab_seg.modeling.clip_adapter.adapter import PIXEL_MEAN, PIXEL_STD
18
+ from open_vocab_seg.modeling.clip_adapter.utils import crop_with_mask
19
 
20
  class OVSegPredictor(DefaultPredictor):
21
  def __init__(self, cfg):
 
137
  else:
138
  raise NotImplementedError
139
 
140
+ return predictions, vis_output
141
+
142
+ class SAMVisualizationDemo(object):
143
+ def __init__(self, cfg, granularity, sam_path, ovsegclip_path, instance_mode=ColorMode.IMAGE, parallel=False):
144
+ self.metadata = MetadataCatalog.get(
145
+ cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
146
+ )
147
+
148
+ self.cpu_device = torch.device("cpu")
149
+ self.instance_mode = instance_mode
150
+
151
+ self.parallel = parallel
152
+ self.granularity = granularity
153
+ sam = sam_model_registry["vit_h"](checkpoint=sam_path)
154
+ self.predictor = SamAutomaticMaskGenerator(sam)
155
+ self.clip_model, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained=ovsegclip_path)
156
+ self.clip_model.cuda()
157
+
158
+ def run_on_image(self, image, class_names):
159
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
160
+ visualizer = OVSegVisualizer(image, self.metadata, instance_mode=self.instance_mode, class_names=class_names)
161
+
162
+ masks = self.predictor.generate(image)
163
+ pred_masks = [masks[i]['segmentation'][None,:,:] for i in range(len(masks))]
164
+ pred_masks = np.row_stack(pred_masks)
165
+ pred_masks = BitMasks(pred_masks)
166
+ bboxes = pred_masks.get_bounding_boxes()
167
+
168
+ mask_fill = [255.0 * c for c in PIXEL_MEAN]
169
+
170
+ image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
171
+
172
+ regions = []
173
+ for bbox, mask in zip(bboxes, pred_masks):
174
+ region, _ = crop_with_mask(
175
+ image,
176
+ mask,
177
+ bbox,
178
+ fill=mask_fill,
179
+ )
180
+ regions.append(region.unsqueeze(0))
181
+ regions = [F.interpolate(r.to(torch.float), size=(224, 224), mode="bicubic") for r in regions]
182
+
183
+ pixel_mean = torch.tensor(PIXEL_MEAN).reshape(1, -1, 1, 1)
184
+ pixel_std = torch.tensor(PIXEL_STD).reshape(1, -1, 1, 1)
185
+ imgs = [(r/255.0 - pixel_mean) / pixel_std for r in regions]
186
+ imgs = torch.cat(imgs)
187
+ if len(class_names) == 1:
188
+ class_names.append('others')
189
+ txts = [f'a photo of {cls_name}' for cls_name in class_names]
190
+ text = open_clip.tokenize(txts)
191
+
192
+ with torch.no_grad(), torch.cuda.amp.autocast():
193
+ image_features = self.clip_model.encode_image(imgs)
194
+ text_features = self.clip_model.encode_text(text)
195
+ image_features /= image_features.norm(dim=-1, keepdim=True)
196
+ text_features /= text_features.norm(dim=-1, keepdim=True)
197
+
198
+ class_preds = (100.0 * image_features @ text_features.T).softmax(dim=-1)
199
+ select_cls = torch.zeros_like(class_preds)
200
+
201
+ max_scores, select_mask = torch.max(class_preds, dim=0)
202
+ if len(class_names) == 2 and class_names[-1] == 'others':
203
+ select_mask = select_mask[:-1]
204
+ if self.granularity < 1:
205
+ thr_scores = max_scores * self.granularity
206
+ select_mask = []
207
+ for i, thr in enumerate(thr_scores):
208
+ cls_pred = class_preds[:,i]
209
+ locs = torch.where(cls_pred > thr)
210
+ select_mask.extend(locs[0].tolist())
211
+ for idx in select_mask:
212
+ select_cls[idx] = class_preds[idx]
213
+ semseg = torch.einsum("qc,qhw->chw", select_cls, pred_masks.tensor.float())
214
+
215
+ r = semseg
216
+ blank_area = (r[0] == 0)
217
+ pred_mask = r.argmax(dim=0).to('cpu')
218
+ pred_mask[blank_area] = 255
219
+ pred_mask = np.array(pred_mask, dtype=np.int)
220
+
221
+ vis_output = visualizer.draw_sem_seg(
222
+ pred_mask
223
+ )
224
+
225
+ return None, vis_output
requirements.txt CHANGED
@@ -19,4 +19,7 @@ torchvision==0.11.2+cu113
19
 
20
  # Detectron
21
  --find-links https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
22
- detectron2
 
 
 
 
19
 
20
  # Detectron
21
  --find-links https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
22
+ detectron2
23
+
24
+ # Segment-anything
25
+ git+https://github.com/facebookresearch/segment-anything.git