Spaces:
Runtime error
Runtime error
JeffLiang
commited on
Commit
·
ee2b9bc
1
Parent(s):
d8659bc
update ovseg + sam
Browse files
open_vocab_seg/modeling/clip_adapter/utils.py
CHANGED
@@ -63,7 +63,7 @@ def crop_with_mask(
|
|
63 |
[image.new_full((1, b - t, r - l), fill_value=val) for val in fill]
|
64 |
)
|
65 |
# return image[:, t:b, l:r], mask[None, t:b, l:r]
|
66 |
-
return image[:, t:b, l:r] * mask[None, t:b, l:r] + (
|
67 |
|
68 |
|
69 |
def build_clip_model(model: str, mask_prompt_depth: int = 0, frozen: bool = True):
|
|
|
63 |
[image.new_full((1, b - t, r - l), fill_value=val) for val in fill]
|
64 |
)
|
65 |
# return image[:, t:b, l:r], mask[None, t:b, l:r]
|
66 |
+
return image[:, t:b, l:r] * mask[None, t:b, l:r] + (~ mask[None, t:b, l:r]) * new_image, mask[None, t:b, l:r]
|
67 |
|
68 |
|
69 |
def build_clip_model(model: str, mask_prompt_depth: int = 0, frozen: bool = True):
|
open_vocab_seg/utils/__init__.py
CHANGED
@@ -2,4 +2,4 @@
|
|
2 |
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
|
3 |
|
4 |
from .events import setup_wandb, WandbWriter
|
5 |
-
from .predictor import VisualizationDemo
|
|
|
2 |
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
|
3 |
|
4 |
from .events import setup_wandb, WandbWriter
|
5 |
+
from .predictor import VisualizationDemo, SAMVisualizationDemo
|
open_vocab_seg/utils/predictor.py
CHANGED
@@ -3,11 +3,19 @@
|
|
3 |
|
4 |
import numpy as np
|
5 |
import torch
|
|
|
|
|
6 |
|
7 |
from detectron2.data import MetadataCatalog
|
|
|
8 |
from detectron2.engine.defaults import DefaultPredictor
|
9 |
from detectron2.utils.visualizer import ColorMode, Visualizer
|
|
|
10 |
|
|
|
|
|
|
|
|
|
11 |
|
12 |
class OVSegPredictor(DefaultPredictor):
|
13 |
def __init__(self, cfg):
|
@@ -129,4 +137,89 @@ class VisualizationDemo(object):
|
|
129 |
else:
|
130 |
raise NotImplementedError
|
131 |
|
132 |
-
return predictions, vis_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
import numpy as np
|
5 |
import torch
|
6 |
+
from torch.nn import functional as F
|
7 |
+
import cv2
|
8 |
|
9 |
from detectron2.data import MetadataCatalog
|
10 |
+
from detectron2.structures import BitMasks
|
11 |
from detectron2.engine.defaults import DefaultPredictor
|
12 |
from detectron2.utils.visualizer import ColorMode, Visualizer
|
13 |
+
from detectron2.modeling.postprocessing import sem_seg_postprocess
|
14 |
|
15 |
+
import open_clip
|
16 |
+
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
|
17 |
+
from open_vocab_seg.modeling.clip_adapter.adapter import PIXEL_MEAN, PIXEL_STD
|
18 |
+
from open_vocab_seg.modeling.clip_adapter.utils import crop_with_mask
|
19 |
|
20 |
class OVSegPredictor(DefaultPredictor):
|
21 |
def __init__(self, cfg):
|
|
|
137 |
else:
|
138 |
raise NotImplementedError
|
139 |
|
140 |
+
return predictions, vis_output
|
141 |
+
|
142 |
+
class SAMVisualizationDemo(object):
|
143 |
+
def __init__(self, cfg, granularity, sam_path, ovsegclip_path, instance_mode=ColorMode.IMAGE, parallel=False):
|
144 |
+
self.metadata = MetadataCatalog.get(
|
145 |
+
cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
|
146 |
+
)
|
147 |
+
|
148 |
+
self.cpu_device = torch.device("cpu")
|
149 |
+
self.instance_mode = instance_mode
|
150 |
+
|
151 |
+
self.parallel = parallel
|
152 |
+
self.granularity = granularity
|
153 |
+
sam = sam_model_registry["vit_h"](checkpoint=sam_path)
|
154 |
+
self.predictor = SamAutomaticMaskGenerator(sam)
|
155 |
+
self.clip_model, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained=ovsegclip_path)
|
156 |
+
self.clip_model.cuda()
|
157 |
+
|
158 |
+
def run_on_image(self, image, class_names):
|
159 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
160 |
+
visualizer = OVSegVisualizer(image, self.metadata, instance_mode=self.instance_mode, class_names=class_names)
|
161 |
+
|
162 |
+
masks = self.predictor.generate(image)
|
163 |
+
pred_masks = [masks[i]['segmentation'][None,:,:] for i in range(len(masks))]
|
164 |
+
pred_masks = np.row_stack(pred_masks)
|
165 |
+
pred_masks = BitMasks(pred_masks)
|
166 |
+
bboxes = pred_masks.get_bounding_boxes()
|
167 |
+
|
168 |
+
mask_fill = [255.0 * c for c in PIXEL_MEAN]
|
169 |
+
|
170 |
+
image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
|
171 |
+
|
172 |
+
regions = []
|
173 |
+
for bbox, mask in zip(bboxes, pred_masks):
|
174 |
+
region, _ = crop_with_mask(
|
175 |
+
image,
|
176 |
+
mask,
|
177 |
+
bbox,
|
178 |
+
fill=mask_fill,
|
179 |
+
)
|
180 |
+
regions.append(region.unsqueeze(0))
|
181 |
+
regions = [F.interpolate(r.to(torch.float), size=(224, 224), mode="bicubic") for r in regions]
|
182 |
+
|
183 |
+
pixel_mean = torch.tensor(PIXEL_MEAN).reshape(1, -1, 1, 1)
|
184 |
+
pixel_std = torch.tensor(PIXEL_STD).reshape(1, -1, 1, 1)
|
185 |
+
imgs = [(r/255.0 - pixel_mean) / pixel_std for r in regions]
|
186 |
+
imgs = torch.cat(imgs)
|
187 |
+
if len(class_names) == 1:
|
188 |
+
class_names.append('others')
|
189 |
+
txts = [f'a photo of {cls_name}' for cls_name in class_names]
|
190 |
+
text = open_clip.tokenize(txts)
|
191 |
+
|
192 |
+
with torch.no_grad(), torch.cuda.amp.autocast():
|
193 |
+
image_features = self.clip_model.encode_image(imgs)
|
194 |
+
text_features = self.clip_model.encode_text(text)
|
195 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
196 |
+
text_features /= text_features.norm(dim=-1, keepdim=True)
|
197 |
+
|
198 |
+
class_preds = (100.0 * image_features @ text_features.T).softmax(dim=-1)
|
199 |
+
select_cls = torch.zeros_like(class_preds)
|
200 |
+
|
201 |
+
max_scores, select_mask = torch.max(class_preds, dim=0)
|
202 |
+
if len(class_names) == 2 and class_names[-1] == 'others':
|
203 |
+
select_mask = select_mask[:-1]
|
204 |
+
if self.granularity < 1:
|
205 |
+
thr_scores = max_scores * self.granularity
|
206 |
+
select_mask = []
|
207 |
+
for i, thr in enumerate(thr_scores):
|
208 |
+
cls_pred = class_preds[:,i]
|
209 |
+
locs = torch.where(cls_pred > thr)
|
210 |
+
select_mask.extend(locs[0].tolist())
|
211 |
+
for idx in select_mask:
|
212 |
+
select_cls[idx] = class_preds[idx]
|
213 |
+
semseg = torch.einsum("qc,qhw->chw", select_cls, pred_masks.tensor.float())
|
214 |
+
|
215 |
+
r = semseg
|
216 |
+
blank_area = (r[0] == 0)
|
217 |
+
pred_mask = r.argmax(dim=0).to('cpu')
|
218 |
+
pred_mask[blank_area] = 255
|
219 |
+
pred_mask = np.array(pred_mask, dtype=np.int)
|
220 |
+
|
221 |
+
vis_output = visualizer.draw_sem_seg(
|
222 |
+
pred_mask
|
223 |
+
)
|
224 |
+
|
225 |
+
return None, vis_output
|
requirements.txt
CHANGED
@@ -19,4 +19,7 @@ torchvision==0.11.2+cu113
|
|
19 |
|
20 |
# Detectron
|
21 |
--find-links https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
|
22 |
-
detectron2
|
|
|
|
|
|
|
|
19 |
|
20 |
# Detectron
|
21 |
--find-links https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
|
22 |
+
detectron2
|
23 |
+
|
24 |
+
# Segment-anything
|
25 |
+
git+https://github.com/facebookresearch/segment-anything.git
|