Spaces:
Running
Running
delete unnecessary files
Browse files- models/SRFlow/__pycache__/srflow.cpython-311.pyc +0 -0
- models/SRFlow/code/Measure.py +0 -134
- models/SRFlow/code/__init__.py +0 -1
- models/SRFlow/code/__pycache__/__init__.cpython-311.pyc +0 -0
- models/SRFlow/code/a.py +0 -27
- models/SRFlow/code/confs/RRDB_CelebA_8X.yml +0 -83
- models/SRFlow/code/confs/RRDB_DF2K_4X.yml +0 -85
- models/SRFlow/code/confs/RRDB_DF2K_8X.yml +0 -85
- models/SRFlow/code/confs/SRFlow_CelebA_8X.yml +0 -107
- models/SRFlow/code/confs/SRFlow_DF2K_8X.yml +0 -112
- models/SRFlow/code/data/LRHR_PKL_dataset.py +0 -179
- models/SRFlow/code/data/__init__.py +0 -51
- models/SRFlow/code/demo_on_pretrained.ipynb +0 -0
- models/SRFlow/code/imresize.py +0 -180
- models/SRFlow/code/prepare_data.py +0 -118
- models/SRFlow/code/test.py +0 -192
- models/SRFlow/code/train.py +0 -328
models/SRFlow/__pycache__/srflow.cpython-311.pyc
DELETED
Binary file (2.18 kB)
|
|
models/SRFlow/code/Measure.py
DELETED
@@ -1,134 +0,0 @@
|
|
1 |
-
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
-
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
-
# you may not use this file except in compliance with the License.
|
4 |
-
# You may obtain a copy of the License at
|
5 |
-
#
|
6 |
-
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
-
#
|
8 |
-
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
import glob
|
16 |
-
import os
|
17 |
-
import time
|
18 |
-
from collections import OrderedDict
|
19 |
-
|
20 |
-
import numpy as np
|
21 |
-
import torch
|
22 |
-
import cv2
|
23 |
-
import argparse
|
24 |
-
|
25 |
-
from natsort import natsort
|
26 |
-
from skimage.metrics import structural_similarity as ssim
|
27 |
-
from skimage.metrics import peak_signal_noise_ratio as psnr
|
28 |
-
import lpips
|
29 |
-
|
30 |
-
|
31 |
-
class Measure():
|
32 |
-
def __init__(self, net='alex', use_gpu=False):
|
33 |
-
self.device = 'cuda' if use_gpu else 'cpu'
|
34 |
-
self.model = lpips.LPIPS(net=net)
|
35 |
-
self.model.to(self.device)
|
36 |
-
|
37 |
-
def measure(self, imgA, imgB):
|
38 |
-
return [float(f(imgA, imgB)) for f in [self.psnr, self.ssim, self.lpips]]
|
39 |
-
|
40 |
-
def lpips(self, imgA, imgB, model=None):
|
41 |
-
tA = t(imgA).to(self.device)
|
42 |
-
tB = t(imgB).to(self.device)
|
43 |
-
dist01 = self.model.forward(tA, tB).item()
|
44 |
-
return dist01
|
45 |
-
|
46 |
-
def ssim(self, imgA, imgB):
|
47 |
-
# multichannel: If True, treat the last dimension of the array as channels. Similarity calculations are done independently for each channel then averaged.
|
48 |
-
score, diff = ssim(imgA, imgB, full=True, multichannel=True, channel_axis=-1)
|
49 |
-
return score
|
50 |
-
|
51 |
-
def psnr(self, imgA, imgB):
|
52 |
-
psnr_val = psnr(imgA, imgB)
|
53 |
-
return psnr_val
|
54 |
-
|
55 |
-
|
56 |
-
def t(img):
|
57 |
-
def to_4d(img):
|
58 |
-
assert len(img.shape) == 3
|
59 |
-
assert img.dtype == np.uint8
|
60 |
-
img_new = np.expand_dims(img, axis=0)
|
61 |
-
assert len(img_new.shape) == 4
|
62 |
-
return img_new
|
63 |
-
|
64 |
-
def to_CHW(img):
|
65 |
-
return np.transpose(img, [2, 0, 1])
|
66 |
-
|
67 |
-
def to_tensor(img):
|
68 |
-
return torch.Tensor(img)
|
69 |
-
|
70 |
-
return to_tensor(to_4d(to_CHW(img))) / 127.5 - 1
|
71 |
-
|
72 |
-
|
73 |
-
def fiFindByWildcard(wildcard):
|
74 |
-
return natsort.natsorted(glob.glob(wildcard, recursive=True))
|
75 |
-
|
76 |
-
|
77 |
-
def imread(path):
|
78 |
-
return cv2.imread(path)[:, :, [2, 1, 0]]
|
79 |
-
|
80 |
-
|
81 |
-
def format_result(psnr, ssim, lpips):
|
82 |
-
return f'{psnr:0.2f}, {ssim:0.3f}, {lpips:0.3f}'
|
83 |
-
|
84 |
-
def measure_dirs(dirA, dirB, use_gpu, verbose=False):
|
85 |
-
if verbose:
|
86 |
-
vprint = lambda x: print(x)
|
87 |
-
else:
|
88 |
-
vprint = lambda x: None
|
89 |
-
|
90 |
-
|
91 |
-
t_init = time.time()
|
92 |
-
|
93 |
-
paths_A = fiFindByWildcard(os.path.join(dirA, f'*.{type}'))
|
94 |
-
paths_B = fiFindByWildcard(os.path.join(dirB, f'*.{type}'))
|
95 |
-
|
96 |
-
vprint("Comparing: ")
|
97 |
-
vprint(dirA)
|
98 |
-
vprint(dirB)
|
99 |
-
|
100 |
-
measure = Measure(use_gpu=use_gpu)
|
101 |
-
|
102 |
-
results = []
|
103 |
-
for pathA, pathB in zip(paths_A, paths_B):
|
104 |
-
result = OrderedDict()
|
105 |
-
|
106 |
-
t = time.time()
|
107 |
-
result['psnr'], result['ssim'], result['lpips'] = measure.measure(imread(pathA), imread(pathB))
|
108 |
-
d = time.time() - t
|
109 |
-
vprint(f"{pathA.split('/')[-1]}, {pathB.split('/')[-1]}, {format_result(**result)}, {d:0.1f}")
|
110 |
-
|
111 |
-
results.append(result)
|
112 |
-
|
113 |
-
psnr = np.mean([result['psnr'] for result in results])
|
114 |
-
ssim = np.mean([result['ssim'] for result in results])
|
115 |
-
lpips = np.mean([result['lpips'] for result in results])
|
116 |
-
|
117 |
-
vprint(f"Final Result: {format_result(psnr, ssim, lpips)}, {time.time() - t_init:0.1f}s")
|
118 |
-
|
119 |
-
|
120 |
-
if __name__ == "__main__":
|
121 |
-
parser = argparse.ArgumentParser()
|
122 |
-
parser.add_argument('-dirA', default='', type=str)
|
123 |
-
parser.add_argument('-dirB', default='', type=str)
|
124 |
-
parser.add_argument('-type', default='png')
|
125 |
-
parser.add_argument('--use_gpu', action='store_true', default=False)
|
126 |
-
args = parser.parse_args()
|
127 |
-
|
128 |
-
dirA = args.dirA
|
129 |
-
dirB = args.dirB
|
130 |
-
type = args.type
|
131 |
-
use_gpu = args.use_gpu
|
132 |
-
|
133 |
-
if len(dirA) > 0 and len(dirB) > 0:
|
134 |
-
measure_dirs(dirA, dirB, use_gpu=use_gpu, verbose=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/SRFlow/code/__init__.py
CHANGED
@@ -22,7 +22,6 @@ sys.path.append('../..')
|
|
22 |
from natsort import natsort
|
23 |
import SRFlow.code.options.options as option
|
24 |
|
25 |
-
from SRFlow.code.models import create_model
|
26 |
import torch
|
27 |
from SRFlow.code.utils.util import opt_get
|
28 |
from SRFlow.code.models.SRFlow_model import SRFlowModel
|
|
|
22 |
from natsort import natsort
|
23 |
import SRFlow.code.options.options as option
|
24 |
|
|
|
25 |
import torch
|
26 |
from SRFlow.code.utils.util import opt_get
|
27 |
from SRFlow.code.models.SRFlow_model import SRFlowModel
|
models/SRFlow/code/__pycache__/__init__.cpython-311.pyc
CHANGED
Binary files a/models/SRFlow/code/__pycache__/__init__.cpython-311.pyc and b/models/SRFlow/code/__pycache__/__init__.cpython-311.pyc differ
|
|
models/SRFlow/code/a.py
DELETED
@@ -1,27 +0,0 @@
|
|
1 |
-
import pickle
|
2 |
-
import numpy as np
|
3 |
-
import os
|
4 |
-
import matplotlib.pyplot as plt
|
5 |
-
|
6 |
-
def load_pkls(path):
|
7 |
-
assert os.path.isfile(path), path
|
8 |
-
images = []
|
9 |
-
with open(path, "rb") as f:
|
10 |
-
images += pickle.load(f)
|
11 |
-
assert len(images) > 0, path
|
12 |
-
images = [np.transpose(image, [2, 0, 1]) for image in images]
|
13 |
-
return images
|
14 |
-
|
15 |
-
path = 'datasets/DIV2K-va.pklv4'
|
16 |
-
loaded_images = load_pkls(path)
|
17 |
-
print(len(loaded_images))
|
18 |
-
# Display the first image
|
19 |
-
if loaded_images:
|
20 |
-
first_image = loaded_images[11]
|
21 |
-
plt.imshow(np.transpose(first_image, [1, 2, 0])) # Transpose image to original shape [height, width, channels]
|
22 |
-
plt.title('First Image')
|
23 |
-
plt.axis('off') # Hide axis
|
24 |
-
plt.show()
|
25 |
-
else:
|
26 |
-
print("No images loaded from the pickle file.")
|
27 |
-
print(loaded_images[11])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/SRFlow/code/confs/RRDB_CelebA_8X.yml
DELETED
@@ -1,83 +0,0 @@
|
|
1 |
-
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
-
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
-
# you may not use this file except in compliance with the License.
|
4 |
-
# You may obtain a copy of the License at
|
5 |
-
#
|
6 |
-
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
-
#
|
8 |
-
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
#
|
15 |
-
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
-
|
17 |
-
#### general settings
|
18 |
-
name: train
|
19 |
-
use_tb_logger: true
|
20 |
-
model: SR
|
21 |
-
distortion: sr
|
22 |
-
scale: 8
|
23 |
-
#gpu_ids: [ 0 ]
|
24 |
-
|
25 |
-
#### datasets
|
26 |
-
datasets:
|
27 |
-
train:
|
28 |
-
name: CelebA_160_tr
|
29 |
-
mode: LRHR_PKL
|
30 |
-
dataroot_GT: ../datasets/celebA-train-gt_1pct.pklv4
|
31 |
-
dataroot_LQ: ../datasets/celebA-train-x8_1pct.pklv4
|
32 |
-
|
33 |
-
use_shuffle: true
|
34 |
-
n_workers: 0 # per GPU
|
35 |
-
batch_size: 16
|
36 |
-
GT_size: 160
|
37 |
-
use_flip: true
|
38 |
-
use_rot: true
|
39 |
-
color: RGB
|
40 |
-
val:
|
41 |
-
name: CelebA_160_va
|
42 |
-
mode: LRHR_PKL
|
43 |
-
dataroot_GT: ../datasets/celebA-valid-gt_1pct.pklv4
|
44 |
-
dataroot_LQ: ../datasets/celebA-valid-x8_1pct.pklv4
|
45 |
-
n_max: 10
|
46 |
-
|
47 |
-
#### network structures
|
48 |
-
network_G:
|
49 |
-
which_model_G: RRDBNet
|
50 |
-
in_nc: 3
|
51 |
-
out_nc: 3
|
52 |
-
nf: 64
|
53 |
-
nb: 23
|
54 |
-
|
55 |
-
#### path
|
56 |
-
path:
|
57 |
-
pretrain_model_G: ~
|
58 |
-
strict_load: true
|
59 |
-
resume_state: auto
|
60 |
-
|
61 |
-
#### training settings: learning rate scheme, loss
|
62 |
-
train:
|
63 |
-
lr_G: !!float 2e-4
|
64 |
-
lr_scheme: CosineAnnealingLR_Restart
|
65 |
-
beta1: 0.9
|
66 |
-
beta2: 0.99
|
67 |
-
niter: 200000
|
68 |
-
warmup_iter: -1 # no warm up
|
69 |
-
T_period: [ 50000, 50000, 50000, 50000 ]
|
70 |
-
restarts: [ 50000, 100000, 150000 ]
|
71 |
-
restart_weights: [ 1, 1, 1 ]
|
72 |
-
eta_min: !!float 1e-7
|
73 |
-
|
74 |
-
pixel_criterion: l1
|
75 |
-
pixel_weight: 1.0
|
76 |
-
|
77 |
-
manual_seed: 10
|
78 |
-
val_freq: !!float 5e3
|
79 |
-
|
80 |
-
#### logger
|
81 |
-
logger:
|
82 |
-
print_freq: 100
|
83 |
-
save_checkpoint_freq: !!float 1e3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/SRFlow/code/confs/RRDB_DF2K_4X.yml
DELETED
@@ -1,85 +0,0 @@
|
|
1 |
-
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
-
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
-
# you may not use this file except in compliance with the License.
|
4 |
-
# You may obtain a copy of the License at
|
5 |
-
#
|
6 |
-
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
-
#
|
8 |
-
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
#
|
15 |
-
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
-
|
17 |
-
#### general settings
|
18 |
-
name: train
|
19 |
-
use_tb_logger: true
|
20 |
-
model: SR
|
21 |
-
distortion: sr
|
22 |
-
scale: 4
|
23 |
-
gpu_ids: [ 0 ]
|
24 |
-
|
25 |
-
#### datasets
|
26 |
-
datasets:
|
27 |
-
train:
|
28 |
-
name: CelebA_160_tr
|
29 |
-
mode: LRHR_PKL
|
30 |
-
dataroot_GT: ../datasets/DF2K-train-gt_1pct.pklv4
|
31 |
-
dataroot_LQ: ../datasets/DF2K-train-x4_1pct.pklv4
|
32 |
-
quant: 32
|
33 |
-
|
34 |
-
use_shuffle: true
|
35 |
-
n_workers: 3 # per GPU
|
36 |
-
batch_size: 16
|
37 |
-
GT_size: 160
|
38 |
-
use_flip: true
|
39 |
-
color: RGB
|
40 |
-
val:
|
41 |
-
name: CelebA_160_va
|
42 |
-
mode: LRHR_PKL
|
43 |
-
dataroot_GT: ../datasets/DF2K-valid-gt_1pct.pklv4
|
44 |
-
dataroot_LQ: ../datasets/DF2K-valid-x4_1pct.pklv4
|
45 |
-
quant: 32
|
46 |
-
n_max: 20
|
47 |
-
|
48 |
-
#### network structures
|
49 |
-
network_G:
|
50 |
-
which_model_G: RRDBNet
|
51 |
-
use_orig: True
|
52 |
-
in_nc: 3
|
53 |
-
out_nc: 3
|
54 |
-
nf: 64
|
55 |
-
nb: 23
|
56 |
-
|
57 |
-
#### path
|
58 |
-
path:
|
59 |
-
pretrain_model_G: ~
|
60 |
-
strict_load: true
|
61 |
-
resume_state: auto
|
62 |
-
|
63 |
-
#### training settings: learning rate scheme, loss
|
64 |
-
train:
|
65 |
-
lr_G: !!float 2e-4
|
66 |
-
lr_scheme: CosineAnnealingLR_Restart
|
67 |
-
beta1: 0.9
|
68 |
-
beta2: 0.99
|
69 |
-
niter: 1000000
|
70 |
-
warmup_iter: -1 # no warm up
|
71 |
-
T_period: [ 50000, 50000, 50000, 50000 ]
|
72 |
-
restarts: [ 50000, 100000, 150000 ]
|
73 |
-
restart_weights: [ 1, 1, 1 ]
|
74 |
-
eta_min: !!float 1e-7
|
75 |
-
|
76 |
-
pixel_criterion: l1
|
77 |
-
pixel_weight: 1.0
|
78 |
-
|
79 |
-
manual_seed: 10
|
80 |
-
val_freq: !!float 5e3
|
81 |
-
|
82 |
-
#### logger
|
83 |
-
logger:
|
84 |
-
print_freq: 100
|
85 |
-
save_checkpoint_freq: !!float 1e3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/SRFlow/code/confs/RRDB_DF2K_8X.yml
DELETED
@@ -1,85 +0,0 @@
|
|
1 |
-
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
-
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
-
# you may not use this file except in compliance with the License.
|
4 |
-
# You may obtain a copy of the License at
|
5 |
-
#
|
6 |
-
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
-
#
|
8 |
-
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
#
|
15 |
-
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
-
|
17 |
-
#### general settings
|
18 |
-
name: train
|
19 |
-
use_tb_logger: true
|
20 |
-
model: SR
|
21 |
-
distortion: sr
|
22 |
-
scale: 8
|
23 |
-
gpu_ids: [ 0 ]
|
24 |
-
|
25 |
-
#### datasets
|
26 |
-
datasets:
|
27 |
-
train:
|
28 |
-
name: CelebA_160_tr
|
29 |
-
mode: LRHR_PKL
|
30 |
-
dataroot_GT: ../datasets/DF2K-train-gt_1pct.pklv4
|
31 |
-
dataroot_LQ: ../datasets/DF2K-train-x8_1pct.pklv4
|
32 |
-
quant: 32
|
33 |
-
|
34 |
-
use_shuffle: true
|
35 |
-
n_workers: 3 # per GPU
|
36 |
-
batch_size: 16
|
37 |
-
GT_size: 160
|
38 |
-
use_flip: true
|
39 |
-
color: RGB
|
40 |
-
|
41 |
-
val:
|
42 |
-
name: CelebA_160_va
|
43 |
-
mode: LRHR_PKL
|
44 |
-
dataroot_GT: ../datasets/DF2K-valid-gt_1pct.pklv4
|
45 |
-
dataroot_LQ: ../datasets/DF2K-valid-x8_1pct.pklv4
|
46 |
-
quant: 32
|
47 |
-
n_max: 20
|
48 |
-
|
49 |
-
#### network structures
|
50 |
-
network_G:
|
51 |
-
which_model_G: RRDBNet
|
52 |
-
in_nc: 3
|
53 |
-
out_nc: 3
|
54 |
-
nf: 64
|
55 |
-
nb: 23
|
56 |
-
|
57 |
-
#### path
|
58 |
-
path:
|
59 |
-
pretrain_model_G: ~
|
60 |
-
strict_load: true
|
61 |
-
resume_state: auto
|
62 |
-
|
63 |
-
#### training settings: learning rate scheme, loss
|
64 |
-
train:
|
65 |
-
lr_G: !!float 2e-4
|
66 |
-
lr_scheme: CosineAnnealingLR_Restart
|
67 |
-
beta1: 0.9
|
68 |
-
beta2: 0.99
|
69 |
-
niter: 200000
|
70 |
-
warmup_iter: -1 # no warm up
|
71 |
-
T_period: [ 50000, 50000, 50000, 50000 ]
|
72 |
-
restarts: [ 50000, 100000, 150000 ]
|
73 |
-
restart_weights: [ 1, 1, 1 ]
|
74 |
-
eta_min: !!float 1e-7
|
75 |
-
|
76 |
-
pixel_criterion: l1
|
77 |
-
pixel_weight: 1.0
|
78 |
-
|
79 |
-
manual_seed: 10
|
80 |
-
val_freq: !!float 5e3
|
81 |
-
|
82 |
-
#### logger
|
83 |
-
logger:
|
84 |
-
print_freq: 100
|
85 |
-
save_checkpoint_freq: !!float 1e3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/SRFlow/code/confs/SRFlow_CelebA_8X.yml
DELETED
@@ -1,107 +0,0 @@
|
|
1 |
-
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
-
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
-
# you may not use this file except in compliance with the License.
|
4 |
-
# You may obtain a copy of the License at
|
5 |
-
#
|
6 |
-
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
-
#
|
8 |
-
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
#
|
15 |
-
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
-
|
17 |
-
#### general settings
|
18 |
-
name: train
|
19 |
-
use_tb_logger: true
|
20 |
-
model: SRFlow
|
21 |
-
distortion: sr
|
22 |
-
scale: 8
|
23 |
-
gpu_ids: [ 0 ]
|
24 |
-
|
25 |
-
#### datasets
|
26 |
-
datasets:
|
27 |
-
train:
|
28 |
-
name: CelebA_160_tr
|
29 |
-
mode: LRHR_PKL
|
30 |
-
dataroot_GT: ../datasets/celebA-train-gt.pklv4
|
31 |
-
dataroot_LQ: ../datasets/celebA-train-x8.pklv4
|
32 |
-
quant: 32
|
33 |
-
|
34 |
-
use_shuffle: true
|
35 |
-
n_workers: 3 # per GPU
|
36 |
-
batch_size: 16
|
37 |
-
GT_size: 160
|
38 |
-
use_flip: true
|
39 |
-
color: RGB
|
40 |
-
val:
|
41 |
-
name: CelebA_160_va
|
42 |
-
mode: LRHR_PKL
|
43 |
-
dataroot_GT: ../datasets/celebA-train-gt.pklv4
|
44 |
-
dataroot_LQ: ../datasets/celebA-train-x8.pklv4
|
45 |
-
quant: 32
|
46 |
-
n_max: 20
|
47 |
-
|
48 |
-
#### Test Settings
|
49 |
-
dataroot_GT: ../datasets/celebA-validation-gt
|
50 |
-
dataroot_LR: ../datasets/celebA-validation-x8
|
51 |
-
model_path: ../pretrained_models/SRFlow_CelebA_8X.pth
|
52 |
-
heat: 0.9 # This is the standard deviation of the latent vectors
|
53 |
-
|
54 |
-
#### network structures
|
55 |
-
network_G:
|
56 |
-
which_model_G: SRFlowNet
|
57 |
-
in_nc: 3
|
58 |
-
out_nc: 3
|
59 |
-
nf: 64
|
60 |
-
nb: 8
|
61 |
-
upscale: 8
|
62 |
-
train_RRDB: false
|
63 |
-
train_RRDB_delay: 0.5
|
64 |
-
|
65 |
-
flow:
|
66 |
-
K: 16
|
67 |
-
L: 4
|
68 |
-
noInitialInj: true
|
69 |
-
coupling: CondAffineSeparatedAndCond
|
70 |
-
additionalFlowNoAffine: 2
|
71 |
-
split:
|
72 |
-
enable: true
|
73 |
-
fea_up0: true
|
74 |
-
stackRRDB:
|
75 |
-
blocks: [ 1, 3, 5, 7 ]
|
76 |
-
concat: true
|
77 |
-
|
78 |
-
#### path
|
79 |
-
path:
|
80 |
-
pretrain_model_G: ../pretrained_models/RRDB_CelebA_8X.pth
|
81 |
-
strict_load: true
|
82 |
-
resume_state: auto
|
83 |
-
|
84 |
-
#### training settings: learning rate scheme, loss
|
85 |
-
train:
|
86 |
-
manual_seed: 10
|
87 |
-
lr_G: !!float 5e-4
|
88 |
-
weight_decay_G: 0
|
89 |
-
beta1: 0.9
|
90 |
-
beta2: 0.99
|
91 |
-
lr_scheme: MultiStepLR
|
92 |
-
warmup_iter: -1 # no warm up
|
93 |
-
lr_steps_rel: [ 0.5, 0.75, 0.9, 0.95 ]
|
94 |
-
lr_gamma: 0.5
|
95 |
-
|
96 |
-
niter: 200000
|
97 |
-
val_freq: 40000
|
98 |
-
|
99 |
-
#### validation settings
|
100 |
-
val:
|
101 |
-
heats: [ 0.0, 0.5, 0.75, 1.0 ]
|
102 |
-
n_sample: 3
|
103 |
-
|
104 |
-
#### logger
|
105 |
-
logger:
|
106 |
-
print_freq: 100
|
107 |
-
save_checkpoint_freq: !!float 1e3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/SRFlow/code/confs/SRFlow_DF2K_8X.yml
DELETED
@@ -1,112 +0,0 @@
|
|
1 |
-
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
-
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
-
# you may not use this file except in compliance with the License.
|
4 |
-
# You may obtain a copy of the License at
|
5 |
-
#
|
6 |
-
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
-
#
|
8 |
-
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
#
|
15 |
-
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
-
|
17 |
-
#### general settings
|
18 |
-
name: train
|
19 |
-
use_tb_logger: true
|
20 |
-
model: SRFlow
|
21 |
-
distortion: sr
|
22 |
-
scale: 8
|
23 |
-
gpu_ids: [ 0 ]
|
24 |
-
|
25 |
-
#### datasets
|
26 |
-
datasets:
|
27 |
-
train:
|
28 |
-
name: CelebA_160_tr
|
29 |
-
mode: LRHR_PKL
|
30 |
-
dataroot_GT: ../datasets/DF2K-tr.pklv4
|
31 |
-
dataroot_LQ: ../datasets/DF2K-tr_X8.pklv4
|
32 |
-
quant: 32
|
33 |
-
|
34 |
-
use_shuffle: true
|
35 |
-
n_workers: 3 # per GPU
|
36 |
-
batch_size: 16
|
37 |
-
GT_size: 160
|
38 |
-
use_flip: true
|
39 |
-
color: RGB
|
40 |
-
|
41 |
-
val:
|
42 |
-
name: CelebA_160_va
|
43 |
-
mode: LRHR_PKL
|
44 |
-
dataroot_GT: ../datasets/DIV2K-va.pklv4
|
45 |
-
dataroot_LQ: ../datasets/DIV2K-va_X8.pklv4
|
46 |
-
quant: 32
|
47 |
-
n_max: 20
|
48 |
-
|
49 |
-
#### Test Settings
|
50 |
-
dataroot_GT: ../datasets/div2k-validation-modcrop8-gt
|
51 |
-
dataroot_LR: ../datasets/div2k-validation-modcrop8-x8
|
52 |
-
model_path: ../pretrained_models/SRFlow_DF2K_8X.pth
|
53 |
-
heat: 0.9 # This is the standard deviation of the latent vectors
|
54 |
-
|
55 |
-
#### network structures
|
56 |
-
network_G:
|
57 |
-
which_model_G: SRFlowNet
|
58 |
-
in_nc: 3
|
59 |
-
out_nc: 3
|
60 |
-
nf: 64
|
61 |
-
nb: 23
|
62 |
-
upscale: 8
|
63 |
-
train_RRDB: false
|
64 |
-
train_RRDB_delay: 0.5
|
65 |
-
|
66 |
-
flow:
|
67 |
-
K: 16
|
68 |
-
L: 4
|
69 |
-
noInitialInj: true
|
70 |
-
coupling: CondAffineSeparatedAndCond
|
71 |
-
additionalFlowNoAffine: 2
|
72 |
-
split:
|
73 |
-
enable: true
|
74 |
-
fea_up0: true
|
75 |
-
stackRRDB:
|
76 |
-
blocks: [ 1, 3, 5, 7 ]
|
77 |
-
concat: true
|
78 |
-
|
79 |
-
#### path
|
80 |
-
path:
|
81 |
-
pretrain_model_G: ../pretrained_models/RRDB_DF2K_8X.pth
|
82 |
-
strict_load: true
|
83 |
-
resume_state: auto
|
84 |
-
|
85 |
-
#### training settings: learning rate scheme, loss
|
86 |
-
train:
|
87 |
-
manual_seed: 10
|
88 |
-
lr_G: !!float 5e-4
|
89 |
-
weight_decay_G: 0
|
90 |
-
beta1: 0.9
|
91 |
-
beta2: 0.99
|
92 |
-
lr_scheme: MultiStepLR
|
93 |
-
warmup_iter: -1 # no warm up
|
94 |
-
lr_steps_rel: [ 0.5, 0.75, 0.9, 0.95 ]
|
95 |
-
lr_gamma: 0.5
|
96 |
-
|
97 |
-
niter: 200000
|
98 |
-
val_freq: 40000
|
99 |
-
|
100 |
-
#### validation settings
|
101 |
-
val:
|
102 |
-
heats: [ 0.0, 0.5, 0.75, 1.0 ]
|
103 |
-
n_sample: 3
|
104 |
-
|
105 |
-
test:
|
106 |
-
heats: [ 0.0, 0.7, 0.8, 0.9 ]
|
107 |
-
|
108 |
-
#### logger
|
109 |
-
logger:
|
110 |
-
# Debug print_freq: 100
|
111 |
-
print_freq: 100
|
112 |
-
save_checkpoint_freq: !!float 1e3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/SRFlow/code/data/LRHR_PKL_dataset.py
DELETED
@@ -1,179 +0,0 @@
|
|
1 |
-
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
-
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
-
# you may not use this file except in compliance with the License.
|
4 |
-
# You may obtain a copy of the License at
|
5 |
-
#
|
6 |
-
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
-
#
|
8 |
-
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
#
|
15 |
-
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
-
|
17 |
-
import os
|
18 |
-
import subprocess
|
19 |
-
import torch.utils.data as data
|
20 |
-
import numpy as np
|
21 |
-
import time
|
22 |
-
import torch
|
23 |
-
|
24 |
-
import pickle
|
25 |
-
|
26 |
-
|
27 |
-
class LRHR_PKLDataset(data.Dataset):
|
28 |
-
def __init__(self, opt):
|
29 |
-
super(LRHR_PKLDataset, self).__init__()
|
30 |
-
self.opt = opt
|
31 |
-
self.crop_size = opt.get("GT_size", None)
|
32 |
-
self.scale = None
|
33 |
-
self.random_scale_list = [1]
|
34 |
-
|
35 |
-
hr_file_path = opt["dataroot_GT"]
|
36 |
-
lr_file_path = opt["dataroot_LQ"]
|
37 |
-
y_labels_file_path = opt['dataroot_y_labels']
|
38 |
-
|
39 |
-
gpu = True
|
40 |
-
augment = True
|
41 |
-
|
42 |
-
self.use_flip = opt["use_flip"] if "use_flip" in opt.keys() else False
|
43 |
-
self.use_rot = opt["use_rot"] if "use_rot" in opt.keys() else False
|
44 |
-
self.use_crop = opt["use_crop"] if "use_crop" in opt.keys() else False
|
45 |
-
self.center_crop_hr_size = opt.get("center_crop_hr_size", None)
|
46 |
-
|
47 |
-
n_max = opt["n_max"] if "n_max" in opt.keys() else int(1e8)
|
48 |
-
|
49 |
-
t = time.time()
|
50 |
-
self.lr_images = self.load_pkls(lr_file_path, n_max)
|
51 |
-
self.hr_images = self.load_pkls(hr_file_path, n_max)
|
52 |
-
|
53 |
-
min_val_hr = np.min([i.min() for i in self.hr_images[:20]])
|
54 |
-
max_val_hr = np.max([i.max() for i in self.hr_images[:20]])
|
55 |
-
|
56 |
-
min_val_lr = np.min([i.min() for i in self.lr_images[:20]])
|
57 |
-
max_val_lr = np.max([i.max() for i in self.lr_images[:20]])
|
58 |
-
|
59 |
-
t = time.time() - t
|
60 |
-
print("Loaded {} HR images with [{:.2f}, {:.2f}] in {:.2f}s from {}".
|
61 |
-
format(len(self.hr_images), min_val_hr, max_val_hr, t, hr_file_path))
|
62 |
-
print("Loaded {} LR images with [{:.2f}, {:.2f}] in {:.2f}s from {}".
|
63 |
-
format(len(self.lr_images), min_val_lr, max_val_lr, t, lr_file_path))
|
64 |
-
|
65 |
-
self.gpu = gpu
|
66 |
-
self.augment = augment
|
67 |
-
|
68 |
-
self.measures = None
|
69 |
-
|
70 |
-
def load_pkls(self, path, n_max):
|
71 |
-
assert os.path.isfile(path), path
|
72 |
-
images = []
|
73 |
-
with open(path, "rb") as f:
|
74 |
-
images += pickle.load(f)
|
75 |
-
assert len(images) > 0, path
|
76 |
-
images = images[:n_max]
|
77 |
-
images = [np.transpose(image, [2, 0, 1]) for image in images]
|
78 |
-
return images
|
79 |
-
|
80 |
-
def __len__(self):
|
81 |
-
return len(self.hr_images)
|
82 |
-
|
83 |
-
def __getitem__(self, item):
|
84 |
-
hr = self.hr_images[item]
|
85 |
-
lr = self.lr_images[item]
|
86 |
-
|
87 |
-
if self.scale == None:
|
88 |
-
self.scale = hr.shape[1] // lr.shape[1]
|
89 |
-
assert hr.shape[1] == self.scale * lr.shape[1], ('non-fractional ratio', lr.shape, hr.shape)
|
90 |
-
|
91 |
-
if self.use_crop:
|
92 |
-
hr, lr = random_crop(hr, lr, self.crop_size, self.scale, self.use_crop)
|
93 |
-
|
94 |
-
if self.center_crop_hr_size:
|
95 |
-
hr, lr = center_crop(hr, self.center_crop_hr_size), center_crop(lr, self.center_crop_hr_size // self.scale)
|
96 |
-
|
97 |
-
if self.use_flip:
|
98 |
-
hr, lr = random_flip(hr, lr)
|
99 |
-
|
100 |
-
if self.use_rot:
|
101 |
-
hr, lr = random_rotation(hr, lr)
|
102 |
-
|
103 |
-
hr = hr / 255.0
|
104 |
-
lr = lr / 255.0
|
105 |
-
|
106 |
-
if self.measures is None or np.random.random() < 0.05:
|
107 |
-
if self.measures is None:
|
108 |
-
self.measures = {}
|
109 |
-
self.measures['hr_means'] = np.mean(hr)
|
110 |
-
self.measures['hr_stds'] = np.std(hr)
|
111 |
-
self.measures['lr_means'] = np.mean(lr)
|
112 |
-
self.measures['lr_stds'] = np.std(lr)
|
113 |
-
|
114 |
-
hr = torch.Tensor(hr)
|
115 |
-
lr = torch.Tensor(lr)
|
116 |
-
|
117 |
-
# if self.gpu:
|
118 |
-
# hr = hr.cuda()
|
119 |
-
# lr = lr.cuda()
|
120 |
-
|
121 |
-
return {'LQ': lr, 'GT': hr, 'LQ_path': str(item), 'GT_path': str(item)}
|
122 |
-
|
123 |
-
def print_and_reset(self, tag):
|
124 |
-
m = self.measures
|
125 |
-
kvs = []
|
126 |
-
for k in sorted(m.keys()):
|
127 |
-
kvs.append("{}={:.2f}".format(k, m[k]))
|
128 |
-
print("[KPI] " + tag + ": " + ", ".join(kvs))
|
129 |
-
self.measures = None
|
130 |
-
|
131 |
-
|
132 |
-
def random_flip(img, seg):
|
133 |
-
random_choice = np.random.choice([True, False])
|
134 |
-
img = img if random_choice else np.flip(img, 2).copy()
|
135 |
-
seg = seg if random_choice else np.flip(seg, 2).copy()
|
136 |
-
return img, seg
|
137 |
-
|
138 |
-
|
139 |
-
def random_rotation(img, seg):
|
140 |
-
random_choice = np.random.choice([0, 1, 3])
|
141 |
-
img = np.rot90(img, random_choice, axes=(1, 2)).copy()
|
142 |
-
seg = np.rot90(seg, random_choice, axes=(1, 2)).copy()
|
143 |
-
return img, seg
|
144 |
-
|
145 |
-
|
146 |
-
def random_crop(hr, lr, size_hr, scale, random):
|
147 |
-
size_lr = size_hr // scale
|
148 |
-
|
149 |
-
size_lr_x = lr.shape[1]
|
150 |
-
size_lr_y = lr.shape[2]
|
151 |
-
|
152 |
-
start_x_lr = np.random.randint(low=0, high=(size_lr_x - size_lr) + 1) if size_lr_x > size_lr else 0
|
153 |
-
start_y_lr = np.random.randint(low=0, high=(size_lr_y - size_lr) + 1) if size_lr_y > size_lr else 0
|
154 |
-
|
155 |
-
# LR Patch
|
156 |
-
lr_patch = lr[:, start_x_lr:start_x_lr + size_lr, start_y_lr:start_y_lr + size_lr]
|
157 |
-
|
158 |
-
# HR Patch
|
159 |
-
start_x_hr = start_x_lr * scale
|
160 |
-
start_y_hr = start_y_lr * scale
|
161 |
-
hr_patch = hr[:, start_x_hr:start_x_hr + size_hr, start_y_hr:start_y_hr + size_hr]
|
162 |
-
|
163 |
-
return hr_patch, lr_patch
|
164 |
-
|
165 |
-
|
166 |
-
def center_crop(img, size):
|
167 |
-
assert img.shape[1] == img.shape[2], img.shape
|
168 |
-
border_double = img.shape[1] - size
|
169 |
-
assert border_double % 2 == 0, (img.shape, size)
|
170 |
-
border = border_double // 2
|
171 |
-
return img[:, border:-border, border:-border]
|
172 |
-
|
173 |
-
|
174 |
-
def center_crop_tensor(img, size):
|
175 |
-
assert img.shape[2] == img.shape[3], img.shape
|
176 |
-
border_double = img.shape[2] - size
|
177 |
-
assert border_double % 2 == 0, (img.shape, size)
|
178 |
-
border = border_double // 2
|
179 |
-
return img[:, :, border:-border, border:-border]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/SRFlow/code/data/__init__.py
DELETED
@@ -1,51 +0,0 @@
|
|
1 |
-
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
-
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
-
# you may not use this file except in compliance with the License.
|
4 |
-
# You may obtain a copy of the License at
|
5 |
-
#
|
6 |
-
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
-
#
|
8 |
-
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
#
|
15 |
-
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
-
|
17 |
-
'''create dataset and dataloader'''
|
18 |
-
import logging
|
19 |
-
import torch
|
20 |
-
import torch.utils.data
|
21 |
-
|
22 |
-
|
23 |
-
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
|
24 |
-
phase = dataset_opt.get('phase', 'test')
|
25 |
-
if phase == 'train':
|
26 |
-
gpu_ids = opt.get('gpu_ids', None)
|
27 |
-
gpu_ids = gpu_ids if gpu_ids else []
|
28 |
-
num_workers = dataset_opt['n_workers'] * len(gpu_ids)
|
29 |
-
batch_size = dataset_opt['batch_size']
|
30 |
-
shuffle = True
|
31 |
-
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
|
32 |
-
num_workers=num_workers, sampler=sampler, drop_last=True,
|
33 |
-
pin_memory=False)
|
34 |
-
else:
|
35 |
-
return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
|
36 |
-
pin_memory=True)
|
37 |
-
|
38 |
-
|
39 |
-
def create_dataset(dataset_opt):
|
40 |
-
print(dataset_opt)
|
41 |
-
mode = dataset_opt['mode']
|
42 |
-
if mode == 'LRHR_PKL':
|
43 |
-
from data.LRHR_PKL_dataset import LRHR_PKLDataset as D
|
44 |
-
else:
|
45 |
-
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
|
46 |
-
dataset = D(dataset_opt)
|
47 |
-
|
48 |
-
logger = logging.getLogger('base')
|
49 |
-
logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
|
50 |
-
dataset_opt['name']))
|
51 |
-
return dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/SRFlow/code/demo_on_pretrained.ipynb
DELETED
The diff for this file is too large to render.
See raw diff
|
|
models/SRFlow/code/imresize.py
DELETED
@@ -1,180 +0,0 @@
|
|
1 |
-
# https://github.com/fatheral/matlab_imresize
|
2 |
-
#
|
3 |
-
# MIT License
|
4 |
-
#
|
5 |
-
# Copyright (c) 2020 Alex
|
6 |
-
#
|
7 |
-
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
8 |
-
# of this software and associated documentation files (the "Software"), to deal
|
9 |
-
# in the Software without restriction, including without limitation the rights
|
10 |
-
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
11 |
-
# copies of the Software, and to permit persons to whom the Software is
|
12 |
-
# furnished to do so, subject to the following conditions:
|
13 |
-
#
|
14 |
-
# The above copyright notice and this permission notice shall be included in all
|
15 |
-
# copies or substantial portions of the Software.
|
16 |
-
#
|
17 |
-
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
18 |
-
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
19 |
-
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
20 |
-
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
21 |
-
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
22 |
-
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
23 |
-
# SOFTWARE.
|
24 |
-
|
25 |
-
|
26 |
-
from __future__ import print_function
|
27 |
-
import numpy as np
|
28 |
-
from math import ceil, floor
|
29 |
-
|
30 |
-
|
31 |
-
def deriveSizeFromScale(img_shape, scale):
|
32 |
-
output_shape = []
|
33 |
-
for k in range(2):
|
34 |
-
output_shape.append(int(ceil(scale[k] * img_shape[k])))
|
35 |
-
return output_shape
|
36 |
-
|
37 |
-
|
38 |
-
def deriveScaleFromSize(img_shape_in, img_shape_out):
|
39 |
-
scale = []
|
40 |
-
for k in range(2):
|
41 |
-
scale.append(1.0 * img_shape_out[k] / img_shape_in[k])
|
42 |
-
return scale
|
43 |
-
|
44 |
-
|
45 |
-
def triangle(x):
|
46 |
-
x = np.array(x).astype(np.float64)
|
47 |
-
lessthanzero = np.logical_and((x >= -1), x < 0)
|
48 |
-
greaterthanzero = np.logical_and((x <= 1), x >= 0)
|
49 |
-
f = np.multiply((x + 1), lessthanzero) + np.multiply((1 - x), greaterthanzero)
|
50 |
-
return f
|
51 |
-
|
52 |
-
|
53 |
-
def cubic(x):
|
54 |
-
x = np.array(x).astype(np.float64)
|
55 |
-
absx = np.absolute(x)
|
56 |
-
absx2 = np.multiply(absx, absx)
|
57 |
-
absx3 = np.multiply(absx2, absx)
|
58 |
-
f = np.multiply(1.5 * absx3 - 2.5 * absx2 + 1, absx <= 1) + np.multiply(-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2,
|
59 |
-
(1 < absx) & (absx <= 2))
|
60 |
-
return f
|
61 |
-
|
62 |
-
|
63 |
-
def contributions(in_length, out_length, scale, kernel, k_width):
|
64 |
-
if scale < 1:
|
65 |
-
h = lambda x: scale * kernel(scale * x)
|
66 |
-
kernel_width = 1.0 * k_width / scale
|
67 |
-
else:
|
68 |
-
h = kernel
|
69 |
-
kernel_width = k_width
|
70 |
-
x = np.arange(1, out_length + 1).astype(np.float64)
|
71 |
-
u = x / scale + 0.5 * (1 - 1 / scale)
|
72 |
-
left = np.floor(u - kernel_width / 2)
|
73 |
-
P = int(ceil(kernel_width)) + 2
|
74 |
-
ind = np.expand_dims(left, axis=1) + np.arange(P) - 1 # -1 because indexing from 0
|
75 |
-
indices = ind.astype(np.int32)
|
76 |
-
weights = h(np.expand_dims(u, axis=1) - indices - 1) # -1 because indexing from 0
|
77 |
-
weights = np.divide(weights, np.expand_dims(np.sum(weights, axis=1), axis=1))
|
78 |
-
aux = np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))).astype(np.int32)
|
79 |
-
indices = aux[np.mod(indices, aux.size)]
|
80 |
-
ind2store = np.nonzero(np.any(weights, axis=0))
|
81 |
-
weights = weights[:, ind2store]
|
82 |
-
indices = indices[:, ind2store]
|
83 |
-
return weights, indices
|
84 |
-
|
85 |
-
|
86 |
-
def imresizemex(inimg, weights, indices, dim):
|
87 |
-
in_shape = inimg.shape
|
88 |
-
w_shape = weights.shape
|
89 |
-
out_shape = list(in_shape)
|
90 |
-
out_shape[dim] = w_shape[0]
|
91 |
-
outimg = np.zeros(out_shape)
|
92 |
-
if dim == 0:
|
93 |
-
for i_img in range(in_shape[1]):
|
94 |
-
for i_w in range(w_shape[0]):
|
95 |
-
w = weights[i_w, :]
|
96 |
-
ind = indices[i_w, :]
|
97 |
-
im_slice = inimg[ind, i_img].astype(np.float64)
|
98 |
-
outimg[i_w, i_img] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0)
|
99 |
-
elif dim == 1:
|
100 |
-
for i_img in range(in_shape[0]):
|
101 |
-
for i_w in range(w_shape[0]):
|
102 |
-
w = weights[i_w, :]
|
103 |
-
ind = indices[i_w, :]
|
104 |
-
im_slice = inimg[i_img, ind].astype(np.float64)
|
105 |
-
outimg[i_img, i_w] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0)
|
106 |
-
if inimg.dtype == np.uint8:
|
107 |
-
outimg = np.clip(outimg, 0, 255)
|
108 |
-
return np.around(outimg).astype(np.uint8)
|
109 |
-
else:
|
110 |
-
return outimg
|
111 |
-
|
112 |
-
|
113 |
-
def imresizevec(inimg, weights, indices, dim):
|
114 |
-
wshape = weights.shape
|
115 |
-
if dim == 0:
|
116 |
-
weights = weights.reshape((wshape[0], wshape[2], 1, 1))
|
117 |
-
outimg = np.sum(weights * ((inimg[indices].squeeze(axis=1)).astype(np.float64)), axis=1)
|
118 |
-
elif dim == 1:
|
119 |
-
weights = weights.reshape((1, wshape[0], wshape[2], 1))
|
120 |
-
outimg = np.sum(weights * ((inimg[:, indices].squeeze(axis=2)).astype(np.float64)), axis=2)
|
121 |
-
if inimg.dtype == np.uint8:
|
122 |
-
outimg = np.clip(outimg, 0, 255)
|
123 |
-
return np.around(outimg).astype(np.uint8)
|
124 |
-
else:
|
125 |
-
return outimg
|
126 |
-
|
127 |
-
|
128 |
-
def resizeAlongDim(A, dim, weights, indices, mode="vec"):
|
129 |
-
if mode == "org":
|
130 |
-
out = imresizemex(A, weights, indices, dim)
|
131 |
-
else:
|
132 |
-
out = imresizevec(A, weights, indices, dim)
|
133 |
-
return out
|
134 |
-
|
135 |
-
|
136 |
-
def imresize(I, scalar_scale=None, method='bicubic', output_shape=None, mode="vec"):
|
137 |
-
if method is 'bicubic':
|
138 |
-
kernel = cubic
|
139 |
-
elif method is 'bilinear':
|
140 |
-
kernel = triangle
|
141 |
-
else:
|
142 |
-
print('Error: Unidentified method supplied')
|
143 |
-
|
144 |
-
kernel_width = 4.0
|
145 |
-
# Fill scale and output_size
|
146 |
-
if scalar_scale is not None:
|
147 |
-
scalar_scale = float(scalar_scale)
|
148 |
-
scale = [scalar_scale, scalar_scale]
|
149 |
-
output_size = deriveSizeFromScale(I.shape, scale)
|
150 |
-
elif output_shape is not None:
|
151 |
-
scale = deriveScaleFromSize(I.shape, output_shape)
|
152 |
-
output_size = list(output_shape)
|
153 |
-
else:
|
154 |
-
print('Error: scalar_scale OR output_shape should be defined!')
|
155 |
-
return
|
156 |
-
scale_np = np.array(scale)
|
157 |
-
order = np.argsort(scale_np)
|
158 |
-
weights = []
|
159 |
-
indices = []
|
160 |
-
for k in range(2):
|
161 |
-
w, ind = contributions(I.shape[k], output_size[k], scale[k], kernel, kernel_width)
|
162 |
-
weights.append(w)
|
163 |
-
indices.append(ind)
|
164 |
-
B = np.copy(I)
|
165 |
-
flag2D = False
|
166 |
-
if B.ndim == 2:
|
167 |
-
B = np.expand_dims(B, axis=2)
|
168 |
-
flag2D = True
|
169 |
-
for k in range(2):
|
170 |
-
dim = order[k]
|
171 |
-
B = resizeAlongDim(B, dim, weights[dim], indices[dim], mode)
|
172 |
-
if flag2D:
|
173 |
-
B = np.squeeze(B, axis=2)
|
174 |
-
return B
|
175 |
-
|
176 |
-
|
177 |
-
def convertDouble2Byte(I):
|
178 |
-
B = np.clip(I, 0.0, 1.0)
|
179 |
-
B = 255 * B
|
180 |
-
return np.around(B).astype(np.uint8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/SRFlow/code/prepare_data.py
DELETED
@@ -1,118 +0,0 @@
|
|
1 |
-
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
-
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
-
# you may not use this file except in compliance with the License.
|
4 |
-
# You may obtain a copy of the License at
|
5 |
-
#
|
6 |
-
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
-
#
|
8 |
-
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
import glob
|
16 |
-
import os
|
17 |
-
import sys
|
18 |
-
|
19 |
-
import numpy as np
|
20 |
-
import random
|
21 |
-
import imageio
|
22 |
-
import pickle
|
23 |
-
|
24 |
-
from natsort import natsort
|
25 |
-
from tqdm import tqdm
|
26 |
-
|
27 |
-
def get_img_paths(dir_path, wildcard='*.png'):
|
28 |
-
return natsort.natsorted(glob.glob(dir_path + '/' + wildcard))
|
29 |
-
|
30 |
-
def create_all_dirs(path):
|
31 |
-
if "." in path.split("/")[-1]:
|
32 |
-
dirs = os.path.dirname(path)
|
33 |
-
else:
|
34 |
-
dirs = path
|
35 |
-
os.makedirs(dirs, exist_ok=True)
|
36 |
-
|
37 |
-
def to_pklv4(obj, path, vebose=False):
|
38 |
-
create_all_dirs(path)
|
39 |
-
with open(path, 'wb') as f:
|
40 |
-
pickle.dump(obj, f, protocol=4)
|
41 |
-
if vebose:
|
42 |
-
print("Wrote {}".format(path))
|
43 |
-
|
44 |
-
|
45 |
-
from imresize import imresize
|
46 |
-
|
47 |
-
def random_crop(img, size):
|
48 |
-
h, w, c = img.shape
|
49 |
-
|
50 |
-
h_start = np.random.randint(0, h - size)
|
51 |
-
h_end = h_start + size
|
52 |
-
|
53 |
-
w_start = np.random.randint(0, w - size)
|
54 |
-
w_end = w_start + size
|
55 |
-
|
56 |
-
return img[h_start:h_end, w_start:w_end]
|
57 |
-
|
58 |
-
|
59 |
-
def imread(img_path):
|
60 |
-
img = imageio.imread(img_path)
|
61 |
-
if len(img.shape) == 2:
|
62 |
-
img = np.stack([img, ] * 3, axis=2)
|
63 |
-
return img
|
64 |
-
|
65 |
-
|
66 |
-
def to_pklv4_1pct(obj, path, vebose):
|
67 |
-
n = int(round(len(obj) * 0.01))
|
68 |
-
path = path.replace(".", "_1pct.")
|
69 |
-
to_pklv4(obj[:n], path, vebose=True)
|
70 |
-
|
71 |
-
|
72 |
-
def main(dir_path):
|
73 |
-
hrs = []
|
74 |
-
lqs = []
|
75 |
-
|
76 |
-
img_paths = get_img_paths(dir_path)
|
77 |
-
for img_path in tqdm(img_paths):
|
78 |
-
img = imread(img_path)
|
79 |
-
|
80 |
-
for i in range(47):
|
81 |
-
crop = random_crop(img, 256)
|
82 |
-
cropX4 = imresize(crop, scalar_scale=0.25)
|
83 |
-
hrs.append(crop)
|
84 |
-
lqs.append(cropX4)
|
85 |
-
|
86 |
-
shuffle_combined(hrs, lqs)
|
87 |
-
|
88 |
-
hrs_path = get_hrs_path(dir_path)
|
89 |
-
to_pklv4(hrs, hrs_path, vebose=True)
|
90 |
-
|
91 |
-
lqs_path = get_lqs_path(dir_path)
|
92 |
-
to_pklv4(lqs, lqs_path, vebose=True)
|
93 |
-
|
94 |
-
|
95 |
-
def get_hrs_path(dir_path):
|
96 |
-
base_dir = '/kaggle/working/'
|
97 |
-
name = os.path.basename(dir_path)
|
98 |
-
hrs_path = os.path.join(base_dir, 'pkls', name + '.pklv4')
|
99 |
-
return hrs_path
|
100 |
-
|
101 |
-
|
102 |
-
def get_lqs_path(dir_path):
|
103 |
-
base_dir = '/kaggle/working/'
|
104 |
-
name = os.path.basename(dir_path)
|
105 |
-
hrs_path = os.path.join(base_dir, 'pkls', name + '_X4.pklv4')
|
106 |
-
return hrs_path
|
107 |
-
|
108 |
-
|
109 |
-
def shuffle_combined(hrs, lqs):
|
110 |
-
combined = list(zip(hrs, lqs))
|
111 |
-
random.shuffle(combined)
|
112 |
-
hrs[:], lqs[:] = zip(*combined)
|
113 |
-
|
114 |
-
|
115 |
-
if __name__ == "__main__":
|
116 |
-
dir_path = sys.argv[1]
|
117 |
-
assert os.path.isdir(dir_path)
|
118 |
-
main(dir_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/SRFlow/code/test.py
DELETED
@@ -1,192 +0,0 @@
|
|
1 |
-
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
-
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
-
# you may not use this file except in compliance with the License.
|
4 |
-
# You may obtain a copy of the License at
|
5 |
-
#
|
6 |
-
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
-
#
|
8 |
-
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
#
|
15 |
-
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
-
|
17 |
-
|
18 |
-
import glob
|
19 |
-
import sys
|
20 |
-
from collections import OrderedDict
|
21 |
-
|
22 |
-
from natsort import natsort
|
23 |
-
|
24 |
-
import options.options as option
|
25 |
-
from Measure import Measure, psnr
|
26 |
-
from imresize import imresize
|
27 |
-
from models import create_model
|
28 |
-
import torch
|
29 |
-
from utils.util import opt_get
|
30 |
-
import numpy as np
|
31 |
-
import pandas as pd
|
32 |
-
import os
|
33 |
-
import cv2
|
34 |
-
|
35 |
-
|
36 |
-
def fiFindByWildcard(wildcard):
|
37 |
-
return natsort.natsorted(glob.glob(wildcard, recursive=True))
|
38 |
-
|
39 |
-
|
40 |
-
def load_model(conf_path):
|
41 |
-
opt = option.parse(conf_path, is_train=False)
|
42 |
-
opt['gpu_ids'] = None
|
43 |
-
opt = option.dict_to_nonedict(opt)
|
44 |
-
model = create_model(opt)
|
45 |
-
|
46 |
-
model_path = opt_get(opt, ['model_path'], None)
|
47 |
-
model.load_network(load_path=model_path, network=model.netG)
|
48 |
-
return model, opt
|
49 |
-
|
50 |
-
|
51 |
-
def predict(model, lr):
|
52 |
-
model.feed_data({"LQ": t(lr)}, need_GT=False)
|
53 |
-
model.test()
|
54 |
-
visuals = model.get_current_visuals(need_GT=False)
|
55 |
-
return visuals.get('rlt', visuals.get("SR"))
|
56 |
-
|
57 |
-
|
58 |
-
def t(array): return torch.Tensor(np.expand_dims(array.transpose([2, 0, 1]), axis=0).astype(np.float32)) / 255
|
59 |
-
|
60 |
-
|
61 |
-
def rgb(t): return (
|
62 |
-
np.clip((t[0] if len(t.shape) == 4 else t).detach().cpu().numpy().transpose([1, 2, 0]), 0, 1) * 255).astype(
|
63 |
-
np.uint8)
|
64 |
-
|
65 |
-
|
66 |
-
def imread(path):
|
67 |
-
return cv2.imread(path)[:, :, [2, 1, 0]]
|
68 |
-
|
69 |
-
|
70 |
-
def imwrite(path, img):
|
71 |
-
os.makedirs(os.path.dirname(path), exist_ok=True)
|
72 |
-
cv2.imwrite(path, img[:, :, [2, 1, 0]])
|
73 |
-
|
74 |
-
|
75 |
-
def imCropCenter(img, size):
|
76 |
-
h, w, c = img.shape
|
77 |
-
|
78 |
-
h_start = max(h // 2 - size // 2, 0)
|
79 |
-
h_end = min(h_start + size, h)
|
80 |
-
|
81 |
-
w_start = max(w // 2 - size // 2, 0)
|
82 |
-
w_end = min(w_start + size, w)
|
83 |
-
|
84 |
-
return img[h_start:h_end, w_start:w_end]
|
85 |
-
|
86 |
-
|
87 |
-
def impad(img, top=0, bottom=0, left=0, right=0, color=255):
|
88 |
-
return np.pad(img, [(top, bottom), (left, right), (0, 0)], 'reflect')
|
89 |
-
|
90 |
-
|
91 |
-
def main():
|
92 |
-
conf_path = sys.argv[1]
|
93 |
-
conf = conf_path.split('/')[-1].replace('.yml', '')
|
94 |
-
model, opt = load_model(conf_path)
|
95 |
-
|
96 |
-
data_dir = opt['dataroot']
|
97 |
-
|
98 |
-
# this_dir = os.path.dirname(os.path.realpath(__file__))
|
99 |
-
test_dir = os.path.join('/kaggle/working/', 'results', conf)
|
100 |
-
print(f"Out dir: {test_dir}")
|
101 |
-
|
102 |
-
measure = Measure(use_gpu=False)
|
103 |
-
|
104 |
-
fname = f'measure_full.csv'
|
105 |
-
fname_tmp = fname + "_"
|
106 |
-
path_out_measures = os.path.join(test_dir, fname_tmp)
|
107 |
-
path_out_measures_final = os.path.join(test_dir, fname)
|
108 |
-
|
109 |
-
if os.path.isfile(path_out_measures_final):
|
110 |
-
df = pd.read_csv(path_out_measures_final)
|
111 |
-
elif os.path.isfile(path_out_measures):
|
112 |
-
df = pd.read_csv(path_out_measures)
|
113 |
-
else:
|
114 |
-
df = None
|
115 |
-
|
116 |
-
scale = opt['scale']
|
117 |
-
|
118 |
-
pad_factor = 2
|
119 |
-
|
120 |
-
data_sets = [
|
121 |
-
'Set5',
|
122 |
-
'Set14',
|
123 |
-
'Urban100',
|
124 |
-
'BSD100'
|
125 |
-
]
|
126 |
-
|
127 |
-
final_df = pd.DataFrame()
|
128 |
-
|
129 |
-
for data_set in data_sets:
|
130 |
-
lr_paths = fiFindByWildcard(os.path.join(data_dir, data_set, '*LR.png'))
|
131 |
-
hr_paths = fiFindByWildcard(os.path.join(data_dir, data_set, '*HR.png'))
|
132 |
-
|
133 |
-
df = pd.DataFrame(columns=['conf', 'heat', 'data_set', 'name', 'PSNR', 'SSIM', 'LPIPS', 'LRC PSNR'])
|
134 |
-
|
135 |
-
for lr_path, hr_path, idx_test in zip(lr_paths, hr_paths, range(len(lr_paths))):
|
136 |
-
with torch.no_grad(), torch.cuda.amp.autocast():
|
137 |
-
lr = imread(lr_path)
|
138 |
-
hr = imread(hr_path)
|
139 |
-
|
140 |
-
# Pad image to be % 2
|
141 |
-
h, w, c = lr.shape
|
142 |
-
lq_orig = lr.copy()
|
143 |
-
lr = impad(lr, bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
|
144 |
-
right=int(np.ceil(w / pad_factor) * pad_factor - w))
|
145 |
-
|
146 |
-
lr_t = t(lr)
|
147 |
-
|
148 |
-
heat = opt['heat']
|
149 |
-
|
150 |
-
if df is not None and len(df[(df['heat'] == heat) & (df['name'] == idx_test)]) == 1:
|
151 |
-
continue
|
152 |
-
|
153 |
-
sr_t = model.get_sr(lq=lr_t, heat=heat)
|
154 |
-
|
155 |
-
sr = rgb(torch.clamp(sr_t, 0, 1))
|
156 |
-
sr = sr[:h * scale, :w * scale]
|
157 |
-
|
158 |
-
path_out_sr = os.path.join(test_dir, data_set, "{:0.2f}".format(heat).replace('.', ''), "{:06d}.png".format(idx_test))
|
159 |
-
imwrite(path_out_sr, sr)
|
160 |
-
|
161 |
-
meas = OrderedDict(conf=conf, heat=heat, data_set=data_set, name=idx_test)
|
162 |
-
meas['PSNR'], meas['SSIM'], meas['LPIPS'] = measure.measure(sr, hr)
|
163 |
-
|
164 |
-
lr_reconstruct_rgb = imresize(sr, 1 / opt['scale'])
|
165 |
-
meas['LRC PSNR'] = psnr(lq_orig, lr_reconstruct_rgb)
|
166 |
-
|
167 |
-
str_out = format_measurements(meas)
|
168 |
-
print(str_out)
|
169 |
-
|
170 |
-
df = df._append(pd.DataFrame([meas]), ignore_index=True)
|
171 |
-
|
172 |
-
final_df = pd.concat([final_df, df])
|
173 |
-
|
174 |
-
final_df.to_csv(path_out_measures, index=False)
|
175 |
-
os.rename(path_out_measures, path_out_measures_final)
|
176 |
-
|
177 |
-
# str_out = format_measurements(df.mean())
|
178 |
-
# print(f"Results in: {path_out_measures_final}")
|
179 |
-
# print('Mean: ' + str_out)
|
180 |
-
|
181 |
-
|
182 |
-
def format_measurements(meas):
|
183 |
-
s_out = []
|
184 |
-
for k, v in meas.items():
|
185 |
-
v = f"{v:0.2f}" if isinstance(v, float) else v
|
186 |
-
s_out.append(f"{k}: {v}")
|
187 |
-
str_out = ", ".join(s_out)
|
188 |
-
return str_out
|
189 |
-
|
190 |
-
|
191 |
-
if __name__ == "__main__":
|
192 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/SRFlow/code/train.py
DELETED
@@ -1,328 +0,0 @@
|
|
1 |
-
# Copyright (c) 2020 Huawei Technologies Co., Ltd.
|
2 |
-
# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
|
3 |
-
# you may not use this file except in compliance with the License.
|
4 |
-
# You may obtain a copy of the License at
|
5 |
-
#
|
6 |
-
# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
7 |
-
#
|
8 |
-
# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
#
|
15 |
-
# This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
|
16 |
-
|
17 |
-
import os
|
18 |
-
from os.path import basename
|
19 |
-
import math
|
20 |
-
import argparse
|
21 |
-
import random
|
22 |
-
import logging
|
23 |
-
import cv2
|
24 |
-
|
25 |
-
import torch
|
26 |
-
import torch.distributed as dist
|
27 |
-
import torch.multiprocessing as mp
|
28 |
-
|
29 |
-
import options.options as option
|
30 |
-
from utils import util
|
31 |
-
from data import create_dataloader, create_dataset
|
32 |
-
from models import create_model
|
33 |
-
from utils.timer import Timer, TickTock
|
34 |
-
from utils.util import get_resume_paths
|
35 |
-
|
36 |
-
import wandb
|
37 |
-
|
38 |
-
def getEnv(name): import os; return True if name in os.environ.keys() else False
|
39 |
-
|
40 |
-
|
41 |
-
def init_dist(backend='nccl', **kwargs):
|
42 |
-
''' initialization for distributed training'''
|
43 |
-
# if mp.get_start_method(allow_none=True) is None:
|
44 |
-
if mp.get_start_method(allow_none=True) != 'spawn':
|
45 |
-
mp.set_start_method('spawn')
|
46 |
-
rank = int(os.environ['RANK'])
|
47 |
-
num_gpus = torch.cuda.device_count()
|
48 |
-
torch.cuda.set_deviceDistIterSampler(rank % num_gpus)
|
49 |
-
dist.init_process_group(backend=backend, **kwargs)
|
50 |
-
|
51 |
-
|
52 |
-
def main():
|
53 |
-
wandb.init(project='srflow')
|
54 |
-
#### options
|
55 |
-
parser = argparse.ArgumentParser()
|
56 |
-
parser.add_argument('-opt', type=str, help='Path to option YMAL file.')
|
57 |
-
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
58 |
-
help='job launcher')
|
59 |
-
parser.add_argument('--local_rank', type=int, default=0)
|
60 |
-
args = parser.parse_args()
|
61 |
-
opt = option.parse(args.opt, is_train=True)
|
62 |
-
|
63 |
-
#### distributed training settings
|
64 |
-
opt['dist'] = False
|
65 |
-
rank = -1
|
66 |
-
print('Disabled distributed training.')
|
67 |
-
|
68 |
-
#### loading resume state if exists
|
69 |
-
if opt['path'].get('resume_state', None):
|
70 |
-
resume_state_path, _ = get_resume_paths(opt)
|
71 |
-
|
72 |
-
# distributed resuming: all load into default GPU
|
73 |
-
if resume_state_path is None:
|
74 |
-
resume_state = None
|
75 |
-
else:
|
76 |
-
device_id = torch.cuda.current_device()
|
77 |
-
resume_state = torch.load(resume_state_path,
|
78 |
-
map_location=lambda storage, loc: storage.cuda(device_id))
|
79 |
-
option.check_resume(opt, resume_state['iter']) # check resume options
|
80 |
-
else:
|
81 |
-
resume_state = None
|
82 |
-
|
83 |
-
#### mkdir and loggers
|
84 |
-
if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
|
85 |
-
if resume_state is None:
|
86 |
-
util.mkdir_and_rename(
|
87 |
-
opt['path']['experiments_root']) # rename experiment folder if exists
|
88 |
-
util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
|
89 |
-
and 'pretrain_model' not in key and 'resume' not in key))
|
90 |
-
|
91 |
-
# config loggers. Before it, the log will not work
|
92 |
-
util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
|
93 |
-
screen=True, tofile=True)
|
94 |
-
util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
|
95 |
-
screen=True, tofile=True)
|
96 |
-
logger = logging.getLogger('base')
|
97 |
-
logger.info(option.dict2str(opt))
|
98 |
-
|
99 |
-
# tensorboard logger
|
100 |
-
if opt.get('use_tb_logger', False) and 'debug' not in opt['name']:
|
101 |
-
version = float(torch.__version__[0:3])
|
102 |
-
if version >= 1.1: # PyTorch 1.1
|
103 |
-
from torch.utils.tensorboard import SummaryWriter
|
104 |
-
else:
|
105 |
-
logger.info(
|
106 |
-
'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
|
107 |
-
from tensorboardX import SummaryWriter
|
108 |
-
conf_name = basename(args.opt).replace(".yml", "")
|
109 |
-
exp_dir = opt['path']['experiments_root']
|
110 |
-
log_dir_train = os.path.join(exp_dir, 'tb', conf_name, 'train')
|
111 |
-
log_dir_valid = os.path.join(exp_dir, 'tb', conf_name, 'valid')
|
112 |
-
tb_logger_train = SummaryWriter(log_dir=log_dir_train)
|
113 |
-
tb_logger_valid = SummaryWriter(log_dir=log_dir_valid)
|
114 |
-
else:
|
115 |
-
util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
|
116 |
-
logger = logging.getLogger('base')
|
117 |
-
|
118 |
-
# convert to NoneDict, which returns None for missing keys
|
119 |
-
opt = option.dict_to_nonedict(opt)
|
120 |
-
|
121 |
-
#### random seed
|
122 |
-
seed = opt['train']['manual_seed']
|
123 |
-
if seed is None:
|
124 |
-
seed = random.randint(1, 10000)
|
125 |
-
if rank <= 0:
|
126 |
-
logger.info('Random seed: {}'.format(seed))
|
127 |
-
util.set_random_seed(seed)
|
128 |
-
|
129 |
-
torch.backends.cudnn.benchmark = True
|
130 |
-
# torch.backends.cudnn.deterministic = True
|
131 |
-
|
132 |
-
#### create train and val dataloader
|
133 |
-
dataset_ratio = 200 # enlarge the size of each epoch
|
134 |
-
for phase, dataset_opt in opt['datasets'].items():
|
135 |
-
if phase == 'train':
|
136 |
-
full_dataset = create_dataset(dataset_opt)
|
137 |
-
print('Dataset created')
|
138 |
-
train_len = int(len(full_dataset) * 0.95)
|
139 |
-
val_len = len(full_dataset) - train_len
|
140 |
-
train_set, val_set = torch.utils.data.random_split(full_dataset, [train_len, val_len])
|
141 |
-
train_size = int(math.ceil(train_len / dataset_opt['batch_size']))
|
142 |
-
total_iters = int(opt['train']['niter'])
|
143 |
-
total_epochs = int(math.ceil(total_iters / train_size))
|
144 |
-
train_sampler = None
|
145 |
-
train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
|
146 |
-
if rank <= 0:
|
147 |
-
logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
|
148 |
-
len(train_set), train_size))
|
149 |
-
logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
|
150 |
-
total_epochs, total_iters))
|
151 |
-
val_loader = torch.utils.data.DataLoader(val_set, batch_size=1, shuffle=False, num_workers=1,
|
152 |
-
pin_memory=True)
|
153 |
-
elif phase == 'val':
|
154 |
-
continue
|
155 |
-
else:
|
156 |
-
raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
|
157 |
-
assert train_loader is not None
|
158 |
-
|
159 |
-
#### create model
|
160 |
-
current_step = 0 if resume_state is None else resume_state['iter']
|
161 |
-
model = create_model(opt, current_step)
|
162 |
-
|
163 |
-
#### resume training
|
164 |
-
if resume_state:
|
165 |
-
logger.info('Resuming training from epoch: {}, iter: {}.'.format(
|
166 |
-
resume_state['epoch'], resume_state['iter']))
|
167 |
-
|
168 |
-
start_epoch = resume_state['epoch']
|
169 |
-
current_step = resume_state['iter']
|
170 |
-
model.resume_training(resume_state) # handle optimizers and schedulers
|
171 |
-
else:
|
172 |
-
current_step = 0
|
173 |
-
start_epoch = 0
|
174 |
-
|
175 |
-
#### training
|
176 |
-
timer = Timer()
|
177 |
-
logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
|
178 |
-
timerData = TickTock()
|
179 |
-
|
180 |
-
for epoch in range(start_epoch, total_epochs + 1):
|
181 |
-
if opt['dist']:
|
182 |
-
train_sampler.set_epoch(epoch)
|
183 |
-
|
184 |
-
timerData.tick()
|
185 |
-
for _, train_data in enumerate(train_loader):
|
186 |
-
timerData.tock()
|
187 |
-
current_step += 1
|
188 |
-
if current_step > total_iters:
|
189 |
-
break
|
190 |
-
|
191 |
-
#### training
|
192 |
-
model.feed_data(train_data)
|
193 |
-
|
194 |
-
#### update learning rate
|
195 |
-
model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
|
196 |
-
|
197 |
-
try:
|
198 |
-
nll = model.optimize_parameters(current_step)
|
199 |
-
except RuntimeError as e:
|
200 |
-
print("Skipping ERROR caught in nll = model.optimize_parameters(current_step): ")
|
201 |
-
print(e)
|
202 |
-
|
203 |
-
if nll is None:
|
204 |
-
nll = 0
|
205 |
-
|
206 |
-
wandb.log({"loss": nll})
|
207 |
-
#### log
|
208 |
-
def eta(t_iter):
|
209 |
-
return (t_iter * (opt['train']['niter'] - current_step)) / 3600
|
210 |
-
|
211 |
-
if current_step % opt['logger']['print_freq'] == 0 \
|
212 |
-
or current_step - (resume_state['iter'] if resume_state else 0) < 25:
|
213 |
-
avg_time = timer.get_average_and_reset()
|
214 |
-
avg_data_time = timerData.get_average_and_reset()
|
215 |
-
message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, t:{:.2e}, td:{:.2e}, eta:{:.2e}, nll:{:.3e}> '.format(
|
216 |
-
epoch, current_step, model.get_current_learning_rate(), avg_time, avg_data_time,
|
217 |
-
eta(avg_time), nll)
|
218 |
-
print(message)
|
219 |
-
timer.tick()
|
220 |
-
# Reduce number of logs
|
221 |
-
if current_step % 5 == 0:
|
222 |
-
tb_logger_train.add_scalar('loss/nll', nll, current_step)
|
223 |
-
tb_logger_train.add_scalar('lr/base', model.get_current_learning_rate(), current_step)
|
224 |
-
tb_logger_train.add_scalar('time/iteration', timer.get_last_iteration(), current_step)
|
225 |
-
tb_logger_train.add_scalar('time/data', timerData.get_last_iteration(), current_step)
|
226 |
-
tb_logger_train.add_scalar('time/eta', eta(timer.get_last_iteration()), current_step)
|
227 |
-
for k, v in model.get_current_log().items():
|
228 |
-
tb_logger_train.add_scalar(k, v, current_step)
|
229 |
-
|
230 |
-
# validation
|
231 |
-
if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
|
232 |
-
avg_psnr = 0.0
|
233 |
-
idx = 0
|
234 |
-
nlls = []
|
235 |
-
for val_data in val_loader:
|
236 |
-
idx += 1
|
237 |
-
img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
|
238 |
-
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
239 |
-
util.mkdir(img_dir)
|
240 |
-
|
241 |
-
model.feed_data(val_data)
|
242 |
-
|
243 |
-
nll = model.test()
|
244 |
-
if nll is None:
|
245 |
-
nll = 0
|
246 |
-
nlls.append(nll)
|
247 |
-
|
248 |
-
visuals = model.get_current_visuals()
|
249 |
-
|
250 |
-
sr_img = None
|
251 |
-
# Save SR images for reference
|
252 |
-
if hasattr(model, 'heats'):
|
253 |
-
for heat in model.heats:
|
254 |
-
for i in range(model.n_sample):
|
255 |
-
sr_img = util.tensor2img(visuals['SR', heat, i]) # uint8
|
256 |
-
save_img_path = os.path.join(img_dir,
|
257 |
-
'{:s}_{:09d}_h{:03d}_s{:d}.png'.format(img_name,
|
258 |
-
current_step,
|
259 |
-
int(heat * 100), i))
|
260 |
-
util.save_img(sr_img, save_img_path)
|
261 |
-
else:
|
262 |
-
sr_img = util.tensor2img(visuals['SR']) # uint8
|
263 |
-
save_img_path = os.path.join(img_dir,
|
264 |
-
'{:s}_{:d}.png'.format(img_name, current_step))
|
265 |
-
util.save_img(sr_img, save_img_path)
|
266 |
-
assert sr_img is not None
|
267 |
-
|
268 |
-
# Save LQ images for reference
|
269 |
-
save_img_path_lq = os.path.join(img_dir,
|
270 |
-
'{:s}_LQ.png'.format(img_name))
|
271 |
-
if not os.path.isfile(save_img_path_lq):
|
272 |
-
lq_img = util.tensor2img(visuals['LQ']) # uint8
|
273 |
-
util.save_img(
|
274 |
-
cv2.resize(lq_img, dsize=None, fx=opt['scale'], fy=opt['scale'],
|
275 |
-
interpolation=cv2.INTER_NEAREST),
|
276 |
-
save_img_path_lq)
|
277 |
-
|
278 |
-
# Save GT images for reference
|
279 |
-
gt_img = util.tensor2img(visuals['GT']) # uint8
|
280 |
-
save_img_path_gt = os.path.join(img_dir,
|
281 |
-
'{:s}_GT.png'.format(img_name))
|
282 |
-
if not os.path.isfile(save_img_path_gt):
|
283 |
-
util.save_img(gt_img, save_img_path_gt)
|
284 |
-
|
285 |
-
# calculate PSNR
|
286 |
-
crop_size = opt['scale']
|
287 |
-
gt_img = gt_img / 255.
|
288 |
-
sr_img = sr_img / 255.
|
289 |
-
cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
|
290 |
-
cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
|
291 |
-
avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
|
292 |
-
|
293 |
-
avg_psnr = avg_psnr / idx
|
294 |
-
avg_nll = sum(nlls) / len(nlls)
|
295 |
-
|
296 |
-
# log
|
297 |
-
logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
|
298 |
-
logger_val = logging.getLogger('val') # validation logger
|
299 |
-
logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
|
300 |
-
epoch, current_step, avg_psnr))
|
301 |
-
|
302 |
-
# tensorboard logger
|
303 |
-
tb_logger_valid.add_scalar('loss/psnr', avg_psnr, current_step)
|
304 |
-
tb_logger_valid.add_scalar('loss/nll', avg_nll, current_step)
|
305 |
-
|
306 |
-
tb_logger_train.flush()
|
307 |
-
tb_logger_valid.flush()
|
308 |
-
|
309 |
-
#### save models and training states
|
310 |
-
if current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
311 |
-
if rank <= 0:
|
312 |
-
logger.info('Saving models and training states.')
|
313 |
-
model.save(current_step)
|
314 |
-
model.save_training_state(epoch, current_step)
|
315 |
-
|
316 |
-
timerData.tick()
|
317 |
-
|
318 |
-
with open(os.path.join(opt['path']['root'], "TRAIN_DONE"), 'w') as f:
|
319 |
-
f.write("TRAIN_DONE")
|
320 |
-
|
321 |
-
if rank <= 0:
|
322 |
-
logger.info('Saving the final model.')
|
323 |
-
model.save('latest')
|
324 |
-
logger.info('End of training.')
|
325 |
-
|
326 |
-
|
327 |
-
if __name__ == '__main__':
|
328 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|