pqt committed on
Commit
eee6b71
·
1 Parent(s): d6257d4
.gitignore CHANGED
@@ -1,4 +1,5 @@
1
  models/HAT/__pycache__/hat.cpython-39.pyc
2
  /.venv
3
  /models/HAT/__pycache__
4
- /models/RCAN/__pycache__
 
 
1
  models/HAT/__pycache__/hat.cpython-39.pyc
2
  /.venv
3
  /models/HAT/__pycache__
4
+ /models/RCAN/__pycache__
5
+ /models/SRGAN/__pycache__
app.py CHANGED
@@ -5,6 +5,7 @@ from PIL import Image
5
  from io import BytesIO
6
  from models.HAT.hat import *
7
  from models.RCAN.rcan import *
 
8
 
9
  # Initialize session state for enhanced images
10
  if 'hat_enhanced_image' not in st.session_state:
@@ -13,11 +14,17 @@ if 'hat_enhanced_image' not in st.session_state:
13
  if 'rcan_enhanced_image' not in st.session_state:
14
  st.session_state['rcan_enhanced_image'] = None
15
 
 
 
 
16
  if 'hat_clicked' not in st.session_state:
17
  st.session_state['hat_clicked'] = False
18
  if 'rcan_clicked' not in st.session_state:
19
  st.session_state['rcan_clicked'] = False
20
-
 
 
 
21
  st.markdown("<h1 style='text-align: center'>Image Super Resolution</h1>", unsafe_allow_html=True)
22
  # Sidebar for navigation
23
  st.sidebar.title("Options")
@@ -38,8 +45,10 @@ elif app_mode == "Take a photo":
38
  def reset_states():
39
  st.session_state['hat_enhanced_image'] = None
40
  st.session_state['rcan_enhanced_image'] = None
 
41
  st.session_state['hat_clicked'] = False
42
  st.session_state['rcan_clicked'] = False
 
43
 
44
  def get_image_download_link(img, filename):
45
  """Generates a link allowing the PIL image to be downloaded"""
@@ -93,3 +102,22 @@ if 'image' in locals():
93
  col2.image(st.session_state['rcan_enhanced_image'], use_column_width=True)
94
  with col2:
95
  get_image_download_link(st.session_state['rcan_enhanced_image'], 'rcan_enhanced.jpg')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from io import BytesIO
6
  from models.HAT.hat import *
7
  from models.RCAN.rcan import *
8
+ from models.SRGAN.srgan import *
9
 
10
  # Initialize session state for enhanced images
11
  if 'hat_enhanced_image' not in st.session_state:
 
14
  if 'rcan_enhanced_image' not in st.session_state:
15
  st.session_state['rcan_enhanced_image'] = None
16
 
17
+ if 'srgan_enhanced_image' not in st.session_state:
18
+ st.session_state['srgan_enhanced_image'] = None
19
+
20
  if 'hat_clicked' not in st.session_state:
21
  st.session_state['hat_clicked'] = False
22
  if 'rcan_clicked' not in st.session_state:
23
  st.session_state['rcan_clicked'] = False
24
+
25
+ if 'srgan_clicked' not in st.session_state:
26
+ st.session_state['srgan_clicked'] = False
27
+
28
  st.markdown("<h1 style='text-align: center'>Image Super Resolution</h1>", unsafe_allow_html=True)
29
  # Sidebar for navigation
30
  st.sidebar.title("Options")
 
45
  def reset_states():
46
  st.session_state['hat_enhanced_image'] = None
47
  st.session_state['rcan_enhanced_image'] = None
48
+ st.session_state['srgan_enhanced_image'] = None
49
  st.session_state['hat_clicked'] = False
50
  st.session_state['rcan_clicked'] = False
51
+ st.session_state['srgan_clicked'] = False
52
 
53
  def get_image_download_link(img, filename):
54
  """Generates a link allowing the PIL image to be downloaded"""
 
102
  col2.image(st.session_state['rcan_enhanced_image'], use_column_width=True)
103
  with col2:
104
  get_image_download_link(st.session_state['rcan_enhanced_image'], 'rcan_enhanced.jpg')
105
+ #--------------------------SRGAN--------------------------#
106
+ if st.button('Enhance with SRGAN'):
107
+ with st.spinner('Processing using SRGAN...'):
108
+ with st.spinner('Wait for it... the model is processing the image'):
109
+ srgan_model = GeneratorResnet()
110
+ device = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda')
111
+ srgan_model = torch.load('models/SRGAN/srgan_checkpoint.pth', map_location=device)
112
+ enhanced_image = srgan_model.inference(image)
113
+ st.session_state['srgan_enhanced_image'] = enhanced_image
114
+ st.session_state['srgan_clicked'] = True
115
+ st.success('Done!')
116
+ if st.session_state['srgan_enhanced_image'] is not None:
117
+ col1, col2 = st.columns(2)
118
+ col1.header("Original")
119
+ col1.image(image, use_column_width=True)
120
+ col2.header("Enhanced")
121
+ col2.image(st.session_state['srgan_enhanced_image'], use_column_width=True)
122
+ with col2:
123
+ get_image_download_link(st.session_state['srgan_enhanced_image'], 'srgan_enhanced.jpg')
models/SRGAN/srgan.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torch import nn
4
+ from PIL import Image
5
+ from torchvision.transforms import ToTensor
6
+
7
class ResidualBlock(nn.Module):
    """Residual unit: conv -> batch norm -> PReLU -> conv -> batch norm, plus skip.

    Both convolutions use 3x3 kernels with stride 1 and padding 1 and keep the
    channel count, so the block output has the same shape as its input and
    can be added to it.

    Args:
        in_features: number of channels of the input (and output) feature map.
    """

    def __init__(self, in_features):
        super().__init__()

        def conv3x3():
            # Shape-preserving 3x3 convolution.
            return nn.Conv2d(in_features, in_features, kernel_size=3, stride=1, padding=1)

        def batch_norm():
            # The second positional argument of BatchNorm2d is `eps`, so the
            # 0.8 used here is the epsilon term (not the momentum). It is
            # spelled out as a keyword to make that visible.
            return nn.BatchNorm2d(in_features, eps=0.8)

        # Layer order (and therefore the state_dict keys conv_block.0 ... .4)
        # must stay as is for existing checkpoints to keep loading.
        self.conv_block = nn.Sequential(
            conv3x3(),
            batch_norm(),
            nn.PReLU(),
            conv3x3(),
            batch_norm(),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.conv_block(x)
        return x + residual
20
+
21
class GeneratorResnet(nn.Module):
    """SRGAN-style generator that enlarges an image 4x in each spatial dimension.

    Pipeline, as built in ``__init__``:

    1. 9x9 convolution + PReLU (``conv1``).
    2. ``n_residual_blocks`` residual blocks (``res_blocks``).
    3. 3x3 convolution + batch norm (``conv2``), added to the output of step 1.
    4. Two conv / batch-norm / ``PixelShuffle(2)`` / PReLU stages
       (``upsampling``), each doubling height and width.
    5. 9x9 convolution + Tanh (``conv3``).

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the generated tensor.
        n_residual_blocks: how many ``ResidualBlock`` modules form the trunk.
    """

    def __init__(self, in_channels=3, out_channels=3, n_residual_blocks=16):
        super().__init__()

        # First layer
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
        )

        # Residual blocks
        self.res_blocks = nn.Sequential(
            *(ResidualBlock(64) for _ in range(n_residual_blocks))
        )

        # Second conv layer after the residual blocks. The 0.8 is BatchNorm2d's
        # `eps` (second positional parameter); kept unchanged so the module
        # matches what existing checkpoints were trained with.
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
        )

        # Two x2 upsampling stages: 256 channels are rearranged by
        # PixelShuffle(2) into 64 channels at twice the resolution.
        upsampling = []
        for _ in range(2):
            upsampling += [
                nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(256),
                nn.PixelShuffle(upscale_factor=2),
                nn.PReLU(),
            ]
        self.upsampling = nn.Sequential(*upsampling)

        # Output layer; Tanh bounds the result to [-1, 1].
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, out_channels, kernel_size=9, stride=1, padding=4),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map a ``(N, in_channels, H, W)`` tensor to ``(N, out_channels, 4H, 4W)``."""
        out1 = self.conv1(x)
        out = self.res_blocks(out1)
        out2 = self.conv2(out)
        # Long skip connection around the whole residual trunk.
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv3(out)
        return out

    @staticmethod
    def _to_uint8_array(batch):
        """Turn a ``(1, C, H, W)`` float tensor into an ``(H, W, C)`` uint8 array.

        Values are clamped to [0, 1] before scaling to 0..255: the Tanh output
        can be negative or overshoot, and casting such values straight to
        uint8 wraps around and produces speckle artifacts.
        """
        pixels = batch.squeeze(0).permute(1, 2, 0).detach().cpu()
        pixels = pixels.clamp(0.0, 1.0) * 255
        return pixels.numpy().astype('uint8')

    def inference(self, x):
        """Super-resolve a PIL image and return the result as a PIL image.

        The forward pass runs in eval mode and without gradient tracking; the
        module's previous train/eval mode is restored afterwards.

        Args:
            x: a PIL image. When the network expects 3 (or 1) input channels
                the image is converted to RGB (or L) first, so images with an
                alpha channel or a palette do not break the first convolution.

        Returns:
            A PIL image built from the clamped network output.
        """
        expected_channels = self.conv1[0].in_channels
        if expected_channels == 3:
            x = x.convert('RGB')
        elif expected_channels == 1:
            x = x.convert('L')

        # Run on whatever device the weights live on (the caller may have
        # loaded the model onto a GPU).
        device = next(self.parameters()).device
        batch = ToTensor()(x).unsqueeze(0).to(device)

        was_training = self.training
        self.eval()  # batch norm must use running statistics for one image
        try:
            with torch.no_grad():
                out = self.forward(batch)
        finally:
            self.train(was_training)

        return Image.fromarray(self._to_uint8_array(out))
64
+
65
if __name__ == '__main__':
    # Manual smoke test: load the checkpoint that sits next to this file and
    # run it once on the demo image.
    current_dir = os.path.dirname(os.path.realpath(__file__))
    checkpoint_path = os.path.join(current_dir, 'srgan_checkpoint.pth')

    # The checkpoint holds a whole pickled model object (its `inference`
    # method is called below), so no GeneratorResnet() needs to be built
    # first. Unpickling can execute arbitrary code: only load trusted files.
    # NOTE(review): recent PyTorch releases default `weights_only` to True,
    # which presumably rejects full-model pickles like this one - confirm
    # against the installed version.
    model = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.eval()
    with torch.no_grad():
        # Path is relative to the current working directory.
        input_image = Image.open('images/demo.png')
        output_image = model.inference(input_image)
models/SRGAN/srgan_checkpoint.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b1423b711b26b2250612e02fddf95a4a7214e883c601eb153bfa82caceb5511
3
+ size 6336189