Spaces:
Build error
Build error
File size: 29,225 Bytes
6155c0e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Note: This file has been borrowed from the facebookresearch/slowfast repo. It is used to add the bounding boxes and predictions to the frame.
# TODO: Migrate this into the core PyTorchVideo library.
from __future__ import annotations
import itertools
# import logging
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import torch
from detectron2.utils.visualizer import Visualizer
# logger = logging.getLogger(__name__)
def _create_text_labels(
classes: List[int],
scores: List[float],
class_names: List[str],
ground_truth: bool = False,
) -> List[str]:
"""
Create text labels.
Args:
classes (list[int]): a list of class ids for each example.
scores (list[float] or None): list of scores for each example.
class_names (list[str]): a list of class names, ordered by their ids.
ground_truth (bool): whether the labels are ground truth.
Returns:
labels (list[str]): formatted text labels.
"""
try:
labels = [class_names.get(c, "n/a") for c in classes]
except IndexError:
# logger.error("Class indices get out of range: {}".format(classes))
return None
if ground_truth:
labels = ["[{}] {}".format("GT", label) for label in labels]
elif scores is not None:
assert len(classes) == len(scores)
labels = ["[{:.2f}] {}".format(s, label) for s, label in zip(scores, labels)]
return labels
class ImgVisualizer(Visualizer):
    """
    Thin wrapper around detectron2's ``Visualizer`` that adds helpers for
    drawing (possibly many) text labels around a bounding box while keeping
    the labels inside the image frame.
    """
    def __init__(
        self, img_rgb: torch.Tensor, meta: Optional[SimpleNamespace] = None, **kwargs
    ) -> None:
        """
        See https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/visualizer.py
        for more details.
        Args:
            img_rgb: a tensor or numpy array of shape (H, W, C), where H and W correspond to
                the height and width of the image respectively. C is the number of
                color channels. The image is required to be in RGB format since that
                is a requirement of the Matplotlib library. The image is also expected
                to be in the range [0, 255].
            meta (MetadataCatalog): image metadata.
                See https://github.com/facebookresearch/detectron2/blob/81d5a87763bfc71a492b5be89b74179bd7492f6b/detectron2/data/catalog.py#L90
        """
        super(ImgVisualizer, self).__init__(img_rgb, meta, **kwargs)
    def draw_text(
        self,
        text: str,
        position: List[int],
        *,
        font_size: Optional[int] = None,
        color: str = "w",
        horizontal_alignment: str = "center",
        vertical_alignment: str = "bottom",
        box_facecolor: str = "black",
        alpha: float = 0.5,
    ) -> None:
        """
        Draw a single text label at the specified position.
        Args:
            text (str): the text to draw on image.
            position (list of 2 ints): the x,y coordinate to place the text.
            font_size (Optional[int]): font of the text. If not provided, a font size
                proportional to the image width is calculated and used.
            color (str): color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            horizontal_alignment (str): see `matplotlib.text.Text`.
            vertical_alignment (str): see `matplotlib.text.Text`.
            box_facecolor (str): color of the box wrapped around the text. Refer to
                `matplotlib.colors` for full list of formats that are accepted.
            alpha (float): transparency level of the box.
        """
        if not font_size:
            # `_default_font_size` is provided by the detectron2 base class.
            font_size = self._default_font_size
        x, y = position
        self.output.ax.text(
            x,
            y,
            text,
            size=font_size * self.output.scale,
            family="monospace",
            bbox={
                "facecolor": box_facecolor,
                "alpha": alpha,
                "pad": 0.7,
                "edgecolor": "none",
            },
            verticalalignment=vertical_alignment,
            horizontalalignment=horizontal_alignment,
            color=color,
            # High zorder so labels render on top of boxes and the image.
            zorder=10,
        )
    def draw_multiple_text(
        self,
        text_ls: List[str],
        box_coordinate: torch.Tensor,
        *,
        top_corner: bool = True,
        font_size: Optional[int] = None,
        color: str = "w",
        box_facecolors: str = "black",
        alpha: float = 0.5,
    ) -> None:
        """
        Draw a list of text labels for some bounding box on the image.
        Args:
            text_ls (list of strings): a list of text labels.
            box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            top_corner (bool): If True, draw the text labels at (x_left, y_top) of the box.
                Else, draw labels at (x_left, y_bottom).
            font_size (Optional[int]): font of the text. If not provided, a font size
                proportional to the image width is calculated and used.
            color (str): color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            box_facecolors (str or list of strs): colors of the box wrapped around the
                text. Refer to `matplotlib.colors` for full list of formats that are
                accepted.
            alpha (float): transparency level of the box.
        """
        if not isinstance(box_facecolors, list):
            box_facecolors = [box_facecolors] * len(text_ls)
        assert len(box_facecolors) == len(
            text_ls
        ), "Number of colors provided is not equal to the number of text labels."
        if not font_size:
            font_size = self._default_font_size
        text_box_width = font_size + font_size // 2
        # If the text labels do not all fit at the assigned corner, we split
        # the list and draw the overflow on the other side of the box.
        if top_corner:
            # Number of labels that fit above the top edge of the box.
            num_text_split = self._align_y_top(
                box_coordinate, len(text_ls), text_box_width
            )
            # y_corner indexes into box_coordinate: 1 = y_top, 3 = y_bottom.
            y_corner = 1
        else:
            # Number of labels that do NOT fit below the bottom edge.
            num_text_split = len(text_ls) - self._align_y_bottom(
                box_coordinate, len(text_ls), text_box_width
            )
            y_corner = 3
        # Sort labels (with their colors kept paired) by text, descending, so
        # the drawing order is deterministic.
        text_color_sorted = sorted(
            zip(text_ls, box_facecolors), key=lambda x: x[0], reverse=True
        )
        if len(text_color_sorted) != 0:
            text_ls, box_facecolors = zip(*text_color_sorted)
        else:
            text_ls, box_facecolors = [], []
        text_ls, box_facecolors = list(text_ls), list(box_facecolors)
        # First part stacks upward from the chosen corner (reversed so the
        # visual order matches the sorted order), the rest stacks downward.
        self.draw_multiple_text_upward(
            text_ls[:num_text_split][::-1],
            box_coordinate,
            y_corner=y_corner,
            font_size=font_size,
            color=color,
            box_facecolors=box_facecolors[:num_text_split][::-1],
            alpha=alpha,
        )
        self.draw_multiple_text_downward(
            text_ls[num_text_split:],
            box_coordinate,
            y_corner=y_corner,
            font_size=font_size,
            color=color,
            box_facecolors=box_facecolors[num_text_split:],
            alpha=alpha,
        )
    def draw_multiple_text_upward(
        self,
        text_ls: List[str],
        box_coordinate: torch.Tensor,
        *,
        y_corner: int = 1,
        font_size: Optional[int] = None,
        color: str = "w",
        box_facecolors: str = "black",
        alpha: float = 0.5,
    ) -> None:
        """
        Draw a list of text labels for some bounding box on the image in upward direction.
        The next text label will be on top of the previous one.
        Args:
            text_ls (list of strings): a list of text labels.
            box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            y_corner (int): Value of either 1 or 3. Indicate the index of the y-coordinate of
                the box to draw labels around (1 = y_top, 3 = y_bottom).
            font_size (Optional[int]): font of the text. If not provided, a font size
                proportional to the image width is calculated and used.
            color (str): color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            box_facecolors (str or list of strs): colors of the box wrapped around the
                text. Refer to `matplotlib.colors` for full list of formats that
                are accepted.
            alpha (float): transparency level of the box.
        """
        if not isinstance(box_facecolors, list):
            box_facecolors = [box_facecolors] * len(text_ls)
        assert len(box_facecolors) == len(
            text_ls
        ), "Number of colors provided is not equal to the number of text labels."
        assert y_corner in [1, 3], "Y_corner must be either 1 or 3"
        if not font_size:
            font_size = self._default_font_size
        x, horizontal_alignment = self._align_x_coordinate(box_coordinate)
        y = box_coordinate[y_corner].item()
        for i, text in enumerate(text_ls):
            self.draw_text(
                text,
                (x, y),
                font_size=font_size,
                color=color,
                horizontal_alignment=horizontal_alignment,
                vertical_alignment="bottom",
                box_facecolor=box_facecolors[i],
                alpha=alpha,
            )
            # Move up by one text-box height for the next label.
            y -= font_size + font_size // 2
    def draw_multiple_text_downward(
        self,
        text_ls: List[str],
        box_coordinate: torch.Tensor,
        *,
        y_corner: int = 1,
        font_size: Optional[int] = None,
        color: str = "w",
        box_facecolors: str = "black",
        alpha: float = 0.5,
    ) -> None:
        """
        Draw a list of text labels for some bounding box on the image in downward direction.
        The next text label will be below the previous one.
        Args:
            text_ls (list of strings): a list of text labels.
            box_coordinate (tensor): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            y_corner (int): Value of either 1 or 3. Indicate the index of the y-coordinate of
                the box to draw labels around (1 = y_top, 3 = y_bottom).
            font_size (Optional[int]): font of the text. If not provided, a font size
                proportional to the image width is calculated and used.
            color (str): color of the text. Refer to `matplotlib.colors` for full list
                of formats that are accepted.
            box_facecolors (str or list of strs): colors of the box wrapped around the
                text. Refer to `matplotlib.colors` for full list of formats that are
                accepted.
            alpha (float): transparency level of the box.
        """
        if not isinstance(box_facecolors, list):
            box_facecolors = [box_facecolors] * len(text_ls)
        assert len(box_facecolors) == len(
            text_ls
        ), "Number of colors provided is not equal to the number of text labels."
        assert y_corner in [1, 3], "Y_corner must be either 1 or 3"
        if not font_size:
            font_size = self._default_font_size
        x, horizontal_alignment = self._align_x_coordinate(box_coordinate)
        y = box_coordinate[y_corner].item()
        for i, text in enumerate(text_ls):
            self.draw_text(
                text,
                (x, y),
                font_size=font_size,
                color=color,
                horizontal_alignment=horizontal_alignment,
                vertical_alignment="top",
                box_facecolor=box_facecolors[i],
                alpha=alpha,
            )
            # Move down by one text-box height for the next label.
            y += font_size + font_size // 2
    def _align_x_coordinate(self, box_coordinate: torch.Tensor) -> Tuple[float, str]:
        """
        Choose an x-coordinate from the box to make sure the text label
        does not go out of frames. By default, the left x-coordinate is
        chosen and text is aligned left. If the box is too close to the
        right side of the image, then the right x-coordinate is chosen
        instead and the text is aligned right.
        Args:
            box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
        Returns:
            x_coordinate (float): the chosen x-coordinate.
            alignment (str): whether to align left or right.
        """
        # If the x-coordinate is greater than 5/6 of the image width,
        # then we align text to the right of the box. This ratio is
        # chosen by heuristics.
        if box_coordinate[0] > (self.output.width * 5) // 6:
            return box_coordinate[2], "right"
        return box_coordinate[0], "left"
    def _align_y_top(
        self, box_coordinate: torch.Tensor, num_text: int, textbox_width: float
    ) -> int:
        """
        Calculate the number of text labels to plot on top of the box
        without going out of frames.
        Args:
            box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            num_text (int): the number of text labels to plot.
            textbox_width (float): the width of the box wrapped around text label.
        Returns:
            int: how many of the `num_text` labels fit between the image top
                edge and the top of the box.
        """
        dist_to_top = box_coordinate[1]
        num_text_top = dist_to_top // textbox_width
        # Floor-division of a tensor yields a tensor; convert to a plain int
        # so the result can be used for list slicing.
        if isinstance(num_text_top, torch.Tensor):
            num_text_top = int(num_text_top.item())
        return min(num_text, num_text_top)
    def _align_y_bottom(
        self, box_coordinate: torch.Tensor, num_text: int, textbox_width: float
    ) -> int:
        """
        Calculate the number of text labels to plot at the bottom of the box
        without going out of frames.
        Args:
            box_coordinate (array-like): shape (4,). The (x_left, y_top, x_right, y_bottom)
                coordinates of the box.
            num_text (int): the number of text labels to plot.
            textbox_width (float): the width of the box wrapped around text label.
        Returns:
            int: how many of the `num_text` labels fit between the bottom of
                the box and the image bottom edge.
        """
        dist_to_bottom = self.output.height - box_coordinate[3]
        num_text_bottom = dist_to_bottom // textbox_width
        # Floor-division of a tensor yields a tensor; convert to a plain int
        # so the result can be used for list slicing.
        if isinstance(num_text_bottom, torch.Tensor):
            num_text_bottom = int(num_text_bottom.item())
        return min(num_text, num_text_bottom)
class VideoVisualizer:
    """
    Draws predicted (or ground-truth) class labels, and optionally bounding
    boxes, onto video frames/clips using :class:`ImgVisualizer`.
    """
    def __init__(
        self,
        num_classes: int,
        class_names: Dict,
        top_k: int = 1,
        colormap: str = "rainbow",
        thres: float = 0.7,
        lower_thres: float = 0.3,
        common_class_names: Optional[List[str]] = None,
        mode: str = "top-k",
    ) -> None:
        """
        Args:
            num_classes (int): total number of classes.
            class_names (dict): Dict mapping classID to name.
            top_k (int): number of top predicted classes to plot.
            colormap (str): the colormap to choose color for class labels from.
                See https://matplotlib.org/tutorials/colors/colormaps.html
            thres (float): threshold for picking predicted classes to visualize.
            lower_thres (float): If `common_class_names` is given,
                this `lower_thres` will be applied to uncommon classes and
                `thres` will be applied to classes in `common_class_names`.
            common_class_names (Optional[list of str]): list of common class names
                to apply `thres`. Class names not included in `common_class_names` will
                have `lower_thres` as a threshold. If None, all classes will have
                `thres` as a threshold. This is helpful for a model trained on a
                highly imbalanced dataset.
            mode (str): Supported modes are {"top-k", "thres"}.
                This is used for choosing predictions for visualization.
        """
        assert mode in ["top-k", "thres"], "Mode {} is not supported.".format(mode)
        self.mode = mode
        self.num_classes = num_classes
        self.class_names = class_names
        self.top_k = top_k
        self.thres = thres
        self.lower_thres = lower_thres
        if mode == "thres":
            # Replaces the scalar `self.thres` with a per-class threshold tensor.
            self._get_thres_array(common_class_names=common_class_names)
        self.color_map = plt.get_cmap(colormap)
    def _get_color(self, class_id: int) -> List[float]:
        """
        Get an RGB color for a class id, sampled from the colormap at a
        position proportional to the id.
        Args:
            class_id (int): class id.
        Returns:
            The (r, g, b) components of the colormap entry (alpha dropped).
        """
        return self.color_map(class_id / self.num_classes)[:3]
    def draw_one_frame(
        self,
        frame: Union[torch.Tensor, np.ndarray],
        preds: Union[torch.Tensor, List[float]],
        bboxes: Optional[torch.Tensor] = None,
        alpha: float = 0.5,
        text_alpha: float = 0.7,
        ground_truth: bool = False,
    ) -> np.ndarray:
        """
        Draw labels and bounding boxes for one image. By default, predicted
        labels are drawn in the top left corner of the image or corresponding
        bounding boxes. For ground truth labels (setting True for ground_truth flag),
        labels will be drawn in the bottom left corner.
        Args:
            frame (array-like): a tensor or numpy array of shape (H, W, C),
                where H and W correspond to
                the height and width of the image respectively. C is the number of
                color channels. The image is required to be in RGB format since that
                is a requirement of the Matplotlib library. The image is also expected
                to be in the range [0, 255].
            preds (tensor or list): If ground_truth is False, provide a float tensor of
                shape (num_boxes, num_classes) that contains all of the confidence
                scores of the model. For recognition task, input shape can be (num_classes,).
                To plot true label (ground_truth is True), preds is a list contains int32
                of the shape (num_boxes, true_class_ids) or (true_class_ids,).
            bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates
                of the bounding boxes.
            alpha (Optional[float]): transparency level of the bounding boxes.
            text_alpha (Optional[float]): transparency level of the box wrapped around
                text labels.
            ground_truth (bool): whether the provided bounding boxes are ground-truth.
        Returns:
            An image with bounding box annotations and corresponding bbox
            labels plotted on it, or None if `preds` has an unsupported type.
        """
        if isinstance(preds, torch.Tensor):
            if preds.ndim == 1:
                # Single-instance prediction; treat as a batch of one.
                preds = preds.unsqueeze(0)
            n_instances = preds.shape[0]
        elif isinstance(preds, list):
            n_instances = len(preds)
        else:
            # logger.error("Unsupported type of prediction input.")
            return
        if ground_truth:
            # Ground-truth entries carry class ids directly and no scores.
            top_scores, top_classes = [None] * n_instances, preds
        elif self.mode == "top-k":
            top_scores, top_classes = torch.topk(preds, k=self.top_k)
            top_scores, top_classes = top_scores.tolist(), top_classes.tolist()
        elif self.mode == "thres":
            top_scores, top_classes = [], []
            for pred in preds:
                # `self.thres` is a per-class tensor here (set in __init__).
                mask = pred >= self.thres
                top_scores.append(pred[mask].tolist())
                top_class = torch.squeeze(torch.nonzero(mask), dim=-1).tolist()
                top_classes.append(top_class)
        # Create labels for the selected classes with their scores.
        text_labels = []
        for i in range(n_instances):
            text_labels.append(
                _create_text_labels(
                    top_classes[i],
                    top_scores[i],
                    self.class_names,
                    ground_truth=ground_truth,
                )
            )
        frame_visualizer = ImgVisualizer(frame, meta=None)
        # Heuristic font size proportional to image area, clamped to [5, 9].
        font_size = min(max(np.sqrt(frame.shape[0] * frame.shape[1]) // 25, 5), 9)
        # Predictions go above the box; ground truth goes below it.
        top_corner = not ground_truth
        if bboxes is not None:
            assert len(preds) == len(
                bboxes
            ), "Encounter {} predictions and {} bounding boxes".format(
                len(preds), len(bboxes)
            )
            for i, box in enumerate(bboxes):
                text = text_labels[i]
                pred_class = top_classes[i]
                colors = [self._get_color(pred) for pred in pred_class]
                box_color = "r" if ground_truth else "g"
                line_style = "--" if ground_truth else "-."
                frame_visualizer.draw_box(
                    box,
                    alpha=alpha,
                    edge_color=box_color,
                    line_style=line_style,
                )
                frame_visualizer.draw_multiple_text(
                    text,
                    box,
                    top_corner=top_corner,
                    font_size=font_size,
                    box_facecolors=colors,
                    alpha=text_alpha,
                )
        else:
            # No boxes: draw labels along the top/bottom edge of the frame.
            text = text_labels[0]
            pred_class = top_classes[0]
            colors = [self._get_color(pred) for pred in pred_class]
            frame_visualizer.draw_multiple_text(
                text,
                torch.Tensor([0, 5, frame.shape[1], frame.shape[0] - 5]),
                top_corner=top_corner,
                font_size=font_size,
                box_facecolors=colors,
                alpha=text_alpha,
            )
        return frame_visualizer.output.get_image()
    def draw_clip_range(
        self,
        frames: Union[torch.Tensor, np.ndarray],
        preds: Union[torch.Tensor, List[float]],
        bboxes: Optional[torch.Tensor] = None,
        text_alpha: float = 0.5,
        ground_truth: bool = False,
        keyframe_idx: Optional[int] = None,
        draw_range: Optional[List[int]] = None,
        repeat_frame: int = 1,
    ) -> List[np.ndarray]:
        """
        Draw predicted labels or ground truth classes to clip.
        Draw bounding boxes to clip if bboxes is provided. Boxes will gradually
        fade in and out the clip, centered around the clip's central frame,
        within the provided `draw_range`.
        Args:
            frames (array-like): video data in the shape (T, H, W, C).
            preds (tensor): a tensor of shape (num_boxes, num_classes) that
                contains all of the confidence scores of the model. For recognition
                task or for ground_truth labels, input shape can be (num_classes,).
            bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates
                of the bounding boxes.
            text_alpha (float): transparency level of the box wrapped around text labels.
            ground_truth (bool): whether the provided bounding boxes are ground-truth.
            keyframe_idx (int): the index of keyframe in the clip.
            draw_range (Optional[list[ints]): only draw frames in range
                [start_idx, end_idx] inclusively in the clip. If None, draw on
                the entire clip.
            repeat_frame (int): repeat each frame in draw_range for `repeat_frame`
                time for slow-motion effect.
        Returns:
            A list of frames with bounding box annotations and corresponding
            bbox labels plotted on them.
        """
        if draw_range is None:
            draw_range = [0, len(frames) - 1]
        else:
            # Clamp the start index; a negative start would slice incorrectly.
            draw_range[0] = max(0, draw_range[0])
        left_frames = frames[: draw_range[0]]
        right_frames = frames[draw_range[1] + 1 :]
        draw_frames = frames[draw_range[0] : draw_range[1] + 1]
        if keyframe_idx is None:
            keyframe_idx = len(frames) // 2
        img_ls = (
            list(left_frames)
            + self.draw_clip(
                draw_frames,
                preds,
                bboxes=bboxes,
                text_alpha=text_alpha,
                ground_truth=ground_truth,
                # Translate the keyframe index into the drawn sub-clip.
                keyframe_idx=keyframe_idx - draw_range[0],
                repeat_frame=repeat_frame,
            )
            + list(right_frames)
        )
        return img_ls
    def draw_clip(
        self,
        frames: Union[torch.Tensor, np.ndarray],
        preds: Union[torch.Tensor, List[float]],
        bboxes: Optional[torch.Tensor] = None,
        text_alpha: float = 0.5,
        ground_truth: bool = False,
        keyframe_idx: Optional[int] = None,
        repeat_frame: int = 1,
    ) -> List[np.ndarray]:
        """
        Draw predicted labels or ground truth classes to clip. Draw bounding boxes to clip
        if bboxes is provided. Boxes will gradually fade in and out the clip, centered
        around the clip's central frame.
        Args:
            frames (array-like): video data in the shape (T, H, W, C).
            preds (tensor): a tensor of shape (num_boxes, num_classes) that contains
                all of the confidence scores of the model. For recognition task or for
                ground_truth labels, input shape can be (num_classes,).
            bboxes (Optional[tensor]): shape (num_boxes, 4) that contains the coordinates
                of the bounding boxes.
            text_alpha (float): transparency level of the box wrapped around text labels.
            ground_truth (bool): whether the provided bounding boxes are ground-truth.
            keyframe_idx (int): the index of keyframe in the clip.
            repeat_frame (int): repeat each frame in draw_range for `repeat_frame`
                time for slow-motion effect.
        Returns:
            A list of frames with bounding box annotations and corresponding
            bbox labels plotted on them.
        """
        assert repeat_frame >= 1, "`repeat_frame` must be a positive integer."
        # Frame indices, each repeated `repeat_frame` times for slow motion.
        repeated_seq = range(0, len(frames))
        repeated_seq = list(
            itertools.chain.from_iterable(
                itertools.repeat(x, repeat_frame) for x in repeated_seq
            )
        )
        frames, adjusted = self._adjust_frames_type(frames)
        if keyframe_idx is None:
            half_left = len(repeated_seq) // 2
            half_right = (len(repeated_seq) + 1) // 2
        else:
            mid = int((keyframe_idx / len(frames)) * len(repeated_seq))
            half_left = mid
            half_right = len(repeated_seq) - mid
        # Box alpha ramps 0 -> 1 up to the keyframe, then 1 -> 0 after it.
        alpha_ls = np.concatenate(
            [
                np.linspace(0, 1, num=half_left),
                np.linspace(1, 0, num=half_right),
            ]
        )
        frames = frames[repeated_seq]
        img_ls = []
        for alpha, frame in zip(alpha_ls, frames):
            draw_img = self.draw_one_frame(
                frame,
                preds,
                bboxes,
                alpha=alpha,
                text_alpha=text_alpha,
                ground_truth=ground_truth,
            )
            if adjusted:
                # Input was float; convert the drawn frame back to [0, 1].
                draw_img = draw_img.astype("float32") / 255
            img_ls.append(draw_img)
        return img_ls
    def _adjust_frames_type(
        self, frames: torch.Tensor
    ) -> Tuple[np.ndarray, bool]:
        """
        Convert video data to a uint8 numpy array with values in [0, 255].
        Args:
            frames (array-like): 4D array of shape (T, H, W, C). Float inputs
                are assumed to be in [0, 1] and are rescaled to [0, 255].
        Returns:
            frames (np.ndarray): uint8 frames in range [0, 255].
            adjusted (bool): whether the original frames were float and had to
                be rescaled (so the caller can convert results back).
        """
        assert (
            frames is not None and len(frames) != 0
        ), "Frames does not contain any values"
        frames = np.array(frames)
        assert frames.ndim == 4, "Frames must have 4 dimensions"
        adjusted = False
        if frames.dtype in [np.float32, np.float64]:
            frames *= 255
            frames = frames.astype(np.uint8)
            adjusted = True
        return frames, adjusted
    def _get_thres_array(self, common_class_names: Optional[List[str]] = None) -> None:
        """
        Compute a thresholds array for all classes based on `self.thres` and
        `self.lower_thres`, and store it back into `self.thres` as a tensor.
        Args:
            common_class_names (Optional[list of str]): a list of common class names.
        """
        common_class_ids = []
        if common_class_names is not None:
            common_classes = set(common_class_names)
            for key, name in self.class_names.items():
                if name in common_classes:
                    common_class_ids.append(key)
        else:
            # Without a common-class list, every class uses `self.thres`.
            common_class_ids = list(range(self.num_classes))
        thres_array = np.full(shape=(self.num_classes,), fill_value=self.lower_thres)
        thres_array[common_class_ids] = self.thres
        self.thres = torch.from_numpy(thres_array)