|
import torch |
|
import random |
|
import numbers |
|
from torchvision.transforms import RandomCrop, RandomResizedCrop |
|
|
|
def _is_tensor_video_clip(clip): |
|
if not torch.is_tensor(clip): |
|
raise TypeError("clip should be Tensor. Got %s" % type(clip)) |
|
|
|
if not clip.ndimension() == 4: |
|
raise ValueError("clip should be 4D. Got %dD" % clip.dim()) |
|
|
|
return True |
|
|
|
|
|
def to_tensor(clip):
    """
    Convert a uint8 video clip to float and divide values by 255.0.

    NOTE: no dimension permutation is performed (the previous docstring
    claimed one); the layout stays (T, C, H, W) throughout.

    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W),
        values scaled into [0, 1].
    Raises:
        TypeError: if clip is not a tensor, or not of dtype uint8.
        ValueError: if clip is not 4-D.
    """
    _is_tensor_video_clip(clip)
    if clip.dtype != torch.uint8:
        raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))

    return clip.float() / 255.0
|
|
|
|
|
def resize(clip, target_size, interpolation_mode):
    """Resize a video clip to a fixed spatial size.

    Args:
        clip (torch.tensor): clip of size (T, C, H, W); interpolate treats
            the leading dim as batch, so each frame is resized spatially.
        target_size (tuple): desired (height, width).
        interpolation_mode (str): any mode accepted by
            torch.nn.functional.interpolate, e.g. "bilinear" or "nearest".

    Returns:
        torch.tensor: resized clip of size (T, C, target_h, target_w).

    Raises:
        ValueError: if target_size is not a (height, width) pair.
    """
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    # align_corners is only legal for the linear-family modes; passing it
    # (even as False) with "nearest"/"area" makes interpolate raise.
    align_corners = False if interpolation_mode in ("linear", "bilinear", "bicubic", "trilinear") else None
    return torch.nn.functional.interpolate(
        clip, size=target_size, mode=interpolation_mode, align_corners=align_corners
    )
|
|
|
|
|
class ToTensorVideo:
    """Callable transform: cast a uint8 video clip to float in [0, 1].

    Delegates to :func:`to_tensor`; the (T, C, H, W) layout is unchanged.
    """

    def __init__(self):
        pass

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
        Return:
            clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
        """
        return to_tensor(clip)

    def __repr__(self) -> str:
        return type(self).__name__
|
|
|
|
|
class ResizeVideo:
    """Resize a video clip to the specified spatial size.

    Args:
        size (int | float | tuple | list): target size. A single number is
            expanded to a square (size, size); a 2-sequence is taken as
            (height, width).
        interpolation_mode (str): interpolation mode passed to resize()
            (default "bilinear").
    """

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        # Accept lists as well as tuples; previously a list fell through to
        # the scalar branch and produced the broken ([h, w], [h, w]).
        if isinstance(size, (tuple, list)):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = tuple(size)
        elif isinstance(size, numbers.Number):
            self.size = (size, size)
        else:
            raise TypeError(f"size should be a number or a (height, width) pair, instead got {type(size)}")

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: scale resized video clip.
                size is (T, C, h, w)
        """
        clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
        return clip_resize

    def __repr__(self) -> str:
        # Previous version was missing the closing parenthesis.
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
|
|
|
|
|
class TemporalRandomCrop(object):
    """Temporally crop the given frame indices at a random location.

    Args:
        size (int): Desired length of frames will be seen in the model.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, total_frames):
        """Pick a random temporal window of up to ``self.size`` frames.

        Args:
            total_frames (int): number of frames available.

        Returns:
            tuple(int, int): half-open (begin_index, end_index); the window
            is shorter than ``self.size`` only when ``total_frames`` is.
        """
        # Off-by-one fix: the last valid start is total_frames - size.
        # The previous "total_frames - self.size - 1" could never select
        # the final window of the clip.
        rand_end = max(0, total_frames - self.size)
        begin_index = random.randint(0, rand_end)
        end_index = min(begin_index + self.size, total_frames)
        return begin_index, end_index
|
|
|
|