import datetime
import glob
import os
import random
import re

import numpy as np
from scipy import signal

from TTS.encoder.models.lstm import LSTMSpeakerEncoder
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
from TTS.utils.io import save_fsspec


class AugmentWAV:
    """Randomly augment waveforms with additive noise and/or reverberation.

    Args:
        ap: Audio processor object. It must expose ``sample_rate`` and
            ``load_wav(path, sr=...)`` returning a 1D numpy array.
        augmentation_config: Mapping that may contain the sections:

            * ``"additive"``: needs ``"sounds_path"`` plus one dict-valued entry
              per noise type. Each noise-type dict provides ``min_num_noises``,
              ``max_num_noises``, ``min_snr_in_db`` and ``max_snr_in_db``. The
              noise type name must match the first-level sub-directory of
              ``sounds_path`` holding its ``.wav`` files.
            * ``"rir"``: needs ``"rir_path"`` (searched recursively for ``.wav``
              files) and ``"conv_mode"`` (forwarded to ``scipy.signal.convolve``).

        A section whose path is empty/falsy is disabled.
    """

    def __init__(self, ap, augmentation_config):
        self.ap = ap
        self.use_additive_noise = False

        if "additive" in augmentation_config.keys():
            self.additive_noise_config = augmentation_config["additive"]
            additive_path = self.additive_noise_config["sounds_path"]
            if additive_path:
                self.use_additive_noise = True
                # Every dict-valued entry of the additive section names a noise type.
                self.additive_noise_types = [
                    key
                    for key in self.additive_noise_config.keys()
                    if isinstance(self.additive_noise_config[key], dict)
                ]

                additive_files = glob.glob(os.path.join(additive_path, "**/*.wav"), recursive=True)

                # Maps noise type -> list of wav files found under its directory.
                self.noise_list = {}
                for wav_file in additive_files:
                    # Use a real relative path so the first component is the
                    # noise directory whether or not `sounds_path` ends with a
                    # path separator (plain string replacement left a leading
                    # separator and produced an empty directory name).
                    noise_dir = os.path.relpath(wav_file, additive_path).split(os.sep)[0]
                    # ignore not listed directories
                    if noise_dir not in self.additive_noise_types:
                        continue
                    self.noise_list.setdefault(noise_dir, []).append(wav_file)

                print(
                    f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}"
                )

        self.use_rir = False

        if "rir" in augmentation_config.keys():
            self.rir_config = augmentation_config["rir"]
            if self.rir_config["rir_path"]:
                self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True)
                self.use_rir = True
                # Only report when RIR is enabled; `rir_files` does not exist otherwise.
                print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances")

        self.create_augmentation_global_list()

    def create_augmentation_global_list(self):
        """Build ``self.global_noise_list``, the pool ``apply_one`` samples from."""
        # Copy so that appending "RIR_AUG" does not leak into `additive_noise_types`.
        self.global_noise_list = list(self.additive_noise_types) if self.use_additive_noise else []
        if self.use_rir:
            self.global_noise_list.append("RIR_AUG")

    def additive_noise(self, noise_type, audio):
        """Mix randomly chosen noise files of ``noise_type`` into ``audio``.

        A random number of files (between ``min_num_noises`` and
        ``max_num_noises``) is drawn; each is truncated to the length of
        ``audio`` and scaled to a random SNR drawn between ``min_snr_in_db``
        and ``max_snr_in_db``. Files shorter than ``audio`` are skipped.

        Args:
            noise_type: Key into ``self.noise_list`` / the additive config.
            audio: 1D numpy array with the clean waveform.

        Returns:
            numpy array: ``audio`` plus the summed, scaled noises.

        Note:
            If every drawn file is shorter than ``audio`` the method retries
            recursively with a new random draw; when no file of this type is
            long enough this ends in ``RecursionError``.
        """
        type_config = self.additive_noise_config[noise_type]
        # 1e-4 keeps the log finite for silent signals.
        clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4)

        noise_list = random.sample(
            self.noise_list[noise_type],
            random.randint(type_config["min_num_noises"], type_config["max_num_noises"]),
        )

        audio_len = audio.shape[0]
        noises_wav = None
        for noise in noise_list:
            noiseaudio = self.ap.load_wav(noise, sr=self.ap.sample_rate)[:audio_len]

            if noiseaudio.shape[0] < audio_len:
                continue

            # The upper bound is the maximum SNR, not the maximum number of noises.
            noise_snr = random.uniform(type_config["min_snr_in_db"], type_config["max_snr_in_db"])
            noise_db = 10 * np.log10(np.mean(noiseaudio**2) + 1e-4)
            # Gain that places the noise `noise_snr` dB below the clean signal power.
            noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio

            if noises_wav is None:
                noises_wav = noise_wav
            else:
                noises_wav += noise_wav

        # if all possible files is less than audio, choose other files
        if noises_wav is None:
            return self.additive_noise(noise_type, audio)

        return audio + noises_wav

    def reverberate(self, audio):
        """Convolve ``audio`` with a random, energy-normalized room impulse response.

        Args:
            audio: 1D numpy array with the input waveform.

        Returns:
            numpy array: the convolved signal truncated to the input length.
        """
        audio_len = audio.shape[0]

        rir_file = random.choice(self.rir_files)
        rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate)
        # Normalize the impulse response to unit energy.
        rir = rir / np.sqrt(np.sum(rir**2))
        return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len]

    def apply_one(self, audio):
        """Apply one augmentation picked uniformly from ``self.global_noise_list``.

        Args:
            audio: 1D numpy array with the input waveform.

        Returns:
            numpy array: the augmented waveform.
        """
        noise_type = random.choice(self.global_noise_list)
        if noise_type == "RIR_AUG":
            return self.reverberate(audio)

        return self.additive_noise(noise_type, audio)


def to_camel(text):
    """Convert a snake_case string to CamelCase.

    The text is first passed through ``str.capitalize`` (first character
    upper-cased, the rest lower-cased). Every ``_<letter>`` pair that is not
    at the very start of the string is then replaced by the upper-cased letter.

    Args:
        text (str): Input string.

    Returns:
        str: The converted string.
    """

    def _upper_letter(match):
        # Drop the underscore and keep only the upper-cased letter after it.
        return match.group(1).upper()

    capitalized = text.capitalize()
    return re.sub(r"(?!^)_([a-zA-Z])", _upper_letter, capitalized)


def setup_encoder_model(config: "Coqpit"):
    """Instantiate the speaker encoder selected by ``config.model_params["model_name"]``.

    Args:
        config: Configuration object exposing ``model_params`` (a mapping with
            ``model_name``, ``input_dim``, ``proj_dim`` and the model specific
            entries read below) and ``audio``.

    Returns:
        ``LSTMSpeakerEncoder`` when the name is ``"lstm"`` or
        ``ResNetSpeakerEncoder`` when it is ``"resnet"`` (case-insensitive).

    Raises:
        ValueError: If ``model_name`` is neither ``"lstm"`` nor ``"resnet"``.
    """
    model_params = config.model_params
    model_name = model_params["model_name"].lower()

    if model_name == "lstm":
        return LSTMSpeakerEncoder(
            model_params["input_dim"],
            model_params["proj_dim"],
            model_params["lstm_dim"],
            model_params["num_lstm_layers"],
            use_torch_spec=model_params.get("use_torch_spec", False),
            audio_config=config.audio,
        )
    if model_name == "resnet":
        return ResNetSpeakerEncoder(
            input_dim=model_params["input_dim"],
            proj_dim=model_params["proj_dim"],
            log_input=model_params.get("log_input", False),
            use_torch_spec=model_params.get("use_torch_spec", False),
            audio_config=config.audio,
        )
    # Fail with a clear message instead of an UnboundLocalError on `model`.
    raise ValueError(
        "Unknown encoder model_name '{}'. Expected 'lstm' or 'resnet'.".format(model_params["model_name"])
    )


def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch):
    """Write a training checkpoint to ``<out_path>/checkpoint_<current_step>.pth``.

    Args:
        model: Object whose ``state_dict()`` is stored under ``"model"``.
        optimizer: Object whose ``state_dict()`` is stored under ``"optimizer"``;
            may be ``None``, in which case ``None`` is stored.
        criterion: Object whose ``state_dict()`` is stored under ``"criterion"``.
        model_loss: Value stored under ``"loss"``.
        out_path: Directory the checkpoint file is written to.
        current_step: Stored under ``"step"`` and used in the file name.
        epoch: Stored under ``"epoch"``.
    """
    target = os.path.join(out_path, "checkpoint_{}.pth".format(current_step))
    print(" | | > Checkpoint saving : {}".format(target))

    # Collect the entries one by one, in the same order as the saved layout.
    state = {}
    state["model"] = model.state_dict()
    if optimizer is None:
        state["optimizer"] = None
    else:
        state["optimizer"] = optimizer.state_dict()
    state["criterion"] = criterion.state_dict()
    state["step"] = current_step
    state["epoch"] = epoch
    state["loss"] = model_loss
    state["date"] = datetime.date.today().strftime("%B %d, %Y")

    save_fsspec(state, target)


def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step, epoch):
    """Save ``<out_path>/best_model.pth`` when ``model_loss`` beats ``best_loss``.

    Args:
        model: Object whose ``state_dict()`` is stored under ``"model"``.
        optimizer: Object whose ``state_dict()`` is stored under ``"optimizer"``;
            may be ``None`` (stored as ``None``), matching ``save_checkpoint``.
        criterion: Object whose ``state_dict()`` is stored under ``"criterion"``.
        model_loss: Loss of the current model.
        best_loss: Best loss seen so far.
        out_path: Directory the file is written to.
        current_step: Stored under ``"step"``.
        epoch: Stored under ``"epoch"``.

    Returns:
        The updated best loss: ``model_loss`` if it is strictly lower than
        ``best_loss`` (and a file was written), otherwise ``best_loss`` unchanged.
    """
    # Nothing to do unless the loss strictly improved.
    if not model_loss < best_loss:
        return best_loss

    state = {
        "model": model.state_dict(),
        # Tolerate a missing optimizer, as save_checkpoint does.
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "criterion": criterion.state_dict(),
        "step": current_step,
        "epoch": epoch,
        "loss": model_loss,
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
    bestmodel_path = os.path.join(out_path, "best_model.pth")
    print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
    save_fsspec(state, bestmodel_path)
    return model_loss