import gradio as gr
import numpy as np
import os
import json
import subprocess
from PIL import Image
from functools import partial
from datetime import datetime
from sam_inference import get_sam_predictor, sam_seg
from utils import blend_seg, blend_seg_pure
import cv2
import uuid
import torch
import trimesh
from huggingface_hub import snapshot_download

from gradio_model3dcolor import Model3DColor
# from gradio_model3dnormal import Model3DNormal

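# Fetch the MeshFormer API bundle (requires HF_TOKEN in the environment);
# its api.json supplies the SEG_CMD and MESH_CMD shell-command templates used below.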
code_dir = snapshot_download("sudo-ai/MeshFormer-API", token=os.environ['HF_TOKEN'])

with open(f'{code_dir}/api.json', 'r') as file:
    api_dict = json.load(file)
    SEG_CMD = api_dict["SEG_CMD"]
    MESH_CMD = api_dict["MESH_CMD"]

STYLE = """
    <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet" integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN" crossorigin="anonymous">
    <style>
        .alert, .alert div, .alert b {
            color: black !important;
        }
    </style>
"""
# info (info-circle-fill), cursor (hand-index-thumb), wait (hourglass-split), done (check-circle)
ICONS = {
    "info": """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="#0d6efd" class="bi bi-info-circle-fill flex-shrink-0 me-2" viewBox="0 0 16 16">
    <path d="M8 16A8 8 0 1 0 8 0a8 8 0 0 0 0 16zm.93-9.412-1 4.705c-.07.34.029.533.304.533.194 0 .487-.07.686-.246l-.088.416c-.287.346-.92.598-1.465.598-.703 0-1.002-.422-.808-1.319l.738-3.468c.064-.293.006-.399-.287-.47l-.451-.081.082-.381 2.29-.287zM8 5.5a1 1 0 1 1 0-2 1 1 0 0 1 0 2z"/>
    </svg>""",
    "cursor": """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="#0dcaf0" class="bi bi-hand-index-thumb-fill flex-shrink-0 me-2" viewBox="0 0 16 16">
    <path d="M8.5 1.75v2.716l.047-.002c.312-.012.742-.016 1.051.046.28.056.543.18.738.288.273.152.456.385.56.642l.132-.012c.312-.024.794-.038 1.158.108.37.148.689.487.88.716.075.09.141.175.195.248h.582a2 2 0 0 1 1.99 2.199l-.272 2.715a3.5 3.5 0 0 1-.444 1.389l-1.395 2.441A1.5 1.5 0 0 1 12.42 16H6.118a1.5 1.5 0 0 1-1.342-.83l-1.215-2.43L1.07 8.589a1.517 1.517 0 0 1 2.373-1.852L5 8.293V1.75a1.75 1.75 0 0 1 3.5 0z"/>
    </svg>""",
    "wait": """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="#6c757d" class="bi bi-hourglass-split flex-shrink-0 me-2" viewBox="0 0 16 16">
    <path d="M2.5 15a.5.5 0 1 1 0-1h1v-1a4.5 4.5 0 0 1 2.557-4.06c.29-.139.443-.377.443-.59v-.7c0-.213-.154-.451-.443-.59A4.5 4.5 0 0 1 3.5 3V2h-1a.5.5 0 0 1 0-1h11a.5.5 0 0 1 0 1h-1v1a4.5 4.5 0 0 1-2.557 4.06c-.29.139-.443.377-.443.59v.7c0 .213.154.451.443.59A4.5 4.5 0 0 1 12.5 13v1h1a.5.5 0 0 1 0 1h-11zm2-13v1c0 .537.12 1.045.337 1.5h6.326c.216-.455.337-.963.337-1.5V2h-7zm3 6.35c0 .701-.478 1.236-1.011 1.492A3.5 3.5 0 0 0 4.5 13s.866-1.299 3-1.48V8.35zm1 0v3.17c2.134.181 3 1.48 3 1.48a3.5 3.5 0 0 0-1.989-3.158C8.978 9.586 8.5 9.052 8.5 8.351z"/>
    </svg>""",
    "done": """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="#198754" class="bi bi-check-circle-fill flex-shrink-0 me-2" viewBox="0 0 16 16">
    <path d="M16 8A8 8 0 1 1 0 8a8 8 0 0 1 16 0zm-3.97-3.03a.75.75 0 0 0-1.08.022L7.477 9.417 5.384 7.323a.75.75 0 0 0-1.06 1.06L6.97 11.03a.75.75 0 0 0 1.079-.02l3.992-4.99a.75.75 0 0 0-.01-1.05z"/>
    </svg>""",
}

icons2alert = {
    "info": "primary",  # blue
    "cursor": "info",  # light blue
    "wait": "secondary",  # gray
    "done": "success",  # green
}


def message(text, icon_type="info"):
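    """Render `text` inside a Bootstrap alert whose color matches the icon type."""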
    return f"""{STYLE}  <div class="alert alert-{icons2alert[icon_type]} d-flex align-items-center" role="alert"> {ICONS[icon_type]}
                            <div> 
                                {text} 
                            </div>
                        </div>"""


def preprocess(tmp_dir, input_img, idx=None):
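    """Save the (optionally gallery-indexed) input image to tmp_dir, run the
    SEG_CMD segmentation command, and return the result resized to 320x320."""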
    if idx is not None:
        print("image idx:", int(idx))
        input_img = Image.open(input_img[int(idx)]["name"])
    input_img.save(f"{tmp_dir}/input.png")
    os.system(SEG_CMD.format(tmp_dir=tmp_dir))
    processed_img = Image.open(f"{tmp_dir}/seg.png")
    return processed_img.resize((320, 320), Image.Resampling.LANCZOS)


def ply_to_glb(ply_path):
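    """Convert a PLY mesh to GLB by running the ply2glb.py script in a subprocess."""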
    result = subprocess.run(
        ["python", "ply2glb.py", "--", ply_path],
        capture_output=True,
        text=True,
    )

    print("Output of blender script:")
    print(result.stdout)

    glb_path = ply_path.replace(".ply", ".glb")
    return glb_path


def mesh_gen(tmp_dir, simplify, num_inference_steps):
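    """Run the MESH_CMD reconstruction command on the preprocessed image in
    tmp_dir, bake a normal-map visualization into a second mesh via vertex
    colors, and return GLB paths for the color and normal meshes.

    Note: `simplify` is accepted from the UI option but currently unused here.
    """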
    os.system(MESH_CMD.format(tmp_dir=tmp_dir, num_inference_steps=num_inference_steps))

    mesh = trimesh.load_mesh(f"{tmp_dir}/mesh.ply")
    vertex_normals = mesh.vertex_normals
    theta = np.radians(-90)  # Rotation angle in radians
    # Rotation matrix about the z-axis
    cos_theta = np.cos(theta)
    sin_theta = np.sin(theta)
    rotation_matrix = np.array([
        [cos_theta, -sin_theta, 0],
        [sin_theta, cos_theta, 0],
        [0, 0, 1]
    ])
    rotated_normal = np.dot(vertex_normals, rotation_matrix.T)
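    # Map the flipped normal components from [-1, 1] into [0, 1] for vertex colors.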
    colors = (-rotated_normal + 1) / 2.0
    colors = (colors * 255).clip(0, 255).astype(np.uint8)  # Convert to 8-bit color
    mesh.visual.vertex_colors = colors[..., [2, 1, 0]]  # RGB -> BGR
    mesh.export(f"{tmp_dir}/mesh_normal.ply", file_type="ply")

    color_path = ply_to_glb(f"{tmp_dir}/mesh.ply")
    normal_path = ply_to_glb(f"{tmp_dir}/mesh_normal.ply")

    return color_path, normal_path


def create_tmp_dir():
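    """Create a unique working directory named by timestamp plus a short UUID prefix."""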
    tmp_dir = (
        "demo_exp/"
        + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        + "_"
        + str(uuid.uuid4())[:4]
    )
    os.makedirs(tmp_dir, exist_ok=True)
    print("Created tmp dir:", tmp_dir)
    return tmp_dir


def vis_seg(checkbox):
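    """Show or hide the manual-segmentation widgets and reset the clicked points."""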
    if checkbox:
        print("Show manual seg windows")
        return (
            [gr.Image(value=None, visible=True)] * 2
            + [gr.Radio(visible=True)]
            + [[], gr.Checkbox(visible=True)]
        )
    else:
        print("Clear manual seg")
        return (
            [gr.Image(visible=False)] * 2
            + [gr.Radio(visible=False)]
            + [[], gr.Checkbox(visible=False)]
        )


def calc_feat(checkbox, predictor, input_image, idx=None):
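    """Pad the input image to a white square and precompute its SAM features."""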
    if checkbox:
        if idx is not None:
            print("image idx:", int(idx))
            input_image = Image.open(input_image[int(idx)]["name"])
        input_image.thumbnail([512, 512], Image.Resampling.LANCZOS)
        w, h = input_image.size
        print("image size:", w, h)
        side_len = max(w, h)
        seg_in = Image.new(input_image.mode, (side_len, side_len), (255, 255, 255))
        seg_in.paste(input_image, ((side_len - w) // 2, (side_len - h) // 2))
        print("Calculating image SAM feature...")
        predictor.set_image(np.array(seg_in.convert("RGB")))
        torch.cuda.empty_cache()
        return gr.Image(value=seg_in, visible=True)
    else:
        print("Quit manual seg")
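        # Raise so the downstream .success() handlers in the checkbox chain are skipped.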
        raise ValueError("Quit manual seg")


def manual_seg(
    predictor,
    seg_in,
    selected_points,
    fg_bg_radio,
    tmp_dir,
    seg_mask_opt,
    evt: gr.SelectData,
):
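    """Record the clicked point, run SAM with all points so far, blend the mask
    for display, then recenter/rescale the cutout and save it as seg.png."""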
    print("Start segmentation")
    selected_points.append(
        {"coord": evt.index, "add_del": fg_bg_radio == "+ (add mask)"}
    )
    input_points = np.array([point["coord"] for point in selected_points])
    input_labels = np.array([point["add_del"] for point in selected_points])
    out_image = sam_seg(
        predictor, np.array(seg_in.convert("RGB")), input_points, input_labels
    )

    if seg_mask_opt:
        segmentation = blend_seg_pure(
            seg_in.convert("RGB"), out_image, input_points, input_labels
        )
    else:
        segmentation = blend_seg(
            seg_in.convert("RGB"), out_image, input_points, input_labels
        )

    # Recenter the object and pad so it fills ~75% of a square RGBA canvas.
    image_arr = np.array(out_image)
    _, mask = cv2.threshold(
        np.array(out_image.split()[-1]), 0, 255, cv2.THRESH_BINARY
    )
    x, y, w, h = cv2.boundingRect(mask)
    max_size = max(w, h)
    ratio = 0.75
    side_len = int(max_size / ratio)
    padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
    center = side_len // 2
    padded_image[
        center - h // 2 : center - h // 2 + h, center - w // 2 : center - w // 2 + w
    ] = image_arr[y : y + h, x : x + w]
    rgba = Image.fromarray(padded_image)
    rgba.save(f"{tmp_dir}/seg.png")
    torch.cuda.empty_cache()

    return segmentation.resize((380, 380), Image.Resampling.LANCZOS), rgba.resize(
        (320, 320), Image.Resampling.LANCZOS
    )


custom_theme = gr.themes.Soft(primary_hue="blue").set(
    button_secondary_background_fill="*neutral_100",
    button_secondary_background_fill_hover="*neutral_200",
)

with gr.Blocks(title="MeshFormer Demo", css="style.css", theme=custom_theme) as demo:
    with gr.Row():
        gr.Markdown(
            "# MeshFormer: High-Quality Mesh Generation with 3D-Guided Reconstruction Model"
        )
    with gr.Row():
        gr.Markdown(
            "[Project Page](https://meshformer3d.github.io/) | [arXiv](https://arxiv.org/pdf/2408.10198)"
        )
    with gr.Row():
        gr.Markdown(
            """
<div>
<b><em>Check out <a href="https://www.sudo.ai/3dgen">Hillbot (sudoAI)</a> for more details and advanced features.</em></b>
</div>
"""
        )
    with gr.Row():
        guide_text_i2m = gr.HTML(message("Please input an image!"), visible=True)

    tmp_dir_img = gr.State("./demo_exp/placeholder")
    tmp_dir_txt = gr.State("./demo_exp/placeholder")
    tmp_dir_3t3 = gr.State("./demo_exp/placeholder")

    example_folder = os.path.join(os.path.dirname(__file__), "demo_examples")
    example_fns = os.listdir(example_folder)
    example_fns.sort()
    img_examples = [
        os.path.join(example_folder, x) for x in example_fns
    ]

    with gr.Row(variant="panel"):
        with gr.Row():
            with gr.Column(scale=8):
                input_image = gr.Image(
                    type="pil",
                    image_mode="RGBA",
                    height=320,
                    label="Input Image",
                    interactive=True,
                )
                gr.Examples(
                    examples=img_examples,
                    inputs=[input_image],
                    outputs=[input_image],
                    cache_examples=False,
                    label="Image Examples (Click one of the images below to start)",
                    examples_per_page=27,
                )
                with gr.Accordion("Options", open=False):
                    img_simplify = gr.Checkbox(
                        False, label="simplify the generated mesh", visible=False
                    )
                    n_steps_img = gr.Slider(
                        value=28,
                        minimum=15,
                        maximum=100,
                        step=1,
                        label="number of inference steps",
                    )
                # manual segmentation
                checkbox_manual_seg = gr.Checkbox(False, label="manual segmentation")
                with gr.Row():
                    with gr.Column(scale=1):
                        seg_in = gr.Image(
                            type="pil",
                            image_mode="RGBA",
                            label="Click to segment",
                            visible=False,
                            show_download_button=False,
                            height=380,
                        )
                    with gr.Column(scale=1):
                        seg_out = gr.Image(
                            type="pil",
                            image_mode="RGBA",
                            label="Segmentation",
                            interactive=False,
                            visible=False,
                            show_download_button=False,
                            height=380,
                            elem_id="disp_image",
                        )
                fg_bg_radio = gr.Radio(
                    ["+ (add mask)", "- (remove area)"],
                    value="+ (add mask)",
                    info="Select foreground (+) or background (-) point",
                    label="Point label",
                    visible=False,
                    interactive=True,
                )
                seg_mask_opt = gr.Checkbox(
                    True,
                    label="show foreground mask in manual segmentation",
                    visible=False,
                )
                # run
                img_run_btn = gr.Button(
                    "Generate", variant="primary", interactive=False
                )
            with gr.Column(scale=6):
                processed_image = gr.Image(
                    type="pil",
                    label="Processed Image",
                    interactive=False,
                    height=320,
                    image_mode="RGBA",
                    elem_id="disp_image",
                )
                mesh_output_normal = Model3DColor(
                    label="Generated Mesh (normal)",
                    elem_id="mesh-normal-out",
                    height=400,
                )
                mesh_output = Model3DColor(
                    label="Generated Mesh (color)",
                    elem_id="mesh-out",
                    height=400,
                )

    predictor = gr.State(value=get_sam_predictor())
    selected_points = gr.State(value=[])
    selected_points_t2i = gr.State(value=[])

    disable_checkbox = lambda: gr.Checkbox(value=False)
    disable_button = lambda: gr.Button(interactive=False)
    enable_button = lambda: gr.Button(interactive=True)
    update_guide = lambda GUIDE_TEXT, icon_type="info": gr.HTML(
        value=message(GUIDE_TEXT, icon_type)
    )
    update_md = lambda GUIDE_TEXT: gr.Markdown(value=GUIDE_TEXT)

    def is_img_clear(input_image):
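        """Abort the event chain when the input image has been cleared."""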
        if input_image is None:
            raise ValueError("Input image cleared.")

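    # Toggling manual segmentation shows the SAM widgets, precomputes image
    # features, and allocates a fresh working directory.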
    checkbox_manual_seg.change(
        vis_seg,
        inputs=[checkbox_manual_seg],
        outputs=[seg_in, seg_out, fg_bg_radio, selected_points, seg_mask_opt],
        queue=False,
    ).success(
        calc_feat,
        inputs=[checkbox_manual_seg, predictor, input_image],
        outputs=[seg_in],
        queue=True,
    ).success(
        fn=create_tmp_dir, outputs=[tmp_dir_img], queue=False
    )

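    # Each click on the segmentation canvas adds a point and reruns SAM.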
    seg_in.select(
        manual_seg,
        [predictor, seg_in, selected_points, fg_bg_radio, tmp_dir_img, seg_mask_opt],
        [seg_out, processed_image],
        queue=True,
    )

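    # A new input image resets the UI, creates a working directory, runs
    # automatic segmentation, and then enables the Generate button.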
    input_image.change(disable_button, outputs=img_run_btn, queue=False).success(
        disable_checkbox, outputs=checkbox_manual_seg, queue=False
    ).success(fn=is_img_clear, inputs=input_image, queue=False).success(
        fn=create_tmp_dir, outputs=tmp_dir_img, queue=False
    ).success(
        fn=partial(update_guide, "Preprocessing the image!", "wait"),
        outputs=[guide_text_i2m],
        queue=False,
    ).success(
        fn=preprocess,
        inputs=[tmp_dir_img, input_image],
        outputs=[processed_image],
        queue=True,
    ).success(
        fn=partial(
            update_guide,
            "Click <b>Generate</b> to generate mesh! If the input image was not segmented accurately, please adjust it using <b>manual segmentation</b>.",
            "cursor",
        ),
        outputs=[guide_text_i2m],
        queue=False,
    ).success(
        enable_button, outputs=img_run_btn, queue=False
    )

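    # Generate runs mesh reconstruction and displays the color and normal meshes.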
    img_run_btn.click(
        fn=partial(update_guide, "Generating the mesh!", "wait"),
        outputs=[guide_text_i2m],
        queue=False,
    ).success(
        fn=mesh_gen,
        inputs=[tmp_dir_img, img_simplify, n_steps_img],
        outputs=[mesh_output, mesh_output_normal],
        queue=True,
    ).success(
        fn=partial(
            update_guide,
            "Successfully generated the mesh. (It might take a few seconds to load the mesh)",
            "done",
        ),
        outputs=[guide_text_i2m],
        queue=False,
    )

demo.queue().launch(
    debug=False, share=False, inline=False, show_api=False, server_name="0.0.0.0"
)