# Copyright (c) 2020 Huawei Technologies Co., Ltd.
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
#
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE

import os
import subprocess
import torch.utils.data as data
import numpy as np
import time
import torch

import pickle


class LRHR_PKLDataset(data.Dataset):
    """Paired low-/high-resolution image dataset backed by two pickle files.

    Each pickle file is expected to hold a list of arrays. The i-th LR array is
    paired with the i-th HR array. On load, every array has its last axis moved
    to the front (presumably HWC -> CHW; confirm against the pickle producer).

    Recognised ``opt`` keys:
        dataroot_GT / dataroot_LQ: paths of the HR / LR pickle files (required).
        GT_size: HR crop size used when ``use_crop`` is enabled.
        use_crop / use_flip / use_rot: enable random crop / flip / rotation.
        center_crop_hr_size: if set, center-crop HR to this size and LR to
            ``center_crop_hr_size // scale``.
        n_max: keep at most this many images from each file (default 1e8).
    """

    def __init__(self, opt):
        super(LRHR_PKLDataset, self).__init__()
        self.opt = opt
        self.crop_size = opt.get("GT_size", None)
        # HR/LR size ratio; inferred lazily from the first fetched sample.
        self.scale = None
        self.random_scale_list = [1]

        hr_file_path = opt["dataroot_GT"]
        lr_file_path = opt["dataroot_LQ"]
        # NOTE: 'dataroot_y_labels' is not read; this dataset does not use labels.

        gpu = True
        augment = True

        self.use_flip = opt.get("use_flip", False)
        self.use_rot = opt.get("use_rot", False)
        self.use_crop = opt.get("use_crop", False)
        self.center_crop_hr_size = opt.get("center_crop_hr_size", None)

        n_max = opt.get("n_max", int(1e8))

        t = time.time()
        self.lr_images = self.load_pkls(lr_file_path, n_max)
        self.hr_images = self.load_pkls(hr_file_path, n_max)

        # Value ranges are sampled from the first 20 images only, for logging.
        min_val_hr = np.min([i.min() for i in self.hr_images[:20]])
        max_val_hr = np.max([i.max() for i in self.hr_images[:20]])

        min_val_lr = np.min([i.min() for i in self.lr_images[:20]])
        max_val_lr = np.max([i.max() for i in self.lr_images[:20]])

        t = time.time() - t
        print("Loaded {} HR images with [{:.2f}, {:.2f}] in {:.2f}s from {}".
              format(len(self.hr_images), min_val_hr, max_val_hr, t, hr_file_path))
        print("Loaded {} LR images with [{:.2f}, {:.2f}] in {:.2f}s from {}".
              format(len(self.lr_images), min_val_lr, max_val_lr, t, lr_file_path))

        self.gpu = gpu
        self.augment = augment

        # Running summary statistics, reported and cleared by print_and_reset.
        self.measures = None

    def load_pkls(self, path, n_max):
        """Load a pickled list of images from ``path``, keeping the first ``n_max``.

        Each image is transposed with axes ``[2, 0, 1]`` (last axis first).

        Raises:
            AssertionError: if ``path`` is not a file or the pickle is empty.
        """
        assert os.path.isfile(path), path
        # SECURITY: pickle.load can execute arbitrary code -- only load
        # dataset files from trusted sources.
        with open(path, "rb") as f:
            images = list(pickle.load(f))
        assert len(images) > 0, path
        images = images[:n_max]
        return [np.transpose(image, [2, 0, 1]) for image in images]

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, item):
        """Return the ``item``-th LR/HR pair as float tensors divided by 255.

        The result is a dict with keys 'LQ' (LR tensor), 'GT' (HR tensor),
        and 'LQ_path' / 'GT_path', which are simply ``str(item)`` because the
        images come from a pickle rather than individual files.
        """
        hr = self.hr_images[item]
        lr = self.lr_images[item]

        if self.scale is None:
            self.scale = hr.shape[1] // lr.shape[1]
            assert hr.shape[1] == self.scale * lr.shape[1], ('non-fractional ratio', lr.shape, hr.shape)

        if self.use_crop:
            hr, lr = random_crop(hr, lr, self.crop_size, self.scale, self.use_crop)

        if self.center_crop_hr_size:
            hr = center_crop(hr, self.center_crop_hr_size)
            lr = center_crop(lr, self.center_crop_hr_size // self.scale)

        if self.use_flip:
            hr, lr = random_flip(hr, lr)

        if self.use_rot:
            hr, lr = random_rotation(hr, lr)

        # Presumably 8-bit pixel data: maps [0, 255] to [0, 1].
        hr = hr / 255.0
        lr = lr / 255.0

        # Refresh the summary stats on the first call, then on ~5% of calls.
        if self.measures is None or np.random.random() < 0.05:
            if self.measures is None:
                self.measures = {}
            self.measures['hr_means'] = np.mean(hr)
            self.measures['hr_stds'] = np.std(hr)
            self.measures['lr_means'] = np.mean(lr)
            self.measures['lr_stds'] = np.std(lr)

        # Tensors are returned on CPU; self.gpu is not consulted here.
        hr = torch.Tensor(hr)
        lr = torch.Tensor(lr)

        return {'LQ': lr, 'GT': hr, 'LQ_path': str(item), 'GT_path': str(item)}

    def print_and_reset(self, tag):
        """Print collected stats as ``[KPI] <tag>: k=v, ...`` and clear them.

        Safe to call before any sample has been fetched (prints no values).
        """
        m = self.measures if self.measures is not None else {}
        kvs = ["{}={:.2f}".format(k, m[k]) for k in sorted(m.keys())]
        print("[KPI] " + tag + ": " + ", ".join(kvs))
        self.measures = None


def random_flip(img, seg):
    """With probability 1/2, flip both arrays along axis 2 (same decision for both)."""
    random_choice = np.random.choice([True, False])
    img = img if random_choice else np.flip(img, 2).copy()
    seg = seg if random_choice else np.flip(seg, 2).copy()
    return img, seg


def random_rotation(img, seg):
    """Rotate both arrays by the same k*90 degrees in the plane of axes (1, 2).

    k is drawn from {0, 1, 3}. NOTE(review): k=2 (180 degrees) is not a
    candidate -- confirm whether that is intentional.
    """
    random_choice = np.random.choice([0, 1, 3])
    img = np.rot90(img, random_choice, axes=(1, 2)).copy()
    seg = np.rot90(seg, random_choice, axes=(1, 2)).copy()
    return img, seg


def random_crop(hr, lr, size_hr, scale, random):
    """Take aligned random crops: ``size_hr`` from ``hr``, ``size_hr // scale`` from ``lr``.

    The LR offset is drawn uniformly, and the HR offset is that value times
    ``scale``, so the patches cover the same region. If an LR axis is not
    larger than the LR crop size, the offset on that axis is 0 and the patch
    may be smaller than requested.

    ``random`` is unused; it is kept for signature compatibility.
    """
    size_lr = size_hr // scale

    size_lr_x = lr.shape[1]
    size_lr_y = lr.shape[2]

    start_x_lr = np.random.randint(low=0, high=(size_lr_x - size_lr) + 1) if size_lr_x > size_lr else 0
    start_y_lr = np.random.randint(low=0, high=(size_lr_y - size_lr) + 1) if size_lr_y > size_lr else 0

    # LR Patch
    lr_patch = lr[:, start_x_lr:start_x_lr + size_lr, start_y_lr:start_y_lr + size_lr]

    # HR Patch
    start_x_hr = start_x_lr * scale
    start_y_hr = start_y_lr * scale
    hr_patch = hr[:, start_x_hr:start_x_hr + size_hr, start_y_hr:start_y_hr + size_hr]

    return hr_patch, lr_patch


def center_crop(img, size):
    """Center-crop a square array (axes 1 and 2) to ``size`` x ``size``.

    Raises:
        AssertionError: if the spatial axes differ, ``size`` exceeds them,
            or the border to remove is odd.
    """
    assert img.shape[1] == img.shape[2], img.shape
    border_double = img.shape[1] - size
    assert border_double >= 0, (img.shape, size)
    assert border_double % 2 == 0, (img.shape, size)
    border = border_double // 2
    # Use an explicit end index: `border:-border` is empty when border == 0.
    return img[:, border:border + size, border:border + size]


def center_crop_tensor(img, size):
    """Center-crop a batched square array (axes 2 and 3) to ``size`` x ``size``.

    Raises:
        AssertionError: if the spatial axes differ, ``size`` exceeds them,
            or the border to remove is odd.
    """
    assert img.shape[2] == img.shape[3], img.shape
    border_double = img.shape[2] - size
    assert border_double >= 0, (img.shape, size)
    assert border_double % 2 == 0, (img.shape, size)
    border = border_double // 2
    # Use an explicit end index: `border:-border` is empty when border == 0.
    return img[:, :, border:border + size, border:border + size]