import pickle as pkl
import numpy as np
import numpy.typing as npt

from PIL import Image
from PIL.Image import Image as ImageType
from pathlib import Path


def build_data(
    data_path: Path, extensions: tuple[str, ...] = (".png", ".jpg", ".jpeg")
) -> dict:
    """Create a fresh record for every image file directly inside *data_path*.

    Files are matched by suffix case-insensitively, so ``IMG.JPG`` is found on
    case-sensitive filesystems too. Subdirectories are not searched and
    directories whose names merely end in an image suffix are ignored.

    Args:
        data_path: Directory to scan. A ``str`` is accepted too. A missing
            path (or one that is not a directory) yields an empty dict.
        extensions: Accepted suffixes, with leading dot. When two files share
            a stem, the one whose suffix comes later in this tuple wins
            (``.png`` < ``.jpg`` < ``.jpeg`` by default).

    Returns:
        Mapping from file stem to a new record dict with keys ``"image"`` (the
        file's path), ``"labels"`` (an empty list), and ``"emb"`` /
        ``"meta_data"`` (both ``None``).
    """
    data_path = Path(data_path)
    if not data_path.is_dir():
        return {}

    rank = {ext.lower(): i for i, ext in enumerate(extensions)}
    image_paths = sorted(
        (p for p in data_path.iterdir() if p.is_file() and p.suffix.lower() in rank),
        # Process in extension order so later extensions overwrite earlier
        # ones on stem collisions; sort by name within an extension so the
        # result does not depend on filesystem listing order.
        key=lambda p: (rank[p.suffix.lower()], p.name),
    )

    data = {}
    for image_path in image_paths:
        data[image_path.stem] = {
            "image": image_path,
            "labels": [],
            "emb": None,
            "meta_data": None,
        }
    return data


class Data:
    """Pickle-backed store of per-image records and their on-disk artifacts.

    ``self.data`` maps an image name to a record dict. Keys written by this
    class are:

    - ``"image"``: path to the image file;
    - ``"emb"``: path to an ``.npy`` embedding;
    - ``"emb.<i>"``: path to the i-th high-quality embedding;
    - ``"masks"`` / ``"labels"`` / ``"bboxes"``: mask file paths, label
      strings and bounding boxes;
    - ``"meta_data"``: an arbitrary dict.

    Artifact files are written next to the pickle file. Every mutating method
    rewrites the pickle immediately.
    """

    def __init__(self, data_path: Path):
        """Open the store at *data_path*, creating an empty one if missing.

        Args:
            data_path: Location of the pickle file. A ``str`` is accepted too
                and converted to ``Path``. Missing parent directories are
                created.
        """
        # Normalise once so every ``self.data_path.parent`` below also works
        # when a plain string was passed.
        self.data_path = Path(data_path)
        if self.data_path.exists():
            # NOTE: unpickling can execute arbitrary code; only open store
            # files this application wrote itself.
            with open(self.data_path, "rb") as f:
                self.data = pkl.load(f)
        else:
            self.data_path.parent.mkdir(parents=True, exist_ok=True)
            self.data = {}
            self._save_data()

    def _save_data(self) -> None:
        """Persist ``self.data`` to the pickle file."""
        # Write to a sibling temp file and swap it in, so a crash mid-write
        # cannot leave a truncated, unloadable store behind.
        tmp_path = self.data_path.with_name(self.data_path.name + ".tmp")
        with open(tmp_path, "wb") as f:
            pkl.dump(self.data, f)
        tmp_path.replace(self.data_path)

    def __contains__(self, image: str) -> bool:
        """Return True if *image* has a record in the store."""
        return image in self.data

    def emb_exists(self, image: str) -> bool:
        """Return True if *image* has a non-None ``"emb"`` entry.

        Raises:
            KeyError: If *image* has no record.
        """
        return self.data[image].get("emb") is not None

    def save_labels(
        self, image: str, masks: list[ImageType], bboxes: list[tuple[int, ...]], labels: list[str]
    ) -> None:
        """Replace the masks, bounding boxes and labels stored for *image*.

        Previously stored mask files are deleted first. Each new mask is
        saved as ``<image>.<label>.<i>.png`` next to the pickle file.

        Args:
            image: Name of an existing record.
            masks: Mask images; each must support ``.save(path)``.
            bboxes: Bounding boxes, stored as given.
            labels: Label strings; each is also embedded in its mask's file name.

        Raises:
            KeyError: If *image* has no record.
        """
        self.clear_labels(image)
        parent = self.data_path.parent
        mask_paths = []
        # NOTE(review): zip() stops at the shorter of masks/labels, while
        # labels and bboxes are stored in full; callers presumably pass
        # equal-length lists — confirm.
        for i, (mask, label) in enumerate(zip(masks, labels)):
            mask_path = parent / f"{image}.{label}.{i}.png"
            mask.save(mask_path)
            mask_paths.append(str(mask_path))
        record = self.data[image]
        record["masks"] = mask_paths
        record["labels"] = labels
        record["bboxes"] = bboxes
        self._save_data()

    def save_meta_data(self, image: str, meta_data: dict) -> None:
        """Store *meta_data* under ``"meta_data"`` for *image* and persist.

        Raises:
            KeyError: If *image* has no record.
        """
        self.data[image]["meta_data"] = meta_data
        self._save_data()

    def save_emb(self, image: str, emb: npt.NDArray) -> None:
        """Save *emb* to ``<image>.emb.npy`` and record its path under ``"emb"``.

        Raises:
            KeyError: If *image* has no record.
        """
        emb_path = self.data_path.parent / f"{image}.emb.npy"
        np.save(emb_path, emb)
        self.data[image]["emb"] = emb_path
        self._save_data()

    def save_hq_emb(self, image: str, embs: list[npt.NDArray]) -> None:
        """Replace the high-quality embeddings of *image* with *embs*.

        Embedding ``i`` is saved to ``<image>.emb.<i>.npy`` and recorded under
        the key ``"emb.<i>"``.

        Raises:
            KeyError: If *image* has no record.
        """
        record = self.data[image]
        # Drop entries (and files) from a previous save first; otherwise a
        # shorter list would leave stale trailing embeddings for get_hq_emb.
        for key in [k for k in record if k.startswith("emb.")]:
            Path(record.pop(key)).unlink(missing_ok=True)
        parent = self.data_path.parent
        for i, emb in enumerate(embs):
            emb_path = parent / f"{image}.emb.{i}.npy"
            np.save(emb_path, emb)
            record[f"emb.{i}"] = emb_path
        self._save_data()

    def save_image(self, image: str, image_pil: ImageType) -> None:
        """Save *image_pil* as ``<image>.png`` and start a new record for it.

        Any existing record for *image* is replaced by one holding only the
        ``"image"`` path. Files referenced by the old record are not deleted.
        """
        image_path = self.data_path.parent / f"{image}.png"
        image_pil.save(image_path)
        self.data[image] = {}
        self.data[image]["image"] = image_path
        self._save_data()

    def clear_labels(self, image: str) -> None:
        """Delete *image*'s mask files and drop its masks and bounding boxes.

        ``"labels"`` is reset to an empty list if present. Afterwards,
        ``get_labels`` returns three empty lists.

        Raises:
            KeyError: If *image* has no record.
        """
        record = self.data[image]
        # Pop rather than only unlink, so the record no longer points at
        # deleted files (get_labels would otherwise try to open them).
        for mask_path in record.pop("masks", []):
            Path(mask_path).unlink(missing_ok=True)
        record.pop("bboxes", None)
        if "labels" in record:
            record["labels"] = []
        self._save_data()

    def delete_image(self, image: str) -> None:
        """Remove *image*'s record and delete every file it references.

        This covers the image, its embedding, its high-quality embeddings
        and its masks. It is a no-op if *image* is unknown.
        """
        if image not in self.data:
            return
        record = self.data[image]
        # NOTE(review): "image" may point at a file outside the store
        # directory (e.g. records produced by build_data); it is deleted too.
        file_paths = [record.get("image"), record.get("emb")]
        file_paths += [v for k, v in record.items() if k.startswith("emb.")]
        file_paths += record.get("masks", [])
        for file_path in file_paths:
            # Entries may legitimately be None (e.g. "emb" before save_emb).
            if file_path is not None:
                Path(file_path).unlink(missing_ok=True)
        del self.data[image]
        self._save_data()

    def get_all_images(self) -> list:
        """Return the names of all stored images."""
        return list(self.data.keys())

    def get_image(self, image: str) -> ImageType:
        """Open and return the stored image file for *image* via PIL.

        Raises:
            KeyError: If *image* has no record or no ``"image"`` entry.
        """
        # NOTE(review): PIL's Image.open is lazy, so the file handle likely
        # stays open until the pixel data is loaded — confirm if many are held.
        return Image.open(self.data[image]["image"])

    def get_emb(self, image: str) -> npt.NDArray:
        """Load and return the embedding saved under ``"emb"`` for *image*."""
        return np.load(self.data[image]["emb"])

    def get_hq_emb(self, image: str) -> list[npt.NDArray]:
        """Load the high-quality embeddings ``emb.0``, ``emb.1``, ... in order.

        Loading stops at the first missing index; an empty list is returned
        if none were saved.
        """
        record = self.data[image]
        embs = []
        i = 0
        while f"emb.{i}" in record:
            embs.append(np.load(record[f"emb.{i}"]))
            i += 1
        return embs

    def get_labels(
        self, image: str
    ) -> tuple[list[ImageType], list[tuple[int, ...]], list[str]]:
        """Return ``(masks, bboxes, labels)`` stored for *image*.

        Masks are opened with PIL and bboxes are converted to tuples. Three
        empty lists are returned if any of ``"masks"``, ``"labels"`` or
        ``"bboxes"`` is missing from the record.
        """
        record = self.data[image]
        if "masks" not in record or "labels" not in record or "bboxes" not in record:
            return [], [], []
        return (
            [Image.open(mask) for mask in record["masks"]],
            [tuple(e) for e in record["bboxes"]],
            record["labels"],
        )

    def get_meta_data(self, image: str) -> dict:
        """Return the ``"meta_data"`` entry of *image*.

        The entry may be ``None`` if it was initialised that way.

        Raises:
            KeyError: If the record has no ``"meta_data"`` key, which is the
                case for records created by ``save_image``.
        """
        return self.data[image]["meta_data"]