import spaces
import gradio as gr
import torch
import random
from diffusers import DiffusionPipeline
import os

# Initialize models
# Use the GPU when one is visible, otherwise fall back to CPU. Every placement
# below goes through this one variable so the choice stays consistent.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Weight/activation dtype for the pipeline.
dtype = torch.bfloat16

# Optional Hugging Face access token read from the environment (None if unset).
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

# Initialize the base model and move it to the selected device.
base_model = "black-forest-labs/FLUX.1-dev"
pipe = DiffusionPipeline.from_pretrained(
    base_model, torch_dtype=dtype, token=huggingface_token
).to(device)

# Load LoRA weights.
# The adapter is deliberately left UNFUSED. pipe.fuse_lora() merges the LoRA
# into the base weights, after which the per-call
# joint_attention_kwargs={"scale": ...} no longer changes the LoRA strength.
# That made the "LoRA Scale" slider a no-op.
pipe.load_lora_weights("gokaygokay/Flux-Detailer-LoRA")

# Largest seed offered by the UI and by random seed selection (inclusive).
MAX_SEED = 2**32-1

@spaces.GPU(duration=75)
def generate_image(prompt, steps, seed, cfg_scale, width, height, lora_scale):
    """Run the FLUX pipeline once and return the first generated image.

    Args:
        prompt: Text prompt for the pipeline.
        steps: Number of inference steps (``num_inference_steps``).
        seed: Seed for the ``torch.Generator`` that drives sampling.
        cfg_scale: Value passed as the pipeline's ``guidance_scale``.
        width: Output width in pixels.
        height: Output height in pixels.
        lora_scale: Forwarded as ``joint_attention_kwargs["scale"]``.

    Returns:
        The first entry of the pipeline output's ``images`` list.
    """
    # Gradio sliders may hand numbers over as floats. The generator seed, step
    # count and image dimensions must be integers, so cast them explicitly.
    # The generator is created on the same device the pipeline was moved to.
    generator = torch.Generator(device=device).manual_seed(int(seed))

    image = pipe(
        prompt=prompt,
        num_inference_steps=int(steps),
        guidance_scale=cfg_scale,
        width=int(width),
        height=int(height),
        generator=generator,
        joint_attention_kwargs={"scale": lora_scale},
    ).images[0]
    return image

def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale):
    """Gradio callback: choose the seed, generate an image, and report the seed.

    When ``randomize_seed`` is truthy, a fresh seed drawn uniformly from
    ``[0, MAX_SEED]`` replaces the one supplied by the UI.

    Returns:
        ``(image, seed_used)``, so the UI can show the seed behind the image.
    """
    chosen_seed = random.randint(0, MAX_SEED) if randomize_seed else seed
    picture = generate_image(prompt, steps, chosen_seed, cfg_scale, width, height, lora_scale)
    return picture, chosen_seed

custom_css = """
.input-group, .output-group {
    border: 1px solid #e0e0e0;
    border-radius: 10px;
    padding: 20px;
    margin-bottom: 20px;
    background-color: #f9f9f9;
}
.submit-btn {
    background-color: #2980b9 !important;
    color: white !important;
}
.submit-btn:hover {
    background-color: #3498db !important;
}
"""

title = """<h1 align="center">FLUX Creativity LoRA</h1>
"""

# Page theme: Soft with a blue primary and gray secondary palette.
_app_theme = gr.themes.Soft(primary_hue="blue", secondary_hue="gray")

# Layout: header, prompt, generate button and output image up front; sampling
# controls live in a collapsed "Advanced Settings" accordion.
with gr.Blocks(theme=_app_theme, css=custom_css) as app:
    gr.HTML(title)

    with gr.Row():
        prompt = gr.Textbox(placeholder="Type your prompt here", lines=3, label="Prompt")

    with gr.Row():
        generate_button = gr.Button("Generate", variant="primary")

    with gr.Row():
        result = gr.Image(label="Generated Image")

    with gr.Accordion("Advanced Settings", open=False):
        # Guidance strength and number of sampling steps.
        with gr.Row():
            cfg_scale = gr.Slider(minimum=1, maximum=20, step=0.5, value=3.5, label="CFG Scale")
            steps = gr.Slider(minimum=1, maximum=50, step=1, value=28, label="Steps")

        # Output resolution in pixels, adjustable in steps of 64.
        with gr.Row():
            width = gr.Slider(minimum=256, maximum=1536, step=64, value=1024, label="Width")
            height = gr.Slider(minimum=256, maximum=1536, step=64, value=1024, label="Height")

        # Seed handling and LoRA strength.
        with gr.Row():
            randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
            seed = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=0, label="Seed")
            lora_scale = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.95, label="LoRA Scale")

    # Order must match run_lora's positional parameters.
    inputs = [prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale]
    # run_lora returns (image, seed): the seed slider is updated to the seed used.
    outputs = [result, seed]

    # The button click and the prompt box's submit event run the same callback.
    for _trigger in (generate_button.click, prompt.submit):
        _trigger(fn=run_lora, inputs=inputs, outputs=outputs)

# Start the Gradio server only when this file is executed as a script. Merely
# importing the module then does not start (and block on) the server.
if __name__ == "__main__":
    app.launch(debug=True)