0x90e committed on
Commit
93c8609
·
1 Parent(s): 17cfe57

Add support for ESRGAN+ and new 4x Valar v1 model.

Browse files
architecture.py → ESRGAN/architecture.py RENAMED
@@ -1,8 +1,7 @@
1
  import math
2
  import torch
3
  import torch.nn as nn
4
- import block as B
5
-
6
 
7
  class RRDB_Net(nn.Module):
8
  def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', \
 
1
  import math
2
  import torch
3
  import torch.nn as nn
4
+ import ESRGAN.block as B
 
5
 
6
  class RRDB_Net(nn.Module):
7
  def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', \
block.py → ESRGAN/block.py RENAMED
File without changes
ESRGAN_plus/architecture.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import ESRGAN_plus.block as B
5
+
6
+
7
class RRDB_Net(nn.Module):
    '''
    ESRGAN+ generator: a first conv, a trunk of RRDB blocks wrapped in a
    shortcut connection, one or more upsampling blocks, and two final convs.

    Args:
        in_nc: number of input channels.
        out_nc: number of output channels.
        nf: number of feature channels used throughout the trunk.
        nb: number of RRDB blocks in the trunk.
        gc: growth channels forwarded to every RRDB block.
        upscale: overall upscaling factor. 3 is handled as a single x3
            upsampling block; any other value uses int(log2(upscale))
            upsampling blocks with their default factor.
        norm_type: normalization type forwarded to the RRDB blocks and to the
            trunk's last conv (None disables normalization).
        act_type: activation type forwarded to the RRDB blocks, the upsampling
            blocks and the first HR conv.
        mode: conv_block mode used for the trunk's last conv only; the RRDB
            blocks are always built with mode='CNA'.
        res_scale: accepted for interface compatibility; not used here.
        upsample_mode: 'upconv' or 'pixelshuffle'.

    Raises:
        NotImplementedError: if upsample_mode is not one of the two above.
    '''

    def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, act_type='leakyrelu',
                 mode='CNA', res_scale=1, upsample_mode='upconv'):
        super(RRDB_Net, self).__init__()
        # math.log2 is exact for powers of two, whereas math.log(x, 2) can
        # land just below the integer and be truncated to the wrong count.
        n_upscale = int(math.log2(upscale))
        if upscale == 3:
            n_upscale = 1

        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        # gc is now honoured (it used to be hard-coded to 32, silently
        # ignoring the constructor argument); the default is unchanged.
        rb_blocks = [B.RRDB(nf, kernel_size=3, gc=gc, stride=1, bias=True, pad_type='zero',
                            norm_type=norm_type, act_type=act_type, mode='CNA') for _ in range(nb)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)
        if upscale == 3:
            # A single block with factor 3; it is unpacked into its children below.
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),
                                  *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        '''Run the input through the whole generator and return its output.'''
        x = self.model(x)
        return x
ESRGAN_plus/block.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ ####################
6
+ # Basic blocks
7
+ ####################
8
+
9
+
10
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
    '''
    Build an activation layer from its (case-insensitive) name.

    act_type: 'relu', 'leakyrelu' or 'prelu'
    inplace: forwarded to ReLU / LeakyReLU
    neg_slope: negative slope of LeakyReLU, and initial value of PReLU
    n_prelu: num_parameters of PReLU

    Raises NotImplementedError for any other name.
    '''
    kind = act_type.lower()
    if kind == 'relu':
        return nn.ReLU(inplace)
    if kind == 'leakyrelu':
        return nn.LeakyReLU(neg_slope, inplace)
    if kind == 'prelu':
        return nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    raise NotImplementedError('activation layer [{:s}] is not found'.format(kind))
24
+
25
+
26
def norm(norm_type, nc):
    '''
    Build a 2D normalization layer over nc channels from its
    (case-insensitive) name: 'batch' (affine) or 'instance' (non-affine).

    Raises NotImplementedError for any other name.
    '''
    kind = norm_type.lower()
    if kind == 'batch':
        return nn.BatchNorm2d(nc, affine=True)
    if kind == 'instance':
        return nn.InstanceNorm2d(nc, affine=False)
    raise NotImplementedError('normalization layer [{:s}] is not found'.format(kind))
36
+
37
+
38
def pad(pad_type, padding):
    '''
    Build an explicit padding layer from its (case-insensitive) name.

    Returns None when padding is 0. Only 'reflect' and 'replicate' are
    supported here; zero padding is expected to be done by the conv layer
    itself, so any other name raises NotImplementedError.
    '''
    kind = pad_type.lower()
    if padding == 0:
        return None
    if kind == 'reflect':
        return nn.ReflectionPad2d(padding)
    if kind == 'replicate':
        return nn.ReplicationPad2d(padding)
    raise NotImplementedError('padding layer [{:s}] is not implemented'.format(kind))
51
+
52
+
53
def get_valid_padding(kernel_size, dilation):
    '''
    Return the per-side padding computed from the dilated kernel extent:
    (effective_kernel - 1) // 2, where
    effective_kernel = kernel_size + (kernel_size - 1) * (dilation - 1).
    '''
    effective_kernel = kernel_size + (kernel_size - 1) * (dilation - 1)
    return (effective_kernel - 1) // 2
57
+
58
+
59
class ConcatBlock(nn.Module):
    '''Wrap a submodule and concatenate its output onto its input along dim 1.'''

    def __init__(self, submodule):
        super(ConcatBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        branch = self.sub(x)
        return torch.cat((x, branch), dim=1)

    def __repr__(self):
        # Prefix every line of the wrapped module's repr with '|'.
        inner = self.sub.__repr__().replace('\n', '\n|')
        return 'Identity .. \n|' + inner
74
+
75
+
76
class ShortcutBlock(nn.Module):
    '''Wrap a submodule and add its output element-wise to its input.'''

    def __init__(self, submodule):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        branch = self.sub(x)
        return x + branch

    def __repr__(self):
        # Prefix every line of the wrapped module's repr with '|'.
        inner = self.sub.__repr__().replace('\n', '\n|')
        return 'Identity + \n|' + inner
91
+
92
+
93
def sequential(*args):
    '''
    Build a flat nn.Sequential from the given modules.

    Nested nn.Sequential arguments are unwrapped into their children, and
    arguments that are not nn.Module instances (e.g. None) are dropped.
    A single argument is returned as-is without wrapping, except that a
    single OrderedDict raises NotImplementedError.
    '''
    if len(args) == 1:
        only = args[0]
        if isinstance(only, OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return only  # Nothing to wrap.
    flat = []
    for item in args:
        if isinstance(item, nn.Sequential):
            flat.extend(item.children())
        elif isinstance(item, nn.Module):
            flat.append(item)
    return nn.Sequential(*flat)
107
+
108
+
109
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
               pad_type='zero', norm_type=None, act_type='relu', mode='CNA'):
    '''
    Conv layer with padding, normalization, activation
    mode: CNA --> Conv -> Norm -> Act
        NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
        CNAC is accepted and built exactly like CNA.

    pad_type: 'zero' pads inside the conv itself; any other non-empty type
        adds an explicit padding layer in front of the conv.
    norm_type / act_type: a falsy value skips that layer.

    Returns the layers combined through sequential().
    Raises AssertionError for an unknown mode.
    '''
    # Explicit raise instead of `assert` so the check survives `python -O`;
    # the exception type seen by callers is unchanged.
    if mode not in ['CNA', 'NAC', 'CNAC']:
        raise AssertionError('Wrong conv mode [{:s}]'.format(mode))
    padding = get_valid_padding(kernel_size, dilation)
    p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
    padding = padding if pad_type == 'zero' else 0

    c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
                  dilation=dilation, bias=bias, groups=groups)
    if 'CNA' in mode:
        a = act(act_type) if act_type else None
        n = norm(norm_type, out_nc) if norm_type else None
        return sequential(p, c, n, a)

    # mode == 'NAC'
    if norm_type is None and act_type is not None:
        # Important!
        # input----ReLU(inplace)----Conv--+----output
        #        |________________________|
        # inplace ReLU will modify the input, therefore wrong output
        a = act(act_type, inplace=False)
    else:
        a = act(act_type) if act_type else None
    n = norm(norm_type, in_nc) if norm_type else None
    return sequential(n, a, p, c)
136
+
137
+
138
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 nn.Conv2d mapping in_planes to out_planes channels."""
    pointwise = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise
141
+
142
+
143
class GaussianNoise(nn.Module):
    '''
    Add Gaussian noise, scaled relative to the input, during training only.

    sigma: relative standard deviation of the noise; 0 disables it.
    is_relative_detach: if True, the scale is computed from a detached copy
        of the input, so no gradient flows through the noise scale.

    In eval mode (or with sigma == 0) the input is returned unchanged.
    '''

    def __init__(self, sigma=0.1, is_relative_detach=False):
        super().__init__()
        self.sigma = sigma
        self.is_relative_detach = is_relative_detach
        # Kept as an attribute for backward compatibility, but no longer
        # pinned to CUDA: the hard-coded torch.device('cuda') made the module
        # impossible to construct on CPU-only machines.
        self.noise = torch.tensor(0, dtype=torch.float)

    def forward(self, x):
        if self.training and self.sigma != 0:
            scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
            # Sample standard-normal noise directly on x's device.
            sampled_noise = torch.randn_like(x) * scale
            x = x + sampled_noise
        return x
156
+
157
+
158
+ ####################
159
+ # Useful blocks
160
+ ####################
161
+
162
+
163
class ResNetBlock(nn.Module):
    '''
    ResNet Block, 3-3 style
    with extra residual scaling used in EDSR
    (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)

    Two conv_blocks form the residual branch; its output is multiplied by
    res_scale and added to the input.
    '''

    def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, stride=1, dilation=1, groups=1,
                 bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA', res_scale=1):
        super(ResNetBlock, self).__init__()
        first = conv_block(in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type,
                           norm_type, act_type, mode)

        # The second conv drops its activation in CNA mode, and both its
        # activation and normalization in CNAC mode (residual path: |-CNAC-|).
        second_norm, second_act = norm_type, act_type
        if mode == 'CNA':
            second_act = None
        elif mode == 'CNAC':
            second_act = None
            second_norm = None
        second = conv_block(mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type,
                            second_norm, second_act, mode)

        self.res = sequential(first, second)
        self.res_scale = res_scale

    def forward(self, x):
        scaled_residual = self.res(x).mul(self.res_scale)
        return x + scaled_residual
194
+
195
+
196
class ResidualDenseBlock_5C(nn.Module):
    '''
    Residual Dense Block
    style: 5 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)

    This variant additionally adds a 1x1-conv projection of the input to the
    second conv's output, adds that result to the fourth conv's output, and
    optionally passes the block output through GaussianNoise.

    nc: number of input/output channels
    gc: growth channel, i.e. intermediate channels
    noise_input: if False, no noise layer is created or applied
    '''

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA', noise_input=True):
        super(ResidualDenseBlock_5C, self).__init__()
        self.noise = GaussianNoise() if noise_input else None
        self.conv1x1 = conv1x1(nc, gc)
        self.conv1 = conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv2 = conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv3 = conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv4 = conv_block(nc+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=act_type, mode=mode)
        # In CNA mode the last conv has no activation.
        if mode == 'CNA':
            last_act = None
        else:
            last_act = act_type
        # NOTE(review): the last conv uses a fixed kernel size of 3 rather
        # than kernel_size; left unchanged here.
        self.conv5 = conv_block(nc+4*gc, nc, 3, stride, bias=bias, pad_type=pad_type,
                                norm_type=norm_type, act_type=last_act, mode=mode)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x2 = x2 + self.conv1x1(x)
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x4 = x4 + x2
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        out = x5.mul(0.2) + x
        # self.noise is None when noise_input=False; calling it
        # unconditionally used to raise TypeError in that configuration.
        if self.noise is None:
            return out
        return self.noise(out)
233
+
234
+
235
class RRDB(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)

    Chains three ResidualDenseBlock_5C modules, scales the result by 0.2,
    adds the block input and passes the sum through GaussianNoise.
    '''

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
                 norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(RRDB, self).__init__()
        # All three dense blocks share the same positional configuration.
        rdb_args = (nc, kernel_size, gc, stride, bias, pad_type, norm_type, act_type, mode)
        self.RDB1 = ResidualDenseBlock_5C(*rdb_args)
        self.RDB2 = ResidualDenseBlock_5C(*rdb_args)
        self.RDB3 = ResidualDenseBlock_5C(*rdb_args)
        self.noise = GaussianNoise()

    def forward(self, x):
        out = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            out = dense_block(out)
        return self.noise(out.mul(0.2) + x)
257
+
258
+
259
+ ####################
260
+ # Upsampler
261
+ ####################
262
+
263
+
264
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
                       pad_type='zero', norm_type=None, act_type='relu'):
    '''
    Pixel shuffle layer
    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
    Neural Network, CVPR17)

    A plain conv expands the channels to out_nc * upscale_factor ** 2, then
    nn.PixelShuffle rearranges them; optional norm and activation follow.
    '''
    expanded_nc = out_nc * (upscale_factor ** 2)
    layers = [
        conv_block(in_nc, expanded_nc, kernel_size, stride, bias=bias,
                   pad_type=pad_type, norm_type=None, act_type=None),
        nn.PixelShuffle(upscale_factor),
        norm(norm_type, out_nc) if norm_type else None,
        act(act_type) if act_type else None,
    ]
    return sequential(*layers)
278
+
279
+
280
def upconv_blcok(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
                 pad_type='zero', norm_type=None, act_type='relu', mode='nearest'):
    '''
    Up conv: nn.Upsample followed by a conv_block
    described in https://distill.pub/2016/deconv-checkerboard/

    mode is the interpolation mode handed to nn.Upsample.
    '''
    layers = (
        nn.Upsample(scale_factor=upscale_factor, mode=mode),
        conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
                   pad_type=pad_type, norm_type=norm_type, act_type=act_type),
    )
    return sequential(*layers)
app.py CHANGED
@@ -35,9 +35,7 @@ with gr.Blocks(title=title, css=css) as demo:
35
  # {title}
36
  This space uses old ESRGAN architecture to upscale images, using models made by the community.
37
 
38
- Once the photo upscaled, click or tap the **download button** under the image to download it. **The preview image is not the upscaled one**
39
-
40
- I'll add more models after optimizing to size of the output image, right now it could be quite big.
41
 
42
  **Colab coming soon™**
43
  """)
@@ -47,7 +45,7 @@ with gr.Blocks(title=title, css=css) as demo:
47
  with gr.Column():
48
  input_image = gr.Image(type="pil", label="Input")
49
  upscale_size = gr.Radio(["x4", "x2"], label="Upscale by:", value="x4")
50
- upscale_type = gr.Radio(["Manga", "Anime", "General"], label="Select the type of picture you want to upscale:", value="Manga")
51
 
52
  with gr.Row():
53
  upscale_btn = gr.Button(value="Upscale", variant="primary")
 
35
  # {title}
36
  This space uses old ESRGAN architecture to upscale images, using models made by the community.
37
 
38
+ Once the photo upscaled (it can take a long time, this space only uses CPU), click or tap the **download button** under the image to download it. **The preview image is not the upscaled one**
 
 
39
 
40
  **Colab coming soon™**
41
  """)
 
45
  with gr.Column():
46
  input_image = gr.Image(type="pil", label="Input")
47
  upscale_size = gr.Radio(["x4", "x2"], label="Upscale by:", value="x4")
48
+ upscale_type = gr.Radio(["Manga", "Anime", "Photo", "General"], label="Select the type of picture you want to upscale:", value="Manga")
49
 
50
  with gr.Row():
51
  upscale_btn = gr.Button(value="Upscale", variant="primary")
inference.py CHANGED
@@ -1,9 +1,9 @@
1
  import sys
2
- import os.path
3
  import cv2
4
  import numpy as np
5
  import torch
6
- import architecture as arch
 
7
  from run_cmd import run_cmd
8
  from ESRGANer import ESRGANer
9
 
@@ -17,6 +17,8 @@ model_type = sys.argv[3]
17
 
18
  if model_type == "Anime":
19
  model_path = "models/4x-AnimeSharp.pth"
 
 
20
  else:
21
  model_path = "models/4x-UniScaleV2_Sharp.pth"
22
 
@@ -24,7 +26,10 @@ img_path = sys.argv[1]
24
  output_dir = sys.argv[2]
25
  device = torch.device('cuda' if is_cuda() else 'cpu')
26
 
27
- model = arch.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', mode='CNA', res_scale=1, upsample_mode='upconv')
 
 
 
28
 
29
  if is_cuda():
30
  print("Using GPU 🥶")
 
1
  import sys
 
2
  import cv2
3
  import numpy as np
4
  import torch
5
+ import ESRGAN.architecture as esrgan
6
+ import ESRGAN_plus.architecture as esrgan_plus
7
  from run_cmd import run_cmd
8
  from ESRGANer import ESRGANer
9
 
 
17
 
18
  if model_type == "Anime":
19
  model_path = "models/4x-AnimeSharp.pth"
20
+ if model_type == "Photo":
21
+ model_path = "models/4x_Valar_v1.pth"
22
  else:
23
  model_path = "models/4x-UniScaleV2_Sharp.pth"
24
 
 
26
  output_dir = sys.argv[2]
27
  device = torch.device('cuda' if is_cuda() else 'cpu')
28
 
29
+ if model_type != "Photo":
30
+ model = esrgan.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', mode='CNA', res_scale=1, upsample_mode='upconv')
31
+ else:
32
+ model = esrgan_plus.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', mode='CNA', res_scale=1, upsample_mode='upconv')
33
 
34
  if is_cuda():
35
  print("Using GPU 🥶")
inference_manga_v2.py CHANGED
@@ -1,9 +1,8 @@
1
  import sys
2
- import os.path
3
  import cv2
4
  import numpy as np
5
  import torch
6
- import architecture as arch
7
  from ESRGANer import ESRGANer
8
 
9
  def is_cuda():
 
1
  import sys
 
2
  import cv2
3
  import numpy as np
4
  import torch
5
+ import ESRGAN.architecture as arch
6
  from ESRGANer import ESRGANer
7
 
8
  def is_cuda():
models/4x_Valar_v1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90c3192bef43e4baaa095c04751868065f23c52d98c1b42e6d0916bfeda75646
3
+ size 67544144
net_interp.py DELETED
@@ -1,21 +0,0 @@
1
- import sys
2
- import torch
3
- from collections import OrderedDict
4
-
5
- alpha = float(sys.argv[1])
6
-
7
- net_PSNR_path = './models/RRDB_PSNR_x4.pth'
8
- net_ESRGAN_path = './models/RRDB_ESRGAN_x4.pth'
9
- net_interp_path = './models/interp_{:02d}.pth'.format(int(alpha*10))
10
-
11
- net_PSNR = torch.load(net_PSNR_path)
12
- net_ESRGAN = torch.load(net_ESRGAN_path)
13
- net_interp = OrderedDict()
14
-
15
- print('Interpolating with alpha = ', alpha)
16
-
17
- for k, v_PSNR in net_PSNR.items():
18
- v_ESRGAN = net_ESRGAN[k]
19
- net_interp[k] = (1 - alpha) * v_PSNR + alpha * v_ESRGAN
20
-
21
- torch.save(net_interp, net_interp_path)