import torch

from flux_pipeline import FluxPipeline
import gradio as gr  # type: ignore
from PIL import Image


def create_demo(
    config_path: str,
):
    generator = FluxPipeline.load_pipeline_from_config_path(config_path)

    def generate_image(
        prompt,
        width,
        height,
        num_steps,
        guidance,
        seed,
        init_image,
        image2image_strength,
        add_sampling_metadata,
    ):

        seed = int(seed)
        if seed == -1:
            seed = None
        out = generator.generate(
            prompt,
            width,
            height,
            num_steps=num_steps,
            guidance=guidance,
            seed=seed,
            init_image=init_image,
            strength=image2image_strength,
            silent=False,
            num_images=1,
            return_seed=True,
        )
        image_bytes = out[0]
        return Image.open(image_bytes), str(out[1]), None

    is_schnell = generator.config.version == "flux-schnell"

    with gr.Blocks() as demo:
        gr.Markdown(f"# Flux Image Generation Demo - Model: {generator.config.version}")

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    value='a photo of a forest with mist swirling around the tree trunks. The word "FLUX" is painted over it in big, red brush strokes with visible texture',
                )
                do_img2img = gr.Checkbox(
                    label="Image to Image", value=False, interactive=not is_schnell
                )
                init_image = gr.Image(label="Input Image", visible=False)
                image2image_strength = gr.Slider(
                    0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False
                )

                with gr.Accordion("Advanced Options", open=False):
                    width = gr.Slider(128, 8192, 1152, step=16, label="Width")
                    height = gr.Slider(128, 8192, 640, step=16, label="Height")
                    num_steps = gr.Slider(
                        1, 50, 4 if is_schnell else 20, step=1, label="Number of steps"
                    )
                    guidance = gr.Slider(
                        1.0,
                        10.0,
                        3.5,
                        step=0.1,
                        label="Guidance",
                        interactive=not is_schnell,
                    )
                    seed = gr.Textbox(-1, label="Seed (-1 for random)")
                    add_sampling_metadata = gr.Checkbox(
                        label="Add sampling parameters to metadata?", value=True
                    )

                generate_btn = gr.Button("Generate")

            with gr.Column(min_width="960px"):
                output_image = gr.Image(label="Generated Image")
                seed_output = gr.Number(label="Used Seed")
                warning_text = gr.Textbox(label="Warning", visible=False)
                # download_btn = gr.File(label="Download full-resolution")

        def update_img2img(do_img2img):
            return {
                init_image: gr.update(visible=do_img2img),
                image2image_strength: gr.update(visible=do_img2img),
            }

        do_img2img.change(
            update_img2img, do_img2img, [init_image, image2image_strength]
        )

        generate_btn.click(
            fn=generate_image,
            inputs=[
                prompt,
                width,
                height,
                num_steps,
                guidance,
                seed,
                init_image,
                image2image_strength,
                add_sampling_metadata,
            ],
            outputs=[output_image, seed_output, warning_text],
        )

    return demo


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Flux")
    parser.add_argument(
        "--config", type=str, default="configs/config-dev.json", help="Config file path"
    )
    parser.add_argument(
        "--share", action="store_true", help="Create a public link to your demo"
    )
    args = parser.parse_args()

    demo = create_demo(args.config)
    demo.launch(share=args.share)