# This file is copied from https://github.com/rnwzd/FSPBT-Image-Translation/blob/master/data.py

# MIT License

# Copyright (c) 2022 Lorenzo Breschi

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from typing import Callable, Dict

import torch

from torch.utils.data import Dataset

import torchvision.transforms.functional as F
from torchvision import transforms
import pytorch_lightning as pl

from collections.abc import Iterable


# image reader writer
from pathlib import Path
from PIL import Image
from typing import Tuple


def read_image(filepath: Path, mode: str = None) -> "Image.Image":
    """Read an image file and convert it to ``mode``.

    :param filepath: path of the image file
    :param mode: PIL mode to convert to (e.g. ``'RGB'``); ``None`` lets PIL
        pick a mode (see ``Image.Image.convert``)
    :return: the converted ``PIL.Image.Image``
    """
    with open(filepath, 'rb') as file:
        image = Image.open(file)
        # Image.open reads lazily; convert() loads the pixel data, so it has to
        # run while the file is still open.
        return image.convert(mode)


# Module-wide converters: PIL image -> tensor (image2tensor) and
# tensor -> PIL image (tensor2image).
image2tensor, tensor2image = transforms.ToTensor(), transforms.ToPILImage()


def write_image(image: "Image.Image", filepath: Path) -> None:
    """Save ``image`` to ``filepath``, creating missing parent directories.

    :param image: PIL image to save; ``save`` is called without an explicit
        format, so PIL derives it from the extension of ``filepath``
    :param filepath: destination path; a plain ``str`` is accepted as well
    """
    # Coerce so that str paths work too (str has no ``.parent``).
    filepath = Path(filepath)
    filepath.parent.mkdir(parents=True, exist_ok=True)
    image.save(str(filepath))


def read_image_tensor(filepath: Path, mode: str = 'RGB') -> torch.Tensor:
    """Read the image at ``filepath``, convert it to ``mode`` and return it as a tensor."""
    pil_image = read_image(filepath, mode)
    return image2tensor(pil_image)


def write_image_tensor(input: torch.Tensor, filepath: Path):
    """Save the tensor ``input`` as an image file at ``filepath``.

    The tensor is converted with ``tensor2image`` and written with
    ``write_image``, which also creates missing parent directories.
    """
    pil_image = tensor2image(input)
    write_image(pil_image, filepath)


def get_valid_indices(H: int, W: int, patch_size: int, random_overlap: int = 0) -> torch.Tensor:
    """Return the top-left corners of a grid of square patches inside an H x W image.

    Grid coordinates are spaced ``patch_size`` apart and start at
    ``random_overlap``. When ``random_overlap > 0`` every grid row (and every
    grid column) is shifted by its own random offset in
    ``[-random_overlap, random_overlap]``; the start/stop margins guarantee
    that every patch still lies inside the image.

    :param H: image height
    :param W: image width
    :param patch_size: patch side length, also used as the grid stride
    :param random_overlap: maximum absolute random shift of the grid coordinates
    :return: int64 tensor of shape ``(num_patches, 2)`` with ``(top, left)``
        rows in row-major order (all columns of the first grid row first);
        shape ``(0, 2)`` when no patch fits
    """
    def _starts(size: int) -> torch.Tensor:
        # Last start must satisfy start + random_overlap + patch_size <= size.
        # Clamp the stop: torch.arange raises when stop < start.
        stop = size - patch_size - random_overlap + 1
        return torch.arange(random_overlap, max(stop, random_overlap), patch_size)

    vih = _starts(H)
    viw = _starts(W)
    if random_overlap > 0:
        # randint's upper bound is exclusive: +1 makes the shift symmetric.
        vih += torch.randint_like(vih, -random_overlap, random_overlap + 1)
        viw += torch.randint_like(viw, -random_overlap, random_overlap + 1)
    # Row-major cartesian product; same ordering as flattening
    # torch.meshgrid(vih, viw) with 'ij' indexing, without meshgrid's
    # indexing-argument warning.
    tops = vih.repeat_interleave(len(viw))
    lefts = viw.repeat(len(vih))
    return torch.stack((tops, lefts), dim=1)


def cut_patches(input: torch.Tensor, indices: Tuple[Tuple[int, int]], patch_size: int, padding: int = 0):
    """Crop square patches from ``input`` and concatenate them along dim 0.

    :param input: tensor whose last two dimensions are cropped by ``F.crop``;
        in this file it is typically a stacked batch (assumed ``(N, C, H, W)``)
    :param indices: ``(top, left)`` corners, either an integer tensor of shape
        ``(num_patches, 2)`` (as returned by ``get_valid_indices``) or a
        sequence of int pairs
    :param patch_size: side length of a patch before padding
    :param padding: extra context added on every side; each crop is
        ``patch_size + 2 * padding`` wide and starts ``padding`` earlier. Crops
        reaching outside the image rely on ``F.crop``'s out-of-bounds handling.
    :return: patches concatenated along dim 0, so for a ``(N, C, H, W)`` input
        the result has ``N * num_patches`` rows, grouped by patch
    """
    # TODO use slices to get all patches at the same time ?
    size = patch_size + 2 * padding
    # int() accepts both 0-d tensors and plain ints, so tensor rows and
    # int tuples are handled alike.
    patches_l = [
        F.crop(input, int(top) - padding, int(left) - padding, size, size)
        for top, left in indices
    ]
    if not patches_l:
        # torch.cat rejects an empty list; return the empty tensor that the
        # concatenation of zero patches would have produced.
        patch_shape = (*input.shape[:-2], size, size)
        return input.new_empty((0, *patch_shape[1:]))
    return torch.cat(patches_l, dim=0)


def prepare_data(data_path: Path, read_func: Callable = read_image_tensor) -> Dict:
    """
    Load a dataset folder into memory as a dict of stacked tensors.

    ``data_path`` must contain a ``target`` subfolder and may contain ``input``
    and ``mask`` subfolders. Every regular file in ``target`` defines one
    sample; a file with the same name is then read from each other existing
    subfolder (a missing file makes ``read_func`` fail).

    :param data_path: path (``Path`` or ``str``) of the folder containing the
        data subfolders
    :param read_func: function that reads one file and returns a tensor
    :return: dict with one stacked tensor per existing subfolder (keys
        ``'target'``, ``'input'``, ``'mask'``), plus ``'name'`` (sorted file
        names, aligned with the first tensor dimension), ``'len'`` (number of
        samples) and ``'H'``/``'W'`` (last two dimensions of the tensor of the
        last subfolder read)
    :raises RuntimeError: if the ``target`` subfolder contains no files
    """
    data_path = Path(data_path)
    data_dict = {}

    subdir_names = ["target", "input", "mask"]  # ,"helper"

    # Only files for which there is a target are considered. iterdir() order
    # depends on the filesystem, so sort for a reproducible sample order.
    # TODO check for images
    name_ls = sorted(file.name for file in (data_path / "target").iterdir()
                     if file.is_file())
    if not name_ls:
        raise RuntimeError(f"no files found in {data_path / 'target'}")

    for sdn in subdir_names:
        sd = data_path / sdn
        if not sd.is_dir():
            continue
        # torch.stack raises if files of one subfolder differ in shape.
        # TODO check that sizes also match across subfolders
        stacked = torch.stack([read_func(sd / name) for name in name_ls], dim=0)
        data_dict[sdn] = stacked
        H, W = stacked.shape[-2:]

    data_dict['name'] = name_ls
    data_dict['len'] = len(name_ls)
    data_dict['H'] = H
    data_dict['W'] = W
    return data_dict


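# TODO: an image is loaded whenever a patch is needed; this may be a bottleneck.
# NOTE(review): DataDictLoader below only slices tensors already held in memory, so this note may be stale — confirm.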
class DataDictLoader():
    """Mini-batch iterator over a dict of equally long, in-memory sequences.

    Values that are ``Iterable`` (tensors, lists) are indexed per batch; all
    other values (e.g. the ints stored under ``'len'``, ``'H'``, ``'W'``) are
    copied unchanged into every batch. The loader is its own iterator, so only
    one iteration over it can be active at a time.
    """

    def __init__(self, data_dict: Dict,
                 batch_size: int = 16,
                 max_length: int = 128,
                 shuffle: bool = False):
        """
        :param data_dict: dict holding the data; ``data_dict['len']`` must be
            the number of items. Iterable values are assumed to hold that many
            items along their first axis (TODO confirm for new callers).
        :param batch_size: number of items per batch; the last batch may be
            smaller
        :param max_length: maximum number of items yielded per iteration, or
            ``None`` for no limit
        :param shuffle: if True, permute the items at the start of every
            iteration
        """
        self.batch_size = batch_size
        self.shuffle = shuffle

        self.data_dict = data_dict
        self.dataset_len = data_dict['len']
        self.len = self.dataset_len if max_length is None else min(
            self.dataset_len, max_length)
        # Number of batches, counting a final partial batch (ceil division).
        self.num_batches = -(-self.len // self.batch_size)

    @staticmethod
    def _take(value, index: torch.Tensor):
        """Reorder ``value`` along its first axis by ``index``; non-iterables pass through."""
        if isinstance(value, (list, tuple)):
            # Python sequences cannot be indexed with a tensor, so gather
            # element by element (e.g. the 'name' list built by prepare_data).
            return [value[i] for i in index.tolist()]
        if isinstance(value, Iterable):
            return value[index]
        return value

    def __iter__(self):
        if self.shuffle:
            r = torch.randperm(self.dataset_len)
            self.data_dict = {k: self._take(v, r)
                              for k, v in self.data_dict.items()}
        self.i = 0
        return self

    def __next__(self):
        if self.i >= self.len:
            raise StopIteration
        # Note: the last batch is sliced up to i + batch_size, so with
        # max_length it may extend past self.len.
        batch = {k: v[self.i:self.i+self.batch_size]
                 if isinstance(v, Iterable) else v for k, v in self.data_dict.items()}

        self.i += self.batch_size
        return batch

    def __len__(self):
        return self.num_batches


class PatchDataModule(pl.LightningDataModule):
    """Lightning data module that trains on patches cut from in-memory images.

    Training batches are square patches cut from the full-size tensors in
    ``data_dict`` on a randomly jittered grid. The patches are re-cut every
    time ``train_dataloader`` is called. Validation and test iterate over the
    full-size tensors of ``data_dict``.
    NOTE(review): those loaders use DataDictLoader's default ``max_length``
    (128), which caps how many images they visit — confirm this is intended.
    """

    def __init__(self, data_dict,
                 patch_size: int = 2**5,
                 batch_size: int = 2**4,
                 patch_num: int = 2**6):
        """
        :param data_dict: dict with stacked tensors plus 'len', 'H' and 'W'
            entries (as produced by prepare_data)
        :param patch_size: side length of the square training patches
        :param batch_size: training batch size
        :param patch_num: maximum number of patches yielded per pass over the
            training loader
        """
        super().__init__()
        self.data_dict = data_dict
        self.H, self.W = data_dict['H'], data_dict['W']
        self.len = data_dict['len']

        self.batch_size = batch_size
        self.patch_size = patch_size
        self.patch_num = patch_num

    def dataloader(self, data_dict, **kwargs):
        return DataDictLoader(data_dict, **kwargs)

    def train_dataloader(self):
        patches = self.cut_patches()
        return self.dataloader(patches, batch_size=self.batch_size, shuffle=True,
                               max_length=self.patch_num)

    def val_dataloader(self):
        return self.dataloader(self.data_dict, batch_size=1)

    def test_dataloader(self):
        return self.dataloader(self.data_dict)  # TODO batch size

    def cut_patches(self):
        """Cut grid patches from every tensor in ``self.data_dict``.

        Patches whose mean mask value (over the last three dimensions) is not
        above ``threshold`` are dropped. For a 0/1 mask that mean is the
        covered fraction. Without a 'mask' entry every patch is kept.

        :return: dict with the patch tensors under the original keys, plus
            'len' (number of kept patches) and 'H'/'W' (the patch size)
        """
        # TODO cycle once
        patch_indices = get_valid_indices(
            self.H, self.W, self.patch_size, self.patch_size // 4)
        dd = {k: cut_patches(v, patch_indices, self.patch_size)
              for k, v in self.data_dict.items()
              if isinstance(v, torch.Tensor)}
        threshold = 0.1
        # Build the all-ones fallback only when there is no mask; dict.get()
        # would allocate it (as large as all input patches) unconditionally.
        mask = dd['mask'] if 'mask' in dd else torch.ones_like(dd['input'])
        mask_p = torch.mean(mask, dim=(-1, -2, -3))
        masked_idx = (mask_p > threshold).nonzero(as_tuple=True)[0]
        dd = {k: v[masked_idx] for k, v in dd.items()}
        dd['len'] = len(masked_idx)
        dd['H'], dd['W'] = (self.patch_size,) * 2

        return dd


class ImageDataset(Dataset):
    """Map-style dataset that reads one image per file path.

    Each item is a ``(tensor, file_name)`` pair, where the tensor is whatever
    ``read_func`` returns for that path.
    """

    def __init__(self, file_paths: Iterable, read_func: Callable = read_image_tensor):
        """
        :param file_paths: indexable sequence of image paths (``Path`` or ``str``)
        :param read_func: function that reads one path and returns a tensor
        """
        self.file_paths = file_paths
        # Previously accepted but ignored; now actually used by __getitem__.
        self.read_func = read_func

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str]:
        file = self.file_paths[idx]
        # Path(...) so that plain str paths also yield their base name.
        return self.read_func(file), Path(file).name

    def __len__(self) -> int:
        return len(self.file_paths)