# Hugging Face Spaces page status when this file was captured: Runtime error
import functools

import gradio as gr
import spaces
import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
# Constants
# Hub repo id of the SDXL base pipeline (passed to from_pretrained).
base: str = "stabilityai/stable-diffusion-xl-base-1.0"
# Hub repo holding the SDXL-Lightning UNet checkpoints.
repo: str = "ByteDance/SDXL-Lightning"
# Checkpoint file inside `repo` whose state dict replaces the base UNet weights;
# generate_image runs 4 inference steps to match this "4step" variant.
ckpt: str = "sdxl_lightning_4step_unet.pth"
# Function
@functools.lru_cache(maxsize=1)
def _load_pipeline():
    """Build the SDXL-Lightning pipeline once and reuse it on later calls.

    Loads the SDXL base pipeline (``base``) in fp16 on CUDA, overwrites its
    UNet weights with the Lightning checkpoint ``ckpt`` from ``repo``, and
    swaps in an Euler scheduler with "trailing" timestep spacing.

    Returns:
        The configured ``StableDiffusionXLPipeline``.
    """
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base, torch_dtype=torch.float16, variant="fp16"
    ).to("cuda")
    # NOTE(review): torch.load unpickles the downloaded .pth file; only load
    # checkpoints from a trusted repo (consider weights_only=True if the
    # installed torch supports it).
    state_dict = torch.load(hf_hub_download(repo, ckpt), map_location="cuda")
    pipe.unet.load_state_dict(state_dict)
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )
    return pipe


@spaces.GPU
def generate_image(prompt):
    """Generate a single image for ``prompt`` with SDXL-Lightning.

    Decorated with ``@spaces.GPU`` so that, on ZeroGPU hardware, a GPU is
    attached while this function runs; the pipeline is created (and moved to
    CUDA) inside this GPU-enabled call and cached for subsequent requests.

    Args:
        prompt: Text prompt from the Gradio text box.

    Returns:
        The first image of the pipeline output (presumably a PIL image, the
        diffusers default output type — confirm if output_type changes).
    """
    pipe = _load_pipeline()
    # 4 steps matches the "4step" checkpoint; guidance_scale=0 as in the
    # original configuration.
    return pipe(prompt, num_inference_steps=4, guidance_scale=0).images[0]
# Gradio Interface
# Text shown beneath the title in the Gradio UI.
description = """
This demo utilizes the SDXL-Lightning model by ByteDance, which is a fast text-to-image generative model capable of producing high-quality images in 4 steps.
As a community effort, this demo was put together by AngryPenguin. Link to model: https://huggingface.co/ByteDance/SDXL-Lightning
"""

# One text input in, one image out; generate_image does all the work.
demo = gr.Interface(
    fn=generate_image,
    inputs="text",
    outputs="image",
    title="Text-to-Image with SDXL Lightning ⚡",
    description=description,
)

# Start the web server (the Space runs this file as a script).
demo.launch()