import copy
import os
import json
from abc import abstractmethod
from typing import Optional, Union, List, Dict

import numpy as np
import torch
from torch.utils.data import Dataset

from mogen.core.evaluation import build_evaluator
from mogen.models.builder import build_submodule
from .builder import DATASETS
from .pipelines import Compose


@DATASETS.register_module()
class BaseMotionDataset(Dataset):
    """
    Base class for motion datasets.

    Args:
        data_prefix (str): Root prefix of the data path. Annotation and motion
            files are resolved under `data_prefix/datasets/dataset_name`.
        pipeline (list): A list of dicts, where each element represents an
            operation defined in `mogen.datasets.pipelines`.
        dataset_name (Optional[str]): Name of the dataset. Used to locate data
            under `data_prefix` and to identify the evaluation metrics.
        fixed_length (Optional[int]): If set, the dataset reports this length
            during iteration and indices wrap around the number of
            annotations. If None, the length equals the number of annotations.
        ann_file (Optional[str]): Name of the annotation file, one sample name
            per line, resolved relative to `data_prefix/datasets/dataset_name`.
        motion_dir (Optional[str]): Name of the directory containing motion
            data, resolved relative to `data_prefix/datasets/dataset_name`.
        eval_cfg (Optional[dict]): Configuration for evaluation. Keys consumed
            by `prepare_evaluation` include 'metrics', 'replication_times',
            and optionally 'evaluator_model' and 'shuffle_indexes'.
        test_mode (bool): Whether the dataset is in test mode. Default: False.

    Attributes:
        data_infos (list): Loaded dataset annotations.
        evaluators (list): List of evaluation objects.
        eval_indexes (np.ndarray): Array of indices used for evaluation.
        evaluator_model (torch.nn.Module): Model used for evaluation.
        pipeline (Compose): Data processing pipeline.
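
    Example:
        A minimal subclassing sketch. The dataset name, file names, and the
        annotation dict returned by `load_anno` below are illustrative
        assumptions, not values shipped with this repository:

            class ToyMotionDataset(BaseMotionDataset):
                def load_anno(self, idx, name):
                    # `name` is one stripped line of `ann_file`; assumed here
                    # to be the basename of a .npy motion file.
                    motion_path = os.path.join(self.motion_dir, name + '.npy')
                    return {'motion_path': motion_path}

            dataset = ToyMotionDataset(
                data_prefix='data',
                pipeline=[],  # real configs list transforms from mogen.datasets.pipelines
                dataset_name='human_ml3d',
                ann_file='train.txt',
                motion_dir='motions')
            sample = dataset[0]  # dict with 'motion_path', 'dataset_name', 'sample_idx'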
    """
    
    def __init__(self,
                 data_prefix: str,
                 pipeline: List[Dict],
                 dataset_name: Optional[str] = None,
                 fixed_length: Optional[int] = None,
                 ann_file: Optional[str] = None,
                 motion_dir: Optional[str] = None,
                 eval_cfg: Optional[dict] = None,
                 test_mode: bool = False):
        super().__init__()

        self.data_prefix = data_prefix
        self.pipeline = Compose(pipeline)
        self.dataset_name = dataset_name
        self.fixed_length = fixed_length
        self.ann_file = os.path.join(data_prefix, 'datasets', dataset_name, ann_file)
        self.motion_dir = os.path.join(data_prefix, 'datasets', dataset_name, motion_dir)
        self.eval_cfg = copy.deepcopy(eval_cfg)
        self.test_mode = test_mode

        self.load_annotations()
        if self.test_mode:
            self.prepare_evaluation()

    @abstractmethod
    def load_anno(self, idx: int, name: str) -> dict:
        """
        Abstract method to load a single annotation.

        Args:
            idx (int): Index of the annotation within `ann_file`.
            name (str): Name or identifier of the annotation to load (one
                stripped line from `ann_file`).

        Returns:
            dict: Loaded annotation as a dictionary.
        """
        pass

    def load_annotations(self):
        """Load annotations from `ann_file` into `data_infos`."""
        self.data_infos = []
        with open(self.ann_file, 'r') as f:
            for idx, line in enumerate(f):
                name = line.strip()
                self.data_infos.append(self.load_anno(idx, name))

    def prepare_data(self, idx: int) -> dict:
        """
        Prepare raw data for the given index.

        Args:
            idx (int): Index of the data to prepare.

        Returns:
            dict: Processed data for the given index.
        """
        results = copy.deepcopy(self.data_infos[idx])
        results['dataset_name'] = self.dataset_name
        results['sample_idx'] = idx
        return self.pipeline(results)

    def __len__(self) -> int:
        """Return the length of the current dataset.

        Returns:
            int: Length of the dataset.
        """
        if self.test_mode:
            return len(self.eval_indexes)
        elif self.fixed_length is not None:
            return self.fixed_length
        return len(self.data_infos)

    def __getitem__(self, idx: int) -> dict:
        """
        Prepare data for the given index.

        Args:
            idx (int): Index of the data.

        Returns:
            dict: Data for the specified index.
        """
        if self.test_mode:
            idx = self.eval_indexes[idx]
        elif self.fixed_length is not None:
            idx = idx % len(self.data_infos)
        elif getattr(self, 'balanced_sampling', False):
            # `balanced_sampling` and `category_list` are expected to be set
            # by subclasses that support category-balanced sampling; the base
            # class does not define them.
            cid = np.random.randint(0, len(self.category_list))
            sid = np.random.randint(0, len(self.category_list[cid]))
            idx = self.category_list[cid][sid]
        return self.prepare_data(idx)

    def prepare_evaluation(self):
        """Prepare evaluation settings, including evaluators and evaluation indices."""
        self.evaluators = []
        self.eval_indexes = []
        # Build the feature-extraction model used by the evaluation metrics.
        self.evaluator_model = build_submodule(self.eval_cfg.get('evaluator_model', None))
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.evaluator_model = self.evaluator_model.to(device)
        self.evaluator_model.eval()
        self.eval_cfg['evaluator_model'] = self.evaluator_model

        # Build one full index array over the dataset for each replication.
        for _ in range(self.eval_cfg['replication_times']):
            eval_indexes = np.arange(len(self.data_infos))
            if self.eval_cfg.get('shuffle_indexes', False):
                np.random.shuffle(eval_indexes)
            self.eval_indexes.append(eval_indexes)

        # Each evaluator may adjust the per-replication index arrays
        # (e.g. subsample them) and returns the updated list.
        for metric in self.eval_cfg['metrics']:
            evaluator, self.eval_indexes = build_evaluator(
                metric, self.eval_cfg, len(self.data_infos), self.eval_indexes)
            self.evaluators.append(evaluator)

        # Flatten the per-replication index arrays into the lookup table
        # used by __getitem__ in test mode.
        self.eval_indexes = np.concatenate(self.eval_indexes)

    def evaluate(self, results: List[dict], work_dir: str, logger=None) -> dict:
        """
        Evaluate the model performance based on the results.

        Args:
            results (list): A list of result dictionaries.
            work_dir (str): Directory where evaluation logs will be stored.
            logger: Logger object to record evaluation results (optional).

        Returns:
            dict: Dictionary containing evaluation metrics.
        """
        metrics = {}
        for evaluator in self.evaluators:
            metrics.update(evaluator.evaluate(results))
        if logger is not None:
            logger.info(metrics)
        eval_output = os.path.join(work_dir, 'eval_results.log')
        with open(eval_output, 'w') as f:
            for k, v in metrics.items():
                f.write(k + ': ' + str(v) + '\n')
        return metrics