import librosa.display
import matplotlib.pyplot as plt
import json
import torch
import torchaudio
import hifigan


def manual_logging(logger, item, idx, tag, global_step, data_type, config):
    """Log one entry of a batched tensor as audio or as a spectrogram image.

    ``logger`` is expected to expose the TensorBoard ``SummaryWriter``
    interface (``add_audio``/``add_figure``); ``idx`` selects the batch entry.
    """
    if data_type == "audio":
        audio = item[idx, ...].detach().cpu().numpy()
        logger.add_audio(
            tag,
            audio,
            global_step,
            sample_rate=config["preprocess"]["sampling_rate"],
        )
    elif data_type == "image":
        image = item[idx, ...].detach().cpu().numpy()
        fig, ax = plt.subplots()
        _ = librosa.display.specshow(
            image,
            x_axis="time",
            y_axis="linear",
            sr=config["preprocess"]["sampling_rate"],
            hop_length=config["preprocess"]["frame_shift"],
            fmax=config["preprocess"]["sampling_rate"] // 2,
            ax=ax,
        )
        logger.add_figure(tag, fig, global_step)
    else:
        raise NotImplementedError(
            f"Unsupported data_type '{data_type}': expected 'audio' or 'image'."
        )
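
# Usage sketch (hedged; `writer`, `mel`, and `global_step` are illustrative
# names, not part of this module): logging the first mel spectrogram of a
# batch shaped (batch, n_mels, frames) might look like:
#
#     from torch.utils.tensorboard import SummaryWriter
#     writer = SummaryWriter("logs")
#     manual_logging(writer, mel, 0, "val/mel", global_step, "image", config)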


def load_vocoder(config):
    """Build a HiFi-GAN generator, load pre-trained weights, and freeze it."""
    with open(
        "hifigan/config_{}.json".format(config["general"]["feature_type"]), "r"
    ) as f:
        config_hifigan = hifigan.AttrDict(json.load(f))
    vocoder = hifigan.Generator(config_hifigan)
    # map_location="cpu" lets the checkpoint load on machines without a GPU;
    # move the model to the target device afterwards if needed.
    state_dict = torch.load(config["general"]["hifigan_path"], map_location="cpu")
    vocoder.load_state_dict(state_dict["generator"])
    vocoder.remove_weight_norm()
    vocoder.eval()  # inference only: the vocoder is never trained here
    for param in vocoder.parameters():
        param.requires_grad = False
    return vocoder
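
# Usage sketch (assumption: the generator maps a (batch, n_mels, frames) mel
# tensor to a (batch, 1, samples) waveform, as in the reference HiFi-GAN
# implementation):
#
#     vocoder = load_vocoder(config)
#     with torch.no_grad():
#         wav = vocoder(mel).squeeze(1)  # (batch, samples)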


def get_conv_padding(kernel_size, dilation=1):
    """Return the padding that preserves sequence length ("same" padding)
    for a stride-1 convolution with an odd ``kernel_size``."""
    return (kernel_size - 1) * dilation // 2
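
# For example (derived directly from the formula above):
#     get_conv_padding(3) == 1
#     get_conv_padding(5) == 2
#     get_conv_padding(7, dilation=2) == 6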


def plot_and_save_mels(wav, save_path, config):
    """Compute the log-compressed mel spectrogram of a 1-D waveform tensor
    and save it to ``save_path`` as an image."""
    spec_module = torchaudio.transforms.MelSpectrogram(
        sample_rate=config["preprocess"]["sampling_rate"],
        n_fft=config["preprocess"]["fft_length"],
        win_length=config["preprocess"]["frame_length"],
        hop_length=config["preprocess"]["frame_shift"],
        f_min=config["preprocess"]["fmin"],
        f_max=config["preprocess"]["fmax"],
        n_mels=config["preprocess"]["n_mels"],
        power=1,
        center=True,
        norm="slaney",
        mel_scale="slaney",
    )
    spec = spec_module(wav.unsqueeze(0))
    log_spec = torch.log(
        torch.clamp_min(spec, config["preprocess"]["min_magnitude"])
        * config["preprocess"]["comp_factor"]
    )
    fig, ax = plt.subplots()
    _ = librosa.display.specshow(
        log_spec.squeeze(0).numpy(),
        x_axis="time",
        y_axis="mel",  # the rows are mel bins, so label the axis on the mel scale
        sr=config["preprocess"]["sampling_rate"],
        hop_length=config["preprocess"]["frame_shift"],
        fmin=config["preprocess"]["fmin"],
        fmax=config["preprocess"]["fmax"],
        ax=ax,
        cmap="viridis",
    )
    fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)  # release the figure; savefig alone leaks memory across calls
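
# Usage sketch ("sample.wav" is a hypothetical path; the file is assumed to
# already be at the configured sampling rate):
#
#     wav, sr = torchaudio.load("sample.wav")  # (channels, samples)
#     plot_and_save_mels(wav[0], "sample_mel.png", config)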


def plot_and_save_mels_all(wavs, keys, save_path, config):
    """Plot the log-mel spectrograms of up to nine waveforms from ``wavs``
    (one panel per entry of ``keys``) in a 3x3 grid and save the figure."""
    spec_module = torchaudio.transforms.MelSpectrogram(
        sample_rate=config["preprocess"]["sampling_rate"],
        n_fft=config["preprocess"]["fft_length"],
        win_length=config["preprocess"]["frame_length"],
        hop_length=config["preprocess"]["frame_shift"],
        f_min=config["preprocess"]["fmin"],
        f_max=config["preprocess"]["fmax"],
        n_mels=config["preprocess"]["n_mels"],
        power=1,
        center=True,
        norm="slaney",
        mel_scale="slaney",
    )
    fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(18, 18))
    for i, key in enumerate(keys):
        wav = wavs[key][0, ...].cpu()
        spec = spec_module(wav.unsqueeze(0))
        log_spec = torch.log(
            torch.clamp_min(spec, config["preprocess"]["min_magnitude"])
            * config["preprocess"]["comp_factor"]
        )
        ax[i // 3, i % 3].set(title=key)
        _ = librosa.display.specshow(
            log_spec.squeeze(0).numpy(),
            x_axis="time",
            y_axis="mel",  # mel-scaled rows, as above
            sr=config["preprocess"]["sampling_rate"],
            hop_length=config["preprocess"]["frame_shift"],
            fmin=config["preprocess"]["fmin"],
            fmax=config["preprocess"]["fmax"],
            ax=ax[i // 3, i % 3],
            cmap="viridis",
        )
    fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)  # release the figure; savefig alone leaks memory across calls


def configure_args(config, args):
    """Override config entries with any command-line arguments that were set.

    Options left at ``None`` keep the values already present in ``config``;
    the boolean flags (``load_pretrained``, ``early_stopping``) are always
    copied. Returns the updated config together with ``args``.
    """
    for key in ["stage", "corpus_type", "source_path", "aux_path", "preprocessed_path"]:
        if getattr(args, key) is not None:
            config["general"][key] = str(getattr(args, key))

    for key in ["n_train", "n_val", "n_test"]:
        if getattr(args, key) is not None:
            config["preprocess"][key] = getattr(args, key)

    for key in ["alpha", "beta", "learning_rate", "epoch"]:
        if getattr(args, key) is not None:
            config["train"][key] = getattr(args, key)

    for key in ["load_pretrained", "early_stopping"]:
        config["train"][key] = getattr(args, key)

    if args.feature_loss_type is not None:
        config["train"]["feature_loss"]["type"] = args.feature_loss_type

    if args.pretrained_path is not None:
        config["train"]["pretrained_path"] = str(args.pretrained_path)

    return config, args
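
# Usage sketch (assumes an argparse interface whose option names mirror the
# config keys handled above, with ``None`` defaults so unset options fall
# through to the values already in the config file):
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--stage", default=None)
#     parser.add_argument("--learning_rate", type=float, default=None)
#     parser.add_argument("--load_pretrained", action="store_true")
#     ...
#     config, args = configure_args(config, parser.parse_args())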