"""Streamlit app: super-resolve an uploaded or webcam-captured image with HAT or RCAN."""
import time  # noqa: F401  (not used in the visible code; kept in case other code relies on it)
from io import BytesIO

import numpy as np  # noqa: F401  (not used in the visible code; kept in case other code relies on it)
import streamlit as st
import torch  # used explicitly below; previously only reachable through a star import
from PIL import Image, UnidentifiedImageError

# Star imports stay last, as in the original, so any names they export keep the same precedence.
from models.HAT.hat import *  # noqa: F401,F403  (provides HAT_for_deployment)
from models.RCAN.rcan import *  # noqa: F401,F403  (provides RCAN)


# Session-state keys holding the enhanced results, with their initial/reset values.
_SESSION_DEFAULTS = {
    'hat_enhanced_image': None,
    'rcan_enhanced_image': None,
    'hat_clicked': False,
    'rcan_clicked': False,
}


def reset_states():
    """Reset all enhanced-image results to their defaults.

    Registered as the ``on_change`` callback of the input widgets, so results produced
    for a previous image are not shown next to a newly selected one.
    """
    for key, default in _SESSION_DEFAULTS.items():
        st.session_state[key] = default


def get_image_download_link(img, filename):
    """Render a Streamlit download button that serves ``img`` encoded as PNG.

    Args:
        img: Image to offer for download. It must support ``save(buffer, format="PNG")``.
            That holds for ``PIL.Image.Image``; NOTE(review): confirm the model outputs
            are PIL images.
        filename: File name suggested to the browser. It should end in ``.png``,
            because the payload is always PNG.

    Returns:
        Whatever ``st.download_button`` returns.
    """
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    return st.download_button(
        label="Download Image",
        data=buffered.getvalue(),
        file_name=filename,
        mime="image/png",
    )


def _open_rgb(source):
    """Open an uploaded/captured file as an RGB PIL image.

    Returns None, after showing an error, if the file is not a readable image.
    """
    try:
        return Image.open(source).convert("RGB")
    except UnidentifiedImageError:
        st.error("The selected file could not be read as an image.")
        return None


def _show_comparison(original, enhanced, filename):
    """Show the original and enhanced images side by side, with a download button under the enhanced one."""
    col1, col2 = st.columns(2)
    col1.header("Original")
    col1.image(original, use_column_width=True)
    col2.header("Enhanced")
    col2.image(enhanced, use_column_width=True)
    with col2:
        get_image_download_link(enhanced, filename)


def _enhance_with_rcan(img):
    """Build an RCAN model, load its checkpoint, and run inference on ``img``."""
    # TODO: the checkpoint is reloaded on every click; consider st.cache_resource
    # (Streamlit >= 1.18) to load it once.
    rcan_model = RCAN()
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # NOTE(review): map_location only decides where the checkpoint tensors are loaded.
    # This function never moves rcan_model itself to `device`, so confirm that
    # RCAN / RCAN.inference handle device placement.
    rcan_model.load_state_dict(torch.load('models/RCAN/rcan_checkpoint.pth', map_location=device))
    return rcan_model.inference(img)


for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# NOTE(review): this title string was split across lines (a syntax error), apparently after
# its HTML markup was stripped. The <h1> wrapper is a reconstruction; adjust styling as needed.
st.markdown("<h1>Image Super Resolution</h1>", unsafe_allow_html=True)

# Sidebar for navigation
st.sidebar.title("Options")
app_mode = st.sidebar.selectbox("Choose the input source", ["Upload image", "Take a photo"])

# Depending on the choice, show the uploader widget or the webcam capture widget.
image = None
if app_mode == "Upload image":
    uploaded_file = st.file_uploader(
        "Choose an image...", type=["jpg", "jpeg", "png"], on_change=reset_states
    )
    if uploaded_file is not None:
        image = _open_rgb(uploaded_file)
elif app_mode == "Take a photo":
    camera_input = st.camera_input("Take a picture", on_change=reset_states)
    if camera_input is not None:
        image = _open_rgb(camera_input)

if image is not None:
    st.write("")

    # ------------------------ HAT ------------------------ #
    if st.button('Enhance with HAT'):
        with st.spinner('Processing using HAT...'):
            with st.spinner('Wait for it... the model is processing the image'):
                st.session_state['hat_enhanced_image'] = HAT_for_deployment(image)
                st.session_state['hat_clicked'] = True
        st.success('Done!')

    if st.session_state['hat_enhanced_image'] is not None:
        # The payload is PNG, so the suggested name uses .png (was .jpg).
        _show_comparison(image, st.session_state['hat_enhanced_image'], 'hat_enhanced.png')

    # ------------------------ RCAN ------------------------ #
    if st.button('Enhance with RCAN'):
        with st.spinner('Processing using RCAN...'):
            with st.spinner('Wait for it... the model is processing the image'):
                st.session_state['rcan_enhanced_image'] = _enhance_with_rcan(image)
                st.session_state['rcan_clicked'] = True
        st.success('Done!')

    if st.session_state['rcan_enhanced_image'] is not None:
        _show_comparison(image, st.session_state['rcan_enhanced_image'], 'rcan_enhanced.png')