import argparse
import os
from pathlib import Path

import gradio as gr
import torch  # needed by optimize_model() below for torch.load / torch.save

from functions.core_functions1 import clear_gpu_cache, load_model, run_tts, load_params_tts, process_srt_and_generate_audio, convert_voice  # preprocess_dataset, load_params, train_model, optimize_model,
from functions.logging_utils import remove_log_file, read_logs
from functions.slice_utils import open_slice, close_slice, kill_process
from utils.formatter import format_audio_list
from utils.gpt_train import train_gpt

import traceback
import shutil

from tools.i18n.i18n import I18nAuto
from tools import my_utils
from multiprocessing import cpu_count
from subprocess import Popen
from config import python_exec, is_share, webui_port_main

if __name__ == "__main__":
    # Clear the old log file from the previous session
    remove_log_file("logs/main.log")

    parser = argparse.ArgumentParser(
        description="""XTTS fine-tuning demo\n\n"""
        """
        Example runs:
        python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port
        """,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--port",
        type=int,
        help="Port to run the gradio demo. Default: 5003",
        default=5003,
    )
    parser.add_argument(
        "--out_path",
        type=str,
        help="Output path (where data and checkpoints will be saved). Default: finetune_models/",
        default=str(Path.cwd() / "finetune_models"),
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        help="Number of epochs to train. Default: 6",
        default=6,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        help="Batch size. Default: 2",
        default=2,
    )
    parser.add_argument(
        "--grad_acumm",
        type=int,
        help="Grad accumulation steps. Default: 1",
        default=1,
    )
    parser.add_argument(
        "--max_audio_length",
        type=int,
        help="Max permitted audio size in seconds. Default: 11",
        default=11,
    )

    args = parser.parse_args()

    i18n = I18nAuto()
    n_cpu = cpu_count()
    '''
    ngpu = torch.cuda.device_count()
    gpu_infos = []
    mem = []
    if_gpu_ok = False
    '''

    with gr.Blocks() as demo:
        with gr.Tab("0 - Audio Slicing"):
            gr.Markdown(value=i18n("0b-语音切分工具"))
            with gr.Row():
                slice_inp_path = gr.Textbox(label=i18n("音频自动切分输入路径,可文件可文件夹"), value="")
                slice_opt_root = gr.Textbox(label=i18n("切分后的子音频的输出根目录"), value="output/slicer_opt")
                threshold = gr.Textbox(label=i18n("threshold:音量小于这个值视作静音的备选切割点"), value="-34")
                min_length = gr.Textbox(label=i18n("min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值"), value="4000")
                min_interval = gr.Textbox(label=i18n("min_interval:最短切割间隔"), value="300")
                hop_size = gr.Textbox(label=i18n("hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)"), value="10")
                max_sil_kept = gr.Textbox(label=i18n("max_sil_kept:切完后静音最多留多长"), value="500")
            with gr.Row():
                open_slicer_button = gr.Button(i18n("开启语音切割"), variant="primary", visible=True)
                close_slicer_button = gr.Button(i18n("终止语音切割"), variant="primary", visible=False)
                _max = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("max:归一化后最大值多少"), value=0.9, interactive=True)
                alpha = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("alpha_mix:混多少比例归一化后音频进来"), value=0.25, interactive=True)
                n_process = gr.Slider(minimum=1, maximum=n_cpu, step=1, label=i18n("切割使用的进程数"), value=4, interactive=True)
                slicer_info = gr.Textbox(label=i18n("语音切割进程输出信息"))
            open_slicer_button.click(
                open_slice,
                [slice_inp_path, slice_opt_root, threshold, min_length, min_interval, hop_size, max_sil_kept, _max, alpha, n_process],
                [slicer_info, open_slicer_button, close_slicer_button],
            )
            close_slicer_button.click(close_slice, [], [slicer_info, open_slicer_button, close_slicer_button])

        with gr.Tab("1 - Data processing"):
            out_path = gr.Textbox(
                label="Output path (where data and checkpoints will be saved):",
                value=args.out_path,
            )
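            # Data flow for this tab: the uploaded files (or every supported audio
            # file found under the given folder) are transcribed with the selected
            # Whisper model and written out as train/eval metadata CSVs under
            # <out_path>/dataset by preprocess_dataset() below.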
            upload_file = gr.File(
                file_count="multiple",
                label="Select here the audio files that you want to use for XTTS training (Supported formats: wav, mp3, and flac)",
            )
            folder_path = gr.Textbox(label="Or input the path of a folder containing audio files")
            whisper_model = gr.Dropdown(
                label="Whisper Model",
                value="large-v3",
                choices=["large-v3", "large-v2", "large", "medium", "small"],
            )
            lang = gr.Dropdown(
                label="Dataset Language",
                value="en",
                choices=["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh", "hu", "ko", "ja"],
            )
            progress_data = gr.Label(label="Progress:")
            prompt_compute_btn = gr.Button(value="Step 1 - Create dataset")

            def get_audio_files_from_folder(folder_path):
                # Recursively collect every supported audio file under the folder
                audio_files = []
                for root, dirs, files in os.walk(folder_path):
                    for file in files:
                        if file.endswith((".wav", ".mp3", ".flac", ".m4a", ".webm")):
                            audio_files.append(os.path.join(root, file))
                return audio_files

            def preprocess_dataset(audio_path, audio_folder, language, whisper_model, out_path, train_csv, eval_csv, progress=gr.Progress(track_tqdm=True)):
                clear_gpu_cache()
                train_csv = ""
                eval_csv = ""
                out_path = os.path.join(out_path, "dataset")
                os.makedirs(out_path, exist_ok=True)

                # Detect whether the input is one file, several files, or a folder
                if audio_path is not None and audio_path != []:
                    # One or more uploaded files
                    try:
                        train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, whisper_model=whisper_model, target_language=language, out_path=out_path, gradio_progress=progress)
                    except:
                        traceback.print_exc()
                        error = traceback.format_exc()
                        return f"The data processing was interrupted due to an error! Please check the console to verify the full error message! \n Error summary: {error}", "", ""
                elif audio_folder is not None:
                    # A folder of audio files
                    audio_files = get_audio_files_from_folder(audio_folder)
                    try:
                        train_meta, eval_meta, audio_total_size = format_audio_list(audio_files, whisper_model=whisper_model, target_language=language, out_path=out_path, gradio_progress=progress)
                    except:
                        traceback.print_exc()
                        error = traceback.format_exc()
                        return f"The data processing was interrupted due to an error! Please check the console to verify the full error message! \n Error summary: {error}", "", ""
                else:
                    return "You should provide either audio files or a folder containing audio files!", "", ""

                # Reject datasets whose total audio length is under 2 minutes
                if audio_total_size < 120:
                    message = "The sum of the duration of the audios that you provided should be at least 2 minutes!"
                    print(message)
                    return message, "", ""

                print("Dataset Processed!")
                return "Dataset Processed!", train_meta, eval_meta
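            # A successful run returns the metadata_train.csv and metadata_eval.csv
            # paths, which fill the Train/Eval CSV fields on the next tab; the
            # dataset language is recorded next to them in lang.txt and read back
            # by load_params() and train_model() below.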
        with gr.Tab("2 - Fine-tuning XTTS Encoder"):
            load_params_btn = gr.Button(value="Load Params from output folder")
            version = gr.Dropdown(
                label="XTTS base version",
                value="v2.0.2",
                choices=["v2.0.3", "v2.0.2", "v2.0.1", "v2.0.0", "main"],
            )
            train_csv = gr.Textbox(label="Train CSV:")
            eval_csv = gr.Textbox(label="Eval CSV:")
            custom_model = gr.Textbox(
                label="(Optional) Custom model.pth file, leave blank if you want to use the base file.",
                value="",
            )
            num_epochs = gr.Slider(
                label="Number of epochs:",
                minimum=1,
                maximum=100,
                step=1,
                value=args.num_epochs,
            )
            batch_size = gr.Slider(
                label="Batch size:",
                minimum=2,
                maximum=512,
                step=1,
                value=args.batch_size,
            )
            grad_acumm = gr.Slider(
                label="Grad accumulation steps:",
                minimum=1,  # was 2, which conflicted with the --grad_acumm default of 1
                maximum=128,
                step=1,
                value=args.grad_acumm,
            )
            max_audio_length = gr.Slider(
                label="Max permitted audio size in seconds:",
                minimum=2,
                maximum=20,
                step=1,
                value=args.max_audio_length,
            )
            clear_train_data = gr.Dropdown(
                label="Clear train data: the selected folder will be deleted after optimizing",
                value="run",
                choices=["none", "run", "dataset", "all"],
            )
            progress_train = gr.Label(label="Progress:")

            # demo.load(read_logs, None, logs_tts_train, every=1)
            train_btn = gr.Button(value="Step 2 - Run the training")
            optimize_model_btn = gr.Button(value="Step 2.5 - Optimize the model")

            def train_model(custom_model, version, language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
                clear_gpu_cache()

                # Remove the train dir left over from any previous run
                run_dir = Path(output_path) / "run"
                if run_dir.exists():
                    shutil.rmtree(run_dir)  # os.remove() cannot delete a directory

                # Check if the dataset language matches the language you specified
                lang_file_path = Path(output_path) / "dataset" / "lang.txt"

                # If lang.txt already exists and contains a different language, prefer it
                current_language = None
                if lang_file_path.exists():
                    with open(lang_file_path, 'r', encoding='utf-8') as existing_lang_file:
                        current_language = existing_lang_file.read().strip()
                    if current_language != language:
                        print("The language prepared for the dataset does not match the specified language. Switching to the language recorded in the dataset.")
                        language = current_language

                if not train_csv or not eval_csv:
                    return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields!", "", "", "", ""

                try:
                    # Convert seconds to waveform frames at the 22050 Hz sample rate used by XTTS
                    max_audio_length = int(max_audio_length * 22050)
                    speaker_xtts_path, config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(custom_model, version, language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
                except:
                    traceback.print_exc()
                    error = traceback.format_exc()
                    return f"The training was interrupted due to an error! Please check the console to verify the full error message! \n Error summary: {error}", "", "", "", ""
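                # After training, copy the best checkpoint and the speaker reference
                # audio into the "ready" folder, which the inference tab reads from.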
                # Copy the original files to avoid issues when parameters change
                # os.system(f"cp {config_path} {exp_path}")
                # os.system(f"cp {vocab_file} {exp_path}")

                ready_dir = Path(output_path) / "ready"

                ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
                shutil.copy(ft_xtts_checkpoint, ready_dir / "unoptimize_model.pth")
                # os.remove(ft_xtts_checkpoint)
                ft_xtts_checkpoint = os.path.join(ready_dir, "unoptimize_model.pth")

                # Copy the reference audio to the output folder and rename it
                speaker_reference_path = Path(speaker_wav)
                speaker_reference_new_path = ready_dir / "reference.wav"
                shutil.copy(speaker_reference_path, speaker_reference_new_path)

                print("Model training done!")
                # clear_gpu_cache()
                return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_xtts_path, speaker_reference_new_path

            def optimize_model(out_path, clear_train_data):
                # print(out_path)
                out_path = Path(out_path)  # Ensure that out_path is a Path object.

                ready_dir = out_path / "ready"
                run_dir = out_path / "run"
                dataset_dir = out_path / "dataset"

                # Clear the selected training data directories.
                if clear_train_data in {"run", "all"} and run_dir.exists():
                    try:
                        shutil.rmtree(run_dir)
                    except PermissionError as e:
                        print(f"An error occurred while deleting {run_dir}: {e}")

                if clear_train_data in {"dataset", "all"} and dataset_dir.exists():
                    try:
                        shutil.rmtree(dataset_dir)
                    except PermissionError as e:
                        print(f"An error occurred while deleting {dataset_dir}: {e}")

                # Get the full path to the model
                model_path = ready_dir / "unoptimize_model.pth"

                if not model_path.is_file():
                    return "Unoptimized model not found in ready folder", ""

                # Load the checkpoint and remove the parts that are not needed for inference.
                checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
                del checkpoint["optimizer"]

                for key in list(checkpoint["model"].keys()):
                    if "dvae" in key:
                        del checkpoint["model"][key]

                # Delete the unoptimized checkpoint; its slimmed copy is saved below.
                os.remove(model_path)

                # Save the optimized model.
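                # The optimizer state and DVAE weights stripped above are only
                # needed during training, so the checkpoint written here is a
                # smaller, inference-only model.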
                optimized_model_file_name = "model.pth"
                optimized_model = ready_dir / optimized_model_file_name

                torch.save(checkpoint, optimized_model)
                ft_xtts_checkpoint = str(optimized_model)

                clear_gpu_cache()

                return f"Model optimized and saved at {ft_xtts_checkpoint}!", ft_xtts_checkpoint

            def load_params(out_path):
                path_output = Path(out_path)

                dataset_path = path_output / "dataset"
                if not dataset_path.exists():
                    return "The output folder does not exist!", "", "", ""  # was missing the fourth output (lang)

                eval_train = dataset_path / "metadata_train.csv"
                eval_csv = dataset_path / "metadata_eval.csv"

                # Read the dataset language back from lang.txt in the output directory
                lang_file_path = dataset_path / "lang.txt"

                current_language = None
                if os.path.exists(lang_file_path):
                    with open(lang_file_path, 'r', encoding='utf-8') as existing_lang_file:
                        current_language = existing_lang_file.read().strip()

                clear_gpu_cache()

                print(current_language)
                return "The data has been updated", eval_train, eval_csv, current_language

        with gr.Tab("3 - Inference"):
            with gr.Row():
                with gr.Column() as col1:
                    load_params_tts_btn = gr.Button(value="Load params for TTS from output folder")
                    xtts_checkpoint = gr.Textbox(label="XTTS checkpoint path:", value="")
                    xtts_config = gr.Textbox(label="XTTS config path:", value="")
                    xtts_vocab = gr.Textbox(label="XTTS vocab path:", value="")
                    xtts_speaker = gr.Textbox(label="XTTS speaker path:", value="")
                    progress_load = gr.Label(label="Progress:")
                    load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")

                with gr.Column() as col2:
                    speaker_reference_audio = gr.Textbox(label="Speaker reference audio:", value="")
                    tts_language = gr.Dropdown(
                        label="Language",
                        value="en",
                        choices=["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh", "hu", "ko", "ja"],
                    )
                    tts_text = gr.Textbox(
                        label="Input Text:",
                        value="This model sounds really good and above all, it's reasonably fast.",
                    )
                    with gr.Accordion("Advanced settings", open=False) as acr:
                        temperature = gr.Slider(label="temperature", minimum=0, maximum=1, step=0.05, value=0.75)
                        length_penalty = gr.Slider(label="length_penalty", minimum=-10.0, maximum=10.0, step=0.5, value=1)
                        repetition_penalty = gr.Slider(label="repetition_penalty", minimum=1, maximum=10, step=0.5, value=5)
                        top_k = gr.Slider(label="top_k", minimum=1, maximum=100, step=1, value=50)
                        top_p = gr.Slider(label="top_p", minimum=0, maximum=1, step=0.05, value=0.85)
                        speed = gr.Slider(label="speed", minimum=0.2, maximum=4.0, step=0.05, value=1.0)
                        sentence_split = gr.Checkbox(label="Enable text splitting", value=True)
                        use_config = gr.Checkbox(label="Use inference settings from config; if disabled, use the settings above", value=False)
                    tts_btn = gr.Button(value="Step 4 - Inference")

                with gr.Column() as col3:
                    progress_gen = gr.Label(label="Progress:")
                    tts_output_audio = gr.Audio(label="Generated Audio.")
                    reference_audio = gr.Audio(label="Reference audio used.")

                with gr.Column() as col4:
                    srt_upload = gr.File(label="Upload SRT File")
                    generate_srt_audio_btn = gr.Button(value="Generate Audio from SRT")
                    srt_output_audio = gr.Audio(label="Combined Audio from SRT")
                    error_message = gr.Textbox(label="Error Message", visible=False)  # Error-message component, hidden by default

            generate_srt_audio_btn.click(
                fn=process_srt_and_generate_audio,
                inputs=[
                    srt_upload,
                    tts_language,
                    speaker_reference_audio,
                    temperature,
                    length_penalty,
                    repetition_penalty,
                    top_k,
                    top_p,
                    speed,
                    sentence_split,
                    use_config,
                ],
                outputs=[srt_output_audio],
            )
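        # Wire up the remaining buttons: dataset creation, parameter loading,
        # training, optimization, model loading, and TTS generation.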
        prompt_compute_btn.click(
            fn=preprocess_dataset,
            inputs=[
                upload_file,
                folder_path,  # was missing: preprocess_dataset() expects the folder path as its second argument
                lang,
                whisper_model,
                out_path,
                train_csv,
                eval_csv,
            ],
            outputs=[
                progress_data,
                train_csv,
                eval_csv,
            ],
        )

        load_params_btn.click(
            fn=load_params,
            inputs=[out_path],
            outputs=[progress_train, train_csv, eval_csv, lang],
        )

        train_btn.click(
            fn=train_model,
            inputs=[
                custom_model,
                version,
                lang,
                train_csv,
                eval_csv,
                num_epochs,
                batch_size,
                grad_acumm,
                out_path,
                max_audio_length,
            ],
            outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, xtts_speaker, speaker_reference_audio],
        )

        optimize_model_btn.click(
            fn=optimize_model,
            inputs=[out_path, clear_train_data],
            outputs=[progress_train, xtts_checkpoint],
        )

        load_btn.click(
            fn=load_model,
            inputs=[xtts_checkpoint, xtts_config, xtts_vocab, xtts_speaker],
            outputs=[progress_load],
        )

        tts_btn.click(
            fn=run_tts,
            inputs=[
                tts_language,
                tts_text,
                speaker_reference_audio,
                temperature,
                length_penalty,
                repetition_penalty,
                top_k,
                top_p,
                speed,
                sentence_split,
                use_config,
            ],
            outputs=[progress_gen, tts_output_audio, reference_audio],
        )

        load_params_tts_btn.click(
            fn=load_params_tts,
            inputs=[out_path, version],
            outputs=[progress_load, xtts_checkpoint, xtts_config, xtts_vocab, xtts_speaker, speaker_reference_audio],
        )

        with gr.Tab("4 - Voice conversion"):
            with gr.Column() as col0:
                gr.Markdown("## OpenVoice Conversion Tool")
                voice_convert_seed = gr.File(label="Upload the reference speaker audio (the target voice)")
                #pitch_shift_slider = gr.Slider(minimum=-12, maximum=12, step=1, value=0, label="Pitch Shift (Semitones)")
                audio_to_convert = gr.Textbox(
                    label="Input the location of the audio to be converted",
                    value="",
                )
                convert_button = gr.Button("Convert Voice")
                converted_audio = gr.Audio(label="Converted Audio")

            convert_button.click(
                convert_voice,
                inputs=[voice_convert_seed, audio_to_convert],  #, pitch_shift_slider],
                outputs=[converted_audio],
            )

        with gr.Tab("5 - Logs"):
            # Add a button to read the logs
            read_logs_btn = gr.Button("Read Logs")
            log_output = gr.Textbox(label="Log Output")
            read_logs_btn.click(fn=read_logs, inputs=None, outputs=log_output)

    demo.launch(
        #share=False,
        share=True,
        debug=False,
        server_port=args.port,
        #server_name="localhost"
        server_name="0.0.0.0",
    )
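# Example launch, matching the invocation in the argparse description above
# (5003 is the --port default; all other flags fall back to their defaults):
#   python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port 5003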