# Copyright (c) 2020 Huawei Technologies Co., Ltd. # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode # # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE '''create dataset and dataloader''' import logging import torch import torch.utils.data def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): phase = dataset_opt.get('phase', 'test') if phase == 'train': gpu_ids = opt.get('gpu_ids', None) gpu_ids = gpu_ids if gpu_ids else [] num_workers = dataset_opt['n_workers'] * len(gpu_ids) batch_size = dataset_opt['batch_size'] shuffle = True return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, sampler=sampler, drop_last=True, pin_memory=False) else: return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True) def create_dataset(dataset_opt): print(dataset_opt) mode = dataset_opt['mode'] if mode == 'LRHR_PKL': from data.LRHR_PKL_dataset import LRHR_PKLDataset as D else: raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) dataset = D(dataset_opt) logger = logging.getLogger('base') logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__, dataset_opt['name'])) return dataset