Nguyễn Bá Thiêm committed
Commit 239e299 · 1 Parent(s): 1ac6098

Add streamlit and gdown to requirements.txt

images/img_003_SRF_4_LR.png ADDED
models/HAT/hat.py CHANGED
@@ -0,0 +1,1363 @@
1
+ import gdown
2
+
3
+ # url = 'https://drive.google.com/file/d/1LHIUM7YoUDk8cXWzVZhroAcA1xXi-d87/view?usp=drive_link'
4
+ output = 'models/HAT/hat_model_checkpoint_best.pth'
5
+ # gdown.download(url, output, quiet=False)
6
+
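+ # Hedged sketch (not part of the original commit): the commented-out download above can be
+ # made idempotent by fetching the checkpoint only when it is missing. The Drive file id is
+ # taken from the commented URL; gdown.download accepts it via the `id=` argument.
+ # import os
+ # if not os.path.exists(output):
+ #     gdown.download(id='1LHIUM7YoUDk8cXWzVZhroAcA1xXi-d87', output=output, quiet=False)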
7
+ import gc
8
+ import os
9
+ import random
10
+ import time
11
+ import wandb
12
+ from tqdm import tqdm
13
+
14
+ import matplotlib.pyplot as plt
15
+ from PIL import Image
16
+ from skimage.metrics import structural_similarity as ssim
17
+
18
+ import torch
19
+ from torch import nn, optim
20
+ import torch.nn.functional as F
21
+ from torch.utils.data import Dataset, DataLoader, ConcatDataset
22
+ from torchvision import transforms
23
+ from torchvision.transforms import Compose
24
+ from torchmetrics.functional.image import structural_similarity_index_measure as ssim
25
+
26
+ from basicsr.archs.arch_util import to_2tuple, trunc_normal_
27
+ from einops import rearrange
28
+ import math
29
+
30
+ class ChannelAttention(nn.Module):
31
+ """Channel attention used in RCAN.
32
+ Args:
33
+ num_feat (int): Channel number of intermediate features.
34
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
35
+ """
36
+
37
+ def __init__(self, num_feat, squeeze_factor=16):
38
+ super(ChannelAttention, self).__init__()
39
+ self.attention = nn.Sequential(
40
+ nn.AdaptiveAvgPool2d(1),
41
+ nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
42
+ nn.ReLU(inplace=True),
43
+ nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
44
+ nn.Sigmoid())
45
+
46
+ def forward(self, x):
47
+ y = self.attention(x)
48
+ return x * y
49
+
50
+
51
+ class CAB(nn.Module):
52
+
53
+ def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
54
+ super(CAB, self).__init__()
55
+
56
+ self.cab = nn.Sequential(
57
+ nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
58
+ nn.GELU(),
59
+ nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
60
+ ChannelAttention(num_feat, squeeze_factor)
61
+ )
62
+
63
+ def forward(self, x):
64
+ return self.cab(x)
65
+
66
+ def window_partition(x, window_size):
67
+ """
68
+ Args:
69
+ x: (b, h, w, c)
70
+ window_size (int): window size
71
+
72
+ Returns:
73
+ windows: (num_windows*b, window_size, window_size, c)
74
+ """
75
+ b, h, w, c = x.shape
76
+ x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
77
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
78
+ return windows
79
+
80
+ def window_reverse(windows, window_size, h, w):
81
+ """
82
+ Args:
83
+ windows: (num_windows*b, window_size, window_size, c)
84
+ window_size (int): Window size
85
+ h (int): Height of image
86
+ w (int): Width of image
87
+
88
+ Returns:
89
+ x: (b, h, w, c)
90
+ """
91
+
92
+ b = int(windows.shape[0] / (h * w / window_size / window_size))
93
+ x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
94
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
95
+ return x
96
+
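+ # Shape sanity check (illustrative values assumed, not from the original file):
+ # x = torch.randn(2, 16, 16, 32)        # (b, h, w, c)
+ # win = window_partition(x, 8)          # -> (8, 8, 8, 32): 2 * (16/8) * (16/8) windows
+ # assert window_reverse(win, 8, 16, 16).shape == x.shape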
97
+
98
+
99
+ class WindowAttention(nn.Module):
100
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
101
+ It supports both of shifted and non-shifted window.
102
+
103
+ Args:
104
+ dim (int): Number of input channels.
105
+ window_size (tuple[int]): The height and width of the window.
106
+ num_heads (int): Number of attention heads.
107
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
108
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
109
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
110
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
111
+ """
112
+
113
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
114
+
115
+ super().__init__()
116
+ self.dim = dim
117
+ self.window_size = window_size # Wh, Ww
118
+ self.num_heads = num_heads
119
+ head_dim = dim // num_heads
120
+ self.scale = qk_scale or head_dim**-0.5
121
+
122
+ # define a parameter table of relative position bias
123
+ self.relative_position_bias_table = nn.Parameter(
124
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
125
+
126
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
127
+ self.attn_drop = nn.Dropout(attn_drop)
128
+ self.proj = nn.Linear(dim, dim)
129
+
130
+ self.proj_drop = nn.Dropout(proj_drop)
131
+
132
+ trunc_normal_(self.relative_position_bias_table, std=.02)
133
+ self.softmax = nn.Softmax(dim=-1)
134
+
135
+ def forward(self, x, rpi, mask=None):
136
+ """
137
+ Args:
138
+ x: input features with shape of (num_windows*b, n, c)
139
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
140
+ """
141
+ b_, n, c = x.shape
142
+ qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
143
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
144
+
145
+ q = q * self.scale
146
+ attn = (q @ k.transpose(-2, -1))
147
+
148
+ relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view(
149
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
150
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
151
+ attn = attn + relative_position_bias.unsqueeze(0)
152
+
153
+ if mask is not None:
154
+ nw = mask.shape[0]
155
+ attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
156
+ attn = attn.view(-1, self.num_heads, n, n)
157
+ attn = self.softmax(attn)
158
+ else:
159
+ attn = self.softmax(attn)
160
+
161
+ attn = self.attn_drop(attn)
162
+
163
+ x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
164
+ x = self.proj(x)
165
+ x = self.proj_drop(x)
166
+ return x
167
+
168
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
169
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
170
+
171
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
172
+ """
173
+ if drop_prob == 0. or not training:
174
+ return x
175
+ keep_prob = 1 - drop_prob
176
+ shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
177
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
178
+ random_tensor.floor_() # binarize
179
+ output = x.div(keep_prob) * random_tensor
180
+ return output
181
+
182
+
183
+ class DropPath(nn.Module):
184
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
185
+
186
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
187
+ """
188
+
189
+ def __init__(self, drop_prob=None):
190
+ super(DropPath, self).__init__()
191
+ self.drop_prob = drop_prob
192
+
193
+ def forward(self, x):
194
+ return drop_path(x, self.drop_prob, self.training)
195
+
196
+
197
+ class Mlp(nn.Module):
198
+
199
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
200
+ super().__init__()
201
+ out_features = out_features or in_features
202
+ hidden_features = hidden_features or in_features
203
+ self.fc1 = nn.Linear(in_features, hidden_features)
204
+ self.act = act_layer()
205
+ self.fc2 = nn.Linear(hidden_features, out_features)
206
+ self.drop = nn.Dropout(drop)
207
+
208
+ def forward(self, x):
209
+ x = self.fc1(x)
210
+ x = self.act(x)
211
+ x = self.drop(x)
212
+ x = self.fc2(x)
213
+ x = self.drop(x)
214
+ return x
215
+
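+ # NOTE: `HAB` is referenced by `AttenBlocks` further below but is not defined anywhere in
+ # this diff, so constructing the model as committed would raise a NameError. The block
+ # below is a hedged sketch following the reference HAT implementation (Hybrid Attention
+ # Block = (shifted) window attention plus a scaled channel-attention conv branch); treat
+ # it as an assumption, not as part of the original commit.
+ class HAB(nn.Module):
+     """Hybrid Attention Block: (S)W-MSA plus a scaled CAB conv branch, then an MLP."""
+ 
+     def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+                  compress_ratio=3, squeeze_factor=30, conv_scale=0.01, mlp_ratio=4.,
+                  qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+                  act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+         super().__init__()
+         self.dim = dim
+         self.input_resolution = input_resolution
+         self.num_heads = num_heads
+         self.window_size = window_size
+         self.shift_size = shift_size
+         self.mlp_ratio = mlp_ratio
+ 
+         self.norm1 = norm_layer(dim)
+         self.attn = WindowAttention(
+             dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+             qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+ 
+         self.conv_scale = conv_scale
+         self.conv_block = CAB(num_feat=dim, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor)
+ 
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.norm2 = norm_layer(dim)
+         self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
+ 
+     def forward(self, x, x_size, rpi_sa, attn_mask):
+         h, w = x_size
+         b, _, c = x.shape
+ 
+         shortcut = x
+         x = self.norm1(x)
+         x = x.view(b, h, w, c)
+ 
+         # channel-attention conv branch
+         conv_x = self.conv_block(x.permute(0, 3, 1, 2))
+         conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(b, h * w, c)
+ 
+         # cyclic shift for SW-MSA (no mask needed for the non-shifted blocks)
+         if self.shift_size > 0:
+             shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+         else:
+             shifted_x = x
+             attn_mask = None
+ 
+         # partition windows and run (shifted) window attention
+         x_windows = window_partition(shifted_x, self.window_size)
+         x_windows = x_windows.view(-1, self.window_size * self.window_size, c)
+         attn_windows = self.attn(x_windows, rpi=rpi_sa, mask=attn_mask)
+ 
+         # merge windows and reverse the cyclic shift
+         attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
+         shifted_x = window_reverse(attn_windows, self.window_size, h, w)
+         if self.shift_size > 0:
+             attn_x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+         else:
+             attn_x = shifted_x
+         attn_x = attn_x.view(b, h * w, c)
+ 
+         # residual fusion of the attention and conv branches, then MLP
+         x = shortcut + self.drop_path(attn_x) + conv_x * self.conv_scale
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+         return x
+ 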
216
+ class OCAB(nn.Module):
217
+ # overlapping cross-attention block
218
+
219
+ def __init__(self, dim,
220
+ input_resolution,
221
+ window_size,
222
+ overlap_ratio,
223
+ num_heads,
224
+ qkv_bias=True,
225
+ qk_scale=None,
226
+ mlp_ratio=2,
227
+ norm_layer=nn.LayerNorm
228
+ ):
229
+
230
+ super().__init__()
231
+ self.dim = dim
232
+ self.input_resolution = input_resolution
233
+ self.window_size = window_size
234
+ self.num_heads = num_heads
235
+ head_dim = dim // num_heads
236
+ self.scale = qk_scale or head_dim**-0.5
237
+ self.overlap_win_size = int(window_size * overlap_ratio) + window_size
238
+
239
+ self.norm1 = norm_layer(dim)
240
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
241
+ self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size, padding=(self.overlap_win_size-window_size)//2)
242
+
243
+ # define a parameter table of relative position bias
244
+ self.relative_position_bias_table = nn.Parameter(
245
+ torch.zeros((window_size + self.overlap_win_size - 1) * (window_size + self.overlap_win_size - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
246
+
247
+ trunc_normal_(self.relative_position_bias_table, std=.02)
248
+ self.softmax = nn.Softmax(dim=-1)
249
+
250
+ self.proj = nn.Linear(dim,dim)
251
+
252
+ self.norm2 = norm_layer(dim)
253
+ mlp_hidden_dim = int(dim * mlp_ratio)
254
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU)
255
+
256
+ def forward(self, x, x_size, rpi):
257
+ h, w = x_size
258
+ b, _, c = x.shape
259
+
260
+ shortcut = x
261
+ x = self.norm1(x)
262
+ x = x.view(b, h, w, c)
263
+
264
+ qkv = self.qkv(x).reshape(b, h, w, 3, c).permute(3, 0, 4, 1, 2) # 3, b, c, h, w
265
+ q = qkv[0].permute(0, 2, 3, 1) # b, h, w, c
266
+ kv = torch.cat((qkv[1], qkv[2]), dim=1) # b, 2*c, h, w
267
+
268
+ # partition windows
269
+ q_windows = window_partition(q, self.window_size) # nw*b, window_size, window_size, c
270
+ q_windows = q_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c
271
+
272
+ kv_windows = self.unfold(kv) # b, c*w*w, nw
273
+ kv_windows = rearrange(kv_windows, 'b (nc ch owh oww) nw -> nc (b nw) (owh oww) ch', nc=2, ch=c, owh=self.overlap_win_size, oww=self.overlap_win_size).contiguous() # 2, nw*b, ow*ow, c
274
+ k_windows, v_windows = kv_windows[0], kv_windows[1] # nw*b, ow*ow, c
275
+
276
+ b_, nq, _ = q_windows.shape
277
+ _, n, _ = k_windows.shape
278
+ d = self.dim // self.num_heads
279
+ q = q_windows.reshape(b_, nq, self.num_heads, d).permute(0, 2, 1, 3) # nw*b, nH, nq, d
280
+ k = k_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3) # nw*b, nH, n, d
281
+ v = v_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3) # nw*b, nH, n, d
282
+
283
+ q = q * self.scale
284
+ attn = (q @ k.transpose(-2, -1))
285
+
286
+ relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view(
287
+ self.window_size * self.window_size, self.overlap_win_size * self.overlap_win_size, -1) # ws*ws, wse*wse, nH
288
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, ws*ws, wse*wse
289
+ attn = attn + relative_position_bias.unsqueeze(0)
290
+
291
+ attn = self.softmax(attn)
292
+ attn_windows = (attn @ v).transpose(1, 2).reshape(b_, nq, self.dim)
293
+
294
+ # merge windows
295
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.dim)
296
+ x = window_reverse(attn_windows, self.window_size, h, w) # b h w c
297
+ x = x.view(b, h * w, self.dim)
298
+
299
+ x = self.proj(x) + shortcut
300
+
301
+ x = x + self.mlp(self.norm2(x))
302
+ return x
303
+ class AttenBlocks(nn.Module):
304
+ """ A series of attention blocks for one RHAG.
305
+
306
+ Args:
307
+ dim (int): Number of input channels.
308
+ input_resolution (tuple[int]): Input resolution.
309
+ depth (int): Number of blocks.
310
+ num_heads (int): Number of attention heads.
311
+ window_size (int): Local window size.
312
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
313
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
314
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
315
+ drop (float, optional): Dropout rate. Default: 0.0
316
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
317
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
318
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
319
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
320
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
321
+ """
322
+
323
+ def __init__(self,
324
+ dim,
325
+ input_resolution,
326
+ depth,
327
+ num_heads,
328
+ window_size,
329
+ compress_ratio,
330
+ squeeze_factor,
331
+ conv_scale,
332
+ overlap_ratio,
333
+ mlp_ratio=4.,
334
+ qkv_bias=True,
335
+ qk_scale=None,
336
+ drop=0.,
337
+ attn_drop=0.,
338
+ drop_path=0.,
339
+ norm_layer=nn.LayerNorm,
340
+ downsample=None,
341
+ use_checkpoint=False):
342
+
343
+ super().__init__()
344
+ self.dim = dim
345
+ self.input_resolution = input_resolution
346
+ self.depth = depth
347
+ self.use_checkpoint = use_checkpoint
348
+
349
+ # build blocks
350
+ self.blocks = nn.ModuleList([
351
+ HAB(
352
+ dim=dim,
353
+ input_resolution=input_resolution,
354
+ num_heads=num_heads,
355
+ window_size=window_size,
356
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
357
+ compress_ratio=compress_ratio,
358
+ squeeze_factor=squeeze_factor,
359
+ conv_scale=conv_scale,
360
+ mlp_ratio=mlp_ratio,
361
+ qkv_bias=qkv_bias,
362
+ qk_scale=qk_scale,
363
+ drop=drop,
364
+ attn_drop=attn_drop,
365
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
366
+ norm_layer=norm_layer) for i in range(depth)
367
+ ])
368
+
369
+ # OCAB
370
+ self.overlap_attn = OCAB(
371
+ dim=dim,
372
+ input_resolution=input_resolution,
373
+ window_size=window_size,
374
+ overlap_ratio=overlap_ratio,
375
+ num_heads=num_heads,
376
+ qkv_bias=qkv_bias,
377
+ qk_scale=qk_scale,
378
+ mlp_ratio=mlp_ratio,
379
+ norm_layer=norm_layer
380
+ )
381
+
382
+ # patch merging layer
383
+ if downsample is not None:
384
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
385
+ else:
386
+ self.downsample = None
387
+
388
+ def forward(self, x, x_size, params):
389
+ for blk in self.blocks:
390
+ x = blk(x, x_size, params['rpi_sa'], params['attn_mask'])
391
+
392
+ x = self.overlap_attn(x, x_size, params['rpi_oca'])
393
+
394
+ if self.downsample is not None:
395
+ x = self.downsample(x)
396
+ return x
397
+
398
+
399
+ class RHAG(nn.Module):
400
+ """Residual Hybrid Attention Group (RHAG).
401
+
402
+ Args:
403
+ dim (int): Number of input channels.
404
+ input_resolution (tuple[int]): Input resolution.
405
+ depth (int): Number of blocks.
406
+ num_heads (int): Number of attention heads.
407
+ window_size (int): Local window size.
408
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
409
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
410
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
411
+ drop (float, optional): Dropout rate. Default: 0.0
412
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
413
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
414
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
415
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
416
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
417
+ img_size: Input image size.
418
+ patch_size: Patch size.
419
+ resi_connection: The convolutional block before residual connection.
420
+ """
421
+
422
+ def __init__(self,
423
+ dim,
424
+ input_resolution,
425
+ depth,
426
+ num_heads,
427
+ window_size,
428
+ compress_ratio,
429
+ squeeze_factor,
430
+ conv_scale,
431
+ overlap_ratio,
432
+ mlp_ratio=4.,
433
+ qkv_bias=True,
434
+ qk_scale=None,
435
+ drop=0.,
436
+ attn_drop=0.,
437
+ drop_path=0.,
438
+ norm_layer=nn.LayerNorm,
439
+ downsample=None,
440
+ use_checkpoint=False,
441
+ img_size=224,
442
+ patch_size=4,
443
+ resi_connection='1conv'):
444
+ super(RHAG, self).__init__()
445
+
446
+ self.dim = dim
447
+ self.input_resolution = input_resolution
448
+
449
+ self.residual_group = AttenBlocks(
450
+ dim=dim,
451
+ input_resolution=input_resolution,
452
+ depth=depth,
453
+ num_heads=num_heads,
454
+ window_size=window_size,
455
+ compress_ratio=compress_ratio,
456
+ squeeze_factor=squeeze_factor,
457
+ conv_scale=conv_scale,
458
+ overlap_ratio=overlap_ratio,
459
+ mlp_ratio=mlp_ratio,
460
+ qkv_bias=qkv_bias,
461
+ qk_scale=qk_scale,
462
+ drop=drop,
463
+ attn_drop=attn_drop,
464
+ drop_path=drop_path,
465
+ norm_layer=norm_layer,
466
+ downsample=downsample,
467
+ use_checkpoint=use_checkpoint)
468
+
469
+ if resi_connection == '1conv':
470
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
471
+ elif resi_connection == 'identity':
472
+ self.conv = nn.Identity()
473
+
474
+ self.patch_embed = PatchEmbed(
475
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
476
+
477
+ self.patch_unembed = PatchUnEmbed(
478
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
479
+
480
+ def forward(self, x, x_size, params):
481
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size, params), x_size))) + x
482
+
483
+
484
+ class PatchEmbed(nn.Module):
485
+ r""" Image to Patch Embedding
486
+
487
+ Args:
488
+ img_size (int): Image size. Default: 224.
489
+ patch_size (int): Patch token size. Default: 4.
490
+ in_chans (int): Number of input image channels. Default: 3.
491
+ embed_dim (int): Number of linear projection output channels. Default: 96.
492
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
493
+ """
494
+
495
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
496
+ super().__init__()
497
+ img_size = to_2tuple(img_size)
498
+ patch_size = to_2tuple(patch_size)
499
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
500
+ self.img_size = img_size
501
+ self.patch_size = patch_size
502
+ self.patches_resolution = patches_resolution
503
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
504
+
505
+ self.in_chans = in_chans
506
+ self.embed_dim = embed_dim
507
+
508
+ if norm_layer is not None:
509
+ self.norm = norm_layer(embed_dim)
510
+ else:
511
+ self.norm = None
512
+
513
+ def forward(self, x):
514
+ x = x.flatten(2).transpose(1, 2) # b Ph*Pw c
515
+ if self.norm is not None:
516
+ x = self.norm(x)
517
+ return x
518
+
519
+
520
+ class PatchUnEmbed(nn.Module):
521
+ r""" Image to Patch Unembedding
522
+
523
+ Args:
524
+ img_size (int): Image size. Default: 224.
525
+ patch_size (int): Patch token size. Default: 4.
526
+ in_chans (int): Number of input image channels. Default: 3.
527
+ embed_dim (int): Number of linear projection output channels. Default: 96.
528
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
529
+ """
530
+
531
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
532
+ super().__init__()
533
+ img_size = to_2tuple(img_size)
534
+ patch_size = to_2tuple(patch_size)
535
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
536
+ self.img_size = img_size
537
+ self.patch_size = patch_size
538
+ self.patches_resolution = patches_resolution
539
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
540
+
541
+ self.in_chans = in_chans
542
+ self.embed_dim = embed_dim
543
+
544
+ def forward(self, x, x_size):
545
+ x = x.transpose(1, 2).contiguous().view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b Ph*Pw c
546
+ return x
547
+
548
+
549
+
550
+ class Upsample(nn.Sequential):
551
+ """Upsample module.
552
+
553
+ Args:
554
+ scale (int): Scale factor. Supported scales: 2^n and 3.
555
+ num_feat (int): Channel number of intermediate features.
556
+ """
557
+
558
+ def __init__(self, scale, num_feat):
559
+ m = []
560
+ if (scale & (scale - 1)) == 0: # scale = 2^n
561
+ for _ in range(int(math.log(scale, 2))):
562
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
563
+ m.append(nn.PixelShuffle(2))
564
+ elif scale == 3:
565
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
566
+ m.append(nn.PixelShuffle(3))
567
+ else:
568
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
569
+ super(Upsample, self).__init__(*m)
570
+
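+ # Illustrative note (assumption, not in the original file): Upsample(scale=4, num_feat=64)
+ # stacks two Conv2d(64, 256, 3) -> PixelShuffle(2) stages, while Upsample(3, 64) uses a
+ # single Conv2d(64, 576, 3) -> PixelShuffle(3) stage.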
571
+ class HAT(nn.Module):
572
+ r""" Hybrid Attention Transformer
573
+ A PyTorch implementation of : `Activating More Pixels in Image Super-Resolution Transformer`.
574
+ Some codes are based on SwinIR.
575
+ Args:
576
+ img_size (int | tuple(int)): Input image size. Default 64
577
+ patch_size (int | tuple(int)): Patch size. Default: 1
578
+ in_chans (int): Number of input image channels. Default: 3
579
+ embed_dim (int): Patch embedding dimension. Default: 96
580
+ depths (tuple(int)): Depth of each Swin Transformer layer.
581
+ num_heads (tuple(int)): Number of attention heads in different layers.
582
+ window_size (int): Window size. Default: 7
583
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
584
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
585
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
586
+ drop_rate (float): Dropout rate. Default: 0
587
+ attn_drop_rate (float): Attention dropout rate. Default: 0
588
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
589
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
590
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
591
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
592
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
593
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
594
+ img_range: Image range. 1. or 255.
595
+ upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
596
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
597
+ """
598
+
599
+ def __init__(self,
600
+ img_size=64,
601
+ patch_size=1,
602
+ in_chans=3,
603
+ embed_dim=96,
604
+ depths=(6, 6, 6, 6),
605
+ num_heads=(6, 6, 6, 6),
606
+ window_size=7,
607
+ compress_ratio=3,
608
+ squeeze_factor=30,
609
+ conv_scale=0.01,
610
+ overlap_ratio=0.5,
611
+ mlp_ratio=4.,
612
+ qkv_bias=True,
613
+ qk_scale=None,
614
+ drop_rate=0.,
615
+ attn_drop_rate=0.,
616
+ drop_path_rate=0.1,
617
+ norm_layer=nn.LayerNorm,
618
+ ape=False,
619
+ patch_norm=True,
620
+ use_checkpoint=False,
621
+ upscale=2,
622
+ img_range=1.,
623
+ upsampler='',
624
+ resi_connection='1conv',
625
+ **kwargs):
626
+ super(HAT, self).__init__()
627
+
628
+ self.window_size = window_size
629
+ self.shift_size = window_size // 2
630
+ self.overlap_ratio = overlap_ratio
631
+
632
+ num_in_ch = in_chans
633
+ num_out_ch = in_chans
634
+ num_feat = 64
635
+ self.img_range = img_range
636
+ if in_chans == 3:
637
+ rgb_mean = (0.4488, 0.4371, 0.4040)
638
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
639
+ else:
640
+ self.mean = torch.zeros(1, 1, 1, 1)
641
+ self.upscale = upscale
642
+ self.upsampler = upsampler
643
+
644
+ # relative position index
645
+ relative_position_index_SA = self.calculate_rpi_sa()
646
+ relative_position_index_OCA = self.calculate_rpi_oca()
647
+ self.register_buffer('relative_position_index_SA', relative_position_index_SA)
648
+ self.register_buffer('relative_position_index_OCA', relative_position_index_OCA)
649
+
650
+ # ------------------------- 1, shallow feature extraction ------------------------- #
651
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
652
+
653
+ # ------------------------- 2, deep feature extraction ------------------------- #
654
+ self.num_layers = len(depths)
655
+ self.embed_dim = embed_dim
656
+ self.ape = ape
657
+ self.patch_norm = patch_norm
658
+ self.num_features = embed_dim
659
+ self.mlp_ratio = mlp_ratio
660
+
661
+ # split image into non-overlapping patches
662
+ self.patch_embed = PatchEmbed(
663
+ img_size=img_size,
664
+ patch_size=patch_size,
665
+ in_chans=embed_dim,
666
+ embed_dim=embed_dim,
667
+ norm_layer=norm_layer if self.patch_norm else None)
668
+ num_patches = self.patch_embed.num_patches
669
+ patches_resolution = self.patch_embed.patches_resolution
670
+ self.patches_resolution = patches_resolution
671
+
672
+ # merge non-overlapping patches into image
673
+ self.patch_unembed = PatchUnEmbed(
674
+ img_size=img_size,
675
+ patch_size=patch_size,
676
+ in_chans=embed_dim,
677
+ embed_dim=embed_dim,
678
+ norm_layer=norm_layer if self.patch_norm else None)
679
+
680
+ # absolute position embedding
681
+ if self.ape:
682
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
683
+ trunc_normal_(self.absolute_pos_embed, std=.02)
684
+
685
+ self.pos_drop = nn.Dropout(p=drop_rate)
686
+
687
+ # stochastic depth
688
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
689
+
690
+ # build Residual Hybrid Attention Groups (RHAG)
691
+ self.layers = nn.ModuleList()
692
+ for i_layer in range(self.num_layers):
693
+ layer = RHAG(
694
+ dim=embed_dim,
695
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
696
+ depth=depths[i_layer],
697
+ num_heads=num_heads[i_layer],
698
+ window_size=window_size,
699
+ compress_ratio=compress_ratio,
700
+ squeeze_factor=squeeze_factor,
701
+ conv_scale=conv_scale,
702
+ overlap_ratio=overlap_ratio,
703
+ mlp_ratio=self.mlp_ratio,
704
+ qkv_bias=qkv_bias,
705
+ qk_scale=qk_scale,
706
+ drop=drop_rate,
707
+ attn_drop=attn_drop_rate,
708
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
709
+ norm_layer=norm_layer,
710
+ downsample=None,
711
+ use_checkpoint=use_checkpoint,
712
+ img_size=img_size,
713
+ patch_size=patch_size,
714
+ resi_connection=resi_connection)
715
+ self.layers.append(layer)
716
+ self.norm = norm_layer(self.num_features)
717
+
718
+ # build the last conv layer in deep feature extraction
719
+ if resi_connection == '1conv':
720
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
721
+ elif resi_connection == 'identity':
722
+ self.conv_after_body = nn.Identity()
723
+
724
+ # ------------------------- 3, high quality image reconstruction ------------------------- #
725
+ if self.upsampler == 'pixelshuffle':
726
+ # for classical SR
727
+ self.conv_before_upsample = nn.Sequential(
728
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
729
+ self.upsample = Upsample(upscale, num_feat)
730
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
731
+
732
+ self.apply(self._init_weights)
733
+
734
+ def _init_weights(self, m):
735
+ if isinstance(m, nn.Linear):
736
+ trunc_normal_(m.weight, std=.02)
737
+ if isinstance(m, nn.Linear) and m.bias is not None:
738
+ nn.init.constant_(m.bias, 0)
739
+ elif isinstance(m, nn.LayerNorm):
740
+ nn.init.constant_(m.bias, 0)
741
+ nn.init.constant_(m.weight, 1.0)
742
+
743
+ def calculate_rpi_sa(self):
744
+ # calculate relative position index for SA
745
+ coords_h = torch.arange(self.window_size)
746
+ coords_w = torch.arange(self.window_size)
747
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
748
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
749
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
750
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
751
+ relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0
752
+ relative_coords[:, :, 1] += self.window_size - 1
753
+ relative_coords[:, :, 0] *= 2 * self.window_size - 1
754
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
755
+ return relative_position_index
756
+
757
+ def calculate_rpi_oca(self):
758
+ # calculate relative position index for OCA
759
+ window_size_ori = self.window_size
760
+ window_size_ext = self.window_size + int(self.overlap_ratio * self.window_size)
761
+
762
+ coords_h = torch.arange(window_size_ori)
763
+ coords_w = torch.arange(window_size_ori)
764
+ coords_ori = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, ws, ws
765
+ coords_ori_flatten = torch.flatten(coords_ori, 1) # 2, ws*ws
766
+
767
+ coords_h = torch.arange(window_size_ext)
768
+ coords_w = torch.arange(window_size_ext)
769
+ coords_ext = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, wse, wse
770
+ coords_ext_flatten = torch.flatten(coords_ext, 1) # 2, wse*wse
771
+
772
+ relative_coords = coords_ext_flatten[:, None, :] - coords_ori_flatten[:, :, None] # 2, ws*ws, wse*wse
773
+
774
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # ws*ws, wse*wse, 2
775
+ relative_coords[:, :, 0] += window_size_ori - window_size_ext + 1 # shift to start from 0
776
+ relative_coords[:, :, 1] += window_size_ori - window_size_ext + 1
777
+
778
+ relative_coords[:, :, 0] *= window_size_ori + window_size_ext - 1
779
+ relative_position_index = relative_coords.sum(-1)
780
+ return relative_position_index
781
+
782
+ def calculate_mask(self, x_size):
783
+ # calculate attention mask for SW-MSA
784
+ h, w = x_size
785
+ img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1
786
+ h_slices = (slice(0, -self.window_size), slice(-self.window_size,
787
+ -self.shift_size), slice(-self.shift_size, None))
788
+ w_slices = (slice(0, -self.window_size), slice(-self.window_size,
789
+ -self.shift_size), slice(-self.shift_size, None))
790
+ cnt = 0
791
+ for h in h_slices:
792
+ for w in w_slices:
793
+ img_mask[:, h, w, :] = cnt
794
+ cnt += 1
795
+
796
+ mask_windows = window_partition(img_mask, self.window_size) # nw, window_size, window_size, 1
797
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
798
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
799
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
800
+
801
+ return attn_mask
802
+
803
+ @torch.jit.ignore
804
+ def no_weight_decay(self):
805
+ return {'absolute_pos_embed'}
806
+
807
+ @torch.jit.ignore
808
+ def no_weight_decay_keywords(self):
809
+ return {'relative_position_bias_table'}
810
+
811
+ def forward_features(self, x):
812
+ x_size = (x.shape[2], x.shape[3])
813
+
814
+ # Calculate attention mask and relative position index in advance to speed up inference.
815
+ # The original code is very time-consuming for large window size.
816
+ attn_mask = self.calculate_mask(x_size).to(x.device)
817
+ params = {'attn_mask': attn_mask, 'rpi_sa': self.relative_position_index_SA, 'rpi_oca': self.relative_position_index_OCA}
818
+
819
+ x = self.patch_embed(x)
820
+ if self.ape:
821
+ x = x + self.absolute_pos_embed
822
+ x = self.pos_drop(x)
823
+
824
+ for layer in self.layers:
825
+ x = layer(x, x_size, params)
826
+
827
+ x = self.norm(x) # b seq_len c
828
+ x = self.patch_unembed(x, x_size)
829
+
830
+ return x
831
+
832
+ def forward(self, x):
833
+ self.mean = self.mean.type_as(x)
834
+ x = (x - self.mean) * self.img_range
835
+
836
+ if self.upsampler == 'pixelshuffle':
837
+ # for classical SR
838
+ x = self.conv_first(x)
839
+ x = self.conv_after_body(self.forward_features(x)) + x
840
+ x = self.conv_before_upsample(x)
841
+ x = self.conv_last(self.upsample(x))
842
+
843
+ x = x / self.img_range + self.mean
844
+
845
+ return x
846
+ # ------------------------------ HYPERPARAMS ------------------------------ #
847
+ config = {
848
+ "network_g": {
849
+ "type": "HAT",
850
+ "upscale": 4,
851
+ "in_chans": 3,
852
+ "img_size": 64,
853
+ "window_size": 16,
854
+ "compress_ratio": 3,
855
+ "squeeze_factor": 30,
856
+ "conv_scale": 0.01,
857
+ "overlap_ratio": 0.5,
858
+ "img_range": 1.,
859
+ "depths": [6, 6, 6, 6, 6, 6],
860
+ "embed_dim": 180,
861
+ "num_heads": [6, 6, 6, 6, 6, 6],
862
+ "mlp_ratio": 2,
863
+ "upsampler": 'pixelshuffle',
864
+ "resi_connection": '1conv'
865
+ },
866
+ "train": {
867
+ "ema_decay": 0.999,
868
+ "optim_g": {
869
+ "type": "Adam",
870
+ "lr": 1e-4,
871
+ "weight_decay": 0,
872
+ "betas": [0.9, 0.99]
873
+ },
874
+ "scheduler": {
875
+ "type": "MultiStepLR",
876
+ "milestones": [12, 20, 25, 30],
877
+ "gamma": 0.5
878
+ },
879
+ "total_iter": 30,
880
+ "warmup_iter": -1,
881
+ "pixel_opt": {
882
+ "type": "L1Loss",
883
+ "loss_weight": 1.0,
884
+ "reduction": "mean"
885
+ }
886
+ },
887
+ 'tile':{
888
+ 'tile_size': 56,
889
+ 'tile_pad': 4
890
+ }
891
+
892
+ }
893
+
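+ # Note on the tile settings (derived from the code below, not stated in the original):
+ # with tile_size=56, tile_pad=4 and upscale=4, each padded LR tile is at most 64x64
+ # (already a multiple of window_size*4), and tile_valid() crops a 16-pixel halo
+ # (tile_pad * upscale) from each padded edge of the 4x-upscaled tile before stitching.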
894
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
895
+ print("Using device:", DEVICE)
896
+
897
+ class Network:
898
+ def __init__(self, train_dataloader=None, valid_dataloader=None,
899
+ config = config, device=DEVICE, run_id=None, wandb_mode = False, STOP = float('inf'), save_temp_model = True, train_model_continue = False):
900
+ self.config = config
901
+ self.model = HAT(
902
+ upscale=self.config['network_g']['upscale'],
903
+ in_chans=self.config['network_g']['in_chans'],
904
+ img_size=self.config['network_g']['img_size'],
905
+ window_size=self.config['network_g']['window_size'],
906
+ compress_ratio=self.config['network_g']['compress_ratio'],
907
+ squeeze_factor=self.config['network_g']['squeeze_factor'],
908
+ conv_scale=self.config['network_g']['conv_scale'],
909
+ overlap_ratio=self.config['network_g']['overlap_ratio'],
910
+ img_range=self.config['network_g']['img_range'],
911
+ depths=self.config['network_g']['depths'],
912
+ embed_dim=self.config['network_g']['embed_dim'],
913
+ num_heads=self.config['network_g']['num_heads'],
914
+ mlp_ratio=self.config['network_g']['mlp_ratio'],
915
+ upsampler=self.config['network_g']['upsampler'],
916
+ resi_connection=self.config['network_g']['resi_connection']
917
+ ).to(device)
918
+ self.device = device
919
+ self.STOP = STOP
920
+ self.wandb_mode = wandb_mode
921
+ self.loss_fn = nn.L1Loss(reduction='mean').to(device)
922
+
923
+ self.optimizer = optim.Adam(self.model.parameters(), lr=self.config['train']['optim_g']['lr'], weight_decay=config['train']['optim_g']['weight_decay'],betas=tuple(config['train']['optim_g']['betas']))
924
+ self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones = self.config['train']['scheduler']['milestones'], gamma=self.config['train']['scheduler']['gamma'])
925
+ self.train_dataloader = train_dataloader
926
+ self.valid_dataloader = valid_dataloader
927
+ self.num_epochs = self.config['train']['total_iter']
928
+ self.run_id = run_id
929
+ self.save_temp_model = save_temp_model
930
+ self.train_model_continue = train_model_continue
931
+ self.last_valid_loss = float('inf')
932
+ checkpoint_path = output
933
+ if self.save_temp_model:
934
+ if self.train_model_continue:
935
+ # Load the network and other states from the checkpoint
936
+ self.start_epoch, train_loss, valid_loss = self.load_network(checkpoint_path)
937
+
938
+ initial_lr = self.config['train']['optim_g']['lr'] * self.config['train']['scheduler']['gamma'] # Define your initial or desired learning rate
939
+ for param_group in self.optimizer.param_groups:
940
+ param_group['lr'] = initial_lr # Resetting learning rate
941
+
942
+ # Recreate the scheduler with the updated optimizer
943
+ self.scheduler = optim.lr_scheduler.MultiStepLR(
944
+ self.optimizer,
945
+ milestones=self.config['train']['scheduler']['milestones'],
946
+ gamma=self.config['train']['scheduler']['gamma'],
947
+ last_epoch = self.start_epoch - 1 # Ensure to set the last_epoch to continue correctly
948
+ )
949
+
950
+ # Print the updated learning rate and scheduler state
951
+ print("Updated Learning Rate is:", self.optimizer.param_groups[0]['lr'])
952
+ print(self.scheduler.state_dict())
953
+ self.last_valid_loss = valid_loss
954
+ # self.num_epochs-= self.start_epoch
955
+ print("Previous train loss: ", train_loss)
956
+ print("Previous valid loss: ", self.last_valid_loss)
957
+
958
+ # Resume training notice
959
+ print("------------------- Resuming training -------------------")
960
+
961
+ self.save_network(0, 0, 0, 'temp_model_checkpoint.pth')
962
+
963
+ def del_model(self):
964
+ del self.model
965
+ del self.optimizer
966
+ del self.scheduler
967
+ gc.collect()
968
+ torch.cuda.empty_cache()
969
+
970
+ def pre_process(self):
971
+ # pad to multiplication of window_size
972
+ window_size = self.config['network_g']['window_size'] * 4
973
+ self.scale = self.config['network_g']['upscale']
974
+
975
+ self.mod_pad_h, self.mod_pad_w = 0, 0
976
+ _, _, h, w = self.input_tile.size()
977
+
978
+ if h % window_size != 0:
979
+ self.mod_pad_h = window_size - h % window_size
980
+ # Loop to add padding to the height until it's a multiple of window_size
981
+ for i in range(self.mod_pad_h):
982
+ self.input_tile = F.pad(self.input_tile, (0, 0, 0, 1), 'reflect')
983
+
984
+ if w % window_size != 0:
985
+ # Loop to add padding to the width until it's a multiple of window_size
986
+ self.mod_pad_w = window_size - w % window_size
987
+ for i in range(self.mod_pad_w):
988
+ self.input_tile = F.pad(self.input_tile, (0, 1, 0, 0), 'reflect')
989
+
990
+
991
+ def post_process(self):
992
+ _, _, h, w = self.output_tile.size()
993
+ self.output_tile = self.output_tile[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
994
+
995
+
996
+ def save_network(self, epoch, train_loss, valid_loss, checkpoint_path):
997
+ checkpoint = {
998
+ 'epoch': epoch,
999
+ 'train_loss': train_loss,
1000
+ 'valid_loss': valid_loss,
1001
+ 'model': self.model.state_dict(),
1002
+ 'optimizer': self.optimizer.state_dict(),
1003
+ 'learning_rate_scheduler': self.scheduler.state_dict(),
1004
+ 'network': self
1005
+ }
1006
+ torch.save(checkpoint, checkpoint_path)
1007
+
1008
+ def load_network(self, checkpoint_path):
1009
+
1010
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
1011
+ self.model = HAT(
1012
+ upscale=self.config['network_g']['upscale'],
1013
+ in_chans=self.config['network_g']['in_chans'],
1014
+ img_size=self.config['network_g']['img_size'],
1015
+ window_size=self.config['network_g']['window_size'],
1016
+ compress_ratio=self.config['network_g']['compress_ratio'],
1017
+ squeeze_factor=self.config['network_g']['squeeze_factor'],
1018
+ conv_scale=self.config['network_g']['conv_scale'],
1019
+ overlap_ratio=self.config['network_g']['overlap_ratio'],
1020
+ img_range=self.config['network_g']['img_range'],
1021
+ depths=self.config['network_g']['depths'],
1022
+ embed_dim=self.config['network_g']['embed_dim'],
1023
+ num_heads=self.config['network_g']['num_heads'],
1024
+ mlp_ratio=self.config['network_g']['mlp_ratio'],
1025
+ upsampler=self.config['network_g']['upsampler'],
1026
+ resi_connection=self.config['network_g']['resi_connection']
1027
+ ).to(self.device)
1028
+ self.optimizer = optim.Adam(self.model.parameters(), lr=self.config['train']['optim_g']['lr'], weight_decay=config['train']['optim_g']['weight_decay'],betas=tuple(config['train']['optim_g']['betas']))
1029
+ self.model.load_state_dict(checkpoint['model'])
1030
+ self.optimizer.load_state_dict(checkpoint['optimizer']) # before create and load scheduler
1031
+
1032
+ self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones = self.config['train']['scheduler']['milestones'], gamma=self.config['train']['scheduler']['gamma'])
1033
+ self.scheduler.load_state_dict(checkpoint['learning_rate_scheduler'])
1034
+ return checkpoint['epoch'], checkpoint['train_loss'], checkpoint['valid_loss']
1035
+
1036
+ def train_step(self, lr_images, hr_images):
1037
+ lr_images, hr_images = lr_images.to(self.device), hr_images.to(self.device)
1038
+ sr_images = self.model(lr_images)
1039
+
1040
+ self.optimizer.zero_grad()
1041
+ loss = self.loss_fn(sr_images, hr_images)
1042
+ loss.backward()
1043
+ self.optimizer.step()
1044
+
1045
+ # Memory cleanup
1046
+ del sr_images, lr_images, hr_images
1047
+ gc.collect()
1048
+ torch.cuda.empty_cache()
1049
+
1050
+ return loss.item()
1051
+
1052
+ def valid_step(self, lr_images, hr_images):
1053
+ lr_images, hr_images = lr_images.to(self.device), hr_images.to(self.device)
1054
+
1055
+ sr_images = self.tile_valid(lr_images)
1056
+
1057
+ loss = self.loss_fn(sr_images, hr_images)
1058
+
1059
+ # Memory cleanup
1060
+ del sr_images, lr_images, hr_images
1061
+ gc.collect()
1062
+ torch.cuda.empty_cache()
1063
+
1064
+ return loss.item()
1065
+
1066
+
1067
+ def tile_valid(self, lr_images):
1068
+ """
1069
+ Process all tiles of an image in a batch and then merge them back into the output image.
1070
+ """
1071
+
1072
+ batch, channel, height, width = lr_images.shape
1073
+ output_height = height * self.config['network_g']['upscale']
1074
+ output_width = width * self.config['network_g']['upscale']
1075
+ output_shape = (batch, channel, output_height, output_width)
1076
+
1077
+ # Start with black image for output
1078
+ sr_images = lr_images.new_zeros(output_shape)
1079
+ tiles_x = math.ceil(width / self.config['tile']['tile_size'])
1080
+ tiles_y = math.ceil(height / self.config['tile']['tile_size'])
1081
+
1082
+ tile_list = []
1083
+
1084
+ # Extract all tiles
1085
+ for y in range(tiles_y):
1086
+ for x in range(tiles_x):
1087
+
1088
+ input_start_x = x * self.config['tile']['tile_size']
1089
+ input_end_x = min(input_start_x + self.config['tile']['tile_size'], width)
1090
+ input_start_y = y * self.config['tile']['tile_size']
1091
+ input_end_y = min(input_start_y + self.config['tile']['tile_size'], height)
1092
+
1093
+ input_start_x_pad = max(input_start_x - self.config['tile']['tile_pad'], 0)
1094
+ input_end_x_pad = min(input_end_x + self.config['tile']['tile_pad'], width)
1095
+ input_start_y_pad = max(input_start_y - self.config['tile']['tile_pad'], 0)
1096
+ input_end_y_pad = min(input_end_y + self.config['tile']['tile_pad'], height)
1097
+
1098
+ # Extract tile and add to list
1099
+ self.input_tile = lr_images[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
1100
+ self.pre_process()
1101
+ tile_list.append(self.input_tile.clone())
1102
+
1103
+ output_tiles = []
1104
+
1105
+ # Determine the number of tiles to process per batch
1106
+ batch_size = 16 # Adjust based on your specific situation
1107
+
1108
+ for i in range(0, len(tile_list), batch_size):
1109
+ # Extract a batch of tiles
1110
+ batch = tile_list[i:i + batch_size]
1111
+ tile_batch = torch.cat(batch, dim=0) # This creates a batch of tiles
1112
+
1113
+ # Process the batch through the model
1114
+ self.model.eval()
1115
+ with torch.no_grad():
1116
+ # Ensure that each tile processed by the model returns a 3D tensor (C, H, W)
1117
+ output_batch = self.model(tile_batch)
1118
+
1119
+ # Extend the list of processed tiles
1120
+ output_tiles.append(output_batch) # Assuming output_batch is 4D
1121
+
1122
+ # Concatenate along the first dimension to combine all the processed tiles
1123
+ output_tile_batch = torch.cat(output_tiles, dim=0) # This should be 4D now
1124
+
1125
+
1126
+ for y in range(tiles_y):
1127
+ for x in range(tiles_x):
1128
+ # input tile area on total image
1129
+ input_start_x = x * self.config['tile']['tile_size']
1130
+ input_end_x = min(input_start_x + self.config['tile']['tile_size'], width)
1131
+ input_start_y = y * self.config['tile']['tile_size']
1132
+ input_end_y = min(input_start_y + self.config['tile']['tile_size'], height)
1133
+
1134
+ # input tile area on total image with padding
1135
+ input_start_x_pad = max(input_start_x - self.config['tile']['tile_pad'], 0)
1136
+ input_end_x_pad = min(input_end_x + self.config['tile']['tile_pad'], width)
1137
+ input_start_y_pad = max(input_start_y - self.config['tile']['tile_pad'], 0)
1138
+ input_end_y_pad = min(input_end_y + self.config['tile']['tile_pad'], height)
1139
+
1140
+ # input tile dimensions
1141
+ input_tile_width = input_end_x - input_start_x
1142
+ input_tile_height = input_end_y - input_start_y
1143
+ tile_idx = y * tiles_x + x
1144
+
1145
+ self.pre_process()
1146
+ self.output_tile = output_tile_batch[tile_idx, :, :, :].unsqueeze(0).clone()
1147
+ self.post_process()
1148
+
1149
+ # output tile area on total image
1150
+ output_start_x = input_start_x * self.config['network_g']['upscale']
1151
+ output_end_x = input_end_x * self.config['network_g']['upscale']
1152
+ output_start_y = input_start_y * self.config['network_g']['upscale']
1153
+ output_end_y = input_end_y * self.config['network_g']['upscale']
1154
+
1155
+ # output tile area without padding
1156
+ output_start_x_tile = (input_start_x - input_start_x_pad) * self.config['network_g']['upscale']
1157
+ output_end_x_tile = output_start_x_tile + input_tile_width * self.config['network_g']['upscale']
1158
+ output_start_y_tile = (input_start_y - input_start_y_pad) * self.config['network_g']['upscale']
1159
+ output_end_y_tile = output_start_y_tile + input_tile_height * self.config['network_g']['upscale']
1160
+
1161
+ # put tile into output image
1162
+ sr_images[:, :, output_start_y:output_end_y,
1163
+ output_start_x:output_end_x] = self.output_tile[:, :, output_start_y_tile:output_end_y_tile,
1164
+ output_start_x_tile:output_end_x_tile]
1165
+
1166
+ del self.input_tile, self.output_tile, tile_batch, tile_list, output_tile_batch, output_tiles
1167
+ gc.collect()
1168
+ torch.cuda.empty_cache()
1169
+ return sr_images
1170
+
1171
+ def train_model(self):
1172
+
1173
+ if self.wandb_mode:
1174
+ wandb.init(project='HAT-for-image-sr',
1175
+ resume='allow',
1176
+ config= self.config,
1177
+ id=self.run_id)
1178
+ wandb.watch(self.model)
1179
+ if self.train_model_continue:
1180
+ epoch_lst = range(self.start_epoch, self.num_epochs)
1181
+ else:
1182
+ epoch_lst = range(self.num_epochs)
1183
+ for epoch in epoch_lst:
1184
+
1185
+ start1 = time.time()
1186
+
1187
+ # ------------------- TRAIN -------------------
1188
+ if self.save_temp_model:
1189
+ self.load_network('temp_model_checkpoint.pth')
1190
+ self.model.train()
1191
+ train_epoch_loss = 0
1192
+
1193
+ stop = 0
1194
+ for hr_images, lr_images in tqdm(self.train_dataloader, desc=f'Epoch {epoch+1}/{self.num_epochs}'):
1195
+
1196
+ if stop == self.STOP:
1197
+ break
1198
+ stop+=1
1199
+
1200
+ loss = self.train_step(lr_images, hr_images)
1201
+ train_epoch_loss += loss
1202
+
1203
+ if self.wandb_mode:
1204
+ wandb.log({
1205
+ 'batch_loss': loss,
1206
+ })
1207
+
1208
+ if self.wandb_mode:
1209
+ wandb.log({
1210
+ 'learning_rate': self.optimizer.param_groups[0]['lr']
1211
+ })
1212
+ print("Learning Rate is:", self.optimizer.param_groups[0]['lr'])
1213
+
1214
+ self.scheduler.step()
1215
+
1216
+
1217
+ if self.save_temp_model:
1218
+ self.save_network(epoch, train_epoch_loss, 0, 'temp_model_checkpoint.pth')
1219
+ print(self.scheduler.state_dict())
1220
+ self.del_model()
1221
+
1222
+ del hr_images
1223
+ del lr_images
1224
+ gc.collect()
1225
+
1226
+ train_epoch_loss /= len(self.train_dataloader)
1227
+
1228
+ end1 = time.time()
1229
+
1230
+
1231
+ # ------------------- VALID -------------------
1232
+ start2 = time.time()
1233
+ if self.save_temp_model:
1234
+ self.load_network('temp_model_checkpoint.pth')
1235
+
1236
+ self.model.eval()
1237
+ with torch.no_grad():
1238
+ valid_epoch_loss = 0
1239
+
1240
+ stop = 0
1241
+ for hr_images, lr_images in tqdm(self.valid_dataloader, desc=f'Epoch {epoch+1}/{self.num_epochs}'):
1242
+ if stop == self.STOP:
1243
+ break
1244
+ stop+=1
1245
+ loss = self.valid_step(lr_images, hr_images)
1246
+ valid_epoch_loss += loss
1247
+
1248
+ valid_epoch_loss /= len(self.valid_dataloader)
1249
+
1250
+ end2 = time.time()
1251
+
1252
+ # ------------------- LOG -------------------
1253
+ if self.wandb_mode:
1254
+ wandb.log({
1255
+ 'train_loss': train_epoch_loss,
1256
+ 'valid_loss': valid_epoch_loss,
1257
+ })
1258
+ # ------------------- VERBOSE -------------------
1259
+ print(f'Epoch {epoch+1}/{self.num_epochs} | Train Loss: {train_epoch_loss:.4f} | Valid Loss: {valid_epoch_loss:.4f} | Time train: {end1-start1:.2f}s | Time valid: {end2-start2:.2f}s')
1260
+
1261
+ # ------------------- CHECKPOINT -------------------
1262
+ self.save_network(epoch, train_epoch_loss, valid_epoch_loss, 'model_checkpoint_latest.pth')
1263
+ if valid_epoch_loss < self.last_valid_loss:
1264
+ self.last_valid_loss = valid_epoch_loss
1265
+ self.save_network(epoch, train_epoch_loss, valid_epoch_loss, 'model_checkpoint_best.pth')
1266
+ print("New best checkpoint saved!")
1267
+
1268
+ if self.save_temp_model:
1269
+ self.del_model()
1270
+
1271
+ del hr_images
1272
+ del lr_images
1273
+ gc.collect()
1274
+
1275
+ if self.wandb_mode:
1276
+ wandb.finish()
1277
+
1278
+ def inference(self, lr_image, hr_image=None):
1279
+ """
1280
+ - lr_image: torch.Tensor
1281
+ 3D Tensor (C, H, W)
1282
+ - hr_image: torch.Tensor
1283
+ 3D Tensor (C, H, W). This parameter is optional, for comparing the model output and the
1284
+ ground-truth high-res image. If used solely for inference, skip this. Default is None.
1285
+ """
1286
+ lr_image = lr_image.unsqueeze(0).to(self.device)
1287
+ self.for_inference = True
1288
+ with torch.no_grad():
1289
+ sr_image = self.tile_valid(lr_image)
1290
+
1291
+ lr_image = lr_image.squeeze(0)
1292
+ sr_image = sr_image.squeeze(0)
1293
+
1294
+ print(">> Size of low-res image:", lr_image.size())
1295
+ print(">> Size of super-res image:", sr_image.size())
1296
+ if hr_image is not None:
1297
+ print(">> Size of high-res image:", hr_image.size())
1298
+
1299
+ if hr_image is not None:
1300
+ fig, axes = plt.subplots(1, 3, figsize=(10, 6))
1301
+ axes[0].imshow(lr_image.cpu().detach().permute((1, 2, 0)))
1302
+ axes[0].set_title('Low Resolution')
1303
+ axes[1].imshow(sr_image.cpu().detach().permute((1, 2, 0)))
1304
+ axes[1].set_title('Super Resolution')
1305
+ axes[2].imshow(hr_image.cpu().detach().permute((1, 2, 0)))
1306
+ axes[2].set_title('High Resolution')
1307
+ for ax in axes.flat:
1308
+ ax.axis('off')
1309
+ else:
1310
+ fig, axes = plt.subplots(1, 2, figsize=(10, 6))
1311
+ axes[0].imshow(lr_image.cpu().detach().permute((1, 2, 0)))
1312
+ axes[0].set_title('Low Resolution')
1313
+ axes[1].imshow(sr_image.cpu().detach().permute((1, 2, 0)))
1314
+ axes[1].set_title('Super Resolution')
1315
+ for ax in axes.flat:
1316
+ ax.axis('off')
1317
+
1318
+ plt.tight_layout()
1319
+ plt.show()
1320
+
1321
+ return sr_image
1322
+
1323
+
1324
+ class TestDataset(Dataset):
1325
+ def __init__(self, lr_images_path):
1326
+ super(TestDataset, self).__init__()
1327
+ # hr_images_list = os.listdir(hr_images_path)
1328
+ self.lr_images_path = lr_images_path
1329
+
1330
+ def __getitem__(self, idx):
1331
+
1332
+ lr_image = Image.open(self.lr_images_path)
1333
+
1334
+ lr_image = transforms.functional.to_tensor(lr_image)
1335
+
1336
+ return lr_image
1337
+
1338
+
1339
+ if __name__ == "__main__":
1340
+ import os
1341
+ import sys
1342
+ # Getting to the Lambda directory
1343
+ sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../"))
1344
+ image_path = "images/img_003_SRF_4_LR.png"
1345
+
1346
+ infer_dataset = TestDataset(lr_images_path=image_path)
1347
+
1348
+ # hat = Network(run_id="hat-for-image-sr-" + str(int(1704006834)),config = config, wandb_mode = False, save_temp_model = True, train_model_continue = False) # STOP = 2
1349
+ # num_params = sum(p.numel() for p in hat.model.parameters() if p.requires_grad)
1350
+ # print("Number of learnable parameters: ", num_params)
1351
+
1352
+ # ---------- LOAD FROM LATEST CHECKPOINT ---------- #
1353
+ gc.collect()
1354
+ torch.cuda.empty_cache()
1355
+ hat = Network()
1356
+ hat.load_network(output)
1357
+ num_params = sum(p.numel() for p in hat.model.parameters() if p.requires_grad)
1358
+ print("Number of learnable parameters: ", num_params)
1359
+ lr_image = infer_dataset[0]  # load the LR image as a (C, H, W) tensor
1360
+ hat.inference(lr_image)
1361
+
1362
+
1363
+
requirements.txt CHANGED
@@ -4,4 +4,5 @@ basicsr
4
  skimage
5
  torchvision
6
  torchmetrics
7
- streamlit
7
+ streamlit
8
+ gdown