Spaces:
Running
Running
Interpolation added (#5)
Browse files- Interpolation added (ff4d653f5d182fc8825c2bf92bb9e4a65905c97f)
Co-authored-by: Nguyễn Nam <[email protected]>
- app.py +68 -0
- models/Interpolation/bicubic.py +7 -0
- models/Interpolation/bilinear.py +7 -0
- models/Interpolation/nearest_neighbor.py +7 -0
app.py
CHANGED
@@ -7,11 +7,20 @@ from io import BytesIO
|
|
7 |
from models.HAT.hat import *
|
8 |
from models.RCAN.rcan import *
|
9 |
from models.SRGAN.srgan import *
|
|
|
|
|
|
|
10 |
|
11 |
subprocess.call('pip install natsort', shell=True)
|
12 |
from models.SRFlow.srflow import *
|
13 |
|
14 |
# Initialize session state for enhanced images
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
if 'hat_enhanced_image' not in st.session_state:
|
16 |
st.session_state['hat_enhanced_image'] = None
|
17 |
if 'rcan_enhanced_image' not in st.session_state:
|
@@ -22,6 +31,12 @@ if 'srflow_enhanced_image' not in st.session_state:
|
|
22 |
st.session_state['srflow_enhanced_image'] = None
|
23 |
|
24 |
# Initialize session state for button clicks
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
if 'hat_clicked' not in st.session_state:
|
26 |
st.session_state['hat_clicked'] = False
|
27 |
if 'rcan_clicked' not in st.session_state:
|
@@ -52,10 +67,16 @@ def reset_states():
|
|
52 |
st.session_state['rcan_enhanced_image'] = None
|
53 |
st.session_state['srgan_enhanced_image'] = None
|
54 |
st.session_state['srflow_enhanced_image'] = None
|
|
|
|
|
|
|
55 |
st.session_state['hat_clicked'] = False
|
56 |
st.session_state['rcan_clicked'] = False
|
57 |
st.session_state['srgan_clicked'] = False
|
58 |
st.session_state['srflow_clicked'] = False
|
|
|
|
|
|
|
59 |
|
60 |
def get_image_download_link(img, filename):
|
61 |
"""Generates a link allowing the PIL image to be downloaded"""
|
@@ -72,6 +93,53 @@ def get_image_download_link(img, filename):
|
|
72 |
if 'image' in locals():
|
73 |
# st.image(image, caption='Uploaded Image', use_column_width=True)
|
74 |
st.write("")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
# ------------------------ HAT ------------------------ #
|
77 |
if st.button('Enhance with HAT'):
|
|
|
7 |
from models.HAT.hat import *
|
8 |
from models.RCAN.rcan import *
|
9 |
from models.SRGAN.srgan import *
|
10 |
+
from models.Interpolation.nearest_neighbor import NearestNeighbor_for_deployment
|
11 |
+
from models.Interpolation.bilinear import Bilinear_for_deployment
|
12 |
+
from models.Interpolation.bicubic import Bicubic_for_deployment
|
13 |
|
14 |
subprocess.call('pip install natsort', shell=True)
|
15 |
from models.SRFlow.srflow import *
|
16 |
|
17 |
# Initialize session state for enhanced images
|
18 |
+
if 'nearest_enhanced_image' not in st.session_state:
|
19 |
+
st.session_state['nearest_enhanced_image'] = None
|
20 |
+
if 'bilinear_enhanced_image' not in st.session_state:
|
21 |
+
st.session_state['bilinear_enhanced_image'] = None
|
22 |
+
if 'bicubic_enhanced_image' not in st.session_state:
|
23 |
+
st.session_state['bicubic_enhanced_image'] = None
|
24 |
if 'hat_enhanced_image' not in st.session_state:
|
25 |
st.session_state['hat_enhanced_image'] = None
|
26 |
if 'rcan_enhanced_image' not in st.session_state:
|
|
|
31 |
st.session_state['srflow_enhanced_image'] = None
|
32 |
|
33 |
# Initialize session state for button clicks
|
34 |
+
if 'nearest_clicked' not in st.session_state:
|
35 |
+
st.session_state['nearest_clicked'] = False
|
36 |
+
if 'bilinear_clicked' not in st.session_state:
|
37 |
+
st.session_state['bilinear_clicked'] = False
|
38 |
+
if 'bicubic_clicked' not in st.session_state:
|
39 |
+
st.session_state['bicubic_clicked'] = False
|
40 |
if 'hat_clicked' not in st.session_state:
|
41 |
st.session_state['hat_clicked'] = False
|
42 |
if 'rcan_clicked' not in st.session_state:
|
|
|
67 |
st.session_state['rcan_enhanced_image'] = None
|
68 |
st.session_state['srgan_enhanced_image'] = None
|
69 |
st.session_state['srflow_enhanced_image'] = None
|
70 |
+
st.session_state['bicubic_enhanced_image'] = None
|
71 |
+
st.session_state['bilinear_enhanced_image'] = None
|
72 |
+
st.session_state['nearest_enhanced_image'] = None
|
73 |
st.session_state['hat_clicked'] = False
|
74 |
st.session_state['rcan_clicked'] = False
|
75 |
st.session_state['srgan_clicked'] = False
|
76 |
st.session_state['srflow_clicked'] = False
|
77 |
+
st.session_state['bicubic_clicked'] = False
|
78 |
+
st.session_state['bilinear_clicked'] = False
|
79 |
+
st.session_state['nearest_clicked'] = False
|
80 |
|
81 |
def get_image_download_link(img, filename):
|
82 |
"""Generates a link allowing the PIL image to be downloaded"""
|
|
|
93 |
if 'image' in locals():
|
94 |
# st.image(image, caption='Uploaded Image', use_column_width=True)
|
95 |
st.write("")
|
96 |
+
# ------------------------ Nearest Neighbor ------------------------ #
|
97 |
+
if st.button('Enhance with Nearest Neighbor'):
|
98 |
+
with st.spinner('Processing using Nearest Neighbor...'):
|
99 |
+
enhanced_image = NearestNeighbor_for_deployment(image)
|
100 |
+
st.session_state['nearest_enhanced_image'] = enhanced_image
|
101 |
+
st.session_state['nearest_clicked'] = True
|
102 |
+
st.success('Done!')
|
103 |
+
if st.session_state['nearest_enhanced_image'] is not None:
|
104 |
+
col1, col2 = st.columns(2)
|
105 |
+
col1.header("Original")
|
106 |
+
col1.image(image, use_column_width=True)
|
107 |
+
col2.header("Enhanced")
|
108 |
+
col2.image(st.session_state['nearest_enhanced_image'], use_column_width=True)
|
109 |
+
with col2:
|
110 |
+
get_image_download_link(st.session_state['nearest_enhanced_image'], 'nearest_enhanced.jpg')
|
111 |
+
|
112 |
+
# ------------------------ Bilinear ------------------------ #
|
113 |
+
if st.button('Enhance with Bilinear'):
|
114 |
+
with st.spinner('Processing using Bilinear...'):
|
115 |
+
enhanced_image = Bilinear_for_deployment(image)
|
116 |
+
st.session_state['bilinear_enhanced_image'] = enhanced_image
|
117 |
+
st.session_state['bilinear_clicked'] = True
|
118 |
+
st.success('Done!')
|
119 |
+
if st.session_state['bilinear_enhanced_image'] is not None:
|
120 |
+
col1, col2 = st.columns(2)
|
121 |
+
col1.header("Original")
|
122 |
+
col1.image(image, use_column_width=True)
|
123 |
+
col2.header("Enhanced")
|
124 |
+
col2.image(st.session_state['bilinear_enhanced_image'], use_column_width=True)
|
125 |
+
with col2:
|
126 |
+
get_image_download_link(st.session_state['bilinear_enhanced_image'], 'bilinear_enhanced.jpg')
|
127 |
+
|
128 |
+
# ------------------------ Bicubic ------------------------ #
|
129 |
+
if st.button('Enhance with Bicubic'):
|
130 |
+
with st.spinner('Processing using Bicubic...'):
|
131 |
+
enhanced_image = Bicubic_for_deployment(image)
|
132 |
+
st.session_state['bicubic_enhanced_image'] = enhanced_image
|
133 |
+
st.session_state['bicubic_clicked'] = True
|
134 |
+
st.success('Done!')
|
135 |
+
if st.session_state['bicubic_enhanced_image'] is not None:
|
136 |
+
col1, col2 = st.columns(2)
|
137 |
+
col1.header("Original")
|
138 |
+
col1.image(image, use_column_width=True)
|
139 |
+
col2.header("Enhanced")
|
140 |
+
col2.image(st.session_state['bicubic_enhanced_image'], use_column_width=True)
|
141 |
+
with col2:
|
142 |
+
get_image_download_link(st.session_state['bicubic_enhanced_image'], 'bicubic_enhanced.jpg')
|
143 |
|
144 |
# ------------------------ HAT ------------------------ #
|
145 |
if st.button('Enhance with HAT'):
|
models/Interpolation/bicubic.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from torchvision import transforms
|
3 |
+
|
4 |
+
def Bicubic_for_deployment(lr_image):
|
5 |
+
w, h = lr_image.size
|
6 |
+
sr_image = transforms.functional.resize(lr_image, size=(h*4, w*4), interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)
|
7 |
+
return sr_image
|
models/Interpolation/bilinear.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from torchvision import transforms
|
3 |
+
|
4 |
+
def Bilinear_for_deployment(lr_image):
|
5 |
+
w, h = lr_image.size
|
6 |
+
sr_image = transforms.functional.resize(lr_image, size=(h*4, w*4), interpolation=transforms.InterpolationMode.BILINEAR, antialias=True)
|
7 |
+
return sr_image
|
models/Interpolation/nearest_neighbor.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from torchvision import transforms
|
3 |
+
|
4 |
+
def NearestNeighbor_for_deployment(lr_image):
|
5 |
+
w, h = lr_image.size
|
6 |
+
sr_image = transforms.functional.resize(lr_image, size=(h*4, w*4),interpolation=transforms.InterpolationMode.NEAREST,antialias=False)
|
7 |
+
return sr_image
|