import functools
import random

import gradio as gr
import spaces
import torch
from diffusers import FluxPipeline
from huggingface_hub import HfApi

"""
This application uses the Flux.1 Lite model:
@article{flux1-lite,
  title={Flux.1 Lite: Distilling Flux1.dev for Efficient Text-to-Image Generation},
  author={Daniel Verdú, Javier Martín},
  email={dverdu@freepik.com, javier.martin@freepik.com},
  year={2024},
}
"""

@functools.lru_cache(maxsize=1)
@spaces.GPU(duration=70)
def initialize_model():
    """Load the Flux.1 Lite pipeline in bfloat16 and move it to CUDA.

    The result is memoized (``lru_cache(maxsize=1)``), so within a single
    process the checkpoint is loaded at most once. Without the cache,
    every call re-read and re-materialized the whole model.

    Returns:
        The ``FluxPipeline`` for ``Freepik/flux.1-lite-8B-alpha``, placed on
        the ``"cuda"`` device.

    NOTE(review): on ZeroGPU the ``@spaces.GPU`` call may execute in a
    separate worker process. If so, this cache may not persist across
    requests there. Loading the pipeline once at module import is the
    pattern Hugging Face documents for ZeroGPU, so consider moving to it.
    All callers share the one cached instance; verify that concurrent use
    of the same pipeline is acceptable before raising Gradio concurrency.
    """
    model_id = "Freepik/flux.1-lite-8B-alpha"
    pipe = FluxPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    return pipe

@spaces.GPU(duration=70)
def generate_image(
    prompt,
    guidance_scale=3.5,
    width=1024,
    height=1024,
    num_inference_steps=25,
    seed=None,
):
    """Generate one image from a text prompt with the Flux.1 Lite pipeline.

    Args:
        prompt: Text description of the desired image.
        guidance_scale: Guidance scale forwarded to the pipeline.
        width: Output width in pixels.
        height: Output height in pixels.
        num_inference_steps: Number of inference steps. Defaults to the 25
            steps the UI description advertises.
        seed: Optional RNG seed for reproducible output. When ``None``, a
            random seed in ``[1, 1000000]`` is drawn for each call.

    Returns:
        The first image of the pipeline's ``.images`` output.

    Raises:
        Any exception from model loading or inference is logged to stdout
        and re-raised unchanged.
    """
    try:
        # Obtain the pipeline inside the GPU context.
        pipe = initialize_model()

        if seed is None:
            seed = random.randint(1, 1000000)
        generator = torch.Generator(device="cuda").manual_seed(int(seed))

        # Defensive casts in case the UI hands numeric inputs over as
        # floats; image dimensions and step counts must be integers.
        with torch.inference_mode():
            result = pipe(
                prompt=prompt,
                generator=generator,
                num_inference_steps=int(num_inference_steps),
                guidance_scale=float(guidance_scale),
                height=int(height),
                width=int(width),
            )
        return result.images[0]
    except Exception as e:
        print(f"Error during image generation: {e}")
        # Bare raise re-raises the active exception with its original traceback.
        raise

# Build the Gradio interface. The four input components map positionally
# onto generate_image's (prompt, guidance_scale, width, height) parameters.
demo = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Textbox(
            label="Prompt",
            placeholder="Enter your image description here...",
            value="a glass cup with beer, inside the beer a scuba diver, with a beautiful sunset background",
        ),
        gr.Slider(
            minimum=1,
            maximum=20,
            value=3.5,
            label="Guidance Scale",
            step=0.5,
        ),
        # Size sliders start at 128 and move in 64 px steps, so every
        # selectable width/height is a multiple of 64.
        gr.Slider(
            minimum=128,
            maximum=1024,
            value=1024,
            label="Width",
            step=64,
        ),
        gr.Slider(
            minimum=128,
            maximum=1024,
            value=1024,
            label="Height",
            step=64,
        ),
    ],
    outputs=gr.Image(type="pil", label="Generated Image"),
    title="Freepik Flux.1-lite-8B-alpha Model (Zero-GPU)",
    description="Generate images using Freepik's Flux model with Zero-GPU allocation. Using 25 fixed steps and random seed for each generation.",
    # Each example row supplies (prompt, guidance_scale, width, height).
    examples=[
        ["A close-up image of a green alien with fluorescent skin in the middle of a dark purple forest", 3.5, 1024, 1024],
        ["a glass cup with beer, inside the beer a scuba diver, with a beautiful sunset background", 3.5, 1024, 1024],
    ],
)

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()