import copy
import json
import os
from typing import Dict, List, Optional

import numpy as np
import torch

from .base_dataset import BaseMotionDataset
from .builder import DATASETS


@DATASETS.register_module()
class TextMotionDataset(BaseMotionDataset):
    """
    TextMotion dataset for handling motion data paired with text descriptions.

    Args:
        data_prefix (str): Path to the base directory containing the dataset.
        pipeline (list): List of data transformations to apply.
        dataset_name (Optional[str]): Name of the dataset.
        fixed_length (Optional[int]): Fixed length of data samples (if applicable).
        ann_file (Optional[str]): Path to the annotation file.
        motion_dir (Optional[str]): Path to the directory containing motion data.
        text_dir (Optional[str]): Path to the directory containing text data.
        token_dir (Optional[str]): Path to the directory containing token data.
        clip_feat_dir (Optional[str]): Path to the directory containing precomputed CLIP feature data.
        meta_dir (Optional[str]): Path to the directory containing metadata.
        eval_cfg (Optional[dict]): Configuration for evaluation metrics.
        test_mode (bool): Whether the dataset is in test mode. Defaults to False.
        siamese_mode (bool): Whether to load paired motions (motion1 vs. motion2) for Siamese training. Defaults to False.
        tcomb_mode (bool): Whether to enable tcomb-specific processing. Defaults to False.
        fine_mode (bool): Whether to use fine-grained (per-segment JSON) text annotations. Defaults to False.
        balanced_sampling (Optional[int]): Number of categories for balanced sampling. If not None, enables balanced sampling.
    """

    def __init__(self,
                 data_prefix: str,
                 pipeline: List[Dict],
                 dataset_name: Optional[str] = None,
                 fixed_length: Optional[int] = None,
                 ann_file: Optional[str] = None,
                 motion_dir: Optional[str] = None,
                 text_dir: Optional[str] = None,
                 token_dir: Optional[str] = None,
                 clip_feat_dir: Optional[str] = None,
                 meta_dir: Optional[str] = None,
                 eval_cfg: Optional[dict] = None,
                 test_mode: bool = False,
                 siamese_mode: bool = False,
                 tcomb_mode: bool = False,
                 fine_mode: bool = False,
                 balanced_sampling: Optional[int] = None):
        self.text_dir = os.path.join(data_prefix, 'datasets', dataset_name, text_dir)
        self.token_dir = os.path.join(data_prefix, 'datasets', dataset_name, token_dir) if token_dir else None
        self.clip_feat_dir = os.path.join(data_prefix, 'datasets', dataset_name, clip_feat_dir) if clip_feat_dir else None
        self.meta_dir = os.path.join(data_prefix, 'datasets', dataset_name, meta_dir) if meta_dir else None
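        # The joins above resolve to a directory layout like the following
        # (folder names are illustrative, not required by this class):
        #   {data_prefix}/datasets/{dataset_name}/{text_dir}/<name>.txt or .json
        #   {data_prefix}/datasets/{dataset_name}/{token_dir}/<name>.txt
        #   {data_prefix}/datasets/{dataset_name}/{clip_feat_dir}/<name>.npy
        #   {data_prefix}/datasets/{dataset_name}/{meta_dir}/<name>_score.npy
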
        self.siamese_mode = siamese_mode
        self.tcomb_mode = tcomb_mode
        self.fine_mode = fine_mode
        self.balanced_sampling = balanced_sampling is not None

        if self.balanced_sampling:
            self.category_list = [[] for _ in range(balanced_sampling)]

        super().__init__(
            data_prefix=data_prefix,
            pipeline=pipeline,
            dataset_name=dataset_name,
            fixed_length=fixed_length,
            ann_file=ann_file,
            motion_dir=motion_dir,
            eval_cfg=eval_cfg,
            test_mode=test_mode
        )

    def load_anno(self, idx: int, name: str) -> Dict:
        """
        Load a single annotation based on the given index and name.

        Args:
            idx (int): Index of the data sample.
            name (str): Name of the data sample (typically used as a file identifier).

        Returns:
            dict: A dictionary containing the loaded data and relevant information.
        """
        results = {}
        if self.siamese_mode:
            motion_path = os.path.join(self.motion_dir, name + '.npz')
            motion_data = np.load(motion_path)
            results['motion1'] = motion_data['motion1']
            results['motion2'] = motion_data['motion2']
            assert results['motion1'].shape == results['motion2'].shape, \
                'Siamese motion pairs must share the same shape'
        else:
            motion_path = os.path.join(self.motion_dir, name + '.npy')
            motion_data = np.load(motion_path)
            results['motion'] = motion_data

        if self.fine_mode:
            text_path = os.path.join(self.text_dir, name + '.json')
            with open(text_path) as f:
                text_data = json.load(f)
            # Strip frame-alignment fields from each text entry.
            for entry in text_data:
                entry.pop('start_frame', None)
                entry.pop('end_frame', None)
                entry.pop('num_frames', None)
            results['text'] = text_data
        else:
            text_path = os.path.join(self.text_dir, name + '.txt')
            with open(text_path) as f:
                results['text'] = [line.strip() for line in f]

        if self.token_dir:
            token_path = os.path.join(self.token_dir, name + '.txt')
            with open(token_path) as f:
                results['token'] = [line.strip() for line in f]

        if self.clip_feat_dir:
            # Store only the path here; prepare_data loads the features
            # lazily so the whole dataset is not held in memory.
            clip_feat_path = os.path.join(self.clip_feat_dir, name + '.npy')
            results['clip_feat_path'] = clip_feat_path

        if self.meta_dir:
            score_path = os.path.join(self.meta_dir, name + '_score.npy')
            results['score'] = torch.from_numpy(np.load(score_path))

        if self.balanced_sampling:
            assert self.meta_dir is not None, \
                'balanced_sampling requires meta_dir for category annotations'
            category_path = os.path.join(self.meta_dir, name + '.json')
            with open(category_path) as f:
                category = json.load(f)['category']
            self.category_list[category].append(idx)

        return results
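
    # For a plain (non-siamese, non-fine) sample, the returned dict is roughly:
    #   {
    #       'motion': np.ndarray,       # (num_frames, feat_dim); layout is dataset-specific
    #       'text': List[str],          # one caption per line of <name>.txt
    #       'token': List[str],         # only when token_dir is set
    #       'clip_feat_path': str,      # only when clip_feat_dir is set
    #       'score': torch.Tensor,      # only when meta_dir is set
    #   }
    # The shapes above are illustrative assumptions, not enforced here.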

    def prepare_data(self, idx: int) -> Dict:
        """
        Prepare raw data for the given index.

        Args:
            idx (int): Index of the data sample.

        Returns:
            dict: Processed data after applying the pipeline.
        """
        results = copy.deepcopy(self.data_infos[idx])

        # Randomly select one caption; aligned per-caption data (tokens,
        # CLIP features) is indexed with the same selection below.
        text_list = results['text']
        selected_idx = np.random.randint(0, len(text_list))
        results['text'] = text_list[selected_idx]

        if 'clip_feat' in results:
            results['clip_feat'] = results['clip_feat'][selected_idx]

        if 'clip_feat_path' in results:
            # Load per-caption CLIP features lazily from disk and keep the
            # row for the selected caption.
            clip_feat = torch.from_numpy(np.load(results['clip_feat_path']))
            if clip_feat.ndim == 2:
                clip_feat = clip_feat.unsqueeze(0)
            results['clip_feat'] = clip_feat[selected_idx]

        if 'token' in results:
            results['token'] = results['token'][selected_idx]

        results['dataset_name'] = self.dataset_name
        results = self.pipeline(results)
        return results
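

# Minimal usage sketch. The paths, dataset name, and pipeline below are
# illustrative assumptions about an mmcv-style config, not part of this module:
#
#     dataset = TextMotionDataset(
#         data_prefix='data',
#         pipeline=[...],               # list of registered transform configs
#         dataset_name='humanml3d',     # hypothetical dataset folder
#         ann_file='train.txt',
#         motion_dir='motions',
#         text_dir='texts')
#     sample = dataset.prepare_data(0)  # one motion + one randomly chosen caption
#
# With balanced_sampling=N, category_list is populated while annotations load;
# a sampler might draw a category uniformly, then an index from that category:
#
#     import random
#     cat = random.randrange(len(dataset.category_list))
#     idx = random.choice(dataset.category_list[cat])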