import spaces
import gradio as gr
import torch
import random
from diffusers import DiffusionPipeline
import os
# Initialize models
# NOTE(review): `device` is not used for placement below; the pipeline and the
# sampling generator are both hard-wired to "cuda" (as the `spaces` GPU
# runtime expects), so this app does not run on a CPU-only host.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16
# Optional access token for the gated FLUX.1-dev checkpoint (None if unset).
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
# Initialize the base model and move it to GPU
base_model = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, token=huggingface_token).to("cuda")
# Load LoRA weights.
# The LoRA is intentionally left UNFUSED: fusing merges it into the base
# weights, after which the per-call `joint_attention_kwargs={"scale": ...}`
# passed by generate_image is ignored, so the "LoRA Scale" control (and the
# examples' 0.5 scale) would have no effect.
pipe.load_lora_weights("gokaygokay/Flux-Detailer-LoRA")
# Upper bound (inclusive) for randomly drawn seeds: largest unsigned 32-bit value.
MAX_SEED = 2**32-1
@spaces.GPU(duration=75)
def generate_image(prompt, steps=28, seed=None, cfg_scale=3.5, width=1024, height=1024, lora_scale=1.0):
    """Render a single image with the module-level FLUX + LoRA pipeline.

    Args:
        prompt: Text prompt for the pipeline.
        steps: Number of denoising steps (cast to int).
        seed: RNG seed; a random value in [0, MAX_SEED] is drawn when None.
        cfg_scale: Guidance scale passed to the pipeline.
        width: Output width in pixels (cast to int).
        height: Output height in pixels (cast to int).
        lora_scale: Strength of the loaded LoRA, forwarded through
            ``joint_attention_kwargs["scale"]``.

    Returns:
        The first entry of the pipeline's ``images`` output (presumably a
        PIL image -- confirm against the diffusers version in use).
    """
    if seed is None:
        seed = random.randint(0, MAX_SEED)
    # UI sliders may deliver numbers as floats; Generator.manual_seed only
    # accepts integers, so normalise here.
    generator = torch.Generator(device="cuda").manual_seed(int(seed))
    image = pipe(
        prompt=prompt,
        num_inference_steps=int(steps),
        guidance_scale=float(cfg_scale),
        width=int(width),
        height=int(height),
        generator=generator,
        joint_attention_kwargs={"scale": float(lora_scale)},
    ).images[0]
    return image
def run_lora(prompt, cfg_scale=3.5, steps=28, randomize_seed=True, seed=None, width=1024, height=1024, lora_scale=1.0):
    """Gradio callback: resolve settings, pick a seed, and generate one image.

    Any setting passed as None falls back to the same default the signature
    declares, so callers may supply only the values they care about.

    Args:
        prompt: Text prompt.
        cfg_scale: Guidance scale.
        steps: Number of inference steps.
        randomize_seed: When truthy, ignore ``seed`` and draw a fresh one.
        seed: Seed to use when not randomizing; a random one is drawn if None.
        width: Output width in pixels.
        height: Output height in pixels.
        lora_scale: LoRA strength.

    Returns:
        ``(image, seed)`` -- the generated image and the integer seed that was
        actually used, so the UI can display it.
    """
    # Per-parameter fallback replaces the old all-or-nothing check, which was
    # unreachable through normal calls because the defaults are not None.
    cfg_scale = 3.5 if cfg_scale is None else cfg_scale
    steps = 28 if steps is None else steps
    randomize_seed = True if randomize_seed is None else randomize_seed
    width = 1024 if width is None else width
    height = 1024 if height is None else height
    lora_scale = 1.0 if lora_scale is None else lora_scale

    if randomize_seed or seed is None:
        seed = random.randint(0, MAX_SEED)
    # Slider values may be floats; use (and report back) an integer seed.
    seed = int(seed)

    image = generate_image(prompt, steps, seed, cfg_scale, width, height, lora_scale)
    return image, seed
# Extra stylesheet handed to gr.Blocks(css=...).
# NOTE(review): only `.submit-btn` is attached to a component (the Generate
# button); `.input-group` / `.output-group` are not referenced by any
# elem_classes in this file -- confirm whether they are still needed.
custom_css = """
.input-group, .output-group {
border: 1px solid #e0e0e0;
border-radius: 10px;
padding: 20px;
margin-bottom: 20px;
background-color: #f9f9f9;
}
.submit-btn {
background-color: #2980b9 !important;
color: white !important;
}
.submit-btn:hover {
background-color: #3498db !important;
}
"""

# Page heading rendered via gr.HTML.
# NOTE(review): the heading says "Creativity" while the LoRA actually loaded
# is "gokaygokay/Flux-Detailer-LoRA" -- confirm the intended name.
title = "\nFLUX Creativity LoRA\n"

# Every example row pairs a prompt with the same LoRA strength.
_EXAMPLE_LORA_SCALE = 0.5
_EXAMPLE_PROMPTS = (
    "anime, cartoon, Hyper-detailed, endearing anime girl, bathed in a vibrant, colorful psychedelic glow, wearing dazzling, holographic Liquid Metal outfit, in a cozy tatami studio",
    "extraterrestrial visage, close-up, highly intricate, ultra-detailed, full high definition",
    "a full body photo shot of a beautiful and breathtaking image of a ((Man) ) wearing a fully clothed casual witchy witch clothes with intricate details in the style of a reapers cloak, he is holding a long curved double edged ((scythe) ). This full body image is a one of a kind unique highly detailed with 8k sharp focus quality masterpiece, hyper detailed, extremely detailed",
    "schizophrenia attacks,go haywire, go crazy, hyper detailed, extremely detailed",
)
# Rows are [prompt, lora_scale], matching gr.Examples(inputs=[prompt, lora_scale]).
examples = [[example_text, _EXAMPLE_LORA_SCALE] for example_text in _EXAMPLE_PROMPTS]
def _run_example(example_prompt, example_lora_scale):
    """Adapter for gr.Examples rows of the form [prompt, lora_scale].

    gr.Examples calls its fn positionally with one argument per entry in
    ``inputs``. Passing run_lora directly would bind the example's LoRA scale
    to run_lora's second parameter (cfg_scale), so forward it by keyword.
    """
    return run_lora(example_prompt, lora_scale=example_lora_scale)


with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray"), css=custom_css) as app:
    gr.HTML(title)
    with gr.Row():
        # Left column: prompt and generation controls.
        with gr.Column(scale=1):
            prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Type your prompt here")
            with gr.Accordion("Advanced Settings", open=False):
                cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
                steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
                width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
                height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                lora_scale = gr.Slider(label="LoRA Scale", minimum=0.0, maximum=1.0, step=0.01, value=1.0)
            generate_button = gr.Button("Generate", variant="primary", elem_classes="submit-btn")
        # Right column: generated output.
        with gr.Column(scale=1):
            result = gr.Image(label="Generated Image")

    # Order must match run_lora's positional parameters.
    inputs = [prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale]
    # run_lora returns (image, seed); the seed is written back to the slider.
    outputs = [result, seed]
    generate_button.click(fn=run_lora, inputs=inputs, outputs=outputs)
    prompt.submit(fn=run_lora, inputs=inputs, outputs=outputs)

    gr.Examples(
        examples=examples,
        inputs=[prompt, lora_scale],
        outputs=[result, seed],
        fn=_run_example,
        cache_examples=True,
    )

app.launch(debug=True)