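# Labels SAM-style segmentation masks with open-vocabulary class names by
# combining OneFormer (closed-vocabulary proposals on COCO and ADE20K),
# BLIP (caption-derived proposals), CLIP (proposal ranking), and CLIPSeg
# (per-pixel verification inside each mask).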
from transformers import (CLIPProcessor, CLIPModel, AutoProcessor, CLIPSegForImageSegmentation,
OneFormerProcessor, OneFormerForUniversalSegmentation,
BlipProcessor, BlipForConditionalGeneration)
import torch
import mmcv
import torch.nn.functional as F
import numpy as np
import spacy
from PIL import Image
import pycocotools.mask as maskUtils  # only needed for the commented-out RLE decoding below
from models.segment_models.configs.ade20k_id2label import CONFIG as CONFIG_ADE20K_ID2LABEL
from models.segment_models.configs.coco_id2label import CONFIG as CONFIG_COCO_ID2LABEL
# from mmdet.core.visualization.image import imshow_det_bboxes # comment this line if you don't use mmdet
nlp = spacy.load('en_core_web_sm')
class SemanticSegment:
    def __init__(self, device):
        self.device = device
        self.model_init()

    def model_init(self):
        self.init_clip()
        self.init_oneformer_ade20k()
        self.init_oneformer_coco()
        self.init_blip()
        self.init_clipseg()

    def init_clip(self):
        # CLIP ranks the candidate class names for each mask crop.
        model_name = "openai/clip-vit-large-patch14"
        self.clip_processor = CLIPProcessor.from_pretrained(model_name)
        self.clip_model = CLIPModel.from_pretrained(model_name).to(self.device)

    def init_oneformer_ade20k(self):
        # OneFormer trained on ADE20K supplies closed-vocabulary proposals.
        model_name = "shi-labs/oneformer_ade20k_swin_large"
        self.oneformer_ade20k_processor = OneFormerProcessor.from_pretrained(model_name)
        self.oneformer_ade20k_model = OneFormerForUniversalSegmentation.from_pretrained(model_name).to(self.device)

    def init_oneformer_coco(self):
        # OneFormer trained on COCO supplies a second set of proposals.
        model_name = "shi-labs/oneformer_coco_swin_large"
        self.oneformer_coco_processor = OneFormerProcessor.from_pretrained(model_name)
        self.oneformer_coco_model = OneFormerForUniversalSegmentation.from_pretrained(model_name).to(self.device)

    def init_blip(self):
        # BLIP captions each crop; noun phrases from the caption become
        # open-vocabulary class proposals.
        model_name = "Salesforce/blip-image-captioning-large"
        self.blip_processor = BlipProcessor.from_pretrained(model_name)
        self.blip_model = BlipForConditionalGeneration.from_pretrained(model_name).to(self.device)

    def init_clipseg(self):
        model_name = "CIDAS/clipseg-rd64-refined"
        self.clipseg_processor = AutoProcessor.from_pretrained(model_name)
        self.clipseg_model = CLIPSegForImageSegmentation.from_pretrained(model_name).to(self.device)
        # Resizing is disabled here and done manually in clipseg_segmentation,
        # so the logits can be mapped back to the original crop size.
        self.clipseg_processor.image_processor.do_resize = False

    @staticmethod
    def get_noun_phrases(text):
        # Extract noun chunks with spaCy ("a red car" -> candidate class name).
        doc = nlp(text)
        return [chunk.text for chunk in doc.noun_chunks]

    def open_vocabulary_classification_blip(self, raw_image):
        # Caption the image with BLIP, then use the caption's noun phrases
        # as open-vocabulary class proposals.
        captioning_inputs = self.blip_processor(raw_image, return_tensors="pt").to(self.device)
        out = self.blip_model.generate(**captioning_inputs)
        caption = self.blip_processor.decode(out[0], skip_special_tokens=True)
        return SemanticSegment.get_noun_phrases(caption)

    def oneformer_segmentation(self, image, processor, model):
        inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt").to(self.device)
        outputs = model(**inputs)
        # (H, W) tensor of per-pixel class ids at the original image size.
        predicted_semantic_map = processor.post_process_semantic_segmentation(
            outputs, target_sizes=[image.size[::-1]])[0]
        return predicted_semantic_map

    def clip_classification(self, image, class_list, top_k):
        # Note: returns a single string when top_k == 1, otherwise a list.
        inputs = self.clip_processor(text=class_list, images=image, return_tensors="pt", padding=True).to(self.device)
        outputs = self.clip_model(**inputs)
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)
        if top_k == 1:
            return class_list[probs.argmax().item()]
        else:
            top_k_indices = probs.topk(top_k, dim=1).indices[0]
            return [class_list[index] for index in top_k_indices]

    def clipseg_segmentation(self, image, class_list):
        inputs = self.clipseg_processor(
            text=class_list, images=[image] * len(class_list),
            padding=True, return_tensors="pt").to(self.device)
        # Resize manually to a fixed scale (resizing was disabled in the
        # processor) and remember the original size for the logits.
        h, w = inputs['pixel_values'].shape[-2:]
        fixed_scale = (512, 512)
        inputs['pixel_values'] = F.interpolate(
            inputs['pixel_values'],
            size=fixed_scale,
            mode='bilinear',
            align_corners=False)
        outputs = self.clipseg_model(**inputs)
        # (num_classes, H, W) logits mapped back to the original crop size.
        logits = F.interpolate(outputs.logits[None], size=(h, w), mode='bilinear', align_corners=False)[0]
        return logits
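
    # Per-mask labeling pipeline implemented below:
    #   1. OneFormer (COCO + ADE20K) majority votes give closed-vocabulary proposals.
    #   2. BLIP captions a large crop around the mask; its noun phrases add
    #      open-vocabulary proposals.
    #   3. CLIP ranks the merged proposals on a tighter crop (top-3 kept).
    #   4. CLIPSeg picks the proposal that best covers the mask pixels.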
    def semantic_class_w_mask(self, img_src, anns, out_file_name="output/test.json", scale_small=1.2, scale_large=1.6):
        """
        Generate a class name for each mask.
        :param img_src: image path
        :param anns: COCO-style annotations whose "segmentation" fields are
                     decoded binary masks; the same dicts are returned with
                     "class_name" and "class_proposals" added
        :param out_file_name: output file name (used only by the commented-out dump below)
        :param scale_small: crop scale used for CLIP classification
        :param scale_large: crop scale used for BLIP captioning and CLIPSeg verification
        :return: anns, each with 'segmentation', 'area', 'bbox', 'predicted_iou',
                 'point_coords', 'stability_score', 'crop_box', plus the new
                 'class_name' and 'class_proposals' keys
        """
        img = mmcv.imread(img_src)
        oneformer_coco_seg = self.oneformer_segmentation(Image.fromarray(img), self.oneformer_coco_processor, self.oneformer_coco_model)
        oneformer_ade20k_seg = self.oneformer_segmentation(Image.fromarray(img), self.oneformer_ade20k_processor, self.oneformer_ade20k_model)
        bitmasks, class_names = [], []
        for ann in anns:
            # for ann in anns['annotations']:
            valid_mask = torch.tensor(ann['segmentation']).bool()
            # valid_mask = torch.tensor(maskUtils.decode(ann['segmentation'])).bool()
            # Majority vote: the most frequent OneFormer class id inside the mask.
            coco_propose_classes_ids = oneformer_coco_seg[valid_mask]
            ade20k_propose_classes_ids = oneformer_ade20k_seg[valid_mask]
            top_k_coco_propose_classes_ids = torch.bincount(coco_propose_classes_ids.flatten()).topk(1).indices
            top_k_ade20k_propose_classes_ids = torch.bincount(ade20k_propose_classes_ids.flatten()).topk(1).indices
            local_class_names = {CONFIG_ADE20K_ID2LABEL['id2label'][str(class_id.item())] for class_id in top_k_ade20k_propose_classes_ids}
            local_class_names.update({CONFIG_COCO_ID2LABEL['refined_id2label'][str(class_id.item())] for class_id in top_k_coco_propose_classes_ids})
            # Crop two context patches around the mask bbox (bbox is [x, y, w, h]).
            bbox = ann['bbox']
            patch_small = mmcv.imcrop(img, np.array([bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]), scale=scale_small)
            patch_large = mmcv.imcrop(img, np.array([bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]), scale=scale_large)
            # Merge closed-vocabulary proposals with BLIP's caption-derived ones.
            op_class_list = self.open_vocabulary_classification_blip(patch_large)
            local_class_list = list(local_class_names.union(op_class_list))
            top_k = min(len(local_class_list), 3)
            mask_categories = self.clip_classification(patch_small, local_class_list, top_k)
            # clip_classification returns a bare string when top_k == 1; wrap it
            # so the CLIPSeg verification below always receives a list.
            if isinstance(mask_categories, str):
                mask_categories = [mask_categories]
            # Keep the proposal whose CLIPSeg prediction dominates the mask pixels.
            class_ids_patch_large = self.clipseg_segmentation(patch_large, mask_categories).argmax(0)
            valid_mask_large_crop = mmcv.imcrop(valid_mask.numpy(), np.array([bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]), scale=scale_large)
            top_1_patch_large = torch.bincount(class_ids_patch_large[torch.tensor(valid_mask_large_crop)].flatten()).topk(1).indices
            top_1_mask_category = mask_categories[top_1_patch_large.item()]
            ann['class_name'] = str(top_1_mask_category)
            ann['class_proposals'] = mask_categories
            class_names.append(ann['class_name'])
            # bitmasks.append(maskUtils.decode(ann['segmentation']))
            bitmasks.append(ann['segmentation'])
        # mmcv.dump(anns, out_file_name)
        return anns
        # Below: optional visualization (requires mmdet; see the commented import above).
        # imshow_det_bboxes(img,
        #                   bboxes=None,
        #                   labels=np.arange(len(bitmasks)),
        #                   segms=np.stack(bitmasks),
        #                   class_names=class_names,
        #                   font_size=25,
        #                   show=False,
        #                   out_file='output/result2.png')
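
# A minimal usage sketch, not part of the original module. The image path and
# annotation file below are hypothetical, and the annotations' "segmentation"
# fields are assumed to be already-decoded binary arrays (see the commented
# maskUtils.decode alternative above for RLE-encoded masks).
if __name__ == "__main__":
    import json

    segmenter = SemanticSegment(device="cuda" if torch.cuda.is_available() else "cpu")
    with open("examples/sam_masks.json") as f:  # hypothetical SAM-style output
        anns = json.load(f)
    labeled = segmenter.semantic_class_w_mask("examples/input.jpg", anns)  # hypothetical image
    for ann in labeled:
        print(ann['class_name'], ann['class_proposals'])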