# Synthesize all Harvard sentences (77 lists of 10 sentences each) into a single .wav per style-prompt set.

# 1. Style prompts: mimic3 English 1x/4x, non-English 1x/4x, and human speech.
# Then call visualize_tts_plesantness.py to produce the 4 figures [English 1x/4x vs. human, non-English 1x/4x vs. human (LibriSpeech)].

import soundfile
import json
import numpy as np
import audb
from pathlib import Path

# Dimensional emotion label names (not referenced elsewhere in this script section).
LABELS = [
    'arousal',
    'dominance',
    'valence',
]




def load_speech(split=None, *, cache_root='/cache/audb/', sampling_rate=16000):
    """Return the audio file paths of the enabled audb speech databases.

    Every active entry of the ``DB`` table below is loaded with ``audb.load``
    (mono wav at ``sampling_rate`` Hz, cached under ``cache_root``). The index
    of the requested table is then collected.

    Args:
        split: Unused. Kept only so existing callers keep working.
        cache_root: audb cache directory.
        sampling_rate: Sampling rate (Hz) that audb converts the audio to.

    Returns:
        A flat list with the index entries (files) of every filewise table.
        Tables flagged as segmented (``has_timedeltas``) are skipped with a
        printed message, because segment extraction is not implemented here.
    """
    # [dataset, version, table, has_timedeltas_or_is_full_wavfile]
    # Uncomment an entry to include that database.
    DB = [
        # ['crema-d', '1.1.1', 'emotion.voice.test', False],
        # ['librispeech', '3.1.0', 'test-clean', False],
        ['emodb', '1.2.0', 'emotion.categories.train.gold_standard', False],
        # ['entertain-playtestcloud', '1.1.0', 'emotion.categories.train.gold_standard', True],
        # ['erik', '2.2.0', 'emotion.categories.train.gold_standard', True],
        # ['meld', '1.3.1', 'emotion.categories.train.gold_standard', False],
        # ['msppodcast', '5.0.0', 'emotion.categories.train.gold_standard', False],  # standalone bucket because it has gt labels?
        # ['myai', '1.0.1', 'emotion.categories.train.gold_standard', False],
        # ['casia', None, 'emotion.categories.gold_standard', False],
        # ['switchboard-1', None, 'sentiment', True],
        # ['swiss-parliament', None, 'segments', True],
        # ['argentinian-parliament', None, 'segments', True],
        # ['austrian-parliament', None, 'segments', True],
        # # 'german', --> bundestag
        # ['brazilian-parliament', None, 'segments', True],
        # ['mexican-parliament', None, 'segments', True],
        # ['portuguese-parliament', None, 'segments', True],
        # ['spanish-parliament', None, 'segments', True],
        # ['chinese-vocal-emotions-liu-pell', None, 'emotion.categories.desired', False],
        # peoples-speech slow
        # ['peoples-speech', None, 'train-initial', False],
    ]

    output_list = []
    for database_name, ver, table, has_timedeltas in DB:
        db = audb.load(database_name,
                       sampling_rate=sampling_rate,
                       format='wav',
                       mixdown=True,
                       version=ver,
                       cache_root=cache_root)
        df = db[table].get()
        if has_timedeltas:
            # Segmented tables would need (file, start, end) handling, which
            # callers of this function do not support; skip them explicitly
            # instead of silently dropping them.
            print(f'skipping {database_name}/{table}: segmented tables are '
                  f'not supported ({has_timedeltas=})')
            continue
        output_list.extend(df.index)  # use file (no timedeltas)
    return output_list





    




    
natural_wav_paths = load_speech()


# Synthesize the Harvard sentences with mimic3 (1x / 4x), foreign and human style prompts.
import msinference
import os
from random import shuffle
import audiofile

TTS_SAMPLING_RATE = 24000  # rate used below to write msinference.inference() output
MIN_STYLE_DURATION = 2  # seconds; shorter style prompts are discarded

with open('harvard.json', 'r') as f:
    harvard_individual_sentences = json.load(f)['sentences']


def _list_wavs(folder, exclude=None):
    """Return ``folder + name`` for each entry of *folder*, skipping names that contain *exclude*."""
    return [folder + name for name in os.listdir(folder)
            if exclude is None or exclude not in name]


def _drop_short(paths, min_duration=MIN_STYLE_DURATION):
    """Keep only the files whose duration (per audiofile.duration) exceeds *min_duration*."""
    return [p for p in paths if audiofile.duration(p) > min_duration]


# Filter out very short styles. Previously the filtered English 4x list was
# assigned to a misspelled name, so short 4x styles were never removed.
synthetic_wav_paths = _drop_short(_list_wavs('./enslow/'))
synthetic_wav_paths_4x = _drop_short(_list_wavs('./style_vector_v2/'))
synthetic_wav_paths_foreign = _drop_short(_list_wavs('./mimic3_foreign/', exclude='en_U'))
synthetic_wav_paths_foreign_4x = _drop_short(_list_wavs('./mimic3_foreign_4x/', exclude='en_U'))  # very short segments

shuffle(synthetic_wav_paths_foreign_4x)
shuffle(synthetic_wav_paths_foreign)
shuffle(synthetic_wav_paths)
shuffle(synthetic_wav_paths_4x)
print(len(synthetic_wav_paths_foreign_4x), len(synthetic_wav_paths_foreign),
      len(synthetic_wav_paths), len(synthetic_wav_paths_4x))  # 134 204 134 204

# Style-prompt pool for each output file, processed in this order.
style_pools = {
    'english': synthetic_wav_paths,
    'english_4x': synthetic_wav_paths_4x,
    'human': natural_wav_paths,
    'foreign': synthetic_wav_paths_foreign,
    'foreign_4x': synthetic_wav_paths_foreign_4x,
}

for audio_prompt, style_paths in style_pools.items():
    OUT_FILE = f'{audio_prompt}_hfullh.wav'
    if os.path.isfile(OUT_FILE):
        print('\nALREADY EXISTS\n')
        continue
    if not style_paths:
        # Cycling with ix % len(style_paths) would divide by zero.
        print(f'no style prompts available for {audio_prompt!r}; skipping')
        continue

    total_audio = []
    total_style = []
    style_rates = set()
    ix = 0
    for list_of_10 in harvard_individual_sentences[:1000]:
        for text in list_of_10['sentences']:
            # Cycle through the (shuffled) style pool, one prompt per sentence.
            _p = style_paths[ix % len(style_paths)]
            style_vec = msinference.compute_style(_p)
            print(ix, text)
            ix += 1
            x = msinference.inference(text,
                                      style_vec,
                                      alpha=0.3,
                                      beta=0.7,
                                      diffusion_steps=7,
                                      embedding_scale=1)
            total_audio.append(x)

            _st, fsr = audiofile.read(_p)
            style_rates.add(fsr)
            # NOTE(review): truncates to len(x) *samples* at fsr Hz. This only
            # matches the TTS duration when fsr == TTS_SAMPLING_RATE (e.g. audb
            # audio is 16 kHz). Confirm whether duration matching is intended.
            total_style.append(_st[:len(x)])
        # -- end of one list of 10 sentences
        print('_____________________')

    if not total_audio:
        print(f'no sentences synthesized for {audio_prompt!r}; nothing written')
        continue

    total_audio = np.concatenate(total_audio)
    soundfile.write(OUT_FILE, total_audio, TTS_SAMPLING_RATE)

    total_style = np.concatenate(total_style)
    if len(style_rates) > 1:
        # Concatenating different rates into one file distorts the style track.
        print(f'WARNING: style prompts for {audio_prompt!r} have mixed sampling '
              f'rates {sorted(style_rates)}; writing style track at {fsr} Hz')
    soundfile.write('_st_' + OUT_FILE, total_style, fsr)  # rate taken from the loaded style files