import albumentations as A
import base64
import cv2
import logging
import gradio as gr
import inspect
import io
import numpy as np

from dataclasses import dataclass
from copy import deepcopy
from functools import wraps
from PIL import Image, ImageDraw
from typing import get_type_hints, Optional
from mixpanel import Mixpanel

from utils import is_not_supported_transform

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

MIXPANEL_TOKEN = os.getenv("MIXPANEL_TOKEN")
mp = Mixpanel(MIXPANEL_TOKEN)

HEADER = f"""
<div align="center">
    <p>
        <img src="https://avatars.githubusercontent.com/u/57894582?s=200&v=4" alt="A" width="50" height="50" style="display:inline;">
        <span style="font-size: 30px; vertical-align: bottom;"> lbumentations Demo ({A.__version__})</span>
    </p>
    <p style="margin-top: -15px;">
        <a href="https://albumentations.ai/docs/" target="_blank" style="color: grey;">Documentation</a>
        &nbsp;
        <a href="https://github.com/albumentations-team/albumentations" target="_blank" style="color: grey;">GitHub Repository</a>
    </p>
</div>
"""

DEFAULT_TRANSFORM = "Rotate"

DEFAULT_IMAGE_PATH = "images/doctor.webp"
DEFAULT_IMAGE = np.array(Image.open(DEFAULT_IMAGE_PATH))
DEFAULT_IMAGE_HEIGHT = DEFAULT_IMAGE.shape[0]
DEFAULT_IMAGE_WIDTH = DEFAULT_IMAGE.shape[1]
DEFAULT_BOXES = [
    [265, 121, 326, 177],  # Mask
    [192, 169, 401, 395],  # Coverall
]

mask_keypoints = [[270, 123], [320, 130], [270, 151], [321, 158]]
pocket_keypoints = [[226, 379], [272, 386], [307, 388], [364, 380]]
arm_keypoints = [[215, 194], [372, 192], [214, 322], [378, 330]]
DEFAULT_KEYPOINTS = mask_keypoints + pocket_keypoints + arm_keypoints

BASE64_DEFAULT_MASKS = [
    {
        "label": "Coverall",
        # light green color
        "color": (144, 238, 144),
        "mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAF+ElEQVR4nO3dwXLjNhBFUSg1///LziLj1Iwt26KkFhuvz9kkWVhFAJcAqXGStQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACe5+3sCyCUsq745+wLSKCsz4T1DMr6RFiUENZT2LI+EhYlhPWnt7t3nvt/MtSvsy+gkcfaeFuXJ11HBJPx7r+s7piPP3o0m/9zFP729tdfjv/gnT8dy1G41npeEc7Dd3astR7q6uOP2rT+4wb7mMLBGfkckildyyw8HMa1HWr8pK7pz1hF55YnrdE315dVHZmTb9IcPLVr7Oi/36oOTMpPe97Q+R16FL7wzW3sqThw2DdkdfOs3JTowDmeN+gbN6tbp+XJHxdk1pAPnIE3TczNnzdrmteaNeJjj1Y3zMyRD5w00WuNGu/RR/afpubZn5dlymjvexH8Znbu+cApk73WlLE+8v3C1Rm69wNnTPdaM0ba6hcOJkz4WhPG2SqrtSZM+Vr5o2yX1Vr5k75W+hhbZrVW+rSvlT3Ctlmt7Hlfa0X/anLnrnpf3DPkhtV86Zpf3sNyw+ouvKzYsPqvW/8rfERsWJxLWOeJ3rKERQlhnSh5y0oNK3nNtpAa1h6C8w8NK3jFNpEZlq5OlxnWNnLvgMiwcpdrH5FhbST2HkgMK3axdpIY1lZS7wJhUSIwrM32gM0u91aBYdFBXlihO8Bu4sLSVQ9xYe0n81ZICytzlTaUFtaOIm8GYVEiLKzIm39LYWHRhbAoIawGEg9wYVEiK6zEW39TWWHtKvCGEBYlhEUJYbWQdxYKixLCokRUWHkHyr6iwqIPYVEiKSwnYSNJYdGIsCghLEoIixLCokRQWF4KOwkKa2txd4WwmkgrS1iUEFYXYVuWsCiRE1bYHb+7nLBoRViUEBYlhEUJYVFCWG1kvdYKixIxYWXd7/tLCUtXzaSEFeBy9gU8VUhYNqxuMsLSVTsZYdFORFgZG1bGKN5FhEU/wqJEQlhZZ0iIhLBoSFiUEBYlhEUJYVEiIaysP70NkRAWDQmLEsLqI+qLXmFRQliUEBYlhEWJX2dfwK4ua4U9bj+XsA66/P0P0vqCo/CQy8dv+X3r/wVhHXElI2VdJ6wDiiOKalRYlBDWo6L2mecRVhtZhQrrUb5wuEpYlBBWF1knYUZYYWsSISKsM8vyiHVdRlivYWM8ICQsa95NSFinleUk/EJKWDQjrCbSDvOYsM5ZGCfhV2LCohdhPcKG9SVh3e5TRk/sKu0RKyis1y+N/eobOWG9nK6+ExTWa7esN119y79XeBdV/URYd5DVz4R1wNtlqepGUa+5+6551DKstaIe3ulEWJQQFiWERQlhUSIqrLx3q31FhUUfwqJEVljOwjaywqINYXUQuNOGhRW4QpsKC4su0sKyZTWRFtaWZe14zT+JCytylTYUuQyb/cJf5Brk7Vhrt5Xa62pvFRnWVmu107UekBlW6mptJDQsZZ0tNaxtpN4BqeN6fzW8/PH3LaUuQOq4PuqaVuz8OwopEXvHfNRyywqe/eCh/aVjV9Fz7z8KcpborIR1jvCo1hLWGQZk5a3wBCO6EtbLzehKWNQQFiWERQlhUUJYlBAWJYRFCWG9Wsc/Di8gLEoMCWvINtHIkLB4NWFRQliUENbLzXjeExYlhEUJYVFiRlgzHmtamREWLycsSgjr9UYczMKihLAoISxKCIsSwqKEsCghLEoI6wQTvsgSFiWERQlhUUJYZxjwkCUsSgiLEiPCGnDytDMirH7yU58QVv4qNjQhLE4gLEoIixLCooSwzhH/QjEhrMuQ/31NKxPCktYJBs14s9MnfOZn7FhrLdvWaw0KS1qvNG+qu5yI4TMfPryrmqSVPfWjjsLfnIgvMHWOO+xa0XM/ccdaK3xRO5gaFsWERQlhUWJqWB0e3qNNDauD6LiFRYmhYUVvFi3MDEtX5UaGpat6I8Oi3r8KSpCuwVpGmQAAAABJRU5ErkJggg==",
    },
    {
        "label": "Mask",
        # light blue color
        "color": (173, 216, 230),
        "mask": "iVBORw0KGgoAAAANSUhEUgAAAlgAAAGQCAAAAABXXkFEAAAB4ElEQVR4nO3csQ6CMBSG0avv/864OFhoobW9UeM5i4ML+fOFkmCMAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIbcPn0BX257ftppkMEqtuY35uplqYN2VhFhsU5m2rnIKsJmXYy00xFWRBjuin1KvV2F6c7dP30Bv2ugwT8krPcp64SwJmzSahJWYbQUZbUIqzD8QK6sBmEVdLKKsEghLFIIa5LDs05YpBAWKYQ1y1lYJSxSCIsUwiKFsF55XlpGWLP83q9KWKQQ1qt37j6OzyphkUJYpBBWwZP4KsIqKWsRYZFCWDtuWWsIixTC2nPLWkJYB8paQVhHY2XpsMosdV0vaozXZpsG/+s3x0BN9bQM1sdOJ45pmauXpU4VadlqgLEubRF2AgAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGCZBwKLGEVAl/J/AAAAAElFTkSuQmCC",
    },
]

# Get all the transforms from the albumentations library
transforms_map = {
    name: cls
    for name, cls in vars(A).items()
    if (
        inspect.isclass(cls)
        and issubclass(cls, (A.DualTransform, A.ImageOnlyTransform))
        and not is_not_supported_transform(cls)
    )
}
transforms_map.pop("DualTransform", None)
transforms_map.pop("ImageOnlyTransform", None)
transforms_map.pop("ReferenceBasedTransform", None)
transforms_keys = list(sorted(transforms_map.keys()))


# Decode the masks
for mask in BASE64_DEFAULT_MASKS:
    mask["mask"] = np.array(
        Image.open(io.BytesIO(base64.b64decode(mask["mask"]))).convert("L")
    )


@dataclass
class RequestParams:
    user_ip: str
    transform_name: Optional[str]

def track_event(event_name, user_id="unknown", properties=None):
    if properties is None:
        properties = {}
    mp.track(user_id, event_name, properties)
    logger.info(f"Event tracked: {event_name} - {properties}")


def get_params(request: gr.Request) -> RequestParams:
    """Parse input request parameters."""
    ip = request.client.host
    transform_name = request.query_params.get("transform", None)
    params = RequestParams(user_ip=ip, transform_name=transform_name)
    track_event("app_opened", user_id=params.user_ip, properties={"transform_name": params.transform_name})
    return params


def run_with_retry(compose):
    @wraps(compose)
    def wrapper(*args, **kwargs):
        processors = deepcopy(compose.processors)
        for _ in range(4):
            try:
                result = compose(*args, **kwargs)
                break
            except NotImplementedError as e:
                print(f"Caught NotImplementedError: {e}")
                if "bbox" in str(e):
                    kwargs.pop("bboxes", None)
                    kwargs.pop("category_id", None)
                    compose.processors.pop("bboxes")
                if "keypoint" in str(e):
                    kwargs.pop("keypoints", None)
                    compose.processors.pop("keypoints")
                if "mask" in str(e):
                    kwargs.pop("mask", None)
            except Exception as e:
                compose.processors = processors
                raise e
        compose.processors = processors
        return result

    return wrapper


def draw_boxes(image, boxes, color=(255, 0, 0), thickness=1) -> np.ndarray:
    """Draw boxes with PIL."""
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    for box in boxes:
        x_min, y_min, x_max, y_max = box
        draw.rectangle([x_min, y_min, x_max, y_max], outline=color, width=thickness)
    return np.array(pil_image)


def draw_keypoints(image, keypoints, color=(255, 0, 0), radius=2):
    """Draw keypoints with PIL."""
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    for keypoint in keypoints:
        x, y = keypoint
        draw.ellipse([x - radius, y - radius, x + radius, y + radius], fill=color)
    return np.array(pil_image)


def get_rgb_mask(masks):
    """Get the RGB mask from the binary mask."""
    rgb_mask = np.zeros((DEFAULT_IMAGE_HEIGHT, DEFAULT_IMAGE_WIDTH, 3), dtype=np.uint8)
    for data in masks:
        mask = data["mask"]
        rgb_mask[mask > 0] = np.array(data["color"])
    return rgb_mask


def draw_mask(image, mask):
    """Draw the mask on the image."""
    image_with_mask = cv2.addWeighted(image, 0.5, mask, 0.5, 0)
    return image_with_mask


def draw_not_implemented_image(image: np.ndarray, annotation_type: str):
    """Draw the image with a text. In the middle."""
    pil_image = Image.fromarray(image)
    draw = ImageDraw.Draw(pil_image)
    # align in the centerm, and make bigger font
    text = f'Transform NOT working with "{annotation_type.upper()}" annotations.'
    length = draw.textlength(text)
    draw.text(
        (DEFAULT_IMAGE_WIDTH // 2 - length // 2, DEFAULT_IMAGE_HEIGHT // 2),
        text,
        fill=(255, 0, 0),
        align="center",
    )
    return np.array(pil_image)


def get_formatted_signature(function_or_class, indentation=4):

    signature = inspect.signature(function_or_class)
    type_hints = get_type_hints(function_or_class)

    args = []
    for param in signature.parameters.values():
        if param.name == "p":
            str_param = "p=1.0,"
        elif param.default == inspect.Parameter.empty:
            str_param = f"{param.name}=,"
        else:
            if isinstance(param.default, str):
                str_param = f'{param.name}="{param.default}",'
            else:
                str_param = f"{param.name}={param.default},"

        annotation = type_hints.get(param.name, param.annotation)
        if isinstance(param.annotation, type):
            str_param += f"  # {param.annotation.__name__}"
        else:
            str_annotation = str(annotation).replace("typing.", "")
            str_param += f"  # {str_annotation}"
        str_param = "\n" + " " * indentation + str_param
        args.append(str_param)

    result = "(" + "".join(args) + "\n" + " " * (indentation - 4) + ")"
    return result


def get_formatted_transform(transform_name):
    track_event("transform_selected", properties={"transform_name": transform_name})
    transform = transforms_map[transform_name]
    return f"A.{transform.__name__}{get_formatted_signature(transform)}"


def get_formatted_transform_docs(transform_name):
    transform = transforms_map[transform_name]
    return transform.__doc__.strip("\n")


def update_augmented_images(image, code):

    augmentation = eval(code)
    track_event("transform_applied", properties={"transform_name": augmentation.__class__.__name__, "code": code})

    compose = A.Compose(
        [augmentation],
        bbox_params=A.BboxParams(format="pascal_voc", label_fields=["category_id"]),
        keypoint_params=A.KeypointParams(format="xy"),
    )
    compose = run_with_retry(compose)  # to prevent NotImplementedError

    keypoints = DEFAULT_KEYPOINTS
    bboxes = DEFAULT_BOXES
    mask = get_rgb_mask(BASE64_DEFAULT_MASKS)
    augmented = compose(
        image=image,
        not_implemented_image=image.copy(),
        mask=mask,
        keypoints=keypoints,
        bboxes=bboxes,
        category_id=range(len(bboxes)),
    )
    image = augmented["image"]
    mask = augmented.get("mask", None)
    bboxes = augmented.get("bboxes", None)
    keypoints = augmented.get("keypoints", None)

    # Draw the augmented images (or replace by placeholder if not implemented)
    if mask is not None:
        image_with_mask = draw_mask(image.copy(), mask)
    else:
        image_with_mask = draw_not_implemented_image(image.copy(), "mask")

    if bboxes is not None:
        image_with_bboxes = draw_boxes(image.copy(), bboxes)
    else:
        image_with_bboxes = draw_not_implemented_image(image.copy(), "boxes")

    if keypoints is not None:
        image_with_keypoints = draw_keypoints(image.copy(), keypoints)
    else:
        image_with_keypoints = draw_not_implemented_image(image.copy(), "keypoints")

    return [
        (image_with_mask, "Mask"),
        (image_with_bboxes, "Boxes"),
        (image_with_keypoints, "Keypoints"),
    ]


def update_image_info(image):
    h, w = image.shape[:2]
    dtype = image.dtype
    max_, min_ = image.max(), image.min()
    return f"Image info:\n\t - shape: {h}x{w}\n\t - dtype: {dtype}\n\t - min/max: {min_}/{max_}"


def update_code_and_docs(select):
    code = get_formatted_transform(select)
    docs = get_formatted_transform_docs(select)
    return code, docs

def update_code_and_docs_on_start(url_params: gr.Request):
    params = get_params(url_params)
    transform_name = params.transform_name if params.transform_name is not None else DEFAULT_TRANSFORM
    return gr.update(value=transform_name)

with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    with gr.Row():
        with gr.Column():
            with gr.Group():
                select = gr.Dropdown(
                    label="Select a transformation",
                    choices=transforms_keys,
                    value=DEFAULT_TRANSFORM,
                    type="value",
                    interactive=True,
                )
                with gr.Accordion("Documentation (click to expand)", open=False):
                    docs = gr.TextArea(
                        get_formatted_transform_docs(DEFAULT_TRANSFORM),
                        show_label=False,
                        interactive=False,
                    )
                code = gr.Code(
                    label="Code",
                    language="python",
                    value=get_formatted_transform(DEFAULT_TRANSFORM),
                    interactive=True,
                    lines=5,
                )
            info = gr.TextArea(
                value=f"Image size: {DEFAULT_IMAGE_HEIGHT} x {DEFAULT_IMAGE_WIDTH} (height x width)",
                show_label=False,
                lines=1,
                max_lines=1,
            )
            button = gr.Button("Apply!")
        image = gr.Image(
            value=DEFAULT_IMAGE_PATH,
            type="numpy",
            height=500,
            width=300,
            sources=[],
        )
    with gr.Row():
        augmented_image = gr.Gallery(
            value=update_augmented_images(DEFAULT_IMAGE, "A.NoOp()"),
            rows=1,
            columns=3,
            show_label=False,
        )
    select.change(fn=update_code_and_docs, inputs=[select], outputs=[code, docs])
    button.click(
        fn=update_augmented_images, inputs=[image, code], outputs=[augmented_image]
    )
    demo.load(
        update_code_and_docs_on_start, inputs=None, outputs=[select]
    )

if __name__ == "__main__":
    demo.launch()