# Source: xtts_awesome / gradio_utils.py
# Uploaded by awesome-paulw ("Upload folder using huggingface_hub"), commit 1207342, 7.21 kB.
import os
import shutil
import torch
import traceback
from pathlib import Path
from multiprocessing import cpu_count
from functions.core_functions1 import clear_gpu_cache
from functions.logging_utils import remove_log_file
from functions.slice_utils import open_slice, close_slice
from utils.formatter import format_audio_list
from utils.gpt_train import train_gpt
def get_audio_files_from_folder(folder_path, extensions=(".wav", ".mp3", ".flac", ".m4a", ".webm")):
    """Recursively collect audio file paths under *folder_path*.

    Extension matching is case-insensitive, so ``clip.WAV`` is found just like
    ``clip.wav``.

    Args:
        folder_path: Root directory to walk. A non-existent directory yields an
            empty list (``os.walk`` produces nothing for it).
        extensions: Filename suffixes that count as audio files.

    Returns:
        List of file paths (directory joined with file name) in ``os.walk`` order.
    """
    # str.endswith accepts a tuple; normalize once so the per-file check is cheap.
    suffixes = tuple(ext.lower() for ext in extensions)
    audio_files = []
    for root, _dirs, files in os.walk(folder_path):
        for file in files:
            # Lower-case the name so upper-case extensions are not skipped.
            if file.lower().endswith(suffixes):
                audio_files.append(os.path.join(root, file))
    return audio_files
def preprocess_dataset(audio_path, audio_folder, language, whisper_model, out_path, train_csv, eval_csv, progress):
    """Build train/eval metadata for fine-tuning via ``format_audio_list``.

    Either ``audio_path`` (a list of audio files) or ``audio_folder`` (a
    directory searched recursively) must be supplied; ``audio_path`` takes
    precedence when both are given. Output goes under ``<out_path>/dataset``.

    ``train_csv`` and ``eval_csv`` are accepted for interface compatibility and
    are not read here.

    Returns:
        ``(status_message, train_meta, eval_meta)``. On any failure the last two
        entries are empty strings.
    """
    out_path = os.path.join(out_path, "dataset")
    os.makedirs(out_path, exist_ok=True)

    # None, [] and "" all mean "not provided" (UI widgets may send empty values).
    if audio_path:
        audio_files = audio_path
    elif audio_folder:
        audio_files = get_audio_files_from_folder(audio_folder)
        if not audio_files:
            return f"No audio files were found in the folder: {audio_folder}", "", ""
    else:
        return "You should provide either audio files or a folder containing audio files!", "", ""

    try:
        train_meta, eval_meta, audio_total_size = format_audio_list(
            audio_files,
            whisper_model=whisper_model,
            target_language=language,
            out_path=out_path,
            gradio_progress=progress,
        )
    except Exception:
        # Keep the UI responsive: print the traceback and surface it as the status text.
        traceback.print_exc()
        error = traceback.format_exc()
        return f"The data processing was interrupted due to an error! Please check the console to verify the full error message! \n Error summary: {error}", "", ""

    # The threshold of 120 pairs with "2 minutes" in the message, i.e. the total is in seconds.
    if audio_total_size < 120:
        message = "The sum of the duration of the audios that you provided should be at least 2 minutes!"
        print(message)
        return message, "", ""

    print("Dataset Processed!")
    return "Dataset Processed!", train_meta, eval_meta
def train_model(custom_model, version, language, train_csv, eval_csv, num_epochs, batch_size, grad_accum, output_path, max_audio_length):
    """Fine-tune the XTTS GPT model and stage its artifacts in ``<output_path>/ready``.

    Args:
        custom_model, version, language, num_epochs, batch_size, grad_accum:
            Forwarded unchanged to ``train_gpt``.
        train_csv, eval_csv: Metadata CSV paths. Both must be non-empty.
        output_path: Output root. ``run/`` is recreated and ``ready/`` receives
            the checkpoint and the reference wav.
        max_audio_length: Multiplied by 22050 and truncated to int before being
            passed to ``train_gpt`` (presumably seconds -> samples at 22050 Hz).

    Returns:
        ``(status, config_path, vocab_file, checkpoint_path, speaker_xtts_path,
        reference_wav_path)``. On failure the status is followed by five ``""``.
    """
    # Validate before touching the filesystem so a bad request keeps the previous run.
    if not train_csv or not eval_csv:
        return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields!", "", "", "", "", ""

    # run/ is a directory: os.remove() cannot delete it, rmtree is required.
    run_dir = Path(output_path) / "run"
    if run_dir.exists():
        shutil.rmtree(run_dir)

    try:
        max_audio_length = int(max_audio_length * 22050)
        speaker_xtts_path, config_path, _original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(
            custom_model, version, language, num_epochs, batch_size, grad_accum, train_csv, eval_csv,
            output_path=output_path, max_audio_length=max_audio_length,
        )
    except Exception:
        traceback.print_exc()
        error = traceback.format_exc()
        return f"The training was interrupted due to an error! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", "", ""

    ready_dir = Path(output_path) / "ready"
    ready_dir.mkdir(parents=True, exist_ok=True)

    ft_xtts_checkpoint = ready_dir / "unoptimize_model.pth"
    shutil.copy(os.path.join(exp_path, "best_model.pth"), ft_xtts_checkpoint)

    speaker_reference_new_path = ready_dir / "reference.wav"
    shutil.copy(Path(speaker_wav), speaker_reference_new_path)

    print("Model training done!")
    # Return plain strings, consistent with load_params_tts.
    return "Model training done!", config_path, vocab_file, str(ft_xtts_checkpoint), speaker_xtts_path, str(speaker_reference_new_path)
def optimize_model(out_path, clear_train_data):
    """Strip training-only state from the fine-tuned checkpoint, optionally cleaning up.

    Args:
        out_path: Output root containing ``ready/``, ``run/`` and ``dataset/``.
        clear_train_data: ``"run"``, ``"dataset"`` or ``"all"`` selects which
            training folders to delete. Any other value deletes nothing.

    Returns:
        ``(status_message, model_path)``. ``model_path`` is ``""`` when
        ``ready/unoptimize_model.pth`` does not exist.
    """
    out_path = Path(out_path)
    ready_dir = out_path / "ready"

    folders_to_clear = []
    if clear_train_data in {"run", "all"}:
        folders_to_clear.append(out_path / "run")
    if clear_train_data in {"dataset", "all"}:
        folders_to_clear.append(out_path / "dataset")
    for folder in folders_to_clear:
        if folder.exists():
            # Best effort: a locked folder is reported but does not abort optimization.
            try:
                shutil.rmtree(folder)
            except PermissionError as e:
                print(f"An error occurred while deleting {folder}: {e}")

    model_path = ready_dir / "unoptimize_model.pth"
    if not model_path.is_file():
        return "Unoptimized model not found in ready folder", ""

    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
    # Optimizer state may be absent; tolerate that instead of raising KeyError.
    checkpoint.pop("optimizer", None)
    # Drop every model weight whose key contains "dvae".
    model_state = checkpoint["model"]
    for key in [k for k in model_state if "dvae" in k]:
        del model_state[key]

    optimized_model = ready_dir / "model.pth"
    # Save before deleting the source so a failed write cannot lose the only copy.
    torch.save(checkpoint, optimized_model)
    os.remove(model_path)

    ft_xtts_checkpoint = str(optimized_model)
    return f"Model optimized and saved at {ft_xtts_checkpoint}!", ft_xtts_checkpoint
def load_params(out_path):
    """Point the UI at the metadata CSVs produced by the preprocessing step.

    Returns:
        ``(status, train_csv, eval_csv, language)``. The CSV entries are
        ``Path`` objects inside ``<out_path>/dataset`` and are not checked for
        existence. ``language`` is the stripped text of ``dataset/lang.txt``, or
        ``None`` when that file is absent. If the dataset folder itself is
        missing, the three trailing values are ``""``.
    """
    dataset_dir = Path(out_path) / "dataset"
    if not dataset_dir.exists():
        return "The output folder does not exist!", "", "", ""

    lang_file = dataset_dir / "lang.txt"
    language = lang_file.read_text(encoding="utf-8").strip() if os.path.exists(lang_file) else None
    print(language)

    train_csv_path = dataset_dir / "metadata_train.csv"
    eval_csv_path = dataset_dir / "metadata_eval.csv"
    return "The data has been updated", train_csv_path, eval_csv_path, language
def load_params_tts(out_path, version):
    """Resolve the fine-tuned model artifacts in ``<out_path>/ready`` for inference.

    ``version`` is accepted for interface compatibility and is not used.

    Returns:
        ``(status, checkpoint, config, vocab, speaker, reference)`` as strings.
        If any artifact is missing, the status lists all missing paths (in a
        fixed order) and is followed by five ``""``.
    """
    ready = Path(out_path) / "ready"
    config, vocab, checkpoint, speaker, reference = (
        ready / name
        for name in ("config.json", "vocab.json", "model.pth", "speaker.pth", "reference.wav")
    )

    # Report every missing artifact at once rather than failing on the first.
    missing = [str(p) for p in (config, vocab, checkpoint, speaker, reference) if not p.exists()]
    if missing:
        return f"The following files are missing from the ready folder: {', '.join(missing)}", "", "", "", "", ""

    print("Loaded parameters for TTS.")
    return (
        "Loaded parameters for TTS.",
        str(checkpoint),
        str(config),
        str(vocab),
        str(speaker),
        str(reference),
    )