import time
import streamlit as st
import numpy as np
from PIL import Image
from io import BytesIO
from models.HAT.hat import *
from models.RCAN.rcan import *
# Seed per-session defaults. The `not in` guard means a key that already has
# a value (e.g. an enhanced image from an earlier button click) is kept.
_SESSION_DEFAULTS = {
    'hat_enhanced_image': None,
    'rcan_enhanced_image': None,
    'hat_clicked': False,
    'rcan_clicked': False,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Page title.
# NOTE(review): this literal was split across lines in the source (a
# SyntaxError; its HTML markup appears to have been lost). It is rebuilt as a
# centred heading, which is what unsafe_allow_html=True suggests -- confirm
# the intended markup.
st.markdown("<h1 style='text-align: center;'>Image Super Resolution</h1>",
            unsafe_allow_html=True)
def reset_states():
    """Clear both models' enhanced results and their click flags.

    Registered as the ``on_change`` callback of the input widgets below, so an
    enhanced image produced from a previous input is discarded as soon as the
    input changes.
    """
    st.session_state['hat_enhanced_image'] = None
    st.session_state['rcan_enhanced_image'] = None
    st.session_state['hat_clicked'] = False
    st.session_state['rcan_clicked'] = False


# Sidebar for navigation
st.sidebar.title("Options")
app_mode = st.sidebar.selectbox("Choose the input source",
                                ["Upload image", "Take a photo"])

# Show the uploader widget or the camera capture, depending on the choice.
# reset_states is defined above, so it is passed directly as the callback. A
# lambda deferring the name lookup is unnecessary, and would raise NameError
# if a previous run died before reaching the definition.
image_source = None
if app_mode == "Upload image":
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"],
                                     on_change=reset_states)
    image_source = uploaded_file
elif app_mode == "Take a photo":
    # Streamlit's camera widget, captured in the browser.
    camera_input = st.camera_input("Take a picture", on_change=reset_states)
    image_source = camera_input

# `image` is bound only when a readable picture was supplied; the processing
# section below relies on that.
if image_source is not None:
    try:
        # Normalise to 3-channel RGB (drops alpha/palette modes) before inference.
        image = Image.open(image_source).convert("RGB")
    except OSError:
        # Pillow raises OSError subclasses (e.g. UnidentifiedImageError) for
        # corrupt or unrecognised files; report it instead of a raw traceback.
        st.error("Could not read the selected image. Please try a different file.")
def png_download_name(filename):
    """Return *filename* with a ``.png`` extension.

    An existing ``.png`` extension (any letter case) is kept unchanged. Any
    other extension is replaced. An empty stem falls back to ``image``.
    """
    import os.path

    stem, ext = os.path.splitext(filename)
    if ext.lower() == ".png":
        return filename
    return f"{stem or 'image'}.png"


def get_image_download_link(img, filename):
    """Render a Streamlit download button that offers *img* as a PNG file.

    Args:
        img: Image to offer. It must support ``save(fp, format="PNG")``
            (e.g. a PIL ``Image``).
        filename: Suggested download name. Its extension is forced to
            ``.png``, because the payload below is always PNG-encoded.

    Returns:
        The return value of ``st.download_button``.
    """
    # Encode the image in memory; nothing is written to disk.
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    return st.download_button(
        label="Download Image",
        data=buffered.getvalue(),
        # Keep the file extension consistent with the PNG bytes / MIME type.
        file_name=png_download_name(filename),
        mime="image/png",
    )
def _show_comparison(original, enhanced, filename):
    """Show *original* and *enhanced* side by side, with a download button under *enhanced*."""
    col1, col2 = st.columns(2)
    col1.header("Original")
    col1.image(original, use_column_width=True)
    col2.header("Enhanced")
    col2.image(enhanced, use_column_width=True)
    with col2:
        get_image_download_link(enhanced, filename)


def _enhance_with_rcan(img):
    """Build an RCAN model, load its checkpoint and return ``inference(img)``.

    NOTE(review): the checkpoint is re-read from disk on every click; caching
    the loaded model would avoid that.
    NOTE(review): the weights are mapped to ``device``, but the model itself is
    never moved there (no ``.to(device)``), so it stays wherever ``RCAN()``
    created it. Confirm that ``inference`` handles device placement and
    eval/no-grad mode as intended.
    """
    rcan_model = RCAN()
    device = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda')
    rcan_model.load_state_dict(torch.load('models/RCAN/rcan_checkpoint.pth', map_location=device))
    return rcan_model.inference(img)


# `image` is bound by the input section only when the user supplied a picture.
# Look it up explicitly instead of probing `locals()`.
source_image = globals().get("image")

if source_image is not None:
    st.write("")  # spacer

    # ------------------------ HAT ------------------------ #
    if st.button('Enhance with HAT'):
        with st.spinner('Processing using HAT...'):
            with st.spinner('Wait for it... the model is processing the image'):
                st.session_state['hat_enhanced_image'] = HAT_for_deployment(source_image)
                # Flag is not read in this section; kept for other readers of session state.
                st.session_state['hat_clicked'] = True
        st.success('Done!')
    # Results live in session state so they survive the rerun that any later
    # widget interaction (e.g. the download button) triggers.
    if st.session_state['hat_enhanced_image'] is not None:
        # .png: get_image_download_link always encodes PNG.
        _show_comparison(source_image, st.session_state['hat_enhanced_image'],
                         'hat_enhanced.png')

    # ------------------------ RCAN ------------------------ #
    if st.button('Enhance with RCAN'):
        with st.spinner('Processing using RCAN...'):
            with st.spinner('Wait for it... the model is processing the image'):
                st.session_state['rcan_enhanced_image'] = _enhance_with_rcan(source_image)
                # Flag is not read in this section; kept for other readers of session state.
                st.session_state['rcan_clicked'] = True
        st.success('Done!')
    if st.session_state['rcan_enhanced_image'] is not None:
        _show_comparison(source_image, st.session_state['rcan_enhanced_image'],
                         'rcan_enhanced.png')