# 1. engineer_style_foreign_style_vectors.py  # for speed=1 & speed=4
# 2. tts_harvard.py                           # (call inside SHIFT repo - needs StyleTTS msinference.py)
# 3. visualize_tts_pleasantness.py            # figures & audinterface


# Visualises the 11-class time series (arousal/dominance/valence + 8 emotion categories) for mimic3, human and sped-up mimic3 (mimic3speed)
#
#
# human_770.wav
# mimic3_770.wav
# mimic3_speedup_770.wav
# Long (concatenated) recordings to analyse; every file is named '<variant>_hfullh.wav'.
FULL_WAV = [f'{variant}_hfullh.wav'
            for variant in ('english', 'english_4x', 'human', 'foreign', 'foreign_4x')]
# Sliding-window length and hop, passed to audinterface as `win_dur` / `hop_dur`.
WIN, HOP = 40, 10
import pandas as pd
import os

import json
import numpy as np
import audonnx
import audb
from pathlib import Path
import transformers
import torch
import audmodel
import audinterface
import matplotlib.pyplot as plt
import audiofile

# Feature columns of every time series: the 3 dimensional scores first, then the
# 8 emotion categories in the categorical model's output order.
LABELS = (
    ['arousal', 'dominance', 'valence']
    # (audioset 'speech_synthesizer' / 'synthetic_singing' columns are disabled)
    + ['Angry', 'Sad', 'Happy', 'Surprise', 'Fear', 'Disgust', 'Contempt', 'Neutral']
)


# A Wav2Vec2Config instance used as a plain namespace for the inference devices:
# `dev` holds the input tensors and the A/D/V model, `dev2` the categorical model.
config = transformers.Wav2Vec2Config()  # finetuning_task='spef2feat_reg')
# Fall back to CPU so the script still runs on machines without a GPU.
_DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
config.dev = torch.device(_DEVICE)
config.dev2 = torch.device(_DEVICE)

# def _softmax(x):
#     '''x : (batch, num_class)'''
#     x -= x.max(1, keepdims=True)  # if all -400 then sum(exp(x)) = 0
#     x = np.minimum(-100, x)
#     x = np.exp(x)
#     x /= x.sum(1, keepdims=True)
#     return x

def _softmax(x):
    '''x : (batch, num_class)'''
    x -= x.max(1, keepdims=True)  # if all -400 then sum(exp(x)) = 0
    x = np.maximum(-100, x)
    x = np.exp(x)
    x /= x.sum(1, keepdims=True)
    return x

def _sigmoid(x):
    '''x : (batch, num_class)'''
    return 1 / (1 + np.exp(-x))    


        # --
    # ALL = anger, contempt, disgust, fear, happiness, neutral, no_agreement, other, sadness, surprise
    # plot - unpleasant emo 7x emo-categories [anger, contempt, disgust, fear, sadness] for artificial/sped-up/natural
    # plot - pleasant emo [neutral, happiness, surprise]
    # plot - Cubes Natural vs spedup   4x speed
    # plot - synthesizer class audioset


    # https://arxiv.org/pdf/2407.12229
    #  https://arxiv.org/pdf/2312.05187
    # https://arxiv.org/abs/2407.05407
    # https://arxiv.org/pdf/2408.06577
    # https://arxiv.org/pdf/2309.07405

    
# wavs are generated concat and plot time-series?

# for mimic3/mimic3speed/human - concat all 77 and run the time series with a sliding window of WIN s and hop HOP s (constants at the top)
def _load_teachers():
    '''Load the two emotion models used to score every window.

    Called at most once, and only when some timeseries pkl is missing. This
    avoids re-loading both pretrained models for every wav in FULL_WAV.

    Returns
    -------
    teacher_cat : '3loi/SER-Odyssey-Baseline-WavLM-Categorical-Attributes' on
        ``config.dev2``, with ``forward`` replaced by ``_infer`` (returns logits).
    dawn : 'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim' on ``config.dev``.
    '''
    import types

    from torch import nn
    from transformers import AutoModelForAudioClassification
    from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel, Wav2Vec2Model

    # ---- CAT MSP ------------------------------------------------------------
    def _infer(self, x):
        '''x: (batch, audio-samples-16KHz) -> categorical logits.'''
        # NOTE(review): the '+' is marked as deliberate ('plus'); the usual
        # normalisation is (x - mean) / std -- confirm against the model card.
        x = (x + self.config.mean) / self.config.std  # plus
        x = self.ssl_model(x, attention_mask=None).last_hidden_state
        # Attention pooling: weights are a softmax along dim 1 (presumably time).
        h = self.pool_model.sap_linear(x).tanh()
        w = torch.matmul(h, self.pool_model.attention)
        w = w.softmax(1)
        # Weighted mean and weighted std (sqrt(E_w[x^2] - mu^2), clamped > 0).
        mu = (x * w).sum(1)
        sigma = ((x * x * w).sum(1) - mu * mu).clamp(min=1e-7).sqrt()
        return self.ser_model(torch.cat([mu, sigma], 1))

    teacher_cat = AutoModelForAudioClassification.from_pretrained(
        '3loi/SER-Odyssey-Baseline-WavLM-Categorical-Attributes',
        trust_remote_code=True  # fun definitions see 3loi/SER-.. repo
    ).to(config.dev2).eval()
    teacher_cat.forward = types.MethodType(_infer, teacher_cat)

    # ---- Dawn (A/D/V) -------------------------------------------------------
    def _prenorm(x, attention_mask=None):
        '''Zero-mean / unit-variance normalisation along dim 1.

        Returns a new tensor. The previous in-place ``x -= ...`` could write
        into the numpy audio buffer, because ``torch.from_numpy`` shares memory
        on CPU.
        '''
        if attention_mask is not None:
            N = attention_mask.sum(1, keepdim=True)  # here attn msk is unprocessed just the original input
            x = x - x.sum(1, keepdim=True) / N
            var = (x * x).sum(1, keepdim=True) / N
        else:
            x = x - x.mean(1, keepdim=True)
            var = (x * x).mean(1, keepdim=True)
        return x / torch.sqrt(var + 1e-7)

    class RegressionHead(nn.Module):
        r"""Regression head: dropout -> dense -> tanh -> dropout -> out_proj (num_labels outputs)."""

        def __init__(self, config):
            super().__init__()
            # Attribute names must match the checkpoint keys (classifier.dense / classifier.out_proj).
            self.dense = nn.Linear(config.hidden_size, config.hidden_size)
            self.dropout = nn.Dropout(config.final_dropout)
            self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

        def forward(self, features, **kwargs):
            x = self.dropout(features)
            x = torch.tanh(self.dense(x))
            x = self.dropout(x)
            return self.out_proj(x)

    class Dawn(Wav2Vec2PreTrainedModel):
        r"""Speech emotion model: wav2vec2 encoder, mean over dim 1, regression head."""

        def __init__(self, config):
            super().__init__(config)
            self.config = config
            self.wav2vec2 = Wav2Vec2Model(config)
            self.classifier = RegressionHead(config)
            self.init_weights()

        def forward(self, input_values, attention_mask=None):
            x = _prenorm(input_values, attention_mask=attention_mask)
            outputs = self.wav2vec2(x, attention_mask=attention_mask)
            hidden_states = torch.mean(outputs[0], dim=1)  # presumably averages over frames
            return self.classifier(hidden_states)

    dawn = Dawn.from_pretrained('audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim').to(config.dev).eval()
    return teacher_cat, dawn


def _make_process_function(teacher_cat, dawn):
    '''Bind the two models into an audinterface ``process_func``.'''

    def process_function(x, sampling_rate, idx):
        '''Score one audio window with both models.

        x : window signal from audinterface (assumed (channels, samples) at the
            16 kHz set on the Feature below; a mono window acts as a batch of 1).
        sampling_rate, idx : passed by audinterface; unused.

        Returns an array whose columns follow LABELS. The first columns are the
        dawn outputs (arousal, dominance, valence). The rest are softmax
        probabilities over the categorical model's classes
        {0: 'Angry', 1: 'Sad', 2: 'Happy', 3: 'Surprise', 4: 'Fear',
         5: 'Disgust', 6: 'Contempt', 7: 'Neutral'}.
        '''
        # Pure inference: don't build an autograd graph for every window.
        with torch.no_grad():
            # Each input goes to the device its model lives on.
            logits_cat = teacher_cat(torch.from_numpy(x).to(config.dev2)).cpu().numpy()
            logits_adv = dawn(torch.from_numpy(x).to(config.dev)).cpu().numpy()
        # (audioset speech-synthesizer / synthetic-singing logits are no longer included)
        cat = np.concatenate([logits_adv, _softmax(logits_cat)], 1)
        print(cat)
        return cat

    return process_function


# Models are loaded lazily, at most once, and only if some pkl has to be computed.
_teachers = None
for long_audio in FULL_WAV:
    # NOTE(review): the cache name does not include WIN/HOP. After changing the
    # window settings, stale pkl files are silently reused -- delete them, or key
    # the name on WIN/HOP here and in the loader below.
    file_interface = f'timeseries_{long_audio.replace("/", "")}.pkl'
    if os.path.exists(file_interface):
        print(file_interface, 'FOUND')
        continue

    print('_______________________________________\nProcessing\n', file_interface, '\n___________')
    if _teachers is None:
        _teachers = _load_teachers()

    interface = audinterface.Feature(
        feature_names=LABELS,
        process_func=_make_process_function(*_teachers),
        # audinterface cuts the signal into WIN-second windows every HOP seconds
        # and calls process_func once per window.
        process_func_applies_sliding_window=False,
        win_dur=WIN,
        hop_dur=HOP,
        sampling_rate=16000,
        resample=True,
        verbose=True,
    )
    df_pred = interface.process_file(long_audio)
    df_pred.to_pickle(file_interface)
# ===============================================================================
# V I S U A L S by loading all pkl files (one per FULL_WAV entry) - mimic3 - speedup - human pd
#
# ===============================================================================


# Load the cached prediction table of every wav.
preds = {}
for long_audio in FULL_WAV:
    file_interface = f'timeseries_{long_audio.replace("/", "")}.pkl'
    preds[long_audio] = pd.read_pickle(file_interface)

# All curves are drawn against one shared time axis, so every table is truncated
# to the shortest one (human is slower than mimic3 -> more segments). This uses
# the true minimum length, with no arbitrary upper cap.
SHORTEST_PD = min((len(y) for y in preds.values()), default=0)  # segments

# Clean the indexes for plotting: keep only the window end time, in seconds.
for k, v in preds.items():
    # Own copy, so the index edits below never act on a slice/view of v.
    p = v.iloc[:SHORTEST_PD].copy()
    p = p.reset_index().drop(columns=['file', 'start']).set_index('end')
    p.index = p.index.map(lambda t: t.total_seconds())
    preds[k] = p  # rebinding an existing key while iterating is safe

print(preds.keys(), 'p')




# 2 PLOTS

# One figure per TTS language. Rows: the 3 A/D/V dimensions, then 5 of the 8
# categories. Column 0 = StyleTTS2 with Mimic-3 style, column 1 = the same at 4x
# speed. Every panel also shades the 'human_hfullh.wav' prediction from 0
# (legend: 'StyleTTS2 using EmoDB') as reference.
_PLOT_DIMS = [
    'arousal', 'dominance', 'valence',
    'Angry', 'Sad', 'Happy', 'Fear', 'Disgust',  # 'Surprise', 'Contempt', 'Neutral' not plotted
]
_LINE_COLOR = (0, 104 / 255, 139 / 255)  # TTS prediction
_FILL_COLOR = (.2, .2, .2)  # human reference


def _plot_panel(axis, time_stamp, synth, human):
    '''Draw one panel: `synth` as a line and `human` as a shaded area from 0.'''
    axis.plot(time_stamp, synth, color=_LINE_COLOR, label='mean_1', linewidth=2)
    axis.fill_between(time_stamp, 0, human, color=_FILL_COLOR, alpha=0.244)
    axis.set_ylim([1e-7, .9999])
    axis.set_xlim([time_stamp[0], time_stamp[-1]])
    # Rows are stacked with hspace=0: hide x tick labels (ticks and grid remain).
    axis.tick_params(labelbottom=False)
    axis.grid()


for lang in ['english', 'foreign']:
    desc = 'English' if lang == 'english' else 'Non-English'
    # (prediction file, legend text) for column 0 (normal speed) and column 1 (4x speed).
    columns = [
        (f'{lang}_hfullh.wav', f'StyleTTS2 using Mimic-3 {desc}'),
        (f'{lang}_4x_hfullh.wav', f'StyleTTS2 using Mimic-3 {desc} 4x speed'),
    ]
    human = preds['human_hfullh.wav']
    time_stamp = human.index.to_numpy()

    fig, ax = plt.subplots(nrows=len(_PLOT_DIMS), ncols=2, figsize=(24, 20.7),
                           gridspec_kw={'hspace': 0, 'wspace': .04})
    for row, dim in enumerate(_PLOT_DIMS):
        for col, (wav, label) in enumerate(columns):
            _plot_panel(ax[row, col], time_stamp, preds[wav][dim], human[dim])
            if row == 0:
                # Legend entries map to the line, then the shaded area.
                ax[row, col].legend([label, 'StyleTTS2 using EmoDB'], prop={'size': 14})
        ax[row, 0].set_ylabel(dim.lower(), color=(.4, .4, .4), fontsize=17)

    # With hspace=0, an x label on an inner row would be drawn over the row below
    # it, so only the bottom row carries one.
    for axis in ax[-1]:
        axis.set_xlabel('720 Harvard Sentences', fontsize=17, color=(.2, .2, .2))

    fig.savefig(f'fig_{lang}_{WIN=}_{HOP=}_HFdisc.png', bbox_inches='tight')
    plt.close(fig)