LMM / mogen /datasets /text_motion_dataset.py
mingyuan's picture
initial commit
373af33
raw
history blame
7.4 kB
import copy
import os
import os.path
from typing import Optional, Union, List, Dict
import numpy as np
import torch
import json
from .base_dataset import BaseMotionDataset
from .builder import DATASETS
@DATASETS.register_module()
class TextMotionDataset(BaseMotionDataset):
"""
TextMotion dataset for handling motion data paired with text descriptions.
Args:
data_prefix (str): Path to the base directory containing the dataset.
pipeline (list): List of data transformations to apply.
dataset_name (Optional[str]): Name of the dataset.
fixed_length (Optional[int]): Fixed length of data samples (if applicable).
ann_file (Optional[str]): Path to the annotation file.
motion_dir (Optional[str]): Path to the directory containing motion data.
text_dir (Optional[str]): Path to the directory containing text data.
token_dir (Optional[str]): Path to the directory containing token data.
clip_feat_dir (Optional[str]): Path to the directory containing clip feature data.
meta_dir (Optional[str]): Path to the directory containing metadata.
eval_cfg (Optional[dict]): Configuration for evaluation metrics.
test_mode (Optional[bool]): Whether the dataset is in test mode. Defaults to False.
siamese_mode (Optional[bool]): Whether to use Siamese mode (motion1 vs. motion2 comparison). Defaults to False.
tcomb_mode (Optional[bool]): Mode for specific processing (tcomb). Defaults to False.
fine_mode (Optional[bool]): Whether to use fine-grained text processing. Defaults to False.
balanced_sampling (Optional[int]): Number of categories for balanced sampling. If not None, enables balanced sampling.
"""
def __init__(self,
data_prefix: str,
pipeline: List[Dict],
dataset_name: Optional[Union[str, None]] = None,
fixed_length: Optional[Union[int, None]] = None,
ann_file: Optional[Union[str, None]] = None,
motion_dir: Optional[Union[str, None]] = None,
text_dir: Optional[Union[str, None]] = None,
token_dir: Optional[Union[str, None]] = None,
clip_feat_dir: Optional[Union[str, None]] = None,
meta_dir: Optional[Union[str, None]] = None,
eval_cfg: Optional[Union[dict, None]] = None,
test_mode: Optional[bool] = False,
siamese_mode: Optional[bool] = False,
tcomb_mode: Optional[bool] = False,
fine_mode: Optional[bool] = False,
balanced_sampling: Optional[Union[int, None]] = None):
self.text_dir = os.path.join(data_prefix, 'datasets', dataset_name, text_dir)
self.token_dir = os.path.join(data_prefix, 'datasets', dataset_name, token_dir) if token_dir else None
self.clip_feat_dir = os.path.join(data_prefix, 'datasets', dataset_name, clip_feat_dir) if clip_feat_dir else None
self.meta_dir = os.path.join(data_prefix, 'datasets', dataset_name, meta_dir) if meta_dir else None
self.siamese_mode = siamese_mode
self.tcomb_mode = tcomb_mode
self.fine_mode = fine_mode
self.balanced_sampling = balanced_sampling is not None
if self.balanced_sampling:
self.category_list = [[] for _ in range(balanced_sampling)]
super(TextMotionDataset, self).__init__(
data_prefix=data_prefix,
pipeline=pipeline,
dataset_name=dataset_name,
fixed_length=fixed_length,
ann_file=ann_file,
motion_dir=motion_dir,
eval_cfg=eval_cfg,
test_mode=test_mode
)
def load_anno(self, idx: int, name: str) -> Dict:
"""
Load a single annotation based on the given index and name.
Args:
idx (int): Index of the data sample.
name (str): Name of the data sample (typically used as a file identifier).
Returns:
dict: A dictionary containing the loaded data and relevant information.
"""
results = {}
if self.siamese_mode:
motion_path = os.path.join(self.motion_dir, name + '.npz')
motion_data = np.load(motion_path)
results['motion1'] = motion_data['motion1']
results['motion2'] = motion_data['motion2']
assert results['motion1'].shape == results['motion2'].shape
else:
motion_path = os.path.join(self.motion_dir, name + '.npy')
motion_data = np.load(motion_path)
results['motion'] = motion_data
if self.fine_mode:
text_path = os.path.join(self.text_dir, name + '.json')
text_data = json.load(open(text_path))
for entry in text_data:
entry.pop('start_frame', None)
entry.pop('end_frame', None)
entry.pop('num_frames', None)
results['text'] = text_data
else:
text_path = os.path.join(self.text_dir, name + '.txt')
results['text'] = [line.strip() for line in open(text_path, 'r')]
if self.token_dir:
token_path = os.path.join(self.token_dir, name + '.txt')
results['token'] = [line.strip() for line in open(token_path, 'r')]
if self.clip_feat_dir:
clip_feat_path = os.path.join(self.clip_feat_dir, name + '.npy')
results['clip_feat_path'] = clip_feat_path
# if self.fine_mode:
# results['clip_feat_path'] = clip_feat_path
# else:
# clip_feat = torch.from_numpy(np.load(clip_feat_path))
# if len(clip_feat.shape) == 2:
# clip_feat = clip_feat.unsqueeze(0)
# results['clip_feat'] = clip_feat
if self.meta_dir:
score_path = os.path.join(self.meta_dir, name + '_score.npy')
results['score'] = torch.from_numpy(np.load(score_path))
if self.balanced_sampling:
assert self.meta_dir is not None
category_path = os.path.join(self.meta_dir, name + '.json')
category = json.load(open(category_path))['category']
self.category_list[category].append(idx)
return results
def prepare_data(self, idx: int) -> Dict:
"""
Prepare raw data for the given index.
Args:
idx (int): Index of the data sample.
Returns:
dict: Processed data after applying the pipeline.
"""
results = copy.deepcopy(self.data_infos[idx])
text_list = results['text']
selected_idx = np.random.randint(0, len(text_list))
results['text'] = text_list[selected_idx]
if 'clip_feat' in results:
results['clip_feat'] = results['clip_feat'][selected_idx]
if 'clip_feat_path' in results:
clip_feat = torch.from_numpy(np.load(results['clip_feat_path']))
if len(clip_feat.shape) == 2:
clip_feat = clip_feat.unsqueeze(0)
results['clip_feat'] = clip_feat[selected_idx]
if 'token' in results:
results['token'] = results['token'][selected_idx]
results['dataset_name'] = self.dataset_name
results = self.pipeline(results)
return results