Spaces:
Running
Running
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
#
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
'''create dataset and dataloader''' | |
import logging | |
import torch | |
import torch.utils.data | |
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
    """Build a ``torch.utils.data.DataLoader`` for ``dataset``.

    Args:
        dataset: The dataset to wrap.
        dataset_opt: Mapping of dataset options. ``'phase'`` selects the
            loader configuration (defaults to ``'test'``). For the
            ``'train'`` phase, ``'n_workers'`` and ``'batch_size'`` are
            required.
        opt: Optional mapping of global options. Only ``'gpu_ids'`` is read,
            and only in the ``'train'`` phase. ``None`` or a missing/empty
            ``'gpu_ids'`` is treated as an empty GPU list.
        sampler: Optional sampler forwarded to the training loader. It is
            ignored for every other phase.

    Returns:
        A ``DataLoader``. In the ``'train'`` phase it uses the configured
        batch size, ``drop_last=True`` and ``pin_memory=False``. Otherwise it
        is a sequential loader with batch size 1, one worker and
        ``pin_memory=True``.

    Raises:
        KeyError: If ``'n_workers'`` or ``'batch_size'`` is missing from
            ``dataset_opt`` in the ``'train'`` phase.
    """
    phase = dataset_opt.get('phase', 'test')
    if phase != 'train':
        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
                                           pin_memory=True)

    # ``opt`` defaults to None in the signature, so it must not be
    # dereferenced unconditionally.
    gpu_ids = opt.get('gpu_ids', None) if opt is not None else None
    gpu_ids = gpu_ids if gpu_ids else []
    # The worker count scales with the number of GPUs; with no GPUs listed
    # this is 0, i.e. data is loaded in the main process.
    num_workers = dataset_opt['n_workers'] * len(gpu_ids)
    batch_size = dataset_opt['batch_size']
    # DataLoader raises ValueError when ``shuffle=True`` is combined with an
    # explicit sampler, so shuffle only when no sampler was supplied.
    shuffle = sampler is None
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                       num_workers=num_workers, sampler=sampler, drop_last=True,
                                       pin_memory=False)
def create_dataset(dataset_opt):
    """Instantiate the dataset selected by ``dataset_opt['mode']``.

    Args:
        dataset_opt: Mapping of dataset options. ``'mode'`` selects the
            dataset class and ``'name'`` is used in the creation log
            message. The whole mapping is passed to the dataset constructor.

    Returns:
        The constructed dataset instance.

    Raises:
        NotImplementedError: If ``'mode'`` is not a recognized dataset mode.
        KeyError: If ``'mode'`` or ``'name'`` is missing from ``dataset_opt``.
    """
    # Retained from the original implementation: echoes the options to stdout.
    print(dataset_opt)
    mode = dataset_opt['mode']
    if mode == 'LRHR_PKL':
        # Imported lazily so the project dataset module is only loaded when
        # this mode is actually requested.
        from data.LRHR_PKL_dataset import LRHR_PKLDataset as D
    else:
        # Plain ``{}`` instead of ``{:s}``: a non-string mode (e.g. None) must
        # still surface as NotImplementedError, not as a formatting TypeError.
        raise NotImplementedError('Dataset [{}] is not recognized.'.format(mode))
    dataset = D(dataset_opt)

    logger = logging.getLogger('base')
    # Lazy %-style arguments: formatting happens only if the record is
    # emitted, and a non-string name cannot raise after the dataset is built.
    logger.info('Dataset [%s - %s] is created.', dataset.__class__.__name__,
                dataset_opt['name'])
    return dataset