import pickle
import pathlib
import torch
from torch.utils.data.dataloader import DataLoader
import pytorch_lightning as pl
import numpy as np
import yaml
import torchaudio
import pyworld
import pysptk
import random


class DataModule(pl.LightningDataModule):
    """Lightning data module that builds ``Dataset`` loaders for each split.

    Config keys read here: ``train.batchsize``, ``train.num_workers``,
    ``general.preprocessed_path`` and, when present, ``dual.enable`` /
    ``dual.config_path``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.batchsize = config["train"]["batchsize"]
        self.preprocessed_dir = pathlib.Path(config["general"]["preprocessed_path"])
        # Split name -> split file name handed to ``Dataset``. Defined in the
        # constructor so the loaders also work if ``setup`` was never called.
        self.flnames = {
            "train": "train.txt",
            "val": "val.txt",
            "test": "test.txt",
        }

    def setup(self, stage=None):
        """Check that the preprocessed data directories exist.

        Args:
            stage: Lightning stage name; unused here.

        Raises:
            RuntimeError: if the preprocessed directory is missing, or, when
                ``dual.enable`` is truthy, if the task corpus's preprocessed
                directory is missing.
        """
        if not self.preprocessed_dir.exists():
            raise RuntimeError("Preprocessed directory was not found")

        if "dual" in self.config and self.config["dual"]["enable"]:
            # Context manager so the YAML file handle is closed promptly.
            with open(self.config["dual"]["config_path"], "r") as fr:
                task_config = yaml.load(fr, Loader=yaml.FullLoader)
            # The task corpus is looked up next to this corpus's preprocessed
            # directory, under the basename of its own preprocessed_path.
            task_preprocessed_dir = (
                self.preprocessed_dir.parent
                / pathlib.Path(task_config["general"]["preprocessed_path"]).name
            )
            if not task_preprocessed_dir.exists():
                raise RuntimeError(
                    "Preprocessed directory for multi-task learning was not found"
                )

    def get_ds(self, phase):
        """Return the ``Dataset`` for ``phase`` ("train", "val" or "test")."""
        return Dataset(self.flnames[phase], self.config)

    def get_loader(self, phase):
        """Return a ``DataLoader`` for ``phase``; only "train" is shuffled."""
        return DataLoader(
            self.get_ds(phase),
            self.batchsize,
            shuffle=(phase == "train"),
            num_workers=self.config["train"]["num_workers"],
            # NOTE(review): drop_last also applies to val/test, so their last
            # partial batch is skipped and a split smaller than batchsize
            # yields no batches at all -- confirm this is intended.
            drop_last=True,
        )

    def train_dataloader(self):
        return self.get_loader(phase="train")

    def val_dataloader(self):
        return self.get_loader(phase="val")

    def test_dataloader(self):
        return self.get_loader(phase="test")


class Dataset(torch.utils.data.Dataset):
    """Waveform dataset built from per-utterance pickles written by preprocessing.

    ``filetxt`` names a split file inside ``config["general"]["preprocessed_path"]``
    that lists one wav path per line. For each listed path,
    ``<basename>.pickle`` in that same directory is loaded. Its ``"wavs"`` and
    (optional) ``"wavsaux"`` sequences are concatenated into ``self.d_out``.
    When ``config["dual"]["enable"]`` is truthy, the ``"wavs"`` of a second
    corpus (located via ``config["dual"]["config_path"]``) are loaded into
    ``self.d_out["wavstask"]``.

    ``__getitem__`` returns a dict of tensors. Its keys depend on
    ``config["general"]["stage"]`` (``"pretrain"`` or ``"ssl..."``) and on
    ``config["general"]["feature_type"]`` (``"melspec"`` or ``"vocfeats"``).
    """

    def __init__(self, filetxt, config):
        self.preprocessed_dir = pathlib.Path(config["general"]["preprocessed_path"])
        self.config = config
        self.spec_module = torchaudio.transforms.MelSpectrogram(
            sample_rate=config["preprocess"]["sampling_rate"],
            n_fft=config["preprocess"]["fft_length"],
            win_length=config["preprocess"]["frame_length"],
            hop_length=config["preprocess"]["frame_shift"],
            f_min=config["preprocess"]["fmin"],
            f_max=config["preprocess"]["fmax"],
            n_mels=config["preprocess"]["n_mels"],
            power=1,
            center=True,
            norm="slaney",
            mel_scale="slaney",
        )
        # Random choices used by ``augmentation``: intermediate sample rates
        # for the down/up-sampling round trip, and mu-law channel counts
        # (even values 64..1024).
        self.resample_candidate = [8000, 11025, 12000, 16000]
        self.quantization_candidate = range(2 ** 6, 2 ** 10 + 2, 2)
        self.segment_length = config["preprocess"]["segment_length"]

        self.filelist = self._read_filelist(self.preprocessed_dir / filetxt)

        collected = {"wavs": [], "wavsaux": []}
        for wp in self.filelist:
            basename = self._basename(wp, config["general"]["corpus_type"])
            d_preprocessed = self._load_pickle(
                self.preprocessed_dir / "{}.pickle".format(basename)
            )
            for item, waves in collected.items():
                # An entry may be absent or None (e.g. no "wavsaux"). Skip it
                # explicitly rather than hiding every error behind a bare except.
                values = d_preprocessed.get(item)
                if values is not None:
                    waves.extend(values)
        self.d_out = {
            item: self._stack_waveforms(waves) for item, waves in collected.items()
        }

        if self._dual_enabled():
            task_config = self._load_yaml(config["dual"]["config_path"])
            # The task corpus is looked up next to this corpus's preprocessed
            # directory, under the basename of its own preprocessed_path.
            task_preprocessed_dir = (
                self.preprocessed_dir.parent
                / pathlib.Path(task_config["general"]["preprocessed_path"]).name
            )
            task_waves = []
            for wp in self._read_filelist(task_preprocessed_dir / filetxt):
                basename = self._basename(wp, task_config["general"]["corpus_type"])
                d_preprocessed = self._load_pickle(
                    task_preprocessed_dir / "{}.pickle".format(basename)
                )
                task_waves.extend(d_preprocessed["wavs"])
            self.d_out["wavstask"] = self._stack_waveforms(task_waves)

    # ------------------------------------------------------------------ loading

    @staticmethod
    def _read_filelist(path):
        """Return the non-blank lines of a split file as ``pathlib.Path`` objects."""
        with open(path, "r") as fr:
            # rstrip("\r\n") also handles CRLF files; blank lines are skipped
            # because they would map to a bogus pickle basename.
            return [pathlib.Path(line.rstrip("\r\n")) for line in fr if line.strip()]

    @staticmethod
    def _basename(wav_path, corpus_type):
        """Map a listed wav path to the basename of its preprocessed pickle.

        A "single" corpus uses the file stem. Any other corpus type uses
        "<parent dir>-<stem>", presumably to disambiguate equal stems across
        directories.
        """
        if corpus_type == "single":
            return str(wav_path.stem)
        return "{}-{}".format(wav_path.parent.name, wav_path.stem)

    @staticmethod
    def _load_pickle(path):
        """Load one preprocessed pickle.

        NOTE: ``pickle`` can execute arbitrary code, so only load trusted,
        locally generated preprocessing output.
        """
        with open(path, "rb") as fr:
            return pickle.load(fr)

    @staticmethod
    def _load_yaml(path):
        """Parse a YAML config file, closing the handle afterwards."""
        with open(path, "r") as fr:
            return yaml.load(fr, Loader=yaml.FullLoader)

    @staticmethod
    def _stack_waveforms(waves):
        """Pack a list of waveforms into a 1-D object array (one entry each).

        ``np.asarray`` on arrays of different lengths raises ``ValueError`` on
        NumPy >= 1.24. An explicit object array works for ragged and
        equal-length inputs alike, and ``arr[idx]`` always yields one waveform.
        """
        packed = np.empty(len(waves), dtype=object)
        for i, wave in enumerate(waves):
            packed[i] = np.asarray(wave)
        return packed

    def _dual_enabled(self):
        """Return whether multi-task ("dual") data loading is switched on."""
        return "dual" in self.config and bool(self.config["dual"]["enable"])

    def _frame_period_ms(self):
        """Return the frame shift converted from samples to milliseconds."""
        pre = self.config["preprocess"]
        return pre["frame_shift"] / pre["sampling_rate"] * 1000

    # ------------------------------------------------------------------ access

    def __len__(self):
        return len(self.d_out["wavs"])

    def __getitem__(self, idx):
        """Build the feature dict for utterance ``idx``.

        Raises:
            NotImplementedError: for an unknown stage or, where checked, an
                unknown feature type.
        """
        d_batch = {}
        # NOTE(review): "wavs" and "wavsaux" are cropped with independent
        # random offsets. If they are time-aligned pairs (and augmentation is
        # off), the crops will not line up -- confirm this is intended.
        for key in ("wavs", "wavsaux"):
            if self.d_out[key].size > 0:
                wav = torch.from_numpy(self.d_out[key][idx])
                if self.segment_length > 0:
                    wav = self.get_segment(wav, self.segment_length)
                d_batch[key] = wav

        stage = self.config["general"]["stage"]
        if stage == "pretrain":
            self._add_pretrain_features(d_batch)
        elif stage.startswith("ssl"):
            self._add_ssl_features(d_batch, idx)
        else:
            raise NotImplementedError()
        return d_batch

    def _add_pretrain_features(self, d_batch):
        """Fill ``d_batch`` in place for the "pretrain" stage (needs "wavsaux")."""
        if self.config["train"]["augment"]:
            d_batch["wavs"] = self.augmentation(d_batch["wavsaux"])
        d_batch["wavs"] = self.normalize_waveform(d_batch["wavs"], db=-3)
        d_batch["wavsaux"] = self.normalize_waveform(d_batch["wavsaux"], db=-3)
        # Trim both signals to a common length (a no-op when already equal).
        min_seq_len = min(len(d_batch["wavs"]), len(d_batch["wavsaux"]))
        d_batch["wavs"] = d_batch["wavs"][:min_seq_len]
        d_batch["wavsaux"] = d_batch["wavsaux"][:min_seq_len]
        d_batch["melspecs"] = self.calc_spectrogram(d_batch["wavs"])

        feature_type = self.config["general"]["feature_type"]
        if feature_type == "melspec":
            d_batch["melspecsaux"] = self.calc_spectrogram(d_batch["wavsaux"])
        elif feature_type == "vocfeats":
            d_batch["melceps"] = self.calc_melcep(d_batch["wavsaux"])
            d_batch["f0s"] = self.calc_f0(d_batch["wavs"])
            d_batch["melcepssrc"] = self.calc_melcep(d_batch["wavs"])
        else:
            raise NotImplementedError()

    def _add_ssl_features(self, d_batch, idx):
        """Fill ``d_batch`` in place for "ssl*" stages, including dual-task data."""
        d_batch["wavs"] = self.normalize_waveform(d_batch["wavs"], db=-3)
        d_batch["melspecs"] = self.calc_spectrogram(d_batch["wavs"])
        feature_type = self.config["general"]["feature_type"]
        if feature_type == "vocfeats":
            d_batch["f0s"] = self.calc_f0(d_batch["wavs"])
            d_batch["melcepssrc"] = self.calc_melcep(d_batch["wavs"])

        if self.d_out["wavsaux"].size > 0:
            d_batch["wavsaux"] = self.normalize_waveform(d_batch["wavsaux"], db=-3)
            if feature_type == "melspec":
                d_batch["melspecsaux"] = self.calc_spectrogram(d_batch["wavsaux"])
            elif feature_type == "vocfeats":
                d_batch["melceps"] = self.calc_melcep(d_batch["wavsaux"])

        if self._dual_enabled():
            wav_task = torch.from_numpy(self.d_out["wavstask"][idx])
            # Same guard as for "wavs"/"wavsaux": a non-positive segment
            # length means "use the whole utterance". Cropping to it would
            # otherwise yield an empty or garbage segment.
            if self.segment_length > 0:
                wav_task = self.get_segment(wav_task, self.segment_length)
            d_batch["wavstask"] = self.normalize_waveform(wav_task, db=-3)
            if feature_type == "melspec":
                d_batch["melspecstask"] = self.calc_spectrogram(d_batch["wavstask"])
            elif feature_type == "vocfeats":
                d_batch["melcepstask"] = self.calc_melcep(d_batch["wavstask"])
            else:
                raise NotImplementedError()

    # ------------------------------------------------------------------ features

    def calc_spectrogram(self, wav):
        """Return the log of the clamped, scaled mel magnitude spectrogram (float32)."""
        pre = self.config["preprocess"]
        mel = self.spec_module(wav)
        # Clamp before the log to keep near-silent bins finite
        # (assumes min_magnitude > 0).
        return torch.log(
            torch.clamp_min(mel, pre["min_magnitude"]) * pre["comp_factor"]
        ).to(torch.float32)

    def calc_melcep(self, wav):
        """Return the mel-cepstrum of ``wav`` from the WORLD spectral envelope.

        ``pysptk.sp2mc`` output is transposed, so coefficients lie on axis 0
        and frames on axis 1. Returned as float32.
        """
        pre = self.config["preprocess"]
        _, sp, _ = pyworld.wav2world(
            wav.numpy().astype(np.float64),
            pre["sampling_rate"],
            fft_size=pre["fft_length"],
            frame_period=self._frame_period_ms(),
        )
        melcep = pysptk.sp2mc(
            sp,
            order=pre["cep_order"],
            alpha=pysptk.util.mcepalpha(pre["sampling_rate"]),
        ).transpose(1, 0)
        return torch.from_numpy(melcep).to(torch.float32)

    def calc_f0(self, wav):
        """Dispatch to the F0 extractor named by ``preprocess.f0_extractor``."""
        extractor = self.config["preprocess"]["f0_extractor"]
        if extractor == "dio":
            return self.calc_f0_dio(wav)
        if extractor == "harvest":
            return self.calc_f0_harvest(wav)
        if extractor == "swipe":
            return self.calc_f0_swipe(wav)
        raise NotImplementedError()

    def _calc_f0_world(self, estimator, wav):
        """Run a WORLD F0 estimator, refine with StoneMask, return float32 F0."""
        fs = self.config["preprocess"]["sampling_rate"]
        x = wav.numpy().astype(np.float64)
        f0_coarse, times = estimator(x, fs, frame_period=self._frame_period_ms())
        f0 = pyworld.stonemask(x, f0_coarse, times, fs)
        return torch.from_numpy(f0).to(torch.float32)

    def calc_f0_dio(self, wav):
        """Return the F0 contour using WORLD DIO + StoneMask."""
        return self._calc_f0_world(pyworld.dio, wav)

    def calc_f0_harvest(self, wav):
        """Return the F0 contour using WORLD Harvest + StoneMask."""
        return self._calc_f0_world(pyworld.harvest, wav)

    def calc_f0_swipe(self, wav):
        """Return the F0 contour using SPTK SWIPE' (search range 71-800 Hz)."""
        f0 = pysptk.sptk.swipe(
            wav.numpy().astype(np.float64),
            fs=self.config["preprocess"]["sampling_rate"],
            min=71,
            max=800,
            hopsize=self.config["preprocess"]["frame_shift"],
            otype="f0",
        )
        return torch.from_numpy(f0).to(torch.float32)

    # ------------------------------------------------------------------ waveform ops

    def augmentation(self, wav):
        """Simulate a degraded recording from ``wav`` without modifying it.

        Steps:

        1. Peak-normalize a copy of ``wav``.
        2. Mu-law encode it with a random channel count q, then map the codes
           linearly to [-1, 1 - 2/q].
        3. Resample down to a random lower rate and back up again.
        """
        peak = torch.max(torch.abs(wav))
        # Out-of-place on purpose: ``wav`` may share memory with the cached
        # dataset array (torch.from_numpy), and ``/=`` would corrupt it.
        # Skipping an all-zero input avoids 0/0 NaNs.
        if peak > 0:
            wav = wav / peak
        new_freq = random.choice(self.resample_candidate)
        new_quantization = random.choice(self.quantization_candidate)
        mulaw_encoder = torchaudio.transforms.MuLawEncoding(
            quantization_channels=new_quantization
        )
        # NOTE(review): the codes are not passed through MuLawDecoding, so the
        # output is the companded signal -- confirm this is intended.
        wav_quantized = mulaw_encoder(wav) / new_quantization * 2.0 - 1.0
        downsampler = torchaudio.transforms.Resample(
            orig_freq=self.config["preprocess"]["sampling_rate"],
            new_freq=new_freq,
            resampling_method="sinc_interpolation",
            lowpass_filter_width=6,
            dtype=torch.float32,
        )
        upsampler = torchaudio.transforms.Resample(
            orig_freq=new_freq,
            new_freq=self.config["preprocess"]["sampling_rate"],
            resampling_method="sinc_interpolation",
            lowpass_filter_width=6,
            dtype=torch.float32,
        )
        return upsampler(downsampler(wav_quantized))

    def normalize_waveform(self, wav, db=-3):
        """Return ``wav`` peak-normalized to ``db`` dBFS via the sox "norm" effect."""
        wav, _ = torchaudio.sox_effects.apply_effects_tensor(
            wav.unsqueeze(0),
            self.config["preprocess"]["sampling_rate"],
            [["norm", "{}".format(db)]],
        )
        return wav.squeeze(0)

    def get_segment(self, wav, segment_length):
        """Randomly crop ``wav`` to ``segment_length`` seconds, or right-pad it with zeros.

        ``segment_length`` may be fractional. The sample count is rounded to
        an integer so slicing and ``random.randint`` receive ints.
        """
        seg_size = int(round(self.config["preprocess"]["sampling_rate"] * segment_length))
        if len(wav) >= seg_size:
            wav_start = random.randint(0, len(wav) - seg_size)
            return wav[wav_start : wav_start + seg_size]
        return torch.nn.functional.pad(wav, (0, seg_size - len(wav)), "constant")