"""Gradio Space serving the EcoDiff-pruned FLUX.1-schnell text-to-image model."""

import gradio as gr
import numpy as np
import random
from huggingface_hub import hf_hub_download
import spaces  # provides the @spaces.GPU decorator (ZeroGPU)
from diffusers import FluxPipeline
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "black-forest-labs/FLUX.1-schnell"  # Replace with the model you would like to use
torch_dtype = torch.bfloat16

# The unpruned reference pipeline is deliberately not loaded: the Space cannot
# host two FLUX models at once (see `header` below).

# Load the base FLUX pipeline, then swap its transformer for the pruned one.
pruned_pipe = FluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
# NOTE(security): flux.pkl is a pickled model object (it is assigned directly as
# the pipeline's transformer, not loaded as a state dict), so it must be fully
# unpickled. torch >= 2.6 defaults to weights_only=True, which would reject it,
# hence the explicit flag. Unpickling can execute arbitrary code; this is only
# acceptable because the file comes from a fixed, trusted Hub repository.
pruned_pipe.transformer = torch.load(
    hf_hub_download("zhangyang-0123/EcoDiffPrunedModels", "model/flux/flux.pkl"),
    map_location="cpu",
    weights_only=False,
)
pruned_pipe = pruned_pipe.to(device)

MAX_SEED = np.iinfo(np.int32).max  # upper bound for the Seed input
MAX_IMAGE_SIZE = 1024  # currently unused
DEFAULT_SEED = 44  # initial Seed value, also used when the field is cleared


@spaces.GPU
def generate_images(prompt, seed, steps):
    """Generate one image for *prompt* with the pruned FLUX pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: RNG seed. ``None`` (a cleared Number field) falls back to
            ``DEFAULT_SEED``; float values are truncated to ``int``.
        steps: Number of inference steps; coerced to an ``int`` of at least 1.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    seed = DEFAULT_SEED if seed is None else int(seed)
    # Gradio numeric components may deliver floats; num_inference_steps is
    # expected to be an integer count.
    steps = max(1, int(steps))
    # Seed a generator on the same device as the pipeline so this also works
    # when `device` is "cpu" (a hard-coded "cuda" generator would not).
    generator = torch.Generator(device).manual_seed(seed)
    result = pruned_pipe(prompt=prompt, generator=generator, num_inference_steps=steps)
    return result.images[0]


examples = [
    "A clock tower floating in a sea of clouds",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
    "A sprawling cyberpunk metropolis at night, with towering skyscrapers emitting neon lights of every color, holographic billboards advertising alien languages",
]

# Not currently passed to gr.Blocks, so it has no effect on the layout.
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

header = """
# 🌱 EcoDiff Pruned FLUX-Schnell (20% Pruning Ratio)

We are not able to host two FLUX models in the same space, so we only show the pruned model here.

**👉 [Click here to compare with the Original FLUX Model](https://huggingface.co/spaces/black-forest-labs/FLUX.1-schnell)**.
"""

header_2 = """
"""

with gr.Blocks() as demo:
    gr.Markdown(header)
    gr.HTML(header_2)

    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            value=examples[0],
            scale=3,
        )
        # Bounded so out-of-range seeds are rejected by the UI instead of
        # crashing torch's manual_seed.
        seed = gr.Number(
            label="Seed",
            value=DEFAULT_SEED,
            precision=0,
            minimum=0,
            maximum=MAX_SEED,
            scale=1,
        )
        steps = gr.Slider(
            label="Number of Steps",
            minimum=1,
            maximum=100,
            value=5,
            step=1,
            scale=1,
        )

    generate_btn = gr.Button("Generate Images")

    gr.Examples(
        examples=examples,
        inputs=[prompt],
    )

    with gr.Row():
        ecodiff_output = gr.Image(label="EcoDiff Output")

    gr.on(
        triggers=[generate_btn.click, prompt.submit],
        fn=generate_images,
        inputs=[prompt, seed, steps],
        outputs=[ecodiff_output],
    )

if __name__ == "__main__":
    demo.launch()