import librosa
import pytorch_lightning as pl
import torch
from auraloss.freq import STFTLoss, MultiResolutionSTFTLoss, apply_reduction, SpectralConvergenceLoss, STFTMagnitudeLoss

from config import CONFIG


class STFTLossDDP(STFTLoss):
    def __init__(self,
                 fft_size=1024,
                 hop_size=256,
                 win_length=1024,
                 window="hann_window",
                 w_sc=1.0,
                 w_log_mag=1.0,
                 w_lin_mag=0.0,
                 w_phs=0.0,
                 sample_rate=None,
                 scale=None,
                 n_bins=None,
                 scale_invariance=False,
                 eps=1e-8,
                 output="loss",
                 reduction="mean",
                 device=None):
        super(STFTLoss, self).__init__()
        self.fft_size = fft_size
        self.hop_size = hop_size
        self.win_length = win_length
        self.window = getattr(torch, window)(win_length)
        self.w_sc = w_sc
        self.w_log_mag = w_log_mag
        self.w_lin_mag = w_lin_mag
        self.w_phs = w_phs
        self.sample_rate = sample_rate
        self.scale = scale
        self.n_bins = n_bins
        self.scale_invariance = scale_invariance
        self.eps = eps
        self.output = output
        self.reduction = reduction
        self.device = device

        self.spectralconv = SpectralConvergenceLoss()
        self.logstft = STFTMagnitudeLoss(log=True, reduction=reduction)
        self.linstft = STFTMagnitudeLoss(log=False, reduction=reduction)

        # setup mel filterbank
        if self.scale == "mel":
            assert (sample_rate is not None)  # Must set sample rate to use mel scale
            assert (n_bins <= fft_size)  # Must be more FFT bins than Mel bins
            fb = librosa.filters.mel(sample_rate, fft_size, n_mels=n_bins)
            self.fb = torch.tensor(fb).unsqueeze(0)
        elif self.scale == "chroma":
            assert (sample_rate is not None)  # Must set sample rate to use chroma scale
            assert (n_bins <= fft_size)  # Must be more FFT bins than chroma bins
            fb = librosa.filters.chroma(sample_rate, fft_size, n_chroma=n_bins)
            self.fb = torch.tensor(fb).unsqueeze(0)

        if scale is not None and device is not None:
            self.fb = self.fb.to(self.device)  # move filterbank to device

    def compressed_loss(self, x, y, alpha=None):
        self.window = self.window.to(x.device)
        x_mag, x_phs = self.stft(x.view(-1, x.size(-1)))
        y_mag, y_phs = self.stft(y.view(-1, y.size(-1)))

        if alpha is not None:
            x_mag = x_mag ** alpha
            y_mag = y_mag ** alpha

        # apply relevant transforms
        if self.scale is not None:
            x_mag = torch.matmul(self.fb.to(x_mag.device), x_mag)
            y_mag = torch.matmul(self.fb.to(y_mag.device), y_mag)

        # normalize scales
        if self.scale_invariance:
            alpha = (x_mag * y_mag).sum([-2, -1]) / ((y_mag ** 2).sum([-2, -1]))
            y_mag = y_mag * alpha.unsqueeze(-1)

        # compute loss terms
        sc_loss = self.spectralconv(x_mag, y_mag) if self.w_sc else 0.0
        mag_loss = self.logstft(x_mag, y_mag) if self.w_log_mag else 0.0
        lin_loss = self.linstft(x_mag, y_mag) if self.w_lin_mag else 0.0

        # combine loss terms
        loss = (self.w_sc * sc_loss) + (self.w_log_mag * mag_loss) + (self.w_lin_mag * lin_loss)
        loss = apply_reduction(loss, reduction=self.reduction)
        return loss

    def forward(self, x, y):
        return self.compressed_loss(x, y, 0.3)


class MRSTFTLossDDP(MultiResolutionSTFTLoss):
    def __init__(self,
                 fft_sizes=(1024, 2048, 512),
                 hop_sizes=(120, 240, 50),
                 win_lengths=(600, 1200, 240),
                 window="hann_window",
                 w_sc=1.0,
                 w_log_mag=1.0,
                 w_lin_mag=0.0,
                 w_phs=0.0,
                 sample_rate=None,
                 scale=None,
                 n_bins=None,
                 scale_invariance=False,
                 **kwargs):
        super(MultiResolutionSTFTLoss, self).__init__()
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)  # must define all
        self.stft_losses = torch.nn.ModuleList()
        for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
            self.stft_losses += [STFTLossDDP(fs,
                                             ss,
                                             wl,
                                             window,
                                             w_sc,
                                             w_log_mag,
                                             w_lin_mag,
                                             w_phs,
                                             sample_rate,
                                             scale,
                                             n_bins,
                                             scale_invariance,
                                             **kwargs)]


class Loss(pl.LightningModule):
    def __init__(self):
        super(Loss, self).__init__()
        self.stft_loss = MRSTFTLossDDP(sample_rate=CONFIG.DATA.sr, device="cpu", w_log_mag=0.0, w_lin_mag=1.0)
        self.window = torch.sqrt(torch.hann_window(CONFIG.DATA.window_size))

    def forward(self, x, y):
        x = x.permute(0, 2, 3, 1)
        y = y.permute(0, 2, 3, 1)
        wave_x = torch.istft(torch.view_as_complex(x.contiguous()), CONFIG.DATA.window_size, CONFIG.DATA.stride,
                             window=self.window.to(x.device))
        wave_y = torch.istft(torch.view_as_complex(y.contiguous()), CONFIG.DATA.window_size, CONFIG.DATA.stride,
                             window=self.window.to(y.device))
        loss = self.stft_loss(wave_x, wave_y)
        return loss