import mimetypes
import os
import sys

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from pydub import AudioSegment

# The Amphion repository is expected to sit next to this file; add it to the
# import path before importing the Vevo inference utilities.
sys.path.append('./Amphion')
import Amphion.models.vc.vevo.vevo_utils as vevo_utils
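

# All Vevo checkpoints are fetched from the amphion/Vevo repo on the
# Hugging Face Hub the first time the app starts and cached under ./ckpts/Vevo.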
def load_model():
    print("Loading model...")
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"Using device: {device}")

    cache_dir = "./ckpts/Vevo"
    os.makedirs(cache_dir, exist_ok=True)

    # Content Tokenizer
    local_dir = snapshot_download(
        repo_id="amphion/Vevo",
        repo_type="model",
        cache_dir=cache_dir,
        allow_patterns=["tokenizer/vq32/*"],
    )
    content_tokenizer_ckpt_path = os.path.join(
        local_dir, "tokenizer/vq32/hubert_large_l18_c32.pkl"
    )

    # Content-Style Tokenizer
    local_dir = snapshot_download(
        repo_id="amphion/Vevo",
        repo_type="model",
        cache_dir=cache_dir,
        allow_patterns=["tokenizer/vq8192/*"],
    )
    content_style_tokenizer_ckpt_path = os.path.join(local_dir, "tokenizer/vq8192")

    # Autoregressive Transformer
    local_dir = snapshot_download(
        repo_id="amphion/Vevo",
        repo_type="model",
        cache_dir=cache_dir,
        allow_patterns=["contentstyle_modeling/Vq32ToVq8192/*"],
    )
    ar_cfg_path = "./config/Vq32ToVq8192.json"
    ar_ckpt_path = os.path.join(local_dir, "contentstyle_modeling/Vq32ToVq8192")

    # Flow Matching Transformer
    local_dir = snapshot_download(
        repo_id="amphion/Vevo",
        repo_type="model",
        cache_dir=cache_dir,
        allow_patterns=["acoustic_modeling/Vq8192ToMels/*"],
    )
    fmt_cfg_path = "./config/Vq8192ToMels.json"
    fmt_ckpt_path = os.path.join(local_dir, "acoustic_modeling/Vq8192ToMels")

    # Vocoder
    local_dir = snapshot_download(
        repo_id="amphion/Vevo",
        repo_type="model",
        cache_dir=cache_dir,
        allow_patterns=["acoustic_modeling/Vocoder/*"],
    )
    vocoder_cfg_path = "./Amphion/models/vc/vevo/config/Vocoder.json"
    vocoder_ckpt_path = os.path.join(local_dir, "acoustic_modeling/Vocoder")

    print("Initializing pipeline...")
    pipeline = vevo_utils.VevoInferencePipeline(
        content_tokenizer_ckpt_path=content_tokenizer_ckpt_path,
        content_style_tokenizer_ckpt_path=content_style_tokenizer_ckpt_path,
        ar_cfg_path=ar_cfg_path,
        ar_ckpt_path=ar_ckpt_path,
        fmt_cfg_path=fmt_cfg_path,
        fmt_ckpt_path=fmt_ckpt_path,
        vocoder_cfg_path=vocoder_cfg_path,
        vocoder_ckpt_path=vocoder_ckpt_path,
        device=device,
    )
    print("Model loaded successfully!")
    return pipeline
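

# Gradio hands uploaded audio to us as file paths; pydub converts non-WAV
# uploads to WAV before they reach the Vevo pipeline.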
def convert_to_wav(audio_path):
    if audio_path is None:
        return None
    mime, _ = mimetypes.guess_type(audio_path)
    if mime in ('audio/wav', 'audio/x-wav'):
        return audio_path
    elif mime == 'audio/mpeg':
        seg = AudioSegment.from_mp3(audio_path)
        wav_path = audio_path.rsplit('.', 1)[0] + '.wav'
        seg.export(wav_path, format="wav")
        return wav_path
    else:
        raise ValueError(f"Unsupported audio format: {mime}")
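

# Run one inference pass. "voice" converts both style and timbre, "timbre"
# converts timbre only, and "tts" synthesizes the source text in the reference
# speaker's voice. Returns the path of the generated WAV file.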
def process_audio(mode, content_audio, ref_style_audio, ref_timbre_audio,
                  src_text, ref_text, src_language, ref_language, steps,
                  progress=gr.Progress()):
    try:
        output_dir = "outputs"
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, "output.wav")

        # Convert uploaded audio files to WAV if needed
        content_path = convert_to_wav(content_audio) if content_audio else None
        ref_style_path = convert_to_wav(ref_style_audio) if ref_style_audio else None
        ref_timbre_path = convert_to_wav(ref_timbre_audio) if ref_timbre_audio else None

        progress(0.2, desc="Processing audio...")

        # Run inference based on mode
        if mode == 'voice':
            if not all([content_path, ref_style_path, ref_timbre_path]):
                raise gr.Error("Voice mode requires source, style, and timbre audio")
            gen_audio = inference_pipeline.inference_ar_and_fm(
                src_wav_path=content_path,
                src_text=None,
                style_ref_wav_path=ref_style_path,
                timbre_ref_wav_path=ref_timbre_path,
                flow_matching_steps=steps
            )
        elif mode == 'timbre':
            if not all([content_path, ref_timbre_path]):
                raise gr.Error("Timbre mode requires source and timbre reference audio")
            gen_audio = inference_pipeline.inference_fm(
                src_wav_path=content_path,
                timbre_ref_wav_path=ref_timbre_path,
                flow_matching_steps=steps
            )
        elif mode == 'tts':
            if not all([ref_style_path, ref_timbre_path]) or not src_text:
                raise gr.Error("TTS mode requires style audio, timbre audio, and source text")
            gen_audio = inference_pipeline.inference_ar_and_fm(
                src_wav_path=None,
                src_text=src_text,
                style_ref_wav_path=ref_style_path,
                timbre_ref_wav_path=ref_timbre_path,
                style_ref_wav_text=ref_text if ref_text else None,
                src_text_language=src_language,
                style_ref_wav_text_language=ref_language,
                flow_matching_steps=steps  # apply the steps slider in TTS mode as well
            )

        progress(0.8, desc="Saving generated audio...")

        # Save and return the generated audio
        vevo_utils.save_audio(gen_audio, target_sample_rate=48000, output_path=output_path)
        return output_path
    except Exception as e:
        raise gr.Error(str(e))
# Initialize the model once at import time so checkpoints download before the UI starts
inference_pipeline = load_model()
# Create the Gradio interface
with gr.Blocks(title="Vevo Voice Conversion") as demo:
    gr.Markdown("# Vevo Voice Conversion")

    with gr.Row():
        mode = gr.Radio(
            choices=["voice", "timbre", "tts"],
            value="timbre",
            label="Inference Mode",
            interactive=True
        )

    with gr.Tabs():
        with gr.TabItem("Audio Inputs"):
            content_audio = gr.Audio(
                label="Source Audio",
                type="filepath",
                interactive=True
            )
            ref_style_audio = gr.Audio(
                label="Reference Style Audio",
                type="filepath",
                interactive=True
            )
            ref_timbre_audio = gr.Audio(
                label="Reference Timbre Audio",
                type="filepath",
                interactive=True
            )
        with gr.TabItem("Text Inputs (TTS Mode)"):
            src_text = gr.Textbox(
                label="Source Text",
                placeholder="Enter text for TTS mode",
                interactive=True
            )
            ref_text = gr.Textbox(
                label="Reference Style Text (Optional)",
                placeholder="Enter reference text",
                interactive=True
            )
            with gr.Row():
                src_language = gr.Dropdown(
                    choices=["en", "zh"],
                    value="en",
                    label="Source Language",
                    interactive=True
                )
                ref_language = gr.Dropdown(
                    choices=["en", "zh"],
                    value="en",
                    label="Reference Language",
                    interactive=True
                )

    with gr.Row():
        steps = gr.Slider(
            minimum=1,
            maximum=64,
            value=32,
            step=1,
            label="Flow Matching Steps"
        )

    with gr.Row():
        with gr.Column():
            submit_btn = gr.Button("Generate")
            error_box = gr.Textbox(label="Status", interactive=False)
            output_audio = gr.Audio(label="Generated Audio")
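
    # Thin wrapper so the UI shows a status string instead of an unhandled
    # stack trace when inference fails.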
    def process_with_error_handling(*args):
        try:
            result = process_audio(*args)
            return [result, "Success!"]
        except Exception as e:
            return [None, str(e)]
    submit_btn.click(
        fn=process_with_error_handling,
        inputs=[
            mode,
            content_audio,
            ref_style_audio,
            ref_timbre_audio,
            src_text,
            ref_text,
            src_language,
            ref_language,
            steps
        ],
        outputs=[output_audio, error_box]
    )
    # Example usage text
    gr.Markdown("""
## Quick Start Guide
1. Select your mode:
   - **Voice**: full voice conversion (style + timbre)
   - **Timbre**: timbre-only conversion
   - **TTS**: text-to-speech with voice cloning
2. For Voice/Timbre modes:
   - Upload the source audio (what you want to convert)
   - Upload the reference audio(s)
3. For TTS mode:
   - Enter your text
   - Select the language
   - Upload the reference audio(s)
4. Adjust the steps slider:
   - Higher values = better quality but slower
   - Lower values = faster but lower quality
5. Click Generate and wait for processing
    """)
if __name__ == "__main__":
    demo.queue().launch()