File size: 2,540 Bytes
d4b77ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import logging
import torch
import torch.utils.data
from importlib import import_module


def create_dataloader(phase, dataset, dataset_opt, opt=None, sampler=None):
    """Build a ``torch.utils.data.DataLoader`` for the given phase.

    Args:
        phase: ``'train'`` selects the training configuration; any other
            value selects the validate/test configuration.
        dataset: Dataset object handed to the ``DataLoader`` unchanged.
        dataset_opt: Mapping providing ``'n_workers'`` and ``'batch_size'``.
        opt: Optional mapping providing ``'world_size'``. It is only read in
            the train phase; when omitted a world size of 1 is used.
        sampler: Optional sampler for the train phase. When given it is
            passed to the loader in place of ``shuffle=True``. It is ignored
            for the validate/test phase.

    Returns:
        The configured ``torch.utils.data.DataLoader``.

    Raises:
        KeyError: If a required key is missing from ``dataset_opt`` (or from
            ``opt`` when it is supplied for the train phase).
    """
    logger = logging.getLogger('base')
    batch_size = dataset_opt['batch_size']

    if phase != 'train':
        n_workers = dataset_opt['n_workers']
        logger.info(
            'N_workers: {}, batch_size: {} validate/test dataloader has been established'.format(
                n_workers, batch_size))
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                           num_workers=n_workers,
                                           pin_memory=False)

    # `opt` is declared optional, so fall back to a single process instead of
    # failing on `None['world_size']`.
    world_size = 1 if opt is None else opt['world_size']
    num_workers = dataset_opt['n_workers'] * world_size

    # `sampler` and `shuffle` are mutually exclusive DataLoader options, so
    # pass exactly one of them.
    if sampler is not None:
        label = 'DDP train'
        order_kwargs = {'sampler': sampler}
    else:
        label = 'train'
        order_kwargs = {'shuffle': True}

    logger.info('N_workers: {}, batch_size: {} {} dataloader has been established'.format(
        num_workers, batch_size, label))
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                       num_workers=num_workers,
                                       pin_memory=True, **order_kwargs)


def create_dataset(dataset_opt, dataInfo, phase, dataset_name):
    """Instantiate the training dataset from the module ``data.<dataset_name>``.

    Args:
        dataset_opt: Options passed to the dataset constructor; ``'mode'``,
            ``'type'`` and ``'name'`` are also read for the log message.
        dataInfo: Passed through to the dataset constructor unchanged.
        phase: Only ``'train'`` is supported.
        dataset_name: Name of the submodule of ``data`` that defines
            ``VideoBasedDataset``.

    Returns:
        The ``VideoBasedDataset`` instance built by the imported module.

    Raises:
        ValueError: If ``phase`` is anything other than ``'train'``; no
            validate/test dataset is implemented here.
    """
    if phase != 'train':  # validate and test dataset
        # Raise rather than return the exception object, otherwise the caller
        # would silently receive a ValueError instance as its "dataset".
        raise ValueError('No dataset initialized for valdataset')

    dataset_package = import_module('data.{}'.format(dataset_name))
    dataset = dataset_package.VideoBasedDataset(dataset_opt, dataInfo)

    mode = dataset_opt['mode']
    logger = logging.getLogger('base')
    logger.info(
        '{} train dataset [{:s} - {:s} - {:s}] is created.'.format(dataset_opt['type'].upper(),
                                                                   dataset.__class__.__name__,
                                                                   dataset_opt['name'], mode))
    return dataset