import random
from typing import Optional

import numpy as np
import torch

from ..builder import PIPELINES


@PIPELINES.register_module()
class Crop(object):
    r"""Crop motion sequences to a fixed length.

    Sequences longer than ``crop_size`` are cropped to a random window;
    shorter sequences are zero-padded at the end.

    Args:
        crop_size (int): The size of the cropped motion sequence.
    """

    def __init__(self, crop_size: Optional[int] = None):
        assert crop_size is not None, 'crop_size must be specified'
        self.crop_size = crop_size

    def __call__(self, results):
        motion = results['motion']
        length = len(motion)
        if length >= self.crop_size:
            # Long enough: take a random window of crop_size frames.
            idx = random.randint(0, length - self.crop_size)
            motion = motion[idx:idx + self.crop_size]
            results['motion_length'] = self.crop_size
        else:
            # Too short: zero-pad the tail up to crop_size frames.
            padding_length = self.crop_size - length
            D = motion.shape[1:]
            padding_zeros = np.zeros((padding_length, *D), dtype=np.float32)
            motion = np.concatenate([motion, padding_zeros], axis=0)
            results['motion_length'] = length
        assert len(motion) == self.crop_size
        results['motion'] = motion
        results['motion_shape'] = motion.shape
        # The mask is 1 for real frames and 0 for zero-padded frames.
        if length >= self.crop_size:
            results['motion_mask'] = torch.ones(self.crop_size).numpy()
        else:
            results['motion_mask'] = torch.cat(
                (torch.ones(length),
                 torch.zeros(self.crop_size - length))).numpy()
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size})'
        return repr_str


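# A minimal usage sketch for Crop (the shapes below are illustrative
# assumptions, not values fixed by this module):
#
#     crop = Crop(crop_size=196)
#     results = crop({'motion': np.zeros((120, 263), dtype=np.float32)})
#     results['motion'].shape       # (196, 263): 120 real frames, 76 padded
#     results['motion_mask'].sum()  # 120.0: only real frames are unmasked

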
@PIPELINES.register_module()
class PairCrop(object):
    r"""Crop a pair of motion sequences with a shared random window.

    Applies the same crop (or zero-padding) to both ``results['motion']``
    and ``results['raw_motion']`` so the two sequences stay frame-aligned.

    Args:
        crop_size (int): The size of the cropped motion sequence.
    """

    def __init__(self, crop_size: Optional[int] = None):
        assert crop_size is not None, 'crop_size must be specified'
        self.crop_size = crop_size

    def __call__(self, results):
        motion = results['motion']
        raw_motion = results['raw_motion']
        length = len(motion)
        if length >= self.crop_size:
            # Crop both sequences with the same random offset.
            idx = random.randint(0, length - self.crop_size)
            motion = motion[idx:idx + self.crop_size]
            raw_motion = raw_motion[idx:idx + self.crop_size]
            results['motion_length'] = self.crop_size
        else:
            # Zero-pad both sequences to crop_size frames.
            padding_length = self.crop_size - length
            D = motion.shape[1:]
            padding_zeros = np.zeros((padding_length, *D), dtype=np.float32)
            motion = np.concatenate([motion, padding_zeros], axis=0)
            D = raw_motion.shape[1:]
            padding_zeros = np.zeros((padding_length, *D), dtype=np.float32)
            raw_motion = np.concatenate([raw_motion, padding_zeros], axis=0)
            results['motion_length'] = length
        assert len(motion) == self.crop_size
        assert len(raw_motion) == self.crop_size
        results['motion'] = motion
        results['raw_motion'] = raw_motion
        results['motion_shape'] = motion.shape
        # The mask is 1 for real frames and 0 for zero-padded frames.
        if length >= self.crop_size:
            results['motion_mask'] = torch.ones(self.crop_size).numpy()
        else:
            results['motion_mask'] = torch.cat(
                (torch.ones(length),
                 torch.zeros(self.crop_size - length))).numpy()
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size})'
        return repr_str


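# PairCrop follows the same contract as Crop but reuses one random offset
# for both keys. A minimal sketch (shapes are illustrative assumptions):
#
#     pair_crop = PairCrop(crop_size=196)
#     results = pair_crop({
#         'motion': np.zeros((300, 263), dtype=np.float32),
#         'raw_motion': np.zeros((300, 263), dtype=np.float32),
#     })
#     results['motion'].shape      # (196, 263), same window for both keys
#     results['raw_motion'].shape  # (196, 263)

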
@PIPELINES.register_module()
class RandomCrop(object):
    r"""Randomly crop motion sequences. Each sequence is padded with zeros
    to the maximum length so that all outputs share one shape.

    Args:
        min_size (int): The minimum size of the cropped motion
            sequence (inclusive).
        max_size (int): The maximum size of the cropped motion
            sequence (inclusive).
    """

    def __init__(self,
                 min_size: Optional[int] = None,
                 max_size: Optional[int] = None):
        assert min_size is not None, 'min_size must be specified'
        assert max_size is not None, 'max_size must be specified'
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, results):
        motion = results['motion']
        length = len(motion)
        # Sample a target length, then crop or keep the whole sequence.
        crop_size = random.randint(self.min_size, self.max_size)
        if length > crop_size:
            idx = random.randint(0, length - crop_size)
            motion = motion[idx:idx + crop_size]
            results['motion_length'] = crop_size
        else:
            results['motion_length'] = length
        # Always pad the result up to max_size so batches stack cleanly.
        padding_length = self.max_size - min(crop_size, length)
        if padding_length > 0:
            D = motion.shape[1:]
            padding_zeros = np.zeros((padding_length, *D), dtype=np.float32)
            motion = np.concatenate([motion, padding_zeros], axis=0)
        results['motion'] = motion
        results['motion_shape'] = motion.shape
        # The mask is 1 for real frames and 0 for zero-padded frames.
        if length >= self.max_size and crop_size == self.max_size:
            results['motion_mask'] = torch.ones(self.max_size).numpy()
        else:
            results['motion_mask'] = torch.cat(
                (torch.ones(min(length, crop_size)),
                 torch.zeros(self.max_size - min(length, crop_size))),
                dim=0).numpy()
        assert len(motion) == self.max_size
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__ + f'(min_size={self.min_size}'
        repr_str += f', max_size={self.max_size})'
        return repr_str


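# RandomCrop samples a crop size in [min_size, max_size], then pads every
# output to max_size so batched tensors share one length. A minimal sketch
# (sizes are illustrative assumptions):
#
#     random_crop = RandomCrop(min_size=32, max_size=196)
#     results = random_crop(
#         {'motion': np.zeros((150, 263), dtype=np.float32)})
#     results['motion'].shape            # always (196, 263)
#     int(results['motion_mask'].sum())  # the sampled size, capped at 150

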
@PIPELINES.register_module()
class Normalize(object):
    """Normalize motion sequences with precomputed statistics.

    Args:
        mean_path (str): Path of the mean file (a ``.npy`` array).
        std_path (str): Path of the std file (a ``.npy`` array).
        eps (float): A small constant added to the std to avoid division
            by zero. Defaults to 1e-9.
        keys (tuple[str]): Keys in ``results`` to normalize.
            Defaults to ``('motion', )``.
    """

    def __init__(self, mean_path, std_path, eps=1e-9, keys=('motion', )):
        self.mean = np.load(mean_path)
        self.std = np.load(std_path)
        self.eps = eps
        self.keys = keys

    def __call__(self, results):
        for k in self.keys:
            motion = results[k]
            # Standardize per feature: (x - mean) / (std + eps).
            motion = (motion - self.mean) / (self.std + self.eps)
            results[k] = motion
        return results
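

# These transforms are chained through the PIPELINES registry. A minimal
# config sketch, assuming an mmcv-style Compose and hypothetical paths for
# the precomputed statistics files:
#
#     pipeline = [
#         dict(type='Crop', crop_size=196),
#         dict(type='Normalize',
#              mean_path='data/mean.npy',
#              std_path='data/std.npy'),
#     ]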