import os
import gradio as gr
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor
from huggingface_hub import login
import spaces
import json
import matplotlib.pyplot as plt
import io
import base64


def check_environment():
    """Verify that the process environment provides everything the app needs.

    Raises:
        ValueError: when one or more required variables (currently just
            ``HF_TOKEN``) are not present in ``os.environ``; the message names
            the missing variables and explains how to fix it.
    """
    required = ("HF_TOKEN",)
    # os.environ only holds str values, so .get() is None exactly when unset.
    absent = [name for name in required if os.environ.get(name) is None]
    if not absent:
        return

    names = ", ".join(absent)
    raise ValueError(
        f"Missing required environment variables: {names}\n"
        "Please set the HF_TOKEN environment variable with your Hugging Face token"
    )


# Authenticate with the Hugging Face Hub at import time. check_environment()
# raises ValueError first if HF_TOKEN is unset, so login() always receives a
# token. add_to_git_credential=True additionally registers the token as a git
# credential on this machine.
check_environment()
login(token=os.environ["HF_TOKEN"], add_to_git_credential=True)

# Load the processor and model once, at module import, so every request reuses
# them instead of reloading the weights inside the inference function.
base_model_path = (
    "taesiri/BugsBunny-LLama-3.2-11B-Vision-BaseCaptioner-XLarge-FullModel"
)

processor = AutoProcessor.from_pretrained(base_model_path)
# Weights are loaded in bfloat16 and placed on the CUDA device via device_map.
model = MllamaForConditionalGeneration.from_pretrained(
    base_model_path,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
# NOTE(review): leftover from a LoRA-adapter variant; neither `PeftModel` nor
# `lora_weights_path` is defined in this file, so it cannot simply be
# uncommented.
# model = PeftModel.from_pretrained(model, lora_weights_path)
# Re-tie any weights the model config declares as shared (presumably the
# input/output embeddings) — TODO confirm this is still needed after loading.
model.tie_weights()


def describe_image_in_JSON(json_string):
    """Parse the model's JSON answer, which may be JSON-encoded twice.

    The captioner's answer has been observed as a JSON *string literal* whose
    content is itself the JSON document, which needs two ``json.loads``
    passes. If the first pass already yields a non-string value (the model
    emitted the document directly), that value is returned as-is. Decoding it
    again would raise ``TypeError``.

    Args:
        json_string: Raw answer text extracted from the model response.

    Returns:
        The decoded Python value (normally a dict) on success, or a string
        starting with ``"Error parsing JSON: "`` if a decode pass fails.
    """
    try:
        decoded = json.loads(json_string)
        # Double-encoded payload: the first pass only unwrapped the outer
        # string literal, so decode its contents as well.
        if isinstance(decoded, str):
            decoded = json.loads(decoded)
        return decoded

    except json.JSONDecodeError as e:
        return f"Error parsing JSON: {str(e)}"


def create_color_palette_image(colors):
    """Render a list of ``#``-prefixed colours as a row of equal-width swatches.

    Args:
        colors: List of colour strings such as ``"#1f77b4"``, as found in the
            model's ``color_palette`` field.

    Returns:
        A ``matplotlib.figure.Figure`` with one swatch per colour, for use
        with ``gr.Plot``. Returns ``None`` in any of these cases: ``colors``
        is empty or not a list, an entry is not a valid ``#``-prefixed
        colour, or drawing fails.
    """
    if not colors or not isinstance(colors, list):
        return None

    # Cheap structural check first: every entry must be a "#..." string.
    if not all(
        isinstance(color, str) and color.startswith("#") for color in colors
    ):
        return None

    # Imported here so the structural validation above needs no plotting code.
    from matplotlib.colors import is_color_like
    from matplotlib.figure import Figure
    from matplotlib.patches import Rectangle

    # A "#" prefix alone is not enough ("#zz" passes it); reject anything
    # matplotlib cannot interpret before any figure is created.
    if not all(is_color_like(color) for color in colors):
        return None

    try:
        # Build the Figure directly instead of via pyplot. pyplot keeps a
        # global reference to every figure until plt.close(), so a
        # long-running server would otherwise accumulate one figure per request.
        fig = Figure(figsize=(10, 2))
        ax = fig.subplots()

        # One unit-wide swatch per colour, laid out left to right.
        for i, color in enumerate(colors):
            ax.add_patch(Rectangle((i, 0), 1, 1, facecolor=color))

        # Fit the view exactly to the swatches and hide the axis ticks.
        ax.set_xlim(0, len(colors))
        ax.set_ylim(0, 1)
        ax.set_xticks([])
        ax.set_yticks([])

        return fig
    except Exception as e:
        print(f"Error creating color palette: {e}")
        return None


def _error_outputs(message, detail=None, raw=None):
    """Return the 10-value output tuple ``inference`` uses for a failed request.

    Args:
        message: Short user-facing message shown in the text fields.
        detail: Text for the error box; defaults to ``message``.
        raw: Text for the raw-output box (e.g. the unparsable model answer);
            defaults to ``message``.

    The three ``gr.JSON`` slots and the ``gr.Plot`` slot receive ``None``.
    Those components expect JSON data or a figure and cannot render a
    free-form message string.
    """
    return (
        message,  # description
        message,  # scene description
        None,  # characters (gr.JSON)
        None,  # objects (gr.JSON)
        None,  # texture details (gr.JSON)
        message,  # lighting details
        None,  # colour palette (gr.Plot)
        message if raw is None else raw,  # raw JSON response
        message if detail is None else detail,  # error box
        "Error",  # status
    )


@spaces.GPU
def inference(image):
    """Caption ``image`` with the model and split the JSON answer into UI fields.

    Args:
        image: A PIL image, or an array accepted by ``Image.fromarray``.
            ``None`` when nothing was uploaded.

    Returns:
        A 10-tuple, in this order: description, scene description,
        characters, objects, texture details, lighting details,
        colour-palette figure, raw JSON text, error message, status. Every
        path (success or failure) returns exactly 10 values.
    """
    if image is None:
        return _error_outputs("Please provide an image")

    if not isinstance(image, Image.Image):
        try:
            image = Image.fromarray(image)
        except Exception as e:
            print(f"Image conversion error: {e}")
            return _error_outputs("Invalid image format", detail=str(e))

    # Single-turn chat: one image placeholder followed by the instruction.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Describe the image in JSON"},
            ],
        }
    ]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    try:
        # Tokenise image + prompt and move the tensors to the model's device.
        inputs = processor(
            image, input_text, add_special_tokens=False, return_tensors="pt"
        ).to(model.device)

        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=2048)
    except Exception as e:
        print(f"Inference error: {e}")
        return _error_outputs("Error during inference", detail=str(e))
    finally:
        # Release cached GPU memory whether or not generation succeeded.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # generate() returns the prompt tokens followed by the completion; decode
    # only the new tokens so no prompt text has to be split off the answer.
    prompt_length = inputs["input_ids"].shape[-1]
    json_str = processor.decode(
        output[0][prompt_length:], skip_special_tokens=True
    ).strip()
    print("DEBUG: Generated response:", json_str)

    parsed_json = describe_image_in_JSON(json_str)
    # describe_image_in_JSON returns an error *string* on failure, and the
    # model may emit valid JSON that is not an object; only dicts have .get().
    if not (isinstance(parsed_json, dict) and parsed_json):
        print("DEBUG: Could not parse response as a JSON object:", parsed_json)
        return _error_outputs(
            "Error parsing response", detail="Failed to parse JSON", raw=json_str
        )

    # Create color palette visualization (None if the palette is unusable).
    colors = parsed_json.get("color_palette", [])
    color_image = create_color_palette_image(colors)

    # The gr.JSON outputs receive these list fields as JSON text.
    character_list = json.dumps(parsed_json.get("character_list", []))
    object_list = json.dumps(parsed_json.get("object_list", []))
    texture_details = json.dumps(parsed_json.get("texture_details", []))

    return (
        parsed_json.get("description", "Not available"),
        parsed_json.get("scene_description", "Not available"),
        character_list,
        object_list,
        texture_details,
        parsed_json.get("lighting_details", "Not available"),
        color_image,
        json_str,
        "",  # Error box
        "Analysis complete",  # Status
    )


# Build the Gradio UI. The top row holds the image input, the button and the
# examples. Tabs show the structured fields and the raw model output, followed
# by an error box and a status box. The button's `outputs` list is positional
# and must line up with the values `inference` returns.
with gr.Blocks() as demo:
    gr.Markdown("# BugsBunny-LLama-3.2-11B-Base-XLarge Demo")

    with gr.Row():
        with gr.Column(scale=1):
            # type="pil" asks Gradio to pass the upload to `inference` as a
            # PIL image (None when nothing is uploaded).
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                elem_id="large-image",
            )
            submit_btn = gr.Button("Analyze Image", variant="primary")

            # Clickable sample images loaded from the local ./examples folder.
            gr.Examples(
                examples=[
                    "./examples/1.jpg",
                    "./examples/2.jpg",
                    "./examples/3.jpg",
                    "./examples/4.jpg",
                    "./examples/5.jpg",
                    "./examples/6.jpg",
                    "./examples/7.jpg",
                    "./examples/8.jpg",
                    "./examples/9.jpg",
                ],
                inputs=image_input,
                label="Example Images",
                examples_per_page=5,
            )

    with gr.Tabs():
        with gr.Tab("Structured Results"):
            with gr.Column(scale=1):
                description_output = gr.Textbox(
                    label="Description",
                    lines=4,
                )
                scene_output = gr.Textbox(
                    label="Scene Description",
                    lines=2,
                )
                # On success `inference` fills these three gr.JSON components
                # with json.dumps()-encoded lists.
                characters_output = gr.JSON(
                    label="Characters",
                )
                objects_output = gr.JSON(
                    label="Objects",
                )
                textures_output = gr.JSON(
                    label="Texture Details",
                )
                lighting_output = gr.Textbox(
                    label="Lighting Details",
                    lines=2,
                )
                # Shows the matplotlib figure from create_color_palette_image
                # (or nothing when that returns None).
                color_palette_output = gr.Plot(
                    label="Color Palette",
                )

        with gr.Tab("Raw Output"):
            raw_output = gr.Textbox(
                label="Raw JSON Response",
                lines=25,
                max_lines=30,
            )

    # NOTE(review): visible=False, so the messages `inference` writes to this
    # box are never shown to the user — consider making it visible.
    error_box = gr.Textbox(label="Error Messages", visible=False)

    with gr.Row():
        status_text = gr.Textbox(label="Status", value="Ready", interactive=False)

    # The order of `outputs` must match the order of values `inference`
    # returns. api_name exposes this handler as the "analyze" API endpoint.
    submit_btn.click(
        fn=inference,
        inputs=[image_input],
        outputs=[
            description_output,
            scene_output,
            characters_output,
            objects_output,
            textures_output,
            lighting_output,
            color_palette_output,
            raw_output,
            error_box,
            status_text,
        ],
        api_name="analyze",
    )

# share=True requests a public share link when the app is run locally.
# NOTE(review): Hugging Face Spaces does not support share links and ignores
# this flag there — confirm it is still wanted.
demo.launch(share=True)