"""Gradio app: text-to-image with FLUX.1-dev plus the Flux-Detailer LoRA."""

import os
import random

import spaces  # kept ahead of torch, as in the original import order
import gradio as gr
import torch
from diffusers import DiffusionPipeline

# Initialize models
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Default generation settings, shared by the function signatures' fallbacks
# and the UI widgets so the two cannot drift apart.
DEFAULT_CFG_SCALE = 3.5
DEFAULT_STEPS = 28
DEFAULT_WIDTH = 1024
DEFAULT_HEIGHT = 1024
DEFAULT_LORA_SCALE = 1.0

# Initialize the base model and move it to the selected device.
base_model = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(
    base_model, torch_dtype=dtype, token=huggingface_token
).to(device)

# Load LoRA weights. They are deliberately NOT fused: once the LoRA is merged
# into the base weights, the per-call `joint_attention_kwargs={"scale": ...}`
# used in generate_image no longer affects the output, which made the
# "LoRA Scale" slider a no-op. Unfused at scale 1.0 matches the fused output.
pipe.load_lora_weights("gokaygokay/Flux-Detailer-LoRA")

MAX_SEED = 2**32 - 1


@spaces.GPU(duration=75)
def generate_image(prompt, steps=28, seed=None, cfg_scale=3.5, width=1024,
                   height=1024, lora_scale=1.0):
    """Run the pipeline once and return the first generated image.

    Args:
        prompt: Text prompt.
        steps: Number of inference steps (cast to int).
        seed: RNG seed; a random one in [0, MAX_SEED] is drawn when None.
        cfg_scale: Forwarded to the pipeline as ``guidance_scale``.
        width: Output width in pixels (cast to int).
        height: Output height in pixels (cast to int).
        lora_scale: LoRA strength, forwarded via ``joint_attention_kwargs``.

    Returns:
        ``pipe(...).images[0]`` (presumably a PIL image — confirm against
        the diffusers version in use).
    """
    if seed is None:
        seed = random.randint(0, MAX_SEED)
    # Gradio sliders may deliver floats; manual_seed needs an int.
    generator = torch.Generator(device=device).manual_seed(int(seed))
    image = pipe(
        prompt=prompt,
        num_inference_steps=int(steps),
        guidance_scale=cfg_scale,
        width=int(width),
        height=int(height),
        generator=generator,
        joint_attention_kwargs={"scale": lora_scale},
    ).images[0]
    return image


def run_lora(prompt, cfg_scale=3.5, steps=28, randomize_seed=True, seed=None,
             width=1024, height=1024, lora_scale=1.0):
    """UI callback: resolve settings and the seed, then generate an image.

    Any setting passed as None falls back to its default individually, so
    callers that supply only some values still get a valid configuration.

    Returns:
        ``(image, seed)`` — the seed actually used, so the UI can show it.
    """
    cfg_scale = DEFAULT_CFG_SCALE if cfg_scale is None else cfg_scale
    steps = DEFAULT_STEPS if steps is None else steps
    randomize_seed = True if randomize_seed is None else randomize_seed
    width = DEFAULT_WIDTH if width is None else width
    height = DEFAULT_HEIGHT if height is None else height
    lora_scale = DEFAULT_LORA_SCALE if lora_scale is None else lora_scale

    if randomize_seed or seed is None:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)

    image = generate_image(prompt, steps, seed, cfg_scale, width, height,
                           lora_scale)
    return image, seed


def run_lora_example(prompt, lora_scale=DEFAULT_LORA_SCALE):
    """Adapter for gr.Examples, whose inputs are only (prompt, lora_scale).

    gr.Examples calls its fn positionally with one value per input
    component; wiring run_lora directly bound the example's LoRA scale to
    run_lora's second parameter, cfg_scale.
    """
    return run_lora(prompt, lora_scale=lora_scale)


custom_css = """
.input-group, .output-group {
    border: 1px solid #e0e0e0;
    border-radius: 10px;
    padding: 20px;
    margin-bottom: 20px;
    background-color: #f9f9f9;
}
.submit-btn {
    background-color: #2980b9 !important;
    color: white !important;
}
.submit-btn:hover {
    background-color: #3498db !important;
}
"""

title = """

FLUX Creativity LoRA

"""

# Each example row is [prompt, lora_scale], matching gr.Examples' inputs.
examples = [
    ["anime, cartoon, Hyper-detailed, endearing anime girl, bathed in a "
     "vibrant, colorful psychedelic glow, wearing dazzling, holographic "
     "Liquid Metal outfit, in a cozy tatami studio", 0.5],
    ["extraterrestrial visage, close-up, highly intricate, ultra-detailed, "
     "full high definition", 0.5],
    ["a full body photo shot of a beautiful and breathtaking image of a "
     "((Man) ) wearing a fully clothed casual witchy witch clothes with "
     "intricate details in the style of a reapers cloak, he is holding a "
     "long curved double edged ((scythe) ). This full body image is a one "
     "of a kind unique highly detailed with 8k sharp focus quality "
     "masterpiece, hyper detailed, extremely detailed", 0.5],
    ["schizophrenia attacks,go haywire, go crazy, hyper detailed, "
     "extremely detailed", 0.5],
]

with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray"),
               css=custom_css) as app:
    gr.HTML(title)

    with gr.Row():
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", lines=3,
                                placeholder="Type your prompt here")

            with gr.Accordion("Advanced Settings", open=False):
                cfg_scale = gr.Slider(label="CFG Scale", minimum=1,
                                      maximum=20, step=0.5,
                                      value=DEFAULT_CFG_SCALE)
                steps = gr.Slider(label="Steps", minimum=1, maximum=50,
                                  step=1, value=DEFAULT_STEPS)
                width = gr.Slider(label="Width", minimum=256, maximum=1536,
                                  step=64, value=DEFAULT_WIDTH)
                height = gr.Slider(label="Height", minimum=256,
                                   maximum=1536, step=64,
                                   value=DEFAULT_HEIGHT)
                randomize_seed = gr.Checkbox(label="Randomize seed",
                                             value=True)
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED,
                                 step=1, value=0)
                lora_scale = gr.Slider(label="LoRA Scale", minimum=0.0,
                                       maximum=1.0, step=0.01,
                                       value=DEFAULT_LORA_SCALE)

            generate_button = gr.Button("Generate", variant="primary",
                                        elem_classes="submit-btn")

        with gr.Column(scale=1):
            result = gr.Image(label="Generated Image")

    inputs = [prompt, cfg_scale, steps, randomize_seed, seed, width, height,
              lora_scale]
    outputs = [result, seed]

    generate_button.click(fn=run_lora, inputs=inputs, outputs=outputs)
    prompt.submit(fn=run_lora, inputs=inputs, outputs=outputs)

    gr.Examples(
        examples=examples,
        inputs=[prompt, lora_scale],
        outputs=[result, seed],
        # Adapter maps the (prompt, lora_scale) pair onto run_lora's keywords.
        fn=run_lora_example,
        cache_examples=True,
    )

app.launch(debug=True)