Spaces:
Running
on
A10G
Running
on
A10G
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np | |
from typing import List | |
import annotator.oneformer.pycocotools.mask as mask_util | |
from annotator.oneformer.detectron2.structures import Instances | |
from annotator.oneformer.detectron2.utils.visualizer import ( | |
ColorMode, | |
Visualizer, | |
_create_text_labels, | |
_PanopticPrediction, | |
) | |
from .colormap import random_color, random_colors | |
class _DetectedInstance:
    """
    Lightweight record describing one object detected in a video frame.

    Records are kept across frames so that the color given to an object can
    be handed over to the matching object in later frames.

    Attributes:
        label (int): class label of the detection.
        bbox (tuple[float]): bounding box of the detection, or None.
        mask_rle (dict): run-length encoded mask of the detection, or None.
        color (tuple[float]): RGB color with components in range (0, 1),
            or None while no color has been assigned yet.
        ttl (int): time-to-live for the instance. For example, with ttl=2 the
            color can still be transferred to objects in the next two frames.
    """

    # Fixed attribute set: no per-instance __dict__ is created.
    __slots__ = ["label", "bbox", "mask_rle", "color", "ttl"]

    def __init__(self, label, bbox, mask_rle, color, ttl):
        # What was detected and where.
        self.label, self.bbox, self.mask_rle = label, bbox, mask_rle
        # Tracking state carried from frame to frame.
        self.color, self.ttl = color, ttl
class VideoVisualizer:
    """
    Draw predictions on successive video frames while trying to keep the
    color of each object stable from one frame to the next.

    Colors are either taken from the predictions ("COLOR" field), looked up
    by tracking id ("ID" field), or assigned by a naive IoU-based matching
    against the instances seen in previous frames.
    """

    def __init__(self, metadata, instance_mode=ColorMode.IMAGE):
        """
        Args:
            metadata (MetadataCatalog): image metadata.
            instance_mode (ColorMode): only ColorMode.IMAGE and
                ColorMode.IMAGE_BW are accepted.
        """
        self.metadata = metadata
        # Instances remembered from earlier frames (used by _assign_colors).
        self._old_instances = []
        assert instance_mode in [
            ColorMode.IMAGE,
            ColorMode.IMAGE_BW,
        ], "Other mode not supported yet."
        self._instance_mode = instance_mode
        self._max_num_instances = self.metadata.get("max_num_instances", 74)
        # Tracking id -> index into self._color_pool.
        self._assigned_colors = {}
        self._color_pool = random_colors(self._max_num_instances, rgb=True, maximum=1)
        # Indices of self._color_pool that are currently not assigned to any id.
        self._color_idx_set = set(range(len(self._color_pool)))

    def draw_instance_predictions(self, frame, predictions):
        """
        Draw instance-level prediction results on an image.

        Args:
            frame (ndarray): an RGB image of shape (H, W, C), in the range [0, 255].
            predictions (Instances): the output of an instance detection/segmentation
                model. Following fields will be used to draw:
                "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle").
                The optional fields "pred_keypoints", "COLOR", "ID" and
                "ID_period" are used as well when present.

        Returns:
            output (VisImage): image object with visualizations.
        """
        frame_visualizer = Visualizer(frame, self.metadata)
        num_instances = len(predictions)
        if num_instances == 0:
            return frame_visualizer.output

        boxes = predictions.pred_boxes.tensor.numpy() if predictions.has("pred_boxes") else None
        scores = predictions.scores if predictions.has("scores") else None
        classes = predictions.pred_classes.numpy() if predictions.has("pred_classes") else None
        keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None
        colors = predictions.COLOR if predictions.has("COLOR") else [None] * len(predictions)
        periods = predictions.ID_period if predictions.has("ID_period") else None
        period_threshold = self.metadata.get("period_threshold", 0)
        # An instance is drawn only if its period exceeds the threshold;
        # without period information everything is drawn.
        visibilities = (
            [True] * len(predictions)
            if periods is None
            else [x > period_threshold for x in periods]
        )

        # Mask IoU matching is not enabled here: color matching below works
        # on boxes only, masks are used for drawing.
        masks = predictions.pred_masks if predictions.has("pred_masks") else None

        if not predictions.has("COLOR"):
            if predictions.has("ID"):
                colors = self._assign_colors_by_id(predictions)
            else:
                # TODO: clean old assign color method and use a default tracker to assign id
                detected = [
                    _DetectedInstance(classes[i], boxes[i], mask_rle=None, color=colors[i], ttl=8)
                    for i in range(num_instances)
                ]
                colors = self._assign_colors(detected)

        labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None))

        if self._instance_mode == ColorMode.IMAGE_BW:
            # any() returns uint8 tensor
            frame_visualizer.output.reset_image(
                frame_visualizer._create_grayscale_image(
                    (masks.any(dim=0) > 0).numpy() if masks is not None else None
                )
            )
            alpha = 0.3
        else:
            alpha = 0.5

        # Keep only the entries that belong to visible instances.
        if labels is not None:
            labels = [label for label, visible in zip(labels, visibilities) if visible]
        assigned_colors = (
            None
            if colors is None
            else [color for color, visible in zip(colors, visibilities) if visible]
        )

        # Boxes are a bit distracting, so they are drawn only when there are
        # no masks. The extra `boxes is None` check avoids indexing None when
        # the predictions carry neither masks nor boxes.
        visible_boxes = None if masks is not None or boxes is None else boxes[visibilities]

        frame_visualizer.overlay_instances(
            boxes=visible_boxes,
            masks=None if masks is None else masks[visibilities],
            labels=labels,
            keypoints=None if keypoints is None else keypoints[visibilities],
            assigned_colors=assigned_colors,
            alpha=alpha,
        )

        return frame_visualizer.output

    def draw_sem_seg(self, frame, sem_seg, area_threshold=None):
        """
        Draw a semantic segmentation on a frame.

        Args:
            frame (ndarray): the image to draw on.
            sem_seg (ndarray or Tensor): semantic segmentation of shape (H, W),
                each value is the integer label.
            area_threshold (Optional[int]): only draw segmentations larger than the threshold

        Returns:
            output (VisImage): image object with visualizations.
        """
        # don't need to do anything special
        frame_visualizer = Visualizer(frame, self.metadata)
        # Forward the caller's threshold instead of discarding it.
        frame_visualizer.draw_sem_seg(sem_seg, area_threshold=area_threshold)
        return frame_visualizer.output

    def draw_panoptic_seg_predictions(
        self, frame, panoptic_seg, segments_info, area_threshold=None, alpha=0.5
    ):
        """
        Draw panoptic segmentation results on a frame.

        Semantic ("stuff") segments are drawn first, then instance segments
        are drawn with colors matched against previous frames by mask IoU.

        Args:
            frame (ndarray): the image to draw on.
            panoptic_seg: panoptic segmentation, passed to `_PanopticPrediction`.
            segments_info: per-segment information, passed to
                `_PanopticPrediction`; each entry is read for "category_id".
            area_threshold (Optional[int]): forwarded to `draw_binary_mask`
                for the semantic segments.
            alpha (float): blending factor forwarded to the drawing calls.

        Returns:
            output (VisImage): image object with visualizations.
        """
        frame_visualizer = Visualizer(frame, self.metadata)
        pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)

        if self._instance_mode == ColorMode.IMAGE_BW:
            frame_visualizer.output.reset_image(
                frame_visualizer._create_grayscale_image(pred.non_empty_mask())
            )

        # draw mask for all semantic segments first i.e. "stuff"
        for mask, sinfo in pred.semantic_masks():
            category_idx = sinfo["category_id"]
            try:
                mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]]
            except AttributeError:
                # Metadata without `stuff_colors`: no explicit color is passed.
                mask_color = None

            frame_visualizer.draw_binary_mask(
                mask,
                color=mask_color,
                text=self.metadata.stuff_classes[category_idx],
                alpha=alpha,
                area_threshold=area_threshold,
            )

        all_instances = list(pred.instance_masks())
        if len(all_instances) == 0:
            return frame_visualizer.output

        # draw mask for all instances second
        masks, sinfo = list(zip(*all_instances))
        num_instances = len(masks)
        # Stack the masks and move the instance axis last, as a Fortran-ordered
        # uint8 array, before RLE-encoding them.
        masks_rles = mask_util.encode(
            np.asarray(np.asarray(masks).transpose(1, 2, 0), dtype=np.uint8, order="F")
        )
        assert len(masks_rles) == num_instances

        category_ids = [x["category_id"] for x in sinfo]
        detected = [
            _DetectedInstance(category_ids[i], bbox=None, mask_rle=masks_rles[i], color=None, ttl=8)
            for i in range(num_instances)
        ]
        colors = self._assign_colors(detected)
        labels = [self.metadata.thing_classes[k] for k in category_ids]

        frame_visualizer.overlay_instances(
            boxes=None,
            masks=masks,
            labels=labels,
            keypoints=None,
            assigned_colors=colors,
            alpha=alpha,
        )

        return frame_visualizer.output

    def _assign_colors(self, instances):
        """
        Naive tracking heuristics to assign same color to the same instance,
        will update the internal state of tracked instances.

        Args:
            instances (list[_DetectedInstance]): detections of the current
                frame. Their `color` attribute is filled in when it is None.

        Returns:
            list[tuple[float]]: list of colors, one per entry of `instances`.
        """
        if not instances:
            # Nothing to match; leave the remembered instances untouched, as
            # the callers do when a frame has no detections.
            return []

        # Compute iou with either boxes or masks:
        is_crowd = np.zeros((len(instances),), dtype=bool)
        if instances[0].bbox is None:
            assert instances[0].mask_rle is not None
            # use mask iou only when box iou is None
            # because box seems good enough
            rles_old = [x.mask_rle for x in self._old_instances]
            rles_new = [x.mask_rle for x in instances]
            ious = mask_util.iou(rles_old, rles_new, is_crowd)
            threshold = 0.5
        else:
            boxes_old = [x.bbox for x in self._old_instances]
            boxes_new = [x.bbox for x in instances]
            ious = mask_util.iou(boxes_old, boxes_new, is_crowd)
            threshold = 0.6
        if len(ious) == 0:
            # Empty result: substitute an (old, new) shaped zero matrix so the
            # reductions below still work.
            ious = np.zeros((len(self._old_instances), len(instances)), dtype="float32")

        # Only allow matching instances of the same label:
        for old_idx, old in enumerate(self._old_instances):
            for new_idx, new in enumerate(instances):
                if old.label != new.label:
                    ious[old_idx, new_idx] = 0

        iou_matrix = np.asarray(ious)
        matched_new_per_old = iou_matrix.argmax(axis=1)
        max_iou_per_old = iou_matrix.max(axis=1)

        # Try to find match for each old instance:
        extra_instances = []
        for idx, inst in enumerate(self._old_instances):
            if max_iou_per_old[idx] > threshold:
                newidx = matched_new_per_old[idx]
                if instances[newidx].color is None:
                    instances[newidx].color = inst.color
                    continue
            # If an old instance does not match any new instances,
            # keep it for the next frame in case it is just missed by the detector
            inst.ttl -= 1
            if inst.ttl > 0:
                extra_instances.append(inst)

        # Assign random color to newly-detected instances:
        for inst in instances:
            if inst.color is None:
                inst.color = random_color(rgb=True, maximum=1)
        self._old_instances = instances[:] + extra_instances
        return [d.color for d in instances]

    def _assign_colors_by_id(self, instances: Instances) -> List:
        """
        Assign colors from the color pool based on the "ID" field.

        An id seen before keeps its color. A new id takes a free color from
        the pool. Ids that were tracked before but do not appear in
        `instances` release their color back to the pool.

        Args:
            instances (Instances): predictions providing an `ID` field.

        Returns:
            list: one color per entry of `instances.ID`, in the same order.

        Raises:
            AssertionError: if a new id shows up while the pool has no free color.
        """
        colors = []
        # Ids known before this frame; whatever is left at the end was not seen.
        untracked_ids = set(self._assigned_colors.keys())
        for instance_id in instances.ID:
            if instance_id in self._assigned_colors:
                colors.append(self._color_pool[self._assigned_colors[instance_id]])
                # discard, not remove: the same id may occur more than once in
                # a frame, or may have been newly assigned earlier in this loop.
                untracked_ids.discard(instance_id)
            else:
                assert len(self._color_idx_set) >= 1, (
                    "Number of id exceeded maximum, "
                    f"max = {self._max_num_instances}"
                )
                idx = self._color_idx_set.pop()
                self._assigned_colors[instance_id] = idx
                colors.append(self._color_pool[idx])
        # Release the colors of ids that disappeared.
        for instance_id in untracked_ids:
            self._color_idx_set.add(self._assigned_colors[instance_id])
            del self._assigned_colors[instance_id]
        return colors