# Hugging Face Space page header (scraped): status "Running on Zero" (ZeroGPU hardware).
import gradio as gr | |
import numpy as np | |
import random | |
import spaces | |
import torch | |
import time | |
import os | |
from diffusers import DiffusionPipeline | |
from custom_pipeline import FLUXPipelineWithIntermediateOutputs | |
from transformers import pipeline | |
# Korean -> English translation model (Helsinki-NLP opus-mt-ko-en),
# explicitly placed on the CPU via device="cpu".
translator = pipeline(
    task="translation",
    model="Helsinki-NLP/opus-mt-ko-en",
    device="cpu",
)
# --- Constants ---
# Largest seed the UI / RNG accepts: the int32 maximum (2**31 - 1).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound, in pixels, for either side of a generated image.
MAX_IMAGE_SIZE = 2048
# Default output resolution (square).
DEFAULT_WIDTH = DEFAULT_HEIGHT = 1024
# FLUX.1-schnell default step count used for realtime generation.
DEFAULT_INFERENCE_STEPS = 1
# ZeroGPU allocation window in seconds (reduced from 25 to stay within quota).
GPU_DURATION = 15
# Device and model setup
def setup_model():
    """Load the FLUX.1-schnell pipeline (intermediate-output variant) in float16.

    The pipeline is moved to CUDA when it is available, otherwise to the CPU.

    Returns:
        The loaded ``FLUXPipelineWithIntermediateOutputs`` instance.
    """
    dtype = torch.float16
    # NOTE(review): the previous device_map="auto" is not a strategy that
    # diffusers' DiffusionPipeline.from_pretrained accepts (it supports
    # strategies such as "balanced") - verify against the pinned diffusers
    # version. ZeroGPU Spaces expect the pipeline to be moved to CUDA at import
    # time and then used inside @spaces.GPU-decorated functions.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = FLUXPipelineWithIntermediateOutputs.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        torch_dtype=dtype,
    )
    # assumes the custom pipeline inherits DiffusionPipeline.to - TODO confirm
    return model.to(device)


pipe = setup_model()
# Menu labels: UI text keyed by its English label. Every entry currently maps
# to itself, so the table is generated from the ordered list of labels.
english_labels = {
    label: label
    for label in (
        "Generated Image",
        "Prompt",
        "Enhance Image",
        "Advanced Options",
        "Seed",
        "Randomize Seed",
        "Width",
        "Height",
        "Inference Steps",
        "Inspiration Gallery",
    )
}
def translate_if_korean(text):
    """Return *text* in English, translating only when it contains Hangul.

    Text with no Hangul character is returned unchanged. Any failure (a
    non-iterable input, the translator raising, an unexpected result shape)
    is printed and the original input is returned, so prompt handling never
    aborts because of translation.
    """

    def _is_hangul(ch):
        # Hangul Compatibility Jamo (U+3131..U+3163) or precomposed
        # Hangul Syllables (U+AC00..U+D7A3).
        return "\u3131" <= ch <= "\u3163" or "\uac00" <= ch <= "\ud7a3"

    try:
        if not any(map(_is_hangul, text)):
            return text
        translated = translator(text)
        return translated[0]["translation_text"]
    except Exception as err:
        print(f"Translation error: {err}")
        return text
# Inference function with error handling and memory management.
# On ZeroGPU, CUDA is only attached inside @spaces.GPU-decorated calls;
# GPU_DURATION bounds each allocation.
@spaces.GPU(duration=GPU_DURATION)
def generate_image(prompt, seed=None, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT,
                   randomize_seed=True, num_inference_steps=DEFAULT_INFERENCE_STEPS):
    """Generate images for *prompt*, yielding ``(image, seed, latency_text)``.

    Args:
        prompt: Text prompt; Korean text is translated to English first.
        seed: Integer seed (integral floats such as ``42.0`` are accepted).
            Any other non-int value forces a random seed.
        width, height: Requested size, clamped to ``[256, MAX_IMAGE_SIZE]``.
        randomize_seed: When true (or when *seed* is None) a new random seed
            in ``[0, MAX_SEED]`` is drawn.
        num_inference_steps: Number of denoising steps passed to the pipeline.

    Yields:
        One ``(image, seed, "Processing Time: ...")`` tuple per image produced
        by ``pipe.generate_images`` (presumably intermediate and final
        images - confirm against the custom pipeline). On any error a single
        ``(None, seed, "Error: ...")`` tuple is yielded instead of raising.
    """
    try:
        # Input validation. Gradio/API clients may send integral seeds as
        # floats; keep them rather than silently switching to a random seed.
        if isinstance(seed, float) and seed.is_integer():
            seed = int(seed)
        elif not isinstance(seed, (int, type(None))):
            seed = None
            randomize_seed = True

        prompt = translate_if_korean(prompt)
        if seed is None or randomize_seed:
            seed = random.randint(0, MAX_SEED)

        # Ensure valid, integral dimensions and step count.
        width = int(min(max(256, width), MAX_IMAGE_SIZE))
        height = int(min(max(256, height), MAX_IMAGE_SIZE))
        num_inference_steps = int(num_inference_steps)

        generator = torch.Generator().manual_seed(seed)
        start_time = time.time()
        use_cuda = torch.cuda.is_available()
        # torch.autocast replaces the deprecated torch.cuda.amp.autocast();
        # disabling it without CUDA avoids the "CUDA is not available" warning.
        with torch.autocast(device_type="cuda", enabled=use_cuda):
            for img in pipe.generate_images(
                prompt=prompt,
                guidance_scale=0,
                num_inference_steps=num_inference_steps,
                width=width,
                height=height,
                generator=generator,
            ):
                latency = f"Processing Time: {(time.time()-start_time):.2f} seconds"
                # Clear the CUDA cache before handing each image back; callers
                # may stop after the first yield, so this cannot wait until
                # after the loop.
                if use_cuda:
                    torch.cuda.empty_cache()
                yield img, seed, latency
    except Exception as e:
        print(f"Error in generate_image: {e}")
        # Report the error in the latency box instead of raising into Gradio.
        yield None, seed, f"Error: {str(e)}"
# Example generator with error handling
def generate_example_image(prompt):
    """Return the first ``(image, seed, latency)`` result for a gallery prompt.

    Always uses a random seed. Any failure, including the generator yielding
    nothing, is printed and reported as ``(None, None, "Error: ...")``.
    """
    try:
        results = generate_image(prompt, randomize_seed=True)
        first = next(results)
    except Exception as exc:
        print(f"Error in example generation: {exc}")
        return None, None, f"Error: {str(exc)}"
    return first
# Example prompts shown in the Inspiration Gallery.
examples = [
    # NOTE(review): this literal appears to be mojibake (UTF-8 Korean decoded
    # as cp874/Thai). It contains no Hangul, so translate_if_korean never
    # translates it. The tail decodes to "애니메이션 일러스트레이션"
    # ("animation illustration") - restore the original Korean from source.
    "๋น๋ ์๋์ฒผ์ ์ ๋๋ฉ์ด์ ์ผ๋ฌ์คํธ๋ ์ด์ ",
    "A steampunk owl wearing Victorian-era clothing and reading a mechanical book",
    "A floating island made of books with waterfalls of knowledge cascading down",
    "A bioluminescent forest where mushrooms glow like neon signs in a cyberpunk city",
    "An ancient temple being reclaimed by nature, with robots performing archaeology",
    "A cosmic coffee shop where baristas are constellations serving drinks made of stardust",
]

# Custom CSS passed to gr.Blocks: hides the default Gradio footer.
css = """
footer {
visibility: hidden;
}
"""
# --- Gradio UI with improved error handling ---
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    with gr.Column(elem_id="app-container"):
        with gr.Row():
            with gr.Column(scale=3):
                result = gr.Image(
                    label=english_labels["Generated Image"],
                    show_label=False,
                    interactive=False,
                )
            with gr.Column(scale=1):
                prompt = gr.Text(
                    label=english_labels["Prompt"],
                    placeholder="Describe the image you want to generate...",
                    lines=3,
                    show_label=False,
                    container=False,
                )
                # The emoji was mis-decoded in the scraped source; the bytes
                # (F0 9F 9A 80) correspond to the rocket emoji restored here.
                enhanceBtn = gr.Button(f"🚀 {english_labels['Enhance Image']}")
                # gr.Column has no label parameter, so the former
                # gr.Column("Advanced Options") never displayed the heading.
                # An always-open Accordion shows it while keeping controls visible.
                with gr.Accordion(english_labels["Advanced Options"], open=True):
                    with gr.Row():
                        latency = gr.Text(show_label=False)
                    with gr.Row():
                        # Integer seed widget bounded to the valid seed range.
                        seed = gr.Number(
                            label=english_labels["Seed"],
                            value=42,
                            precision=0,
                            minimum=0,
                            maximum=MAX_SEED,
                        )
                        randomize_seed = gr.Checkbox(
                            label=english_labels["Randomize Seed"],
                            value=True,
                        )
                    with gr.Row():
                        width = gr.Slider(
                            label=english_labels["Width"],
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=DEFAULT_WIDTH,
                        )
                        height = gr.Slider(
                            label=english_labels["Height"],
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=DEFAULT_HEIGHT,
                        )
                        num_inference_steps = gr.Slider(
                            label=english_labels["Inference Steps"],
                            minimum=1,
                            maximum=4,
                            step=1,
                            value=DEFAULT_INFERENCE_STEPS,
                        )
        with gr.Row():
            # NOTE(review): the emoji in this heading is mojibake in the scraped
            # source and could not be recovered unambiguously - restore from source.
            gr.Markdown(f"### ๐ {english_labels['Inspiration Gallery']}")
        with gr.Row():
            # generate_example_image returns (image, seed, latency), so all
            # three outputs are listed; a 2-item list mismatches the return arity.
            gr.Examples(
                examples=examples,
                fn=generate_example_image,
                inputs=[prompt],
                outputs=[result, seed, latency],
                cache_examples=False,
            )

    # The Enhance button passes every generation control, so the
    # "Randomize Seed" checkbox and "Inference Steps" slider are honoured
    # instead of falling back to generate_image's defaults.
    enhanceBtn.click(
        fn=generate_image,
        inputs=[prompt, seed, width, height, randomize_seed, num_inference_steps],
        outputs=[result, seed, latency],
        show_progress="hidden",
        show_api=False,
        queue=False,
    )

    def validated_generate(*args):
        """Run generate_image and return only its first ``(image, seed, latency)``.

        *args* follow generate_image's positional order (prompt, seed, width,
        height, randomize_seed, num_inference_steps). Any exception is printed
        and reported as ``(None, <input seed>, "Error: ...")``.
        """
        # NOTE(review): next() keeps only the first yielded image; if the
        # pipeline yields intermediates, later ones are dropped - confirm intent.
        try:
            return next(generate_image(*args))
        except Exception as e:
            print(f"Error in validated_generate: {e}")
            return None, args[1], f"Error: {str(e)}"

    # Realtime regeneration whenever the prompt, size or step count is edited;
    # "always_last" drops stale intermediate triggers.
    gr.on(
        triggers=[prompt.input, width.input, height.input, num_inference_steps.input],
        fn=validated_generate,
        inputs=[prompt, seed, width, height, randomize_seed, num_inference_steps],
        outputs=[result, seed, latency],
        show_progress="hidden",
        show_api=False,
        trigger_mode="always_last",
        queue=False,
    )

if __name__ == "__main__":
    demo.launch()