|
import copy
import json
import os
import pickle as pkl
from typing import List, Optional

import numpy as np
import torch
from torch.utils.data import ConcatDataset, Dataset, WeightedRandomSampler

# NOTE: build_evaluator is called in prepare_evaluation() but was never
# imported; the import path below is assumed from the mogen package layout.
from mogen.core.evaluation import build_evaluator
from mogen.models.builder import build_submodule

from .builder import DATASETS
from .pipelines import Compose
|
|
|
class SingleMotionVerseDataset(Dataset): |
|
""" |
|
A dataset class for handling single MotionVerse datasets. |
|
|
|
Args: |
|
dataset_name (str): Name of the dataset and task to load. |
|
data_prefix (str): Path to the directory containing the dataset. |
|
ann_file (str): Path to the annotation file. |
|
pipeline (list): A list of transformations to apply on the data. |
|
mode (str): the mode of current work. Choices: ['pretrain', 'train', 'test']. |
|
eval_cfg (dict): Configuration for evaluation metrics. |
|
""" |
|
|
|
def __init__(self, |
|
dataset_path: Optional[str] = None, |
|
task_name: Optional[str] = None, |
|
data_prefix: Optional[str] = None, |
|
ann_file: Optional[str] = None, |
|
pipeline: Optional[List[dict]] = None, |
|
|
|
|
|
tgt_min_motion_length: int = 20, |
|
tgt_max_motion_length: int = 200, |
|
|
|
|
|
v2m_window_size: int = 20, |
|
|
|
|
|
mp_input_length: int = 50, |
|
mp_output_length: int = 25, |
|
mp_stride_step: int = 5, |
|
|
|
|
|
test_rotation_type: str = 'h3d_rot', |
|
target_framerate: float = 20, |
|
eval_cfg: Optional[dict] = None, |
|
test_mode: Optional[bool] = False): |
|
data_prefix = os.path.join(data_prefix, 'datasets', dataset_path) |
|
self.dataset_path = dataset_path |
|
assert task_name in ['mocap', 't2m', 'v2m', 's2g', 'm2d'] |
|
self.task_name = task_name |
|
self.dataset_name = dataset_path + '_' + task_name |
|
|
|
|
|
self.meta_dir = os.path.join(data_prefix, 'metas') |
|
self.motion_dir = os.path.join(data_prefix, 'motions') |
|
self.eval_motion_dir = os.path.join(data_prefix, 'eval_motions') |
|
self.text_dir = os.path.join(data_prefix, 'texts') |
|
self.text_feat_dir = os.path.join(data_prefix, 'text_feats') |
|
self.speech_dir = os.path.join(data_prefix, 'speeches') |
|
self.speech_feat_dir = os.path.join(data_prefix, 'speech_feats') |
|
self.music_dir = os.path.join(data_prefix, 'musics') |
|
self.music_feat_dir = os.path.join(data_prefix, 'music_feats') |
|
self.video_feat_dir = os.path.join(data_prefix, 'video_feats') |
|
self.anno_file = os.path.join(data_prefix, 'splits', ann_file) |
|
|
|
self.pipeline = Compose(pipeline) |
|
|
|
self.tgt_min_motion_length = tgt_min_motion_length |
|
self.tgt_max_motion_length = tgt_max_motion_length |
|
|
|
self.v2m_window_size = v2m_window_size |
|
|
|
self.mp_input_length = mp_input_length |
|
self.mp_output_length = mp_output_length |
|
self.mp_stride_step = mp_stride_step |
|
|
|
self.target_framerate = target_framerate |
|
self.test_rotation_type = test_rotation_type |
|
self.test_mode = test_mode |
|
self.load_annotations() |
|
self.eval_cfg = copy.deepcopy(eval_cfg) |
|
if self.test_mode: |
|
self.prepare_evaluation() |
|
|
|
def __len__(self) -> int: |
|
"""Return the length of the current dataset.""" |
|
if self.test_mode: |
|
return len(self.eval_indexes) |
|
return len(self.name_list) |
|
|
|
def __getitem__(self, idx: int) -> dict: |
|
"""Prepare data for the given index.""" |
|
if self.test_mode: |
|
idx = self.eval_indexes[idx] |
|
return self.prepare_data(idx) |
|
|
|
def load_annotations(self): |
|
if self.task_name == 'mocap': |
|
self.load_annotations_mocap() |
|
elif self.task_name == 't2m': |
|
self.load_annotations_t2m() |
|
elif self.task_name == 'v2m': |
|
self.load_annotations_v2m() |
|
elif self.task_name == 's2g': |
|
self.load_annotations_s2g() |
|
elif self.task_name == 'm2d': |
|
self.load_annotations_m2d() |
|
else: |
|
raise NotImplementedError() |
|
|
|
def load_annotations_mocap(self): |
|
if self.test_mode: |
|
self.name_list = [] |
|
self.src_start_frame = [] |
|
self.src_end_frame = [] |
|
self.tgt_start_frame = [] |
|
self.tgt_end_frame = [] |
|
tgt_motion_length = self.mp_input_length + self.mp_output_length |
|
for name in open(self.anno_file): |
|
name = name.strip() |
|
meta_path = os.path.join(self.meta_dir, name + ".json") |
|
meta_data = json.load(open(meta_path)) |
|
num_frames = meta_data['num_frames'] |
|
downrate = int(meta_data['framerate'] / self.target_framerate + 0.1) |
|
                if num_frames < tgt_motion_length * downrate:
|
continue |
|
lim = num_frames // downrate - tgt_motion_length |
|
for start_frame in range(0, lim, self.mp_stride_step): |
|
self.name_list.append(name) |
|
self.src_start_frame.append((start_frame + 1) * downrate) |
|
self.src_end_frame.append((start_frame + tgt_motion_length + 1) * downrate) |
|
self.tgt_start_frame.append(start_frame + self.mp_input_length) |
|
self.tgt_end_frame.append(start_frame + tgt_motion_length) |
|
else: |
|
self.name_list = [] |
|
for name in open(self.anno_file): |
|
name = name.strip() |
|
self.name_list.append(name) |
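    # Worked example for the test-time windowing above: with the defaults
    # mp_input_length=50, mp_output_length=25, mp_stride_step=5, the target
    # window length is 75 frames. For a 40-fps sequence (downrate 2) with 200
    # frames, start_frame takes the values 0, 5, 10, 15, 20; the window at
    # start_frame=0 reads source frames [2, 152) and predicts target frames
    # [50, 75) in downsampled coordinates.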
|
|
|
def load_annotations_t2m(self): |
|
self.name_list = [] |
|
self.text_idx = [] |
|
for name in open(self.anno_file): |
|
name = name.strip() |
|
meta_path = os.path.join(self.meta_dir, name + ".json") |
|
meta_data = json.load(open(meta_path)) |
|
downrate = int(meta_data['framerate'] / self.target_framerate + 0.1) |
|
text_path = os.path.join(self.text_dir, name + ".json") |
|
text_data = json.load(open(text_path)) |
|
for i, anno in enumerate(text_data): |
|
start_frame = anno['start_frame'] // downrate |
|
end_frame = min(anno['end_frame'], meta_data['num_frames']) // downrate |
|
num_frame = end_frame - start_frame |
|
if num_frame < self.tgt_min_motion_length or num_frame > self.tgt_max_motion_length: |
|
continue |
|
if len(anno['body_text']) > 0: |
|
self.name_list.append(name) |
|
self.text_idx.append(i) |
|
|
|
def load_annotations_v2m(self): |
|
if not self.test_mode: |
|
self.name_list = [] |
|
for name in open(self.anno_file): |
|
name = name.strip() |
|
self.name_list.append(name) |
|
else: |
|
self.name_list = [] |
|
self.start_frame = [] |
|
self.end_frame = [] |
|
self.valid_start_frame = [] |
|
self.valid_end_frame = [] |
|
for name in open(self.anno_file): |
|
name = name.strip() |
|
meta_path = os.path.join(self.meta_dir, name + ".json") |
|
meta_data = json.load(open(meta_path)) |
|
num_frames = meta_data['num_frames'] |
|
assert num_frames >= self.v2m_window_size |
|
cur_idx = 0 |
|
while cur_idx < num_frames: |
|
if cur_idx + self.v2m_window_size < num_frames: |
|
self.name_list.append(name) |
|
self.start_frame.append(cur_idx) |
|
self.end_frame.append(cur_idx + self.v2m_window_size) |
|
self.valid_start_frame.append(cur_idx) |
|
self.valid_end_frame.append(cur_idx + self.v2m_window_size) |
|
cur_idx += self.v2m_window_size |
|
else: |
|
self.name_list.append(name) |
|
self.start_frame.append(num_frames - self.v2m_window_size) |
|
self.end_frame.append(num_frames) |
|
self.valid_start_frame.append(cur_idx) |
|
self.valid_end_frame.append(num_frames) |
|
break |
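    # Worked example for the loop above: with v2m_window_size=20 and
    # num_frames=50, the windows are [0, 20), [20, 40), [30, 50); the final
    # window is shifted back so it stays in range, and only its unseen tail
    # [40, 50) is recorded as the valid span.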
|
|
|
def load_annotations_s2g(self): |
|
self.name_list = [] |
|
self.speech_idx = [] |
|
for name in open(self.anno_file): |
|
name = name.strip() |
|
meta_path = os.path.join(self.meta_dir, name + ".json") |
|
meta_data = json.load(open(meta_path)) |
|
downrate = int(meta_data['framerate'] / self.target_framerate + 0.1) |
|
speech_path = os.path.join(self.speech_dir, name + ".json") |
|
speech_data = json.load(open(speech_path)) |
|
for i, anno in enumerate(speech_data): |
|
start_frame = anno['start_frame'] // downrate |
|
end_frame = min(anno['end_frame'], meta_data['num_frames']) // downrate |
|
num_frame = end_frame - start_frame |
|
if num_frame < self.tgt_min_motion_length or num_frame > self.tgt_max_motion_length: |
|
continue |
|
self.name_list.append(name) |
|
self.speech_idx.append(i) |
|
|
|
def load_annotations_m2d(self): |
|
self.name_list = [] |
|
self.music_idx = [] |
|
for name in open(self.anno_file): |
|
name = name.strip() |
|
meta_path = os.path.join(self.meta_dir, name + ".json") |
|
meta_data = json.load(open(meta_path)) |
|
downrate = int(meta_data['framerate'] / self.target_framerate + 0.1) |
|
music_path = os.path.join(self.music_dir, name + ".json") |
|
music_data = json.load(open(music_path)) |
|
for i, anno in enumerate(music_data): |
|
start_frame = anno['start_frame'] // downrate |
|
end_frame = min(anno['end_frame'], meta_data['num_frames']) // downrate |
|
num_frame = end_frame - start_frame |
|
if num_frame < self.tgt_min_motion_length or num_frame > self.tgt_max_motion_length: |
|
continue |
|
self.name_list.append(name) |
|
self.music_idx.append(i) |
|
|
|
def prepare_data_base(self, idx: int) -> dict: |
|
results = {} |
|
name = self.name_list[idx] |
|
results['motion_path'] = os.path.join(self.motion_dir, name + ".npz") |
|
meta_path = os.path.join(self.meta_dir, name + ".json") |
|
meta_data = json.load(open(meta_path)) |
|
meta_data['dataset_name'] = self.dataset_name |
|
results['meta_data'] = meta_data |
|
results['meta_data']['sample_idx'] = idx |
|
        # Zero-filled placeholders for every conditioning modality; the
        # task-specific loaders below overwrite the relevant entries and set
        # the matching *_cond flag to 1.
        results.update({
            'text_word_feat': np.zeros((77, 1024), dtype=np.float32),
            'text_seq_feat': np.zeros((1024,), dtype=np.float32),
            'text_cond': 0,
            'music_word_feat': np.zeros((229, 768), dtype=np.float32),
            'music_seq_feat': np.zeros((1024,), dtype=np.float32),
            'music_cond': 0,
            'speech_word_feat': np.zeros((229, 768), dtype=np.float32),
            'speech_seq_feat': np.zeros((1024,), dtype=np.float32),
            'speech_cond': 0,
            'video_seq_feat': np.zeros((1024,), dtype=np.float32),
            'video_cond': 0,
        })
|
return results |
|
|
|
def prepare_data(self, idx: int) -> dict: |
|
if self.task_name == 'mocap': |
|
results = self.prepare_data_mocap(idx) |
|
elif self.task_name == 't2m': |
|
results = self.prepare_data_t2m(idx) |
|
elif self.task_name == 'v2m': |
|
results = self.prepare_data_v2m(idx) |
|
elif self.task_name == 's2g': |
|
results = self.prepare_data_s2g(idx) |
|
elif self.task_name == 'm2d': |
|
results = self.prepare_data_m2d(idx) |
|
else: |
|
raise NotImplementedError() |
|
results = self.pipeline(results) |
|
return results |
|
|
|
def prepare_data_mocap(self, idx: int) -> dict: |
|
results = self.prepare_data_base(idx) |
|
if self.test_mode: |
|
results['meta_data']['start_frame'] = self.src_start_frame[idx] |
|
results['meta_data']['end_frame'] = self.src_end_frame[idx] |
|
results['context_mask'] = np.concatenate( |
|
(np.ones((self.mp_input_length - 1)), np.zeros((self.mp_output_length))), |
|
axis=-1 |
|
) |
|
return results |
|
|
|
def prepare_data_t2m(self, idx: int) -> dict: |
|
results = self.prepare_data_base(idx) |
|
name = self.name_list[idx] |
|
text_idx = self.text_idx[idx] |
|
text_path = os.path.join(self.text_dir, name + ".json") |
|
text_data = json.load(open(text_path))[text_idx] |
|
text_feat_path = os.path.join(self.text_feat_dir, name + ".pkl") |
|
text_feat_data = pkl.load(open(text_feat_path, "rb")) |
|
text_list = text_data['body_text'] |
|
tid = np.random.randint(len(text_list)) |
|
text = text_list[tid] |
|
text_word_feat = text_feat_data['text_word_feats'][text_idx][tid] |
|
text_seq_feat = text_feat_data['text_seq_feats'][text_idx][tid] |
|
assert text_word_feat.shape[0] == 77 |
|
assert text_word_feat.shape[1] == 1024 |
|
assert text_seq_feat.shape[0] == 1024 |
|
|
|
if self.test_mode: |
|
motion_path = os.path.join(self.eval_motion_dir, name + ".npy") |
|
motion_data = np.load(motion_path) |
|
assert not np.isnan(motion_data).any() |
|
downrate = int(results['meta_data']['framerate'] / self.target_framerate + 0.1) |
|
start_frame = text_data['start_frame'] // downrate |
|
end_frame = text_data['end_frame'] // downrate |
|
motion_data = motion_data[start_frame: end_frame] |
|
results['meta_data']['framerate'] = self.target_framerate |
|
results['meta_data']['rotation_type'] = self.test_rotation_type |
|
assert motion_data.shape[0] > 0 |
|
if 'body_tokens' in text_data: |
|
token = text_data['body_tokens'][tid] |
|
else: |
|
token = "" |
|
text_cond = 1 |
|
results.update({ |
|
'motion': motion_data, |
|
'text_word_feat': text_word_feat, |
|
'text_seq_feat': text_seq_feat, |
|
'text_cond': text_cond, |
|
'text': text, |
|
'token': token |
|
}) |
|
else: |
|
results['meta_data']['start_frame'] = text_data['start_frame'] |
|
results['meta_data']['end_frame'] = text_data['end_frame'] |
|
text_cond = 1 |
|
results.update({ |
|
'text_word_feat': text_word_feat, |
|
'text_seq_feat': text_seq_feat, |
|
'text_cond': text_cond |
|
}) |
|
return results |
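    # Annotation schema inferred from the accesses above: each entry of
    # texts/<name>.json carries 'start_frame', 'end_frame', a list of captions
    # under 'body_text' and, optionally, matching 'body_tokens'; the pickle in
    # text_feats/<name>.pkl holds, per annotation and per caption,
    # 'text_word_feats' of shape (77, 1024) and 'text_seq_feats' of shape (1024,).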
|
|
|
def prepare_data_v2m(self, idx: int) -> dict: |
|
results = self.prepare_data_base(idx) |
|
name = self.name_list[idx] |
|
video_feat_path = os.path.join(self.video_feat_dir, name + ".pkl") |
|
video_feat_data = pkl.load(open(video_feat_path, "rb")) |
|
video_word_feat = video_feat_data['video_word_feats'] |
|
video_seq_feat = video_feat_data['video_seq_feats'] |
|
assert video_word_feat.shape[0] == results['meta_data']['num_frames'] |
|
assert video_word_feat.shape[1] == 1024 |
|
assert video_seq_feat.shape[0] == 1024 |
|
video_cond = 1 |
|
if self.test_mode: |
|
results['meta_data']['start_frame'] = self.start_frame[idx] |
|
results['meta_data']['end_frame'] = self.end_frame[idx] |
|
motion_path = os.path.join(self.eval_motion_dir, name + ".npy") |
|
motion_data = np.load(motion_path) |
|
assert not np.isnan(motion_data).any() |
|
|
|
start_frame = self.start_frame[idx] |
|
end_frame = self.end_frame[idx] |
|
motion_data = motion_data[start_frame: end_frame] |
|
video_word_feat = video_word_feat[start_frame: end_frame] |
|
results['meta_data']['framerate'] = self.target_framerate |
|
results['meta_data']['rotation_type'] = self.test_rotation_type |
|
assert motion_data.shape[0] > 0 |
|
results.update({ |
|
'motion': motion_data, |
|
'video_word_feat': video_word_feat, |
|
'video_seq_feat': video_seq_feat, |
|
'video_cond': video_cond |
|
}) |
|
else: |
|
results.update({ |
|
'video_word_feat': video_word_feat, |
|
'video_seq_feat': video_seq_feat, |
|
'video_cond': video_cond |
|
}) |
|
return results |
|
|
|
def prepare_data_s2g(self, idx: int) -> dict: |
|
results = self.prepare_data_base(idx) |
|
name = self.name_list[idx] |
|
speech_idx = self.speech_idx[idx] |
|
speech_path = os.path.join(self.speech_dir, name + ".json") |
|
speech_data = json.load(open(speech_path))[speech_idx] |
|
speech_feat_path = os.path.join(self.speech_feat_dir, name + ".pkl") |
|
speech_feat_data = pkl.load(open(speech_feat_path, "rb")) |
|
        # Some feature files store the arrays under 'audio_*' keys, others
        # under 'speech_*'; fall back accordingly.
        try:
            speech_word_feat = speech_feat_data['audio_word_feats'][speech_idx]
            speech_seq_feat = speech_feat_data['audio_seq_feats'][speech_idx]
        except KeyError:
            speech_word_feat = speech_feat_data['speech_word_feats'][speech_idx]
            speech_seq_feat = speech_feat_data['speech_seq_feats'][speech_idx]
|
del speech_feat_data |
|
assert speech_word_feat.shape[0] == 229 |
|
assert speech_word_feat.shape[1] == 768 |
|
assert speech_seq_feat.shape[0] == 1024 |
|
|
|
results['meta_data']['start_frame'] = speech_data['start_frame'] |
|
results['meta_data']['end_frame'] = speech_data['end_frame'] |
|
speech_cond = 1 |
|
results.update({ |
|
'speech_word_feat': speech_word_feat, |
|
'speech_seq_feat': speech_seq_feat, |
|
'speech_cond': speech_cond |
|
}) |
|
if self.test_mode: |
|
results['meta_data']['framerate'] = self.target_framerate |
|
results['meta_data']['rotation_type'] = self.test_rotation_type |
|
eval_data_path = os.path.join(self.eval_motion_dir, name + ".npz") |
|
eval_data = np.load(eval_data_path) |
|
motion_data = eval_data["bvh_rot_beat141"] |
|
sem_data = eval_data["sem"] |
|
wav_data = eval_data["wave16k"] |
|
assert not np.isnan(motion_data).any() |
|
|
|
start_frame = results['meta_data']['start_frame'] |
|
end_frame = results['meta_data']['end_frame'] |
|
            # Convert motion-frame indices to 16 kHz audio sample indices
            # (slice bounds must be integers).
            wav_start_frame = int(start_frame / results['meta_data']['framerate'] * 16000)
            wav_end_frame = int(end_frame / results['meta_data']['framerate'] * 16000)
|
motion_data = motion_data[start_frame: end_frame] |
|
sem_data = sem_data[start_frame: end_frame] |
|
wav_data = wav_data[wav_start_frame: wav_end_frame] |
|
assert motion_data.shape[0] > 0 |
|
results.update({ |
|
'motion': motion_data, |
|
'sem_score': sem_data, |
|
'wav_feat': wav_data |
|
}) |
|
return results |
|
|
|
def prepare_data_m2d(self, idx: int) -> dict: |
|
results = self.prepare_data_base(idx) |
|
name = self.name_list[idx] |
|
music_idx = self.music_idx[idx] |
|
music_path = os.path.join(self.music_dir, name + ".json") |
|
music_data = json.load(open(music_path))[music_idx] |
|
music_feat_path = os.path.join(self.music_feat_dir, name + ".pkl") |
|
music_feat_data = pkl.load(open(music_feat_path, "rb")) |
|
music_word_feat = music_feat_data['audio_word_feats'][music_idx] |
|
music_seq_feat = music_feat_data['audio_seq_feats'][music_idx] |
|
assert music_word_feat.shape[0] == 229 |
|
assert music_word_feat.shape[1] == 768 |
|
assert music_seq_feat.shape[0] == 1024 |
|
|
|
results['meta_data']['start_frame'] = music_data['start_frame'] |
|
results['meta_data']['end_frame'] = music_data['end_frame'] |
|
music_cond = 1 |
|
results.update({ |
|
'music_word_feat': music_word_feat, |
|
'music_seq_feat': music_seq_feat, |
|
'music_cond': music_cond |
|
}) |
|
return results |
|
|
|
def prepare_evaluation(self): |
|
""" |
|
Prepare the dataset for evaluation by initializing evaluators and creating evaluation indexes. |
|
""" |
|
self.evaluators = [] |
|
self.eval_indexes = [] |
|
self.evaluator_model = build_submodule(self.eval_cfg.get('evaluator_model', None)) |
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
if self.evaluator_model is not None: |
|
self.evaluator_model = self.evaluator_model.to(device) |
|
self.evaluator_model.eval() |
|
self.eval_cfg['evaluator_model'] = self.evaluator_model |
|
|
|
for _ in range(self.eval_cfg['replication_times']): |
|
eval_indexes = np.arange(len(self.name_list)) |
|
if self.eval_cfg.get('shuffle_indexes', False): |
|
np.random.shuffle(eval_indexes) |
|
self.eval_indexes.append(eval_indexes) |
|
|
|
for metric in self.eval_cfg['metrics']: |
|
evaluator, self.eval_indexes = build_evaluator( |
|
metric, self.eval_cfg, len(self.name_list), self.eval_indexes) |
|
self.evaluators.append(evaluator) |
|
|
|
self.eval_indexes = np.concatenate(self.eval_indexes) |
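    # Sketch of the expected ``eval_cfg`` (field names taken from the accesses
    # above; the metric and model configs themselves are placeholders):
    #
    #   eval_cfg = dict(
    #       replication_times=3,
    #       shuffle_indexes=True,
    #       evaluator_model=dict(type=...),  # built via build_submodule
    #       metrics=[dict(type=...), ...],   # each built via build_evaluator
    #   )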
|
|
|
def process_outputs(self, results): |
|
return results |
|
|
|
def evaluate(self, results: List[dict], work_dir: str, logger=None) -> dict: |
|
""" |
|
Evaluate the model performance based on the results. |
|
|
|
Args: |
|
results (list): A list of result dictionaries. |
|
work_dir (str): Directory where evaluation logs will be stored. |
|
logger: Logger object to record evaluation results (optional). |
|
|
|
Returns: |
|
dict: Dictionary containing evaluation metrics. |
|
""" |
|
metrics = {} |
|
results = self.process_outputs(results) |
|
for evaluator in self.evaluators: |
|
metrics.update(evaluator.evaluate(results)) |
|
if logger is not None: |
|
logger.info(metrics) |
|
eval_output = os.path.join(work_dir, 'eval_results.log') |
|
with open(eval_output, 'w') as f: |
|
for k, v in metrics.items(): |
|
f.write(k + ': ' + str(v) + '\n') |
|
return metrics |
|
|
|
|
|
def create_single_dataset(cfg: dict): |
|
dataset_path = cfg['dataset_path'] |
|
if dataset_path == 'amass': |
|
return MotionVerseAMASS(**cfg) |
|
elif dataset_path == 'humanml3d': |
|
return MotionVerseH3D(**cfg) |
|
elif dataset_path == 'kitml': |
|
return MotionVerseKIT(**cfg) |
|
elif dataset_path == 'babel': |
|
return MotionVerseBABEL(**cfg) |
|
elif dataset_path == 'motionx': |
|
return MotionVerseMotionX(**cfg) |
|
elif dataset_path == 'humanact12': |
|
return MotionVerseACT12(**cfg) |
|
elif dataset_path == 'uestc': |
|
return MotionVerseUESTC(**cfg) |
|
elif dataset_path == 'ntu': |
|
return MotionVerseNTU(**cfg) |
|
elif dataset_path == 'h36m': |
|
return MotionVerseH36M(**cfg) |
|
elif dataset_path == 'mpi': |
|
return MotionVerseMPI(**cfg) |
|
elif dataset_path == 'pw3d': |
|
return MotionVersePW3D(**cfg) |
|
elif dataset_path == 'aist': |
|
return MotionVerseAIST(**cfg) |
|
elif dataset_path == 'beat': |
|
return MotionVerseBEAT(**cfg) |
|
elif dataset_path == 'tedg': |
|
return MotionVerseTEDG(**cfg) |
|
elif dataset_path == 'tedex': |
|
return MotionVerseTEDEx(**cfg) |
|
elif dataset_path == 's2g3d': |
|
return MotionVerseS2G3D(**cfg) |
|
else: |
|
raise NotImplementedError() |
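# Usage sketch (hypothetical values; the pipeline contents depend on the task):
#
#   dataset = create_single_dataset(dict(
#       dataset_path='humanml3d',
#       task_name='t2m',
#       data_prefix='data/motionverse',
#       ann_file='train.txt',
#       pipeline=[...],
#   ))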
|
|
|
|
|
@DATASETS.register_module() |
|
class MotionVerse(Dataset): |
|
""" |
|
A dataset class that handles multiple MotionBench datasets. |
|
|
|
Args: |
|
dataset_cfgs (list[str]): List of dataset configurations. |
|
partitions (list[float]): List of partition weights corresponding to the datasets. |
|
num_data (Optional[int]): Number of data samples to load. Defaults to None. |
|
data_prefix (str): Path to the directory containing the dataset. |
|
""" |
|
|
|
def __init__(self, |
|
dataset_cfgs: List[dict], |
|
                 partitions: List[float],
|
num_data: Optional[int] = None, |
|
data_prefix: Optional[str] = None): |
|
"""Load data from multiple datasets.""" |
|
assert min(partitions) >= 0 |
|
assert len(dataset_cfgs) == len(partitions) |
|
datasets = [] |
|
new_partitions = [] |
|
for idx, cfg in enumerate(dataset_cfgs): |
|
if partitions[idx] == 0: |
|
continue |
|
new_partitions.append(partitions[idx]) |
|
cfg.update({ |
|
'data_prefix': data_prefix |
|
}) |
|
datasets.append(create_single_dataset(cfg)) |
|
self.dataset = ConcatDataset(datasets) |
|
if num_data is not None: |
|
self.length = num_data |
|
else: |
|
self.length = max(len(ds) for ds in datasets) |
|
partitions = new_partitions |
|
weights = [np.ones(len(ds)) * p / len(ds) for (p, ds) in zip(partitions, datasets)] |
|
weights = np.concatenate(weights, axis=0) |
|
self.weights = weights |
|
self.task_proj = { |
|
'mocap': 0, |
|
't2m': 1, |
|
'v2m': 2, |
|
's2g': 3, |
|
'm2d': 4 |
|
} |
|
self.task_idx_list = [] |
|
for ds in datasets: |
|
self.task_idx_list += [self.task_proj[ds.task_name]] * len(ds) |
|
|
|
def __len__(self) -> int: |
|
"""Get the size of the dataset.""" |
|
return self.length |
|
|
|
def __getitem__(self, idx: int) -> dict: |
|
"""Given an index, sample data from multiple datasets with the specified proportion.""" |
|
return self.dataset[idx] |
|
|
|
def get_task_idx(self, idx: int) -> int: |
|
return self.task_idx_list[idx] |
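    # ``self.weights`` is intended for a weighted sampler so that batches are
    # drawn according to ``partitions``; a minimal sketch:
    #
    #   from torch.utils.data import DataLoader
    #   sampler = WeightedRandomSampler(weights=dataset.weights,
    #                                   num_samples=len(dataset),
    #                                   replacement=True)
    #   loader = DataLoader(dataset, batch_size=32, sampler=sampler)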
|
|
|
|
|
@DATASETS.register_module() |
|
class MotionVerseEval(Dataset): |
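    """A thin wrapper that builds a single MotionVerse dataset for evaluation.

    Args:
        eval_cfgs (dict): Mapping from '{dataset_path}_{task_name}' keys to
            dataset configurations.
        testset (str): Key in ``eval_cfgs`` selecting the dataset to evaluate.
        test_mode (bool): Whether the dataset is used for testing.
    """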
|
|
|
def __init__(self, |
|
eval_cfgs: dict, |
|
testset: str, |
|
test_mode: bool = True): |
|
"""Load data from multiple datasets.""" |
|
assert testset in eval_cfgs |
|
        dataset_path, task_name = testset.split('_')  # e.g. 'humanml3d_t2m'
|
dataset_cfg = eval_cfgs[testset] |
|
dataset_cfg['dataset_path'] = dataset_path |
|
dataset_cfg['task_name'] = task_name |
|
dataset_cfg['test_mode'] = test_mode |
|
self.dataset = create_single_dataset(dataset_cfg) |
|
|
|
def __len__(self) -> int: |
|
return len(self.dataset) |
|
|
|
def __getitem__(self, idx: int) -> dict: |
|
return self.dataset[idx] |
|
|
|
    def load_annotations(self):
        self.dataset.load_annotations()
|
|
|
def prepare_data(self, idx: int) -> dict: |
|
return self.dataset.prepare_data(idx) |
|
|
|
def prepare_evaluation(self): |
|
self.dataset.prepare_evaluation() |
|
|
|
def process_outputs(self, results): |
|
return self.dataset.process_outputs(results) |
|
|
|
def evaluate(self, results: List[dict], work_dir: str, logger=None) -> dict: |
|
return self.dataset.evaluate(results=results, work_dir=work_dir, logger=logger) |
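    # Construction sketch (hypothetical config values):
    #
    #   testset = 'humanml3d_t2m'
    #   eval_cfgs = {
    #       testset: dict(data_prefix='data/motionverse', ann_file='test.txt',
    #                     pipeline=[...],
    #                     eval_cfg=dict(replication_times=1, metrics=[...])),
    #   }
    #   dataset = MotionVerseEval(eval_cfgs=eval_cfgs, testset=testset)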
|
|
|
|
|
@DATASETS.register_module() |
|
class MotionVerseAMASS(SingleMotionVerseDataset): |
|
|
|
def __init__(self, **kwargs): |
|
if 'dataset_path' not in kwargs: |
|
kwargs['dataset_path'] = 'amass' |
|
task_name = kwargs['task_name'] |
|
assert task_name in ['mocap'] |
|
super().__init__(**kwargs) |
|
|
|
|
|
@DATASETS.register_module() |
|
class MotionVerseH3D(SingleMotionVerseDataset): |
|
|
|
def __init__(self, **kwargs): |
|
if 'dataset_path' not in kwargs: |
|
kwargs['dataset_path'] = 'humanml3d' |
|
task_name = kwargs['task_name'] |
|
assert task_name in ['mocap', 't2m'] |
|
super().__init__(**kwargs) |
|
|
|
|
|
@DATASETS.register_module() |
|
class MotionVerseKIT(SingleMotionVerseDataset): |
|
|
|
def __init__(self, **kwargs): |
|
if 'dataset_path' not in kwargs: |
|
kwargs['dataset_path'] = 'kitml' |
|
task_name = kwargs['task_name'] |
|
assert task_name in ['mocap', 't2m'] |
|
super().__init__(**kwargs) |
|
|
|
|
|
@DATASETS.register_module() |
|
class MotionVerseBABEL(SingleMotionVerseDataset): |
|
|
|
def __init__(self, **kwargs): |
|
if 'dataset_path' not in kwargs: |
|
kwargs['dataset_path'] = 'babel' |
|
task_name = kwargs['task_name'] |
|
assert task_name in ['mocap', 't2m'] |
|
super().__init__(**kwargs) |
|
|
|
|
|
@DATASETS.register_module() |
|
class MotionVerseMotionX(SingleMotionVerseDataset): |
|
|
|
def __init__(self, **kwargs): |
|
if 'dataset_path' not in kwargs: |
|
kwargs['dataset_path'] = 'motionx' |
|
task_name = kwargs['task_name'] |
|
assert task_name in ['mocap', 't2m'] |
|
super().__init__(**kwargs) |
|
|
|
|
|
@DATASETS.register_module() |
|
class MotionVerseACT12(SingleMotionVerseDataset): |
|
|
|
def __init__(self, **kwargs): |
|
if 'dataset_path' not in kwargs: |
|
kwargs['dataset_path'] = 'humanact12' |
|
task_name = kwargs['task_name'] |
|
assert task_name in ['mocap', 't2m'] |
|
super().__init__(**kwargs) |
|
|
|
|
|
@DATASETS.register_module() |
|
class MotionVerseUESTC(SingleMotionVerseDataset): |
|
|
|
def __init__(self, **kwargs): |
|
if 'dataset_path' not in kwargs: |
|
kwargs['dataset_path'] = 'uestc' |
|
task_name = kwargs['task_name'] |
|
assert task_name in ['mocap', 't2m'] |
|
super().__init__(**kwargs) |
|
|
|
|
|
@DATASETS.register_module() |
|
class MotionVerseNTU(SingleMotionVerseDataset): |
|
|
|
def __init__(self, **kwargs): |
|
if 'dataset_path' not in kwargs: |
|
kwargs['dataset_path'] = 'ntu' |
|
task_name = kwargs['task_name'] |
|
assert task_name in ['mocap', 't2m'] |
|
super().__init__(**kwargs) |
|
|
|
|
|
@DATASETS.register_module() |
|
class MotionVerseH36M(SingleMotionVerseDataset): |
|
|
|
def __init__(self, **kwargs): |
|
if 'dataset_path' not in kwargs: |
|
kwargs['dataset_path'] = 'h36m' |
|
task_name = kwargs['task_name'] |
|
assert task_name in ['mocap', 'v2m'] |
|
super().__init__(**kwargs) |
|
|
|
|
|
@DATASETS.register_module() |
|
class MotionVerseMPI(SingleMotionVerseDataset): |
|
|
|
def __init__(self, **kwargs): |
|
if 'dataset_path' not in kwargs: |
|
kwargs['dataset_path'] = 'mpi' |
|
task_name = kwargs['task_name'] |
|
assert task_name in ['mocap', 'v2m'] |
|
super().__init__(**kwargs) |
|
|
|
|
|
@DATASETS.register_module() |
|
class MotionVersePW3D(SingleMotionVerseDataset): |
|
|
|
def __init__(self, **kwargs): |
|
if 'dataset_path' not in kwargs: |
|
            kwargs['dataset_path'] = 'pw3d'
|
task_name = kwargs['task_name'] |
|
assert task_name in ['mocap', 'v2m'] |
|
super().__init__(**kwargs) |
|
|
|
|
|
@DATASETS.register_module() |
|
class MotionVerseAIST(SingleMotionVerseDataset): |
|
|
|
def __init__(self, **kwargs): |
|
if 'dataset_path' not in kwargs: |
|
kwargs['dataset_path'] = 'aist' |
|
task_name = kwargs['task_name'] |
|
assert task_name in ['mocap', 'm2d'] |
|
super().__init__(**kwargs) |
|
|
|
|
|
@DATASETS.register_module() |
|
class MotionVerseBEAT(SingleMotionVerseDataset): |
|
|
|
def __init__(self, **kwargs): |
|
if 'dataset_path' not in kwargs: |
|
kwargs['dataset_path'] = 'beat' |
|
task_name = kwargs['task_name'] |
|
assert task_name in ['mocap', 's2g'] |
|
super().__init__(**kwargs) |
|
|
|
|
|
@DATASETS.register_module() |
|
class MotionVerseTEDG(SingleMotionVerseDataset): |
|
|
|
def __init__(self, **kwargs): |
|
if 'dataset_path' not in kwargs: |
|
kwargs['dataset_path'] = 'tedg' |
|
task_name = kwargs['task_name'] |
|
assert task_name in ['mocap', 's2g'] |
|
super().__init__(**kwargs) |
|
|
|
|
|
@DATASETS.register_module() |
|
class MotionVerseTEDEx(SingleMotionVerseDataset): |
|
|
|
def __init__(self, **kwargs): |
|
if 'dataset_path' not in kwargs: |
|
kwargs['dataset_path'] = 'tedex' |
|
task_name = kwargs['task_name'] |
|
assert task_name in ['mocap', 's2g'] |
|
super().__init__(**kwargs) |
|
|
|
|
|
@DATASETS.register_module() |
|
class MotionVerseS2G3D(SingleMotionVerseDataset): |
|
|
|
def __init__(self, **kwargs): |
|
if 'dataset_path' not in kwargs: |
|
kwargs['dataset_path'] = 's2g3d' |
|
task_name = kwargs['task_name'] |
|
assert task_name in ['mocap', 's2g'] |
|
super().__init__(**kwargs) |
|
|