import random
from typing import Optional, Union

import numpy as np
import torch

from ..builder import PIPELINES


def _pad_to_length(motion, target_length):
    """Zero-pad ``motion`` along the time axis (axis 0) to ``target_length``.

    Returns ``motion`` unchanged when it is already long enough; padding
    frames are float32 zeros with the same per-frame shape as the input.
    """
    padding_length = target_length - len(motion)
    if padding_length <= 0:
        return motion
    padding_zeros = np.zeros((padding_length, *motion.shape[1:]),
                             dtype=np.float32)
    return np.concatenate([motion, padding_zeros], axis=0)


def _valid_frame_mask(valid_length, total_length):
    """Build a length-``total_length`` mask: 1 for the first ``valid_length``
    (real) frames, 0 for the zero-padded tail.

    Built with torch so the dtype (float32) matches the original pipeline
    output, then converted to numpy.
    """
    if valid_length >= total_length:
        return torch.ones(total_length).numpy()
    return torch.cat((torch.ones(valid_length),
                      torch.zeros(total_length - valid_length))).numpy()


@PIPELINES.register_module()
class Crop(object):
    r"""Crop a motion sequence to a fixed length.

    Sequences longer than ``crop_size`` are cropped at a random start index;
    shorter ones are zero-padded at the end. ``results['motion_mask']`` marks
    real frames with 1 and padding with 0.

    Args:
        crop_size (int): The size of the cropped motion sequence.
    """

    def __init__(self, crop_size: Optional[int] = None):
        # Explicit raise instead of `assert` so the check survives `-O`.
        if crop_size is None:
            raise ValueError('crop_size must be specified')
        self.crop_size = crop_size

    def __call__(self, results):
        motion = results['motion']
        length = len(motion)
        if length >= self.crop_size:
            # Random temporal crop of exactly crop_size frames.
            idx = random.randint(0, length - self.crop_size)
            motion = motion[idx:idx + self.crop_size]
            results['motion_length'] = self.crop_size
        else:
            # Pad short sequences; only the first `length` frames are real.
            motion = _pad_to_length(motion, self.crop_size)
            results['motion_length'] = length
        assert len(motion) == self.crop_size
        results['motion'] = motion
        results['motion_shape'] = motion.shape
        results['motion_mask'] = _valid_frame_mask(
            min(length, self.crop_size), self.crop_size)
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(crop_size={self.crop_size})'


@PIPELINES.register_module()
class PairCrop(object):
    r"""Crop a pair of aligned motion sequences to a fixed length.

    Applies the same random crop (or zero padding) to both
    ``results['motion']`` and ``results['raw_motion']`` so the two streams
    stay frame-aligned.

    Args:
        crop_size (int): The size of the cropped motion sequence.
    """

    def __init__(self, crop_size: Optional[int] = None):
        # Explicit raise instead of `assert` so the check survives `-O`.
        if crop_size is None:
            raise ValueError('crop_size must be specified')
        self.crop_size = crop_size

    def __call__(self, results):
        motion = results['motion']
        raw_motion = results['raw_motion']
        length = len(motion)
        if length >= self.crop_size:
            # One shared random start index keeps both streams aligned.
            idx = random.randint(0, length - self.crop_size)
            motion = motion[idx:idx + self.crop_size]
            raw_motion = raw_motion[idx:idx + self.crop_size]
            results['motion_length'] = self.crop_size
        else:
            motion = _pad_to_length(motion, self.crop_size)
            raw_motion = _pad_to_length(raw_motion, self.crop_size)
            results['motion_length'] = length
        assert len(motion) == self.crop_size
        assert len(raw_motion) == self.crop_size
        results['motion'] = motion
        results['raw_motion'] = raw_motion
        results['motion_shape'] = motion.shape
        results['motion_mask'] = _valid_frame_mask(
            min(length, self.crop_size), self.crop_size)
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(crop_size={self.crop_size})'


@PIPELINES.register_module()
class RandomCrop(object):
    r"""Random crop motion sequences to a randomly drawn length.

    A crop size is sampled uniformly from ``[min_size, max_size]``; the
    sequence is cropped (or kept whole if shorter), then zero-padded to
    ``max_size`` so every output has the same length.

    Args:
        min_size (int): Minimum crop size (inclusive).
        max_size (int): Maximum crop size (inclusive); also the padded
            output length.
    """

    def __init__(self,
                 min_size: Optional[int] = None,
                 max_size: Optional[int] = None):
        # Explicit raise instead of `assert` so the check survives `-O`.
        if min_size is None or max_size is None:
            raise ValueError('min_size and max_size must be specified')
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, results):
        motion = results['motion']
        length = len(motion)
        crop_size = random.randint(self.min_size, self.max_size)
        if length > crop_size:
            idx = random.randint(0, length - crop_size)
            motion = motion[idx:idx + crop_size]
            results['motion_length'] = crop_size
        else:
            results['motion_length'] = length
        # Every sample is padded to max_size so batches have a uniform shape.
        motion = _pad_to_length(motion, self.max_size)
        results['motion'] = motion
        results['motion_shape'] = motion.shape
        results['motion_mask'] = _valid_frame_mask(
            min(length, crop_size), self.max_size)
        assert len(motion) == self.max_size
        return results

    def __repr__(self):
        return (self.__class__.__name__ +
                f'(min_size={self.min_size}, max_size={self.max_size})')


@PIPELINES.register_module()
class Normalize(object):
    """Normalize motion sequences with precomputed statistics.

    Args:
        mean_path (str): Path of the ``.npy`` mean file.
        std_path (str): Path of the ``.npy`` std file.
        eps (float): Small constant added to the std to avoid division by
            zero. Defaults to 1e-9.
        keys (Sequence[str]): Keys in ``results`` to normalize. Defaults to
            ``('motion', )``; a tuple avoids the mutable-default pitfall.
    """

    def __init__(self, mean_path, std_path, eps=1e-9, keys=('motion', )):
        self.mean = np.load(mean_path)
        self.std = np.load(std_path)
        self.eps = eps
        self.keys = keys

    def __call__(self, results):
        for key in self.keys:
            results[key] = (results[key] - self.mean) / (self.std + self.eps)
        return results