360PanoImage / app.py
gokaygokay's picture
pano
ece05f2
raw
history blame
4.59 kB
import spaces
import streamlit as st
import torch
from huggingface_hub import snapshot_download
from txt2panoimg import Text2360PanoramaImagePipeline
from img2panoimg import Image2360PanoramaImagePipeline
from PIL import Image
from streamlit_pannellum import streamlit_pannellum
# Page-wide stylesheet injected once at start-up to polish the default
# Streamlit look (centred layout, card-style tabs, accent-coloured buttons).
# NOTE(review): no element in the visible code uses the `viewer-column` class.
_CUSTOM_CSS = """
<style>
.stApp {
max-width: 1200px;
margin: 0 auto;
}
.main {
background-color: #f0f2f6;
}
h1 {
color: #1E3A8A;
text-align: center;
padding: 20px 0;
font-size: 2.5rem;
}
.stTabs {
background-color: white;
padding: 20px;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.stButton>button {
background-color: #1E3A8A;
color: white;
font-weight: bold;
}
.viewer-column {
background-color: white;
padding: 20px;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
</style>
"""
# Raw HTML is required for a <style> tag to take effect.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
@st.cache_resource
def _load_resources():
    """Download the model weights and build the expensive shared resources.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the snapshot download and both float16 pipeline
    constructions would be repeated on every rerun. ``st.cache_resource``
    builds them once and reuses the same objects on later reruns.

    Returns:
        Tuple ``(model_path, text_pipeline, image_pipeline, default_mask)``:
        the local snapshot directory, the text-to-panorama pipeline, the
        image-to-panorama pipeline, and the bundled mask as an RGB image.
    """
    path = snapshot_download("archerfmy0831/sd-t2i-360panoimage")
    text_pipeline = Text2360PanoramaImagePipeline(path, torch_dtype=torch.float16)
    image_pipeline = Image2360PanoramaImagePipeline(path, torch_dtype=torch.float16)
    # convert() returns an in-memory copy, so the file handle can be closed
    # immediately instead of relying on Pillow's lazy-load bookkeeping.
    with Image.open("i2p-mask.jpg") as mask_file:
        mask = mask_file.convert("RGB")
    return path, text_pipeline, image_pipeline, mask


# Keep the original module-level names so the rest of the script is unchanged.
model_path, txt2panoimg, img2panoimg, default_mask = _load_resources()
@spaces.GPU(duration=200)
def text_to_pano(prompt, upscale):
    """Generate a 360° panorama from a text prompt.

    Args:
        prompt: Text description of the desired scene.
        upscale: Forwarded unchanged as the pipeline's ``upscale`` option.

    Returns:
        Whatever ``txt2panoimg`` returns (presumably a PIL image — confirm).
    """
    return txt2panoimg({'prompt': prompt, 'upscale': upscale})
@spaces.GPU(duration=200)
def image_to_pano(image, mask, prompt, upscale):
    """Generate a 360° panorama from an input image, a mask and a prompt.

    Args:
        image: PIL image to build the panorama from.
        mask: Optional PIL mask image. When ``None``, the bundled
            ``default_mask`` is used.
        prompt: Text guidance passed to the pipeline.
        upscale: Forwarded unchanged as the pipeline's ``upscale`` option.

    Returns:
        Whatever ``img2panoimg`` returns (presumably a PIL image — confirm).
    """
    size = (512, 512)
    # Uploaded PNGs are frequently RGBA, palette ("P") or greyscale. Normalise
    # them to RGB, the same mode the default mask is loaded in. Convert *before*
    # resizing, because Pillow always resizes "P"/"1" images with nearest-neighbour.
    image = image.convert("RGB").resize(size)
    if mask is None:
        mask = default_mask
    mask = mask.convert("RGB").resize(size)
    input_data = {
        'prompt': prompt,
        'image': image,
        'mask': mask,
        'upscale': upscale,
    }
    return img2panoimg(input_data)
# Page header, followed by one tab per generation mode.
st.title("360° Panorama Image Generation")
tab1, tab2 = st.tabs(
    [
        "Text to 360° Panorama",
        "Image to 360° Panorama",
    ]
)
def _panorama_to_data_uri(image, quality=90):
    """Return *image* as a string that Pannellum can load as a panorama source.

    Strings are assumed to already be a URL or data URI and are returned as-is.
    Anything else is encoded as a base64 JPEG ``data:`` URI.

    Args:
        image: A PIL image, an array-like image, or a URL string.
        quality: JPEG quality used for the encoding.

    Returns:
        A URL / data-URI string.
    """
    import base64
    import io

    if isinstance(image, str):
        return image
    if not isinstance(image, Image.Image):
        # NOTE(review): assumes the pipeline may hand back an array-like image
        # (e.g. uint8 HxWx3) — confirm the actual return type.
        image = Image.fromarray(image)
    buffer = io.BytesIO()
    # JPEG cannot store alpha or palette modes, so normalise to RGB first.
    image.convert("RGB").save(buffer, format="JPEG", quality=quality)
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{encoded}"


def display_panorama(image):
    """Render *image* in an interactive Pannellum 360° viewer.

    The component ``config`` is serialised to JSON for the browser, and
    Pannellum's ``panorama`` option expects a URL string. The image therefore
    cannot be passed as a PIL object and is embedded as a data URI instead.

    Args:
        image: The generated panorama (PIL image, array-like, or URL string).
    """
    streamlit_pannellum(
        config={
            "default": {
                "firstScene": "generated",
            },
            "scenes": {
                "generated": {
                    "title": "Generated Panorama",
                    "type": "equirectangular",
                    "panorama": _panorama_to_data_uri(image),
                    "autoLoad": True,
                }
            },
        }
    )
with tab1:
    col1, col2 = st.columns([1, 1])
    with col1:
        st.subheader("Input")
        # Explicit key: tab 2 creates a text_area with the same label and height.
        # Streamlit derives keyless widget IDs from those parameters, so the two
        # widgets would otherwise collide with a duplicate-widget-ID error.
        t2p_input = st.text_area("Enter your prompt", height=100, key="t2p_prompt")
        t2p_upscale = st.checkbox("Upscale (requires >16GB GPU)")
        generate_button = st.button("Generate Panorama")
    with col2:
        st.subheader("Output")
        # Reserve slots in the output column; they are filled once generation finishes.
        output_placeholder = st.empty()
        viewer_placeholder = st.empty()
    if generate_button:
        with st.spinner("Generating your 360° panorama..."):
            output = text_to_pano(t2p_input, t2p_upscale)
        output_placeholder.image(output, caption="Generated 360° Panorama", use_column_width=True)
        with viewer_placeholder.container():
            display_panorama(output)
with tab2:
    col1, col2 = st.columns([1, 1])
    with col1:
        st.subheader("Input")
        i2p_image = st.file_uploader("Upload Input Image", type=["png", "jpg", "jpeg"])
        i2p_mask = st.file_uploader("Upload Mask Image (Optional)", type=["png", "jpg", "jpeg"])
        # Explicit key: tab 1 creates a text_area with the same label and height.
        # Streamlit derives keyless widget IDs from those parameters, so the two
        # widgets would otherwise collide with a duplicate-widget-ID error.
        i2p_prompt = st.text_area("Enter your prompt", height=100, key="i2p_prompt")
        i2p_upscale = st.checkbox("Upscale (requires >16GB GPU)", key="i2p_upscale")
        generate_button = st.button("Generate Panorama", key="i2p_generate")
    with col2:
        st.subheader("Output")
        # Reserve slots in the output column; they are filled once generation finishes.
        output_placeholder = st.empty()
        viewer_placeholder = st.empty()
    if generate_button:
        if i2p_image is None:
            # The input image is mandatory; the mask falls back to the default one.
            st.error("Please upload an input image.")
        else:
            with st.spinner("Generating your 360° panorama..."):
                image = Image.open(i2p_image)
                mask = Image.open(i2p_mask) if i2p_mask is not None else None
                output = image_to_pano(image, mask, i2p_prompt, i2p_upscale)
            output_placeholder.image(output, caption="Generated 360° Panorama", use_column_width=True)
            with viewer_placeholder.container():
                display_panorama(output)