import os
import sys
import torch
import gradio as gr
from pydub import AudioSegment
import mimetypes

sys.path.append('./Amphion')
import Amphion.models.vc.vevo.vevo_utils as vevo_utils
from huggingface_hub import snapshot_download
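
# Gradio demo for Vevo voice conversion (Amphion). The script downloads the
# Vevo checkpoints from the "amphion/Vevo" Hugging Face repo, builds a
# VevoInferencePipeline, and exposes three inference modes through a web UI:
# "voice" (style + timbre conversion), "timbre" (timbre-only conversion),
# and "tts" (text-to-speech with voice cloning).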

def load_model():
    print("Loading model...")
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"Using device: {device}")

    cache_dir = "./ckpts/Vevo"
    os.makedirs(cache_dir, exist_ok=True)

    # Content Tokenizer
    local_dir = snapshot_download(
        repo_id="amphion/Vevo",
        repo_type="model",
        cache_dir=cache_dir,
        allow_patterns=["tokenizer/vq32/*"],
    )
    content_tokenizer_ckpt_path = os.path.join(
        local_dir, "tokenizer/vq32/hubert_large_l18_c32.pkl"
    )

    # Content-Style Tokenizer
    local_dir = snapshot_download(
        repo_id="amphion/Vevo",
        repo_type="model",
        cache_dir=cache_dir,
        allow_patterns=["tokenizer/vq8192/*"],
    )
    content_style_tokenizer_ckpt_path = os.path.join(local_dir, "tokenizer/vq8192")

    # Autoregressive Transformer
    local_dir = snapshot_download(
        repo_id="amphion/Vevo",
        repo_type="model",
        cache_dir=cache_dir,
        allow_patterns=["contentstyle_modeling/Vq32ToVq8192/*"],
    )
    ar_cfg_path = "./config/Vq32ToVq8192.json"
    ar_ckpt_path = os.path.join(local_dir, "contentstyle_modeling/Vq32ToVq8192")

    # Flow Matching Transformer
    local_dir = snapshot_download(
        repo_id="amphion/Vevo",
        repo_type="model",
        cache_dir=cache_dir,
        allow_patterns=["acoustic_modeling/Vq8192ToMels/*"],
    )
    fmt_cfg_path = "./config/Vq8192ToMels.json"
    fmt_ckpt_path = os.path.join(local_dir, "acoustic_modeling/Vq8192ToMels")

    # Vocoder
    local_dir = snapshot_download(
        repo_id="amphion/Vevo",
        repo_type="model",
        cache_dir=cache_dir,
        allow_patterns=["acoustic_modeling/Vocoder/*"],
    )
    vocoder_cfg_path = "./Amphion/models/vc/vevo/config/Vocoder.json"
    vocoder_ckpt_path = os.path.join(local_dir, "acoustic_modeling/Vocoder")

    print("Initializing pipeline...")
    pipeline = vevo_utils.VevoInferencePipeline(
        content_tokenizer_ckpt_path=content_tokenizer_ckpt_path,
        content_style_tokenizer_ckpt_path=content_style_tokenizer_ckpt_path,
        ar_cfg_path=ar_cfg_path,
        ar_ckpt_path=ar_ckpt_path,
        fmt_cfg_path=fmt_cfg_path,
        fmt_ckpt_path=fmt_ckpt_path,
        vocoder_cfg_path=vocoder_cfg_path,
        vocoder_ckpt_path=vocoder_ckpt_path,
        device=device,
    )
    print("Model loaded successfully!")
    return pipeline

def convert_to_wav(audio_path):
    if audio_path is None:
        return None
    mime, _ = mimetypes.guess_type(audio_path)
    if mime == 'audio/wav' or mime == 'audio/x-wav':
        return audio_path
    elif mime == 'audio/mpeg':
        seg = AudioSegment.from_mp3(audio_path)
        wav_path = audio_path.rsplit('.', 1)[0] + '.wav'
        seg.export(wav_path, format="wav")
        return wav_path
    else:
        raise ValueError(f"Unsupported audio format: {mime}")
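
# Note: pydub decodes MP3 through an external ffmpeg (or libav) binary, so that
# binary must be available on the system for AudioSegment.from_mp3 to succeed.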

def process_audio(mode, content_audio, ref_style_audio, ref_timbre_audio,
                  src_text, ref_text, src_language, ref_language, steps,
                  progress=gr.Progress()):
    try:
        output_dir = "outputs"
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, "output.wav")

        # Convert uploaded audio files to WAV if needed
        if content_audio:
            content_path = convert_to_wav(content_audio)
        else:
            content_path = None
        if ref_style_audio:
            ref_style_path = convert_to_wav(ref_style_audio)
        else:
            ref_style_path = None
        if ref_timbre_audio:
            ref_timbre_path = convert_to_wav(ref_timbre_audio)
        else:
            ref_timbre_path = None

        progress(0.2, "Processing audio...")

        # Run inference based on mode
        if mode == 'voice':
            if not all([content_path, ref_style_path, ref_timbre_path]):
                raise gr.Error("Voice mode requires all audio inputs")
            gen_audio = inference_pipeline.inference_ar_and_fm(
                src_wav_path=content_path,
                src_text=None,
                style_ref_wav_path=ref_style_path,
                timbre_ref_wav_path=ref_timbre_path,
                flow_matching_steps=steps,
            )
        elif mode == 'timbre':
            if not all([content_path, ref_timbre_path]):
                raise gr.Error("Timbre mode requires source and timbre reference audio")
            gen_audio = inference_pipeline.inference_fm(
                src_wav_path=content_path,
                timbre_ref_wav_path=ref_timbre_path,
                flow_matching_steps=steps,
            )
        elif mode == 'tts':
            if not all([ref_style_path, ref_timbre_path]) or not src_text:
                raise gr.Error("TTS mode requires style audio, timbre audio, and source text")
            gen_audio = inference_pipeline.inference_ar_and_fm(
                src_wav_path=None,
                src_text=src_text,
                style_ref_wav_path=ref_style_path,
                timbre_ref_wav_path=ref_timbre_path,
                style_ref_wav_text=ref_text if ref_text else None,
                src_text_language=src_language,
                style_ref_wav_text_language=ref_language,
            )
        else:
            raise gr.Error(f"Unknown inference mode: {mode}")

        progress(0.8, "Saving generated audio...")

        # Save and return the generated audio
        vevo_utils.save_audio(gen_audio, target_sample_rate=48000, output_path=output_path)
        return output_path
    except Exception as e:
        raise gr.Error(str(e))

# Initialize the model
inference_pipeline = load_model()
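
# A minimal sketch of calling the pipeline directly, without the Gradio UI.
# The file paths are placeholders (assumed to exist locally); the calls mirror
# the "timbre" branch of process_audio above.
#
#   gen = inference_pipeline.inference_fm(
#       src_wav_path="./examples/source.wav",          # hypothetical input
#       timbre_ref_wav_path="./examples/timbre.wav",   # hypothetical reference
#       flow_matching_steps=32,
#   )
#   vevo_utils.save_audio(gen, target_sample_rate=48000, output_path="./examples/out.wav")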

# Create the Gradio interface
with gr.Blocks(title="Vevo Voice Conversion") as demo:
    gr.Markdown("# Vevo Voice Conversion")

    with gr.Row():
        mode = gr.Radio(
            choices=["voice", "timbre", "tts"],
            value="timbre",
            label="Inference Mode",
            interactive=True,
        )

    with gr.Tabs():
        with gr.TabItem("Audio Inputs"):
            content_audio = gr.Audio(
                label="Source Audio",
                type="filepath",
                interactive=True,
            )
            ref_style_audio = gr.Audio(
                label="Reference Style Audio",
                type="filepath",
                interactive=True,
            )
            ref_timbre_audio = gr.Audio(
                label="Reference Timbre Audio",
                type="filepath",
                interactive=True,
            )
        with gr.TabItem("Text Inputs (TTS Mode)"):
            src_text = gr.Textbox(
                label="Source Text",
                placeholder="Enter text for TTS mode",
                interactive=True,
            )
            ref_text = gr.Textbox(
                label="Reference Style Text (Optional)",
                placeholder="Enter reference text",
                interactive=True,
            )
            with gr.Row():
                src_language = gr.Dropdown(
                    choices=["en", "zh"],
                    value="en",
                    label="Source Language",
                    interactive=True,
                )
                ref_language = gr.Dropdown(
                    choices=["en", "zh"],
                    value="en",
                    label="Reference Language",
                    interactive=True,
                )

    with gr.Row():
        steps = gr.Slider(
            minimum=1,
            maximum=64,
            value=32,
            step=1,
            label="Flow Matching Steps",
        )

    with gr.Row():
        with gr.Column():
            submit_btn = gr.Button("Generate")
            error_box = gr.Textbox(label="Status", interactive=False)
            output_audio = gr.Audio(label="Generated Audio")

    def process_with_error_handling(*args):
        try:
            result = process_audio(*args)
            return [result, "Success!"]
        except Exception as e:
            error_msg = str(e)
            return [None, error_msg]

    submit_btn.click(
        fn=process_with_error_handling,
        inputs=[
            mode,
            content_audio,
            ref_style_audio,
            ref_timbre_audio,
            src_text,
            ref_text,
            src_language,
            ref_language,
            steps,
        ],
        outputs=[output_audio, error_box],
    )

    # Example usage text
    gr.Markdown("""
## Quick Start Guide

1. Select your mode:
   - **Voice**: Full voice conversion (style + timbre)
   - **Timbre**: Timbre-only conversion
   - **TTS**: Text-to-speech with voice cloning
2. For Voice/Timbre modes:
   - Upload the source audio (what you want to convert)
   - Upload the reference audio(s)
3. For TTS mode:
   - Enter your text
   - Select the language
   - Upload the reference audio(s)
4. Adjust the steps slider:
   - Higher values = better quality but slower
   - Lower values = faster but lower quality
5. Click Generate and wait for processing
""")
if __name__ == "__main__":
    demo.queue().launch()