import os
import sys
import torch
import gradio as gr
from pydub import AudioSegment
import mimetypes

sys.path.append('./Amphion')
import Amphion.models.vc.vevo.vevo_utils as vevo_utils
from huggingface_hub import snapshot_download

def load_model():
    """Fetch every Vevo checkpoint and assemble the inference pipeline.

    Each component lives in its own sub-folder of the ``amphion/Vevo``
    Hugging Face model repo; only that sub-folder is requested from
    ``snapshot_download`` (cached under ``./ckpts/Vevo``).

    Returns:
        The ``vevo_utils.VevoInferencePipeline`` built on CUDA when it is
        available and on the CPU otherwise.
    """
    print("Loading model...")
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    ckpt_root = "./ckpts/Vevo"
    os.makedirs(ckpt_root, exist_ok=True)

    def fetch(subfolder):
        # Download (or reuse from cache) only `subfolder`; returns the snapshot root.
        return snapshot_download(
            repo_id="amphion/Vevo",
            repo_type="model",
            cache_dir=ckpt_root,
            allow_patterns=[subfolder + "/*"],
        )

    # Content tokenizer: a single checkpoint file.
    content_tokenizer_ckpt_path = os.path.join(
        fetch("tokenizer/vq32"), "tokenizer/vq32/hubert_large_l18_c32.pkl"
    )
    # Content-style tokenizer: checkpoint sub-folder.
    content_style_tokenizer_ckpt_path = os.path.join(
        fetch("tokenizer/vq8192"), "tokenizer/vq8192"
    )
    # Autoregressive transformer.
    ar_ckpt_path = os.path.join(
        fetch("contentstyle_modeling/Vq32ToVq8192"),
        "contentstyle_modeling/Vq32ToVq8192",
    )
    # Flow-matching transformer.
    fmt_ckpt_path = os.path.join(
        fetch("acoustic_modeling/Vq8192ToMels"), "acoustic_modeling/Vq8192ToMels"
    )
    # Vocoder.
    vocoder_ckpt_path = os.path.join(
        fetch("acoustic_modeling/Vocoder"), "acoustic_modeling/Vocoder"
    )

    # NOTE(review): the AR / flow-matching configs are read from ./config while
    # the vocoder config comes from the Amphion checkout — confirm both exist.
    ar_cfg_path = "./config/Vq32ToVq8192.json"
    fmt_cfg_path = "./config/Vq8192ToMels.json"
    vocoder_cfg_path = "./Amphion/models/vc/vevo/config/Vocoder.json"

    print("Initializing pipeline...")
    pipeline = vevo_utils.VevoInferencePipeline(
        content_tokenizer_ckpt_path=content_tokenizer_ckpt_path,
        content_style_tokenizer_ckpt_path=content_style_tokenizer_ckpt_path,
        ar_cfg_path=ar_cfg_path,
        ar_ckpt_path=ar_ckpt_path,
        fmt_cfg_path=fmt_cfg_path,
        fmt_ckpt_path=fmt_ckpt_path,
        vocoder_cfg_path=vocoder_cfg_path,
        vocoder_ckpt_path=vocoder_ckpt_path,
        device=device
    )
    print("Model loaded successfully!")
    return pipeline

def convert_to_wav(audio_path):
    """Return a path to a WAV version of ``audio_path``.

    WAV files are returned unchanged. Any other audio type (as guessed by
    ``mimetypes.guess_type`` from the file extension) is decoded with pydub
    and exported next to the input file with a ``.wav`` suffix.

    Args:
        audio_path: Path to an audio file, or ``None``.

    Returns:
        The path of a WAV file, or ``None`` when ``audio_path`` is ``None``.

    Raises:
        ValueError: if the file type cannot be guessed or is not audio.
    """
    if audio_path is None:
        return None

    mime, _ = mimetypes.guess_type(audio_path)
    # The MIME type reported for .wav depends on the platform's MIME database,
    # so accept every common spelling rather than rejecting genuine WAV files.
    if mime in {"audio/wav", "audio/x-wav", "audio/wave", "audio/vnd.wave"}:
        return audio_path
    if mime is None or not mime.startswith("audio/"):
        raise ValueError(f"Unsupported audio format: {mime}")

    if mime == "audio/mpeg":
        seg = AudioSegment.from_mp3(audio_path)
    else:
        # Other audio containers (flac, ogg, m4a, ...) are decoded by pydub/ffmpeg.
        seg = AudioSegment.from_file(audio_path)
    wav_path = os.path.splitext(audio_path)[0] + ".wav"
    seg.export(wav_path, format="wav")
    return wav_path

def process_audio(mode, content_audio, ref_style_audio, ref_timbre_audio, 
                 src_text, ref_text, src_language, ref_language, steps,
                 progress=gr.Progress()):
    """Run one Vevo inference request and write the result to disk.

    Args:
        mode: ``"voice"`` (style + timbre conversion), ``"timbre"``
            (timbre-only conversion) or ``"tts"`` (text-to-speech).
        content_audio: Path to the source audio (voice / timbre modes).
        ref_style_audio: Path to the style reference audio (voice / tts modes).
        ref_timbre_audio: Path to the timbre reference audio (all modes).
        src_text: Text to synthesize (tts mode).
        ref_text: Optional transcript of the style reference (tts mode).
        src_language: Language code of ``src_text``.
        ref_language: Language code of ``ref_text``.
        steps: Number of flow-matching steps.
        progress: Gradio progress tracker.

    Returns:
        Path of the generated WAV file (``outputs/output.wav``).

    Raises:
        gr.Error: on missing inputs, an unknown mode, or any failure during
            conversion / inference.
    """
    try:
        output_dir = "outputs"
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, "output.wav")

        # Convert uploaded audio files to WAV if needed (empty inputs -> None).
        content_path = convert_to_wav(content_audio) if content_audio else None
        ref_style_path = convert_to_wav(ref_style_audio) if ref_style_audio else None
        ref_timbre_path = convert_to_wav(ref_timbre_audio) if ref_timbre_audio else None

        progress(0.2, "Processing audio...")

        # The slider value may arrive as a float; the step count must be an int.
        n_steps = int(steps)

        # Run inference based on mode
        if mode == 'voice':
            if not all([content_path, ref_style_path, ref_timbre_path]):
                raise gr.Error("Voice mode requires all audio inputs")

            gen_audio = inference_pipeline.inference_ar_and_fm(
                src_wav_path=content_path,
                src_text=None,
                style_ref_wav_path=ref_style_path,
                timbre_ref_wav_path=ref_timbre_path,
                flow_matching_steps=n_steps
            )

        elif mode == 'timbre':
            if not all([content_path, ref_timbre_path]):
                raise gr.Error("Timbre mode requires source and timbre reference audio")

            gen_audio = inference_pipeline.inference_fm(
                src_wav_path=content_path,
                timbre_ref_wav_path=ref_timbre_path,
                flow_matching_steps=n_steps
            )

        elif mode == 'tts':
            if not all([ref_style_path, ref_timbre_path]) or not src_text:
                raise gr.Error("TTS mode requires style audio, timbre audio, and source text")

            gen_audio = inference_pipeline.inference_ar_and_fm(
                src_wav_path=None,
                src_text=src_text,
                style_ref_wav_path=ref_style_path,
                timbre_ref_wav_path=ref_timbre_path,
                style_ref_wav_text=ref_text if ref_text else None,
                src_text_language=src_language,
                style_ref_wav_text_language=ref_language,
                # Honour the steps slider in TTS mode too, as voice mode does.
                flow_matching_steps=n_steps
            )

        else:
            # Without this branch `gen_audio` would be unbound below.
            raise gr.Error(f"Unknown inference mode: {mode}")

        progress(0.8, "Saving generated audio...")

        # Save and return the generated audio
        vevo_utils.save_audio(gen_audio, target_sample_rate=48000, output_path=output_path)
        return output_path

    except gr.Error:
        # Already a user-facing error; propagate it unchanged.
        raise
    except Exception as e:
        raise gr.Error(str(e)) from e

# Initialize the model once at import time; process_audio reads this
# module-level global.
inference_pipeline = load_model()

# Create the Gradio interface
with gr.Blocks(title="Vevo Voice Conversion") as demo:
    gr.Markdown("# Vevo Voice Conversion")
    
    with gr.Row():
        mode = gr.Radio(
            choices=["voice", "timbre", "tts"],
            value="timbre",
            label="Inference Mode",
            interactive=True
        )

    with gr.Tabs():
        with gr.TabItem("Audio Inputs"):
            content_audio = gr.Audio(
                label="Source Audio",
                type="filepath",
                interactive=True
            )
            
            ref_style_audio = gr.Audio(
                label="Reference Style Audio",
                type="filepath",
                interactive=True
            )
            
            ref_timbre_audio = gr.Audio(
                label="Reference Timbre Audio",
                type="filepath",
                interactive=True
            )
        
        with gr.TabItem("Text Inputs (TTS Mode)"):
            src_text = gr.Textbox(
                label="Source Text",
                placeholder="Enter text for TTS mode",
                interactive=True
            )
            
            ref_text = gr.Textbox(
                label="Reference Style Text (Optional)",
                placeholder="Enter reference text",
                interactive=True
            )
            
            with gr.Row():
                src_language = gr.Dropdown(
                    choices=["en", "zh"],
                    value="en",
                    label="Source Language",
                    interactive=True
                )
                
                ref_language = gr.Dropdown(
                    choices=["en", "zh"],
                    value="en",
                    label="Reference Language",
                    interactive=True
                )
    
    with gr.Row():
        steps = gr.Slider(
            minimum=1,
            maximum=64,
            value=32,
            step=1,
            label="Flow Matching Steps"
        )
    
    with gr.Row():
        with gr.Column():
            submit_btn = gr.Button("Generate")
            error_box = gr.Textbox(label="Status", interactive=False)
        output_audio = gr.Audio(label="Generated Audio")
    
    def process_with_error_handling(*args):
        """Run ``process_audio`` and report the outcome in the Status box.

        Returns a two-item list matching ``outputs=[output_audio, error_box]``:
        the generated file path (``None`` on failure) and a status message.
        The status is delivered only through this return value; components are
        not mutated directly from inside the handler.
        """
        try:
            result = process_audio(*args)
        except Exception as e:
            # Show the failure in the Status textbox instead of propagating.
            return [None, str(e)]
        return [result, "Success!"]
    
    submit_btn.click(
        fn=process_with_error_handling,
        inputs=[
            mode,
            content_audio,
            ref_style_audio,
            ref_timbre_audio,
            src_text,
            ref_text,
            src_language,
            ref_language,
            steps
        ],
        outputs=[output_audio, error_box]
    )

    # Example usage text
    gr.Markdown("""
    ## Quick Start Guide
    
    1. Select your mode:
       - **Voice**: Full voice conversion (style + timbre)
       - **Timbre**: Only voice timbre conversion
       - **TTS**: Text-to-speech with voice cloning
    
    2. For Voice/Timbre modes:
       - Upload source audio (what you want to convert)
       - Upload reference audio(s)
    
    3. For TTS mode:
       - Enter your text
       - Select language
       - Upload reference audio(s)
    
    4. Adjust steps slider:
       - Higher values = better quality but slower
       - Lower values = faster but lower quality
    
    5. Click Generate and wait for processing
    """)

if __name__ == "__main__":
    demo.queue().launch()