SunderAli17 committed on
Commit
b230f3c
·
verified ·
1 Parent(s): 73bdc04

Create utils/degradation_pipeline.py

Browse files
Files changed (1) hide show
  1. utils/degradation_pipeline.py +353 -0
utils/degradation_pipeline.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import random
5
+ import torch
6
+ from torch.utils import data as data
7
+
8
+ from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
9
+ from basicsr.data.transforms import augment
10
+ from basicsr.utils import img2tensor, DiffJPEG, USMSharp
11
+ from basicsr.utils.img_process_util import filter2D
12
+ from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
13
+ from basicsr.data.transforms import paired_random_crop
14
+
15
# Default augmentation switches. RealESRGANDegradation stores this dict as
# ``self.augment_opt``; both flips/rotations are disabled by default.
AUGMENT_OPT = {
    "use_hflip": False,
    "use_rot": False,
}
19
+
20
# Default blur-kernel sampling settings for RealESRGANDegradation.
# Keys without a suffix drive the first degradation stage; keys ending in
# ``2`` drive the second stage.
KERNEL_OPT = {
    # --- first degradation stage ---
    "blur_kernel_size": 21,
    "kernel_list": [
        "iso",
        "aniso",
        "generalized_iso",
        "generalized_aniso",
        "plateau_iso",
        "plateau_aniso",
    ],
    # one sampling probability per entry of ``kernel_list``
    "kernel_prob": [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
    # probability of using a sinc (circular low-pass) kernel instead
    "sinc_prob": 0.1,
    "blur_sigma": [0.2, 3],
    # betag is used by the generalized Gaussian kernels
    "betag_range": [0.5, 4],
    # betap is used by the plateau kernels
    "betap_range": [1, 2],

    # --- second degradation stage ---
    "blur_kernel_size2": 21,
    "kernel_list2": [
        "iso",
        "aniso",
        "generalized_iso",
        "generalized_aniso",
        "plateau_iso",
        "plateau_aniso",
    ],
    "kernel_prob2": [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
    "sinc_prob2": 0.1,
    "blur_sigma2": [0.2, 1.5],
    "betag_range2": [0.5, 4],
    "betap_range2": [1, 2],

    # probability of applying a real sinc filter (rather than an identity
    # pulse) at the very end of the pipeline
    "final_sinc_prob": 0.8,
}
38
+
39
# Default degradation settings for RealESRGANDegradation.
# Keys without a suffix drive the first degradation stage; keys ending in
# ``2`` drive the second stage.
DEGRADE_OPT = {
    # --- the first degradation process ---
    "resize_prob": [0.2, 0.7, 0.1],  # up, down, keep
    "resize_range": [0.15, 1.5],
    # probability of Gaussian noise; Poisson noise is used otherwise
    "gaussian_noise_prob": 0.5,
    "noise_range": [1, 30],
    "poisson_scale_range": [0.05, 3],
    "gray_noise_prob": 0.4,
    "jpeg_range": [30, 95],

    # --- the second degradation process ---
    "second_blur_prob": 0.8,
    "resize_prob2": [0.3, 0.4, 0.3],  # up, down, keep
    "resize_range2": [0.3, 1.2],
    "gaussian_noise_prob2": 0.5,
    "noise_range2": [1, 25],
    "poisson_scale_range2": [0.05, 2.5],
    "gray_noise_prob2": 0.4,
    "jpeg_range2": [30, 95],

    # --- output handling ---
    "gt_size": 512,  # overridden by the ``resolution`` constructor argument
    "no_degradation_prob": 0.01,  # chance of returning the GT as the LQ image
    "use_usm": True,  # sharpen the GT with USMSharp before degrading
    "sf": 4,  # downscale factor between GT and LQ
    "random_size": False,  # not implemented by the pipeline
    "resize_lq": True,  # bicubic-resize the LQ crop back to the GT crop size
}
65
+
66
class RealESRGANDegradation:
    """Synthesize low-quality (LQ) images from ground-truth (GT) batches.

    Calling an instance runs two rounds of blur -> random resize -> noise ->
    JPEG compression, followed by a resize to ``1 / sf`` of the input size
    combined with a final sinc filter, and returns the ``(LQ, GT)`` pair
    rescaled to ``[-1, 1]``.
    """

    # Side length every blur / sinc kernel is zero-padded to, so that the
    # per-sample kernels of a batch can be stacked into one tensor.
    _PADDED_KERNEL_SIZE = 21

    # Interpolation modes sampled uniformly for every resize step.
    _RESIZE_MODES = ('area', 'bilinear', 'bicubic')

    def __init__(self, augment_opt=None, kernel_opt=None, degrade_opt=None, device='cuda', resolution=None):
        """Store the option dicts and build the JPEG / USM helper modules.

        Args:
            augment_opt: augmentation options; defaults to ``AUGMENT_OPT``.
            kernel_opt: blur-kernel options; defaults to ``KERNEL_OPT``.
            degrade_opt: degradation options; defaults to ``DEGRADE_OPT``.
            device: device the helper modules and all tensors are moved to.
            resolution: if given, overrides ``degrade_opt['gt_size']`` for
                this instance only.
        """
        if augment_opt is None:
            augment_opt = AUGMENT_OPT
        if kernel_opt is None:
            kernel_opt = KERNEL_OPT
        if degrade_opt is None:
            degrade_opt = DEGRADE_OPT
        if resolution is not None:
            # Apply the override on a copy: writing into ``degrade_opt``
            # directly would silently change the module-level defaults (or
            # the caller's dict) for every other user of that dict.
            degrade_opt = {**degrade_opt, 'gt_size': resolution}

        self.augment_opt = augment_opt
        self.kernel_opt = kernel_opt
        self.degrade_opt = degrade_opt
        self.device = device

        self.jpeger = DiffJPEG(differentiable=False).to(self.device)
        self.usm_sharpener = USMSharp().to(self.device)

        # blur settings for the first degradation
        self.blur_kernel_size = kernel_opt['blur_kernel_size']
        self.kernel_list = kernel_opt['kernel_list']
        self.kernel_prob = kernel_opt['kernel_prob']  # a list for each kernel probability
        self.blur_sigma = kernel_opt['blur_sigma']
        self.betag_range = kernel_opt['betag_range']  # betag used in generalized Gaussian blur kernels
        self.betap_range = kernel_opt['betap_range']  # betap used in plateau blur kernels
        self.sinc_prob = kernel_opt['sinc_prob']  # the probability for sinc filters

        # blur settings for the second degradation
        self.blur_kernel_size2 = kernel_opt['blur_kernel_size2']
        self.kernel_list2 = kernel_opt['kernel_list2']
        self.kernel_prob2 = kernel_opt['kernel_prob2']
        self.blur_sigma2 = kernel_opt['blur_sigma2']
        self.betag_range2 = kernel_opt['betag_range2']
        self.betap_range2 = kernel_opt['betap_range2']
        self.sinc_prob2 = kernel_opt['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = kernel_opt['final_sinc_prob']

        # TODO: kernel range is hard-coded (odd sizes 7..21) and the padded
        # size is fixed at 21; ``blur_kernel_size`` / ``blur_kernel_size2``
        # are stored above but not consulted when sampling kernels.
        self.kernel_range = [2 * v + 1 for v in range(3, 11)]
        # convolving with the pulse tensor brings no blurry effect
        padded = self._PADDED_KERNEL_SIZE
        self.pulse_tensor = torch.zeros(padded, padded).float()
        self.pulse_tensor[padded // 2, padded // 2] = 1

    def _sample_blur_kernel(self, kernel_list, kernel_prob, blur_sigma, betag_range, betap_range, sinc_prob):
        """Sample one blur kernel and zero-pad it to the common kernel size.

        With probability ``sinc_prob`` a circular low-pass (sinc) kernel is
        drawn, otherwise a kernel from ``random_mixed_kernels``. Returns a
        float tensor.
        """
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < sinc_prob:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                kernel_list,
                kernel_prob,
                kernel_size,
                blur_sigma,
                blur_sigma, [-math.pi, math.pi],
                betag_range,
                betap_range,
                noise_range=None)
        # pad kernel symmetrically up to the common size
        pad_size = (self._PADDED_KERNEL_SIZE - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
        return torch.FloatTensor(kernel)

    def get_kernel(self):
        """Sample the three kernels used to degrade one image.

        Returns:
            ``(kernel, kernel2, sinc_kernel)``: the first-stage blur kernel,
            the second-stage blur kernel and the final sinc kernel (or the
            identity pulse tensor when no final sinc filter is drawn).
        """
        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
        kernel = self._sample_blur_kernel(
            self.kernel_list,
            self.kernel_prob,
            self.blur_sigma,
            self.betag_range,
            self.betap_range,
            self.kernel_opt['sinc_prob'],
        )

        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
        kernel2 = self._sample_blur_kernel(
            self.kernel_list2,
            self.kernel_prob2,
            self.blur_sigma2,
            self.betag_range2,
            self.betap_range2,
            self.kernel_opt['sinc_prob2'],
        )

        # ------------------------------------- the final sinc kernel ------------------------------------- #
        if np.random.uniform() < self.kernel_opt['final_sinc_prob']:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=self._PADDED_KERNEL_SIZE)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor

        return (kernel, kernel2, sinc_kernel)

    def _sample_resize_scale(self, prob_key, range_key):
        """Pick up / down / keep using ``degrade_opt[prob_key]`` and draw a scale factor."""
        updown_type = random.choices(
            ['up', 'down', 'keep'],
            self.degrade_opt[prob_key],
        )[0]
        if updown_type == 'up':
            return random.uniform(1, self.degrade_opt[range_key][1])
        if updown_type == 'down':
            return random.uniform(self.degrade_opt[range_key][0], 1)
        return 1

    def _add_random_noise(self, out, suffix=''):
        """Add Gaussian or Poisson noise using the stage options selected by ``suffix``."""
        opt = self.degrade_opt
        gray_noise_prob = opt['gray_noise_prob' + suffix]
        if random.random() < opt['gaussian_noise_prob' + suffix]:
            return random_add_gaussian_noise_pt(
                out,
                sigma_range=opt['noise_range' + suffix],
                clip=True,
                rounds=False,
                gray_prob=gray_noise_prob,
            )
        return random_add_poisson_noise_pt(
            out,
            scale_range=opt['poisson_scale_range' + suffix],
            gray_prob=gray_noise_prob,
            clip=True,
            rounds=False,
        )

    def _jpeg_compress(self, out, range_key):
        """JPEG-compress ``out`` with a per-sample quality drawn from ``degrade_opt[range_key]``."""
        jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.degrade_opt[range_key])
        # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
        out = torch.clamp(out, 0, 1)
        return self.jpeger(out, quality=jpeg_p)

    def _resize_back_and_sinc(self, out, size, sinc_kernel):
        """Resize ``out`` to ``size`` with a random mode, then apply the final sinc filter."""
        mode = random.choice(list(self._RESIZE_MODES))
        out = torch.nn.functional.interpolate(out, size=size, mode=mode)
        out = out.contiguous()
        return filter2D(out, sinc_kernel)

    @torch.no_grad()
    def __call__(self, img_gt, kernels=None):
        """Degrade a batch of GT images into LQ images.

        Args:
            img_gt: BCHW, RGB, [0, 1] float32 tensor.
            kernels: optional ``(kernel, kernel2, sinc_kernel)`` tensors
                created elsewhere (e.g. in the dataset); when ``None`` one
                kernel triple is sampled per batch element.

        Returns:
            ``(im_lq, im_gt)``, both rescaled from ``[0, 1]`` to ``[-1, 1]``
            and clamped. ``im_gt`` is the (optionally USM-sharpened) input
            after ``paired_random_crop``; ``im_lq`` is the degraded crop,
            bicubic-resized to the GT crop size when
            ``degrade_opt['resize_lq']`` is set.

        Raises:
            NotImplementedError: if ``degrade_opt['random_size']`` is set.
        """
        opt = self.degrade_opt
        if opt['random_size']:
            # Fail fast instead of running the whole pipeline first.
            raise NotImplementedError("degrade_opt['random_size'] is not supported")

        if kernels is None:
            # Draw an independent kernel triple for every sample in the batch.
            sampled = [self.get_kernel() for _ in range(img_gt.shape[0])]
            kernel = torch.stack([triple[0] for triple in sampled])
            kernel2 = torch.stack([triple[1] for triple in sampled])
            sinc_kernel = torch.stack([triple[2] for triple in sampled])
        else:
            # kernels created in dataset.
            kernel, kernel2, sinc_kernel = kernels

        # ----------------------- Pre-process ----------------------- #
        im_gt = img_gt.to(self.device)
        if opt['use_usm']:
            im_gt = self.usm_sharpener(im_gt)
        im_gt = im_gt.to(memory_format=torch.contiguous_format).float()
        kernel = kernel.to(self.device)
        kernel2 = kernel2.to(self.device)
        sinc_kernel = sinc_kernel.to(self.device)
        ori_h, ori_w = im_gt.size()[2:4]
        sf = opt['sf']
        lq_size = (ori_h // sf, ori_w // sf)

        # ----------------------- The first degradation process ----------------------- #
        # blur
        out = filter2D(im_gt, kernel)
        # random resize
        scale = self._sample_resize_scale('resize_prob', 'resize_range')
        mode = random.choice(list(self._RESIZE_MODES))
        out = torch.nn.functional.interpolate(out, scale_factor=scale, mode=mode)
        # add noise
        out = self._add_random_noise(out)
        # JPEG compression
        out = self._jpeg_compress(out, 'jpeg_range')

        # ----------------------- The second degradation process ----------------------- #
        # blur
        if random.random() < opt['second_blur_prob']:
            out = out.contiguous()
            out = filter2D(out, kernel2)
        # random resize, relative to the final LQ size
        scale = self._sample_resize_scale('resize_prob2', 'resize_range2')
        mode = random.choice(list(self._RESIZE_MODES))
        out = torch.nn.functional.interpolate(
            out,
            size=(int(ori_h / sf * scale),
                  int(ori_w / sf * scale)),
            mode=mode,
        )
        # add noise
        out = self._add_random_noise(out, suffix='2')

        # JPEG compression + the final sinc filter
        # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
        # as one operation.
        # We consider two orders:
        #   1. [resize back + sinc filter] + JPEG compression
        #   2. JPEG compression + [resize back + sinc filter]
        # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
        if random.random() < 0.5:
            out = self._resize_back_and_sinc(out, lq_size, sinc_kernel)
            out = self._jpeg_compress(out, 'jpeg_range2')
        else:
            out = self._jpeg_compress(out, 'jpeg_range2')
            out = self._resize_back_and_sinc(out, lq_size, sinc_kernel)

        # clamp to the valid image range
        im_lq = torch.clamp(out, 0, 1.0)

        # random crop
        gt_size = opt['gt_size']
        im_gt, im_lq = paired_random_crop(im_gt, im_lq, gt_size, sf)

        if opt['resize_lq']:
            # bring the LQ crop back to the spatial size of the GT crop
            im_lq = torch.nn.functional.interpolate(
                im_lq,
                size=(im_gt.size(-2),
                      im_gt.size(-1)),
                mode='bicubic',
            )

        # Occasionally (or when the degraded result contains NaNs) hand back
        # the clean image as the LQ input.
        if random.random() < opt['no_degradation_prob'] or torch.isnan(im_lq).any():
            im_lq = im_gt

        im_lq = im_lq.contiguous()  # for the warning: grad and param do not obey the gradient layout contract
        # map [0, 1] -> [-1, 1]
        im_lq = im_lq*2 - 1.0
        im_gt = im_gt*2 - 1.0

        im_lq = torch.clamp(im_lq, -1.0, 1.0)
        im_gt = torch.clamp(im_gt, -1.0, 1.0)

        return (im_lq, im_gt)