import IPython

import sys
import subprocess
import os

def _pip_install(*args):
    """Run ``pip install <args>`` with the current interpreter; raise if pip fails."""
    subprocess.check_call([sys.executable, "-m", "pip", "install", *args])


_pip_install("--upgrade", "--force-reinstall", "git+https://github.com/osanseviero/tortoise-tts.git")

# entmax could not be installed at same time as torch
_pip_install("entmax")

from tortoise_tts.api import TextToSpeech
from tortoise_tts.utils.audio import load_audio, get_voices
import torch 
import torchaudio
import numpy as np
import gradio as gr

# Set before the first CUDA query below so the setting is in place before the
# CUDA runtime can be initialized.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
device = "cuda" if torch.cuda.is_available() else "cpu"

# This will download all the models used by Tortoise from the HF hub.
# Pass the detected device rather than hard-coding "cuda", so the model is
# requested on CPU when no GPU is available.
tts = TextToSpeech(device=device)

# Voice names offered in the "Pre-recorded voices" dropdown; each one is looked
# up in ``voice_paths`` by ``inference``.
# NOTE(review): "William" is the only capitalized name; confirm it matches a key
# returned by get_voices(), otherwise selecting it raises KeyError.
voices = [
  "angie",
  "daniel",
  "deniro",
  "emma",
  "freeman",
  "geralt",
  "halle",
  "jlaw",
  "lj",
  "snakes",
  "William",
]
# Mapping from voice name to its reference audio clip paths.
voice_paths = get_voices()
print(voice_paths)

# Tortoise generation preset used by both tabs.
preset = "fast"

def inference(text, voice):
    """Speak ``text`` using one of the bundled reference voices.

    Args:
        text: Text to synthesize; only its first 256 characters are used.
        voice: Name of the voice, used as a key into ``voice_paths``.

    Returns:
        The path ``"generated.wav"``, to which the 24000 Hz result is written.
    """
    prompt = text[:256]
    clip_paths = voice_paths[voice]
    print(voice_paths, voice, clip_paths)
    # 22050 is presumably the sample rate Tortoise expects for conditioning clips.
    conditioning = [load_audio(path, 22050) for path in clip_paths]
    print(prompt, conditioning, preset)
    speech = tts.tts_with_preset(prompt, conditioning, preset)
    print("gen")
    out_path = "generated.wav"
    torchaudio.save(out_path, speech.squeeze(0).cpu(), 24000)
    return out_path

def load_audio_special(sr, data, target_sr=22050):
    """Turn an in-memory recording into a mono conditioning tensor.

    Args:
        sr: Sample rate of ``data`` in Hz.
        data: Raw samples as a NumPy array. It is either 1-D, or 2-D with a
            small (< 5) channel axis on either side. Signed-integer PCM is
            scaled by its full-scale value (2**15 for int16, 2**31 for int32).
            Floating-point data is assumed to already lie in [-1, 1].
        target_sr: Sample rate to resample to. The default of 22050 Hz matches
            the rate ``inference`` passes to ``load_audio`` for the bundled
            voice clips. ``None`` disables resampling.

    Returns:
        A float32 tensor of shape ``(1, num_samples)``, clipped to [-1, 1].

    Raises:
        ValueError: if ``data`` has an unsupported dtype, too many channels,
            or no samples.
    """
    if np.issubdtype(data.dtype, np.floating):
        norm_fix = 1.
    elif np.issubdtype(data.dtype, np.signedinteger):
        # Full-scale magnitude of signed PCM: 2**(bits - 1).
        norm_fix = 2 ** (np.iinfo(data.dtype).bits - 1)
    else:
        raise ValueError(f"Unsupported audio dtype: {data.dtype}")
    # astype() copies, so the caller's array is never modified.
    audio = torch.from_numpy(data.astype(np.float32)) / norm_fix

    # Remove any channel data: keep only the first channel, whichever axis holds it.
    if audio.dim() > 1:
        if audio.shape[0] < 5:
            audio = audio[0]
        elif audio.shape[1] < 5:
            audio = audio[:, 0]
        else:
            raise ValueError(f"Cannot identify the channel axis of audio with shape {tuple(audio.shape)}")

    if audio.numel() == 0:
        raise ValueError("Recorded audio contains no samples.")

    # Bring the recording to the rate used for the bundled reference clips.
    if target_sr is not None and sr != target_sr:
        audio = torchaudio.functional.resample(audio, int(sr), int(target_sr))

    # Sanity-check the range: values above 2, or no negative samples at all
    # (e.g. silence), suggest badly scaled input. '2' is arbitrarily chosen
    # since audio will often "overdrive" the [-1, 1] bounds.
    if torch.any(audio > 2) or not torch.any(audio < 0):
        print(f"Error with recorded audio (sr={sr}). Max={audio.max()} min={audio.min()}")
    audio.clip_(-1, 1)
    return audio.unsqueeze(0)
    
def inference_own_voice(text, voice_1, voice_2, voice_3):
    """Speak ``text`` conditioned on the user's own microphone recordings.

    Args:
        text: Text to synthesize; only its first 256 characters are used.
        voice_1, voice_2, voice_3: Recordings as ``(sample_rate, samples)``
            pairs, as unpacked below. A recording left empty presumably
            arrives as ``None`` and is skipped.

    Returns:
        The path ``"generated.wav"``, to which the 24000 Hz result is written.

    Raises:
        ValueError: if none of the three recordings was provided.
    """
    text = text[:256]
    # Indexing a missing (None) recording would crash with an opaque TypeError,
    # so condition on whichever recordings are present.
    recordings = [v for v in (voice_1, voice_2, voice_3) if v is not None]
    if not recordings:
        raise ValueError("Please record at least one audio clip of your voice.")
    # Log only the sample rates; the raw sample arrays are too large to print usefully.
    print([sr for sr, _ in recordings])
    conds = [load_audio_special(sr, data) for sr, data in recordings]
    print(text, conds, preset)
    gen = tts.tts_with_preset(text, conds, preset)
    print("gen")
    torchaudio.save('generated.wav', gen.squeeze(0).cpu(), 24000)
    return "generated.wav"
 
# Default prompt for both text boxes; also reused by the clickable examples.
text = (
    "Joining two modalities results in a surprising increase in generalization! "
    "What would happen if we combined them all?"
)
# Each example is a [text, voice] pair matching the pre-recorded tab's inputs.
examples = [[text, name] for name in ("angie", "emma")]
examples.append(["how are you doing this day", "freeman"])

# Two-tab UI: one tab for the bundled voices, one for recording your own.
block = gr.Blocks(enable_queue=True)
with block:
    gr.Markdown("# TorToiSe")
    gr.Markdown("A multi-voice TTS system trained with an emphasis on quality")
    with gr.Tabs():
        with gr.TabItem("Pre-recorded voices"):
            iface = gr.Interface(
                inference,
                inputs=[
                    gr.inputs.Textbox(type="str", default=text, label="Text", lines=3),
                    gr.inputs.Dropdown(voices),
                ],
                outputs="audio",
                examples=examples,
            )
        with gr.TabItem("Record your voice (experimental, might not work well)"):
            iface = gr.Interface(
                inference_own_voice,
                inputs=[gr.inputs.Textbox(type="str", default=text, label="Text", lines=3)] + [
                    # Three microphone inputs, one per conditioning clip.
                    gr.inputs.Audio(
                        source="microphone",
                        label=f"Record yourself reading something out loud (audio {i})",
                        type="numpy",
                    )
                    for i in (1, 2, 3)
                ],
                outputs="audio",
            )

    # Report the preset actually in use; the old text claimed "ultra fast" while
    # `preset` is "fast".
    gr.Markdown(
        f"This demo uses the `{preset}` preset of the TorToiSe system. "
        "For more info check the <a href='https://github.com/neonbjb/tortoise-tts' target='_blank'>Repository</a>."
    )

# Launch after the `with` context has closed and the layout is complete,
# following Gradio's documented Blocks usage.
block.launch(debug=True)