# super-resolution/models/SRFlow/code/data/LRHR_PKL_dataset.py
# Uploaded by hail75 -- commit d2821a4 ("add train.py")
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
#
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
import os
import subprocess
import torch.utils.data as data
import numpy as np
import time
import torch
import pickle
class LRHR_PKLDataset(data.Dataset):
    """Paired low-/high-resolution image dataset backed by two pickle files.

    ``opt`` keys read here:
        dataroot_GT, dataroot_LQ: paths to the HR / LR pickle files (required).
        GT_size: HR patch size used by random cropping (``use_crop``).
        use_flip, use_rot, use_crop: enable each augmentation (default False).
        center_crop_hr_size: if set, HR is center-cropped to this size and LR
            to ``center_crop_hr_size // scale``.
        n_max: keep at most this many images per file (default ``int(1e8)``).
        dataroot_y_labels: accepted but not used by this class.

    Images are transposed with axes ``[2, 0, 1]`` on load (presumably HWC ->
    CHW) and divided by 255 on access (presumably stored as uint8 in
    [0, 255] -- confirm against the pickling script).
    """

    def __init__(self, opt):
        super(LRHR_PKLDataset, self).__init__()
        self.opt = opt
        self.crop_size = opt.get("GT_size", None)
        self.scale = None  # SR factor; inferred from the first sample in __getitem__
        self.random_scale_list = [1]
        hr_file_path = opt["dataroot_GT"]
        lr_file_path = opt["dataroot_LQ"]

        self.use_flip = opt.get("use_flip", False)
        self.use_rot = opt.get("use_rot", False)
        self.use_crop = opt.get("use_crop", False)
        self.center_crop_hr_size = opt.get("center_crop_hr_size", None)
        n_max = opt.get("n_max", int(1e8))

        t = time.time()
        self.lr_images = self.load_pkls(lr_file_path, n_max)
        self.hr_images = self.load_pkls(hr_file_path, n_max)
        # The value range is only sampled from the first 20 images to keep startup cheap.
        min_val_hr = np.min([i.min() for i in self.hr_images[:20]])
        max_val_hr = np.max([i.max() for i in self.hr_images[:20]])
        min_val_lr = np.min([i.min() for i in self.lr_images[:20]])
        max_val_lr = np.max([i.max() for i in self.lr_images[:20]])
        t = time.time() - t
        print("Loaded {} HR images with [{:.2f}, {:.2f}] in {:.2f}s from {}".
              format(len(self.hr_images), min_val_hr, max_val_hr, t, hr_file_path))
        print("Loaded {} LR images with [{:.2f}, {:.2f}] in {:.2f}s from {}".
              format(len(self.lr_images), min_val_lr, max_val_lr, t, lr_file_path))

        # Attributes kept for compatibility with external readers; nothing in
        # this class consults them (augmentation is driven by the use_* flags).
        self.gpu = True
        self.augment = True
        # Sampled statistics of returned samples; see __getitem__ / print_and_reset.
        self.measures = None

    def load_pkls(self, path, n_max):
        """Load the image list stored in one pickle file.

        Args:
            path: pickle file holding a sequence of 3-D image arrays.
            n_max: keep at most the first ``n_max`` images.

        Returns:
            List of arrays transposed with axes ``[2, 0, 1]``.

        Raises:
            AssertionError: if ``path`` is not a file or no images remain
                after truncation.

        SECURITY: ``pickle.load`` can execute arbitrary code embedded in the
        file; only point this at trusted dataset files.
        """
        assert os.path.isfile(path), path
        with open(path, "rb") as f:
            images = list(pickle.load(f))
        images = images[:n_max]
        # Checked after truncation so n_max <= 0 fails here with the path,
        # rather than later inside np.min on an empty list.
        assert len(images) > 0, path
        return [np.transpose(image, [2, 0, 1]) for image in images]

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, item):
        """Return one (optionally augmented) sample normalised by 1/255.

        Returns:
            dict with ``'LQ'`` / ``'GT'`` torch tensors and ``'LQ_path'`` /
            ``'GT_path'`` set to ``str(item)`` (there are no real file paths).
        """
        hr = self.hr_images[item]
        lr = self.lr_images[item]

        if self.scale is None:
            # Infer the SR factor once; every later sample must use the same ratio.
            self.scale = hr.shape[1] // lr.shape[1]
        assert hr.shape[1] == self.scale * lr.shape[1], ('non-fractional ratio', lr.shape, hr.shape)

        if self.use_crop:
            hr, lr = random_crop(hr, lr, self.crop_size, self.scale, self.use_crop)

        if self.center_crop_hr_size:
            hr, lr = center_crop(hr, self.center_crop_hr_size), center_crop(lr, self.center_crop_hr_size // self.scale)

        if self.use_flip:
            hr, lr = random_flip(hr, lr)

        if self.use_rot:
            hr, lr = random_rotation(hr, lr)

        hr = hr / 255.0
        lr = lr / 255.0

        # Record stats for the first sample after a reset, then refresh them
        # on roughly 5% of subsequent calls.
        if self.measures is None or np.random.random() < 0.05:
            if self.measures is None:
                self.measures = {}
            self.measures['hr_means'] = np.mean(hr)
            self.measures['hr_stds'] = np.std(hr)
            self.measures['lr_means'] = np.mean(lr)
            self.measures['lr_stds'] = np.std(lr)

        hr = torch.Tensor(hr)
        lr = torch.Tensor(lr)

        return {'LQ': lr, 'GT': hr, 'LQ_path': str(item), 'GT_path': str(item)}

    def print_and_reset(self, tag):
        """Print the sampled statistics as ``[KPI] <tag>: k=v, ...`` and clear them.

        If nothing has been measured since the last reset, an empty KPI line
        is printed instead of raising.
        """
        m = self.measures or {}
        kvs = ["{}={:.2f}".format(k, m[k]) for k in sorted(m)]
        print("[KPI] " + tag + ": " + ", ".join(kvs))
        self.measures = None
def random_flip(img, seg):
    """Mirror ``img`` and ``seg`` together along axis 2 with probability 1/2.

    Both arrays share a single random decision, so a paired sample stays
    aligned. When no flip is drawn the inputs are returned unchanged (same
    objects); flipped results are contiguous copies.
    """
    keep_original = np.random.choice([True, False])
    if keep_original:
        return img, seg
    flipped_img = np.flip(img, 2).copy()
    flipped_seg = np.flip(seg, 2).copy()
    return flipped_img, flipped_seg
def random_rotation(img, seg):
    """Rotate ``img`` and ``seg`` by the same random number of quarter turns.

    The number of 90-degree turns is drawn uniformly from {0, 1, 3}; note
    that a 180-degree turn (k=2) is never selected. Rotation happens in the
    plane of axes 1 and 2, and both outputs are always fresh copies.
    """
    quarter_turns = np.random.choice([0, 1, 3])
    return tuple(np.rot90(arr, quarter_turns, axes=(1, 2)).copy() for arr in (img, seg))
def random_crop(hr, lr, size_hr, scale, random):
    """Crop spatially aligned random patches from a paired (HR, LR) sample.

    The crop origin is drawn on the LR grid and multiplied by ``scale`` for
    HR, so both patches cover the same image region.

    Args:
        hr: HR array indexed ``[channel, x, y]``.
        lr: LR array indexed the same way; ``hr`` is expected to be ``scale``
            times larger along axes 1 and 2.
        size_hr: requested HR patch size. The LR patch is ``size_hr // scale``
            and the HR patch is exactly ``scale`` times that, so a ``size_hr``
            that is not a multiple of ``scale`` is rounded down.
        scale: integer super-resolution factor.
        random: unused; kept so existing call sites keep working.

    Returns:
        ``(hr_patch, lr_patch)`` slices of the inputs. Along an axis where the
        LR image is not larger than the patch, the crop starts at 0 and the
        patch is limited to the image extent.
    """
    size_lr = size_hr // scale
    # Keep HR exactly `scale` x the LR patch; slicing with the raw size_hr
    # would leave extra HR pixels with no LR counterpart when
    # size_hr % scale != 0.
    size_hr = size_lr * scale

    size_lr_x = lr.shape[1]
    size_lr_y = lr.shape[2]

    # high is exclusive, hence +1 to allow the last valid start position.
    start_x_lr = np.random.randint(low=0, high=(size_lr_x - size_lr) + 1) if size_lr_x > size_lr else 0
    start_y_lr = np.random.randint(low=0, high=(size_lr_y - size_lr) + 1) if size_lr_y > size_lr else 0

    # LR Patch
    lr_patch = lr[:, start_x_lr:start_x_lr + size_lr, start_y_lr:start_y_lr + size_lr]

    # HR Patch
    start_x_hr = start_x_lr * scale
    start_y_hr = start_y_lr * scale
    hr_patch = hr[:, start_x_hr:start_x_hr + size_hr, start_y_hr:start_y_hr + size_hr]

    return hr_patch, lr_patch
def center_crop(img, size):
    """Return the central ``size`` x ``size`` window of a square 3-D array.

    Args:
        img: array indexed ``[channel, x, y]`` with ``img.shape[1] == img.shape[2]``.
        size: side of the crop. It must not exceed the image side and must
            leave an equal border on both sides (even difference).

    Returns:
        The slice ``img[:, b:b + size, b:b + size]`` with
        ``b = (img.shape[1] - size) // 2``. When ``size`` equals the image
        side, the whole image is returned.

    Raises:
        AssertionError: on a non-square input, a crop larger than the image,
            or an odd border.
    """
    assert img.shape[1] == img.shape[2], img.shape
    border_double = img.shape[1] - size
    assert border_double >= 0, (img.shape, size)
    assert border_double % 2 == 0, (img.shape, size)
    border = border_double // 2
    # Use an explicit end index: `border:-border` becomes the empty slice
    # `0:0` when border == 0.
    return img[:, border:border + size, border:border + size]
def center_crop_tensor(img, size):
    """Return the central ``size`` x ``size`` window of a square 4-D batch.

    Args:
        img: array/tensor indexed ``[batch, channel, x, y]`` with
            ``img.shape[2] == img.shape[3]``.
        size: side of the crop. It must not exceed the image side and must
            leave an equal border on both sides (even difference).

    Returns:
        The slice ``img[:, :, b:b + size, b:b + size]`` with
        ``b = (img.shape[2] - size) // 2``. When ``size`` equals the image
        side, the whole input is returned.

    Raises:
        AssertionError: on a non-square input, a crop larger than the image,
            or an odd border.
    """
    assert img.shape[2] == img.shape[3], img.shape
    border_double = img.shape[2] - size
    assert border_double >= 0, (img.shape, size)
    assert border_double % 2 == 0, (img.shape, size)
    border = border_double // 2
    # Use an explicit end index: `border:-border` becomes the empty slice
    # `0:0` when border == 0.
    return img[:, :, border:border + size, border:border + size]