# Hugging Face Spaces status: Running on Zero
import base64
import io

import spaces
import streamlit as st
import torch
from huggingface_hub import snapshot_download
from txt2panoimg import Text2360PanoramaImagePipeline
from img2panoimg import Image2360PanoramaImagePipeline
from PIL import Image
from streamlit_pannellum import streamlit_pannellum
# Page-wide styling: centered app container, heading, tab panel, buttons and
# a card-style class for viewer columns. Injected as raw HTML, hence
# unsafe_allow_html=True.
_CUSTOM_CSS = """
<style>
.stApp {
max-width: 1200px;
margin: 0 auto;
}
.main {
background-color: #f0f2f6;
}
h1 {
color: #1E3A8A;
text-align: center;
padding: 20px 0;
font-size: 2.5rem;
}
.stTabs {
background-color: white;
padding: 20px;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.stButton>button {
background-color: #1E3A8A;
color: white;
font-weight: bold;
}
.viewer-column {
background-color: white;
padding: 20px;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
</style>
"""

st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
@st.cache_resource
def _load_resources():
    """Download the model and build everything that is expensive to create.

    Streamlit re-executes this script on every widget interaction; caching
    the result keeps the download, both pipelines and the default mask from
    being rebuilt on each rerun.

    Returns:
        A tuple ``(model_dir, text_pipeline, image_pipeline, mask)`` where
        ``model_dir`` is the local snapshot path, the two pipelines are
        constructed from it with ``torch.float16``, and ``mask`` is
        ``i2p-mask.jpg`` loaded as an RGB image.
    """
    # Download the model
    model_dir = snapshot_download("archerfmy0831/sd-t2i-360panoimage")
    # Initialize pipelines
    text_pipeline = Text2360PanoramaImagePipeline(model_dir, torch_dtype=torch.float16)
    image_pipeline = Image2360PanoramaImagePipeline(model_dir, torch_dtype=torch.float16)
    # Load the default mask image
    mask = Image.open("i2p-mask.jpg").convert("RGB")
    return model_dir, text_pipeline, image_pipeline, mask


# Module-level names used by the rest of the script.
model_path, txt2panoimg, img2panoimg, default_mask = _load_resources()
def text_to_pano(prompt, upscale):
    """Run the text-to-panorama pipeline for *prompt*.

    Args:
        prompt: Text passed to the pipeline under the ``'prompt'`` key.
        upscale: Value passed to the pipeline under the ``'upscale'`` key.

    Returns:
        Whatever the ``txt2panoimg`` pipeline returns for this request.
    """
    return txt2panoimg({'prompt': prompt, 'upscale': upscale})
def image_to_pano(image, mask, prompt, upscale):
    """Run the image-to-panorama pipeline on an input image and mask.

    Args:
        image: Input image (PIL-style object); converted to RGB and resized
            to 512x512 before being handed to the pipeline.
        mask: Mask image, or ``None`` to use the module-level
            ``default_mask``. Converted to RGB and resized to 512x512.
        prompt: Text passed to the pipeline under the ``'prompt'`` key.
        upscale: Value passed to the pipeline under the ``'upscale'`` key.

    Returns:
        Whatever the ``img2panoimg`` pipeline returns for this request.
    """
    target_size = (512, 512)

    # Uploaded files can be RGBA, palette or grayscale; the default mask is
    # loaded as RGB, so normalize every input to the same mode.
    image = image.convert("RGB").resize(target_size)
    if mask is None:
        mask = default_mask
    mask = mask.convert("RGB").resize(target_size)

    input_data = {
        'prompt': prompt,
        'image': image,
        'mask': mask,
        'upscale': upscale,
    }
    return img2panoimg(input_data)
# Page heading followed by one tab per generation mode.
st.title("360° Panorama Image Generation")
_TAB_LABELS = ["Text to 360° Panorama", "Image to 360° Panorama"]
tab1, tab2 = st.tabs(_TAB_LABELS)
# Function to display the panorama viewer
def display_panorama(image):
    """Show *image* in an interactive Pannellum equirectangular viewer.

    The Pannellum ``panorama`` setting is a string (a URL or data URI), so an
    in-memory image object cannot be placed in the config directly.

    Args:
        image: Either a string, which is used as the panorama source
            unchanged, or a PIL-style image object (must provide
            ``convert`` and ``save``), which is JPEG-encoded and embedded as
            a base64 data URI.
    """
    if isinstance(image, str):
        panorama_src = image
    else:
        buffer = io.BytesIO()
        # JPEG cannot store alpha/palette modes, so force RGB first.
        image.convert("RGB").save(buffer, format="JPEG", quality=90)
        encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
        panorama_src = "data:image/jpeg;base64," + encoded

    streamlit_pannellum(
        config={
            "default": {
                "firstScene": "generated",
            },
            "scenes": {
                "generated": {
                    "title": "Generated Panorama",
                    "type": "equirectangular",
                    "panorama": panorama_src,
                    "autoLoad": True,
                }
            }
        }
    )
# Tab 1: text prompt -> panorama. Inputs on the left, results on the right.
with tab1:
    t2p_input_col, t2p_output_col = st.columns([1, 1])

    with t2p_input_col:
        st.subheader("Input")
        t2p_input = st.text_area("Enter your prompt", height=100)
        t2p_upscale = st.checkbox("Upscale (requires >16GB GPU)")
        t2p_clicked = st.button("Generate Panorama")

    with t2p_output_col:
        st.subheader("Output")
        # Reserve the slots up front so results land in the output column
        # regardless of where the generation code runs.
        t2p_image_slot = st.empty()
        t2p_viewer_slot = st.empty()

    if t2p_clicked:
        with st.spinner("Generating your 360° panorama..."):
            t2p_result = text_to_pano(t2p_input, t2p_upscale)
            t2p_image_slot.image(
                t2p_result,
                caption="Generated 360° Panorama",
                use_column_width=True,
            )
            with t2p_viewer_slot.container():
                display_panorama(t2p_result)
# Tab 2: input image (+ optional mask) and prompt -> panorama.
with tab2:
    col1, col2 = st.columns([1, 1])

    with col1:
        st.subheader("Input")
        i2p_image = st.file_uploader("Upload Input Image", type=["png", "jpg", "jpeg"])
        i2p_mask = st.file_uploader("Upload Mask Image (Optional)", type=["png", "jpg", "jpeg"])
        # Explicit key: this text area has the same label and height as the
        # one in the first tab, so without a key the two widgets would get
        # the same auto-generated ID.
        i2p_prompt = st.text_area("Enter your prompt", height=100, key="i2p_prompt")
        i2p_upscale = st.checkbox("Upscale (requires >16GB GPU)", key="i2p_upscale")
        generate_button = st.button("Generate Panorama", key="i2p_generate")

    with col2:
        st.subheader("Output")
        output_placeholder = st.empty()
        viewer_placeholder = st.empty()

    if generate_button:
        if i2p_image is None:
            # The input image is mandatory; only the mask is optional.
            st.error("Please upload an input image.")
        else:
            with st.spinner("Generating your 360° panorama..."):
                image = Image.open(i2p_image)
                mask = Image.open(i2p_mask) if i2p_mask is not None else None
                output = image_to_pano(image, mask, i2p_prompt, i2p_upscale)
                output_placeholder.image(output, caption="Generated 360° Panorama", use_column_width=True)
                with viewer_placeholder.container():
                    display_panorama(output)