# xtts_awesome/utils/core_functions.py
import os
import sys
# Get the parent of this file's directory (the project root)
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Add the root directory to sys.path
sys.path.append(root_dir)
import tempfile
import logging
from pathlib import Path
from datetime import datetime
from pydub import AudioSegment
import pysrt
import torch
import torchaudio
import traceback
from .utils.formatter import format_audio_list, find_latest_best_model
from .utils.gpt_train import train_gpt
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from .openvoice_cli.downloader import download_checkpoint
from .openvoice_cli.api import ToneColorConverter
from .openvoice_cli import se_extractor
from logging_utils import setup_logger, read_logs
# Set up the log handler
setup_logger("logs/core_functions.log")
logger = logging.getLogger(__name__)
def clear_gpu_cache():
# clear the GPU cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
XTTS_MODEL = None
def load_model(xtts_checkpoint, xtts_config, xtts_vocab, xtts_speaker):
global XTTS_MODEL
clear_gpu_cache()
if not xtts_checkpoint or not xtts_config or not xtts_vocab:
return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!"
config = XttsConfig()
config.load_json(xtts_config)
XTTS_MODEL = Xtts.init_from_config(config)
print("Loading XTTS model! ")
XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab,speaker_file_path=xtts_speaker, use_deepspeed=False)
if torch.cuda.is_available():
XTTS_MODEL.cuda()
print("Model Loaded!")
return "Model Loaded!"
def run_tts(lang, tts_text, speaker_audio_file, output_file_path, temperature, length_penalty, repetition_penalty, top_k, top_p, speed, sentence_split, use_config):
if XTTS_MODEL is None:
raise Exception("XTTS_MODEL is not loaded. Please load the model before running TTS.")
if not tts_text.strip():
raise ValueError("Text for TTS is empty.")
if not os.path.exists(speaker_audio_file):
raise FileNotFoundError(f"Speaker audio file not found: {speaker_audio_file}")
gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs)
if use_config:
out = XTTS_MODEL.inference(
text=tts_text,
language=lang,
gpt_cond_latent=gpt_cond_latent,
speaker_embedding=speaker_embedding,
            temperature=XTTS_MODEL.config.temperature,  # Use the values from the model config
length_penalty=XTTS_MODEL.config.length_penalty,
repetition_penalty=XTTS_MODEL.config.repetition_penalty,
top_k=XTTS_MODEL.config.top_k,
top_p=XTTS_MODEL.config.top_p,
speed=speed,
            enable_text_splitting=True,
)
else:
out = XTTS_MODEL.inference(
text=tts_text,
language=lang,
gpt_cond_latent=gpt_cond_latent,
speaker_embedding=speaker_embedding,
            temperature=temperature,  # Use the custom values passed in by the caller
length_penalty=length_penalty,
repetition_penalty=float(repetition_penalty),
top_k=top_k,
top_p=top_p,
speed=speed,
            enable_text_splitting=sentence_split,
)
    # Honour the requested output path when one is given (the SRT pipeline
    # below relies on this); otherwise fall back to a temporary file.
    if output_file_path:
        out_path = output_file_path
    else:
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
            out_path = fp.name
    out["wav"] = torch.tensor(out["wav"]).unsqueeze(0)
    torchaudio.save(out_path, out["wav"], 24000)
return "Speech generated !", out_path, speaker_audio_file
def load_params_tts(out_path, version):
out_path = Path(out_path)
# base_model_path = Path.cwd() / "models" / version
# if not base_model_path.exists():
# return "Base model not found !","","",""
ready_model_path = out_path / "ready"
vocab_path = ready_model_path / "vocab.json"
config_path = ready_model_path / "config.json"
speaker_path = ready_model_path / "speakers_xtts.pth"
reference_path = ready_model_path / "reference.wav"
model_path = ready_model_path / "model.pth"
if not model_path.exists():
model_path = ready_model_path / "unoptimize_model.pth"
    if not model_path.exists():
        return "Params for TTS not found", "", "", "", "", ""
    return "Params for TTS loaded", model_path, config_path, vocab_path, speaker_path, reference_path
def process_srt_and_generate_audio(
        srt_file,
        lang,
        speaker_reference_audio,
        temperature,
        length_penalty,
        repetition_penalty,
        top_k,
        top_p,
        speed,
        sentence_split,
        use_config):
try:
subtitles = pysrt.open(srt_file)
audio_files = []
output_dir = create_output_dir(parent_dir='/content/drive/MyDrive/Voice Conversion Result')
for index, subtitle in enumerate(subtitles):
audio_filename = f"audio_{index+1:03d}.wav"
audio_file_path = os.path.join(output_dir, audio_filename)
            subtitle_text = remove_endperiod(subtitle.text)
run_tts(lang, subtitle_text, speaker_reference_audio, audio_file_path,
temperature, length_penalty, repetition_penalty, top_k, top_p,
speed, sentence_split, use_config)
logger.info(f"Generated audio file: {audio_file_path}")
audio_files.append(audio_file_path)
output_audio_path = merge_audio_with_srt_timing(subtitles, audio_files, output_dir)
return output_audio_path
except Exception as e:
logger.error(f"Error in process_srt_and_generate_audio: {e}")
raise
def create_output_dir(parent_dir):
try:
        # Name the output folder after the current date and time
        folder_name = datetime.now().strftime("audio_outputs_%Y-%m-%d_%H-%M-%S")
        # Parent directory (assumed here to be on the mounted Colab drive)
        #parent_dir = "/content/drive/MyDrive/Voice Conversion Result"
        # Full path of the output folder
        output_dir = os.path.join(parent_dir, folder_name)
        # Create the folder if it does not already exist
if not os.path.exists(output_dir):
os.makedirs(output_dir)
logger.info(f"Created output directory at: {output_dir}")
return output_dir
except Exception as e:
logger.error(f"Failed to create output directory: {e}")
raise
def srt_time_to_ms(srt_time):
return (srt_time.hours * 3600 + srt_time.minutes * 60 + srt_time.seconds) * 1000 + srt_time.milliseconds
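# Worked example: the SRT timestamp 00:01:02,500 maps to
# (0*3600 + 1*60 + 2) * 1000 + 500 = 62500 ms.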
def merge_audio_with_srt_timing(subtitles, audio_files, output_dir):
try:
combined = AudioSegment.silent(duration=0)
last_position_ms = 0
for subtitle, audio_file in zip(subtitles, audio_files):
start_time_ms = srt_time_to_ms(subtitle.start)
if last_position_ms < start_time_ms:
silence_duration = start_time_ms - last_position_ms
combined += AudioSegment.silent(duration=silence_duration)
last_position_ms = start_time_ms
audio = AudioSegment.from_file(audio_file, format="wav")
combined += audio
last_position_ms += len(audio)
output_path = os.path.join(output_dir, "combined_audio_with_timing.wav")
#combined_with_set_frame_rate = combined.set_frame_rate(24000)
#combined_with_set_frame_rate.export(output_path, format="wav")
combined.export(output_path, format="wav")
logger.info(f"Exported combined audio to: {output_path}")
return output_path
except Exception as e:
logger.error(f"Error merging audio files: {e}")
raise
def remove_endperiod(subtitle):
"""Removes the period (.) at the end of a subtitle.
"""
if subtitle.endswith('.'):
subtitle = subtitle[:-1]
return subtitle
def convert_voice(reference_audio, audio_to_convert):
device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Define the input and output audio paths
#input_audio_path = audio_to_convert
base_name, ext = os.path.splitext(os.path.basename(audio_to_convert))
new_file_name = base_name + 'convertedvoice' + ext
output_path = os.path.join(os.path.dirname(audio_to_convert), new_file_name)
tune_one(input_file=audio_to_convert, ref_file=reference_audio, output_file=output_path, device=device)
return output_path
def tune_one(input_file, ref_file, output_file, device):
current_dir = os.path.dirname(os.path.realpath(__file__))
checkpoints_dir = os.path.join(current_dir, 'checkpoints')
ckpt_converter = os.path.join(checkpoints_dir, 'converter')
if not os.path.exists(ckpt_converter):
os.makedirs(ckpt_converter, exist_ok=True)
download_checkpoint(ckpt_converter)
tone_color_converter = ToneColorConverter(os.path.join(ckpt_converter, 'config.json'), device=device)
tone_color_converter.load_ckpt(os.path.join(ckpt_converter, 'checkpoint.pth'))
source_se, _ = se_extractor.get_se(input_file, tone_color_converter, vad=True)
target_se, _ = se_extractor.get_se(ref_file, tone_color_converter, vad=True)
# Ensure output directory exists and is writable
output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
# Run the tone color converter
tone_color_converter.convert(
audio_src_path=input_file,
src_se=source_se,
tgt_se=target_se,
output_path=output_file,
)
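# A minimal usage sketch for tune_one (file names are placeholders): clone the
# tone colour of ref.wav onto input.wav, writing the result to tuned.wav.
#
#   tune_one(input_file="input.wav", ref_file="ref.wav",
#            output_file="tuned.wav",
#            device="cuda:0" if torch.cuda.is_available() else "cpu")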
# The batch-conversion CLI below is intentionally disabled; re-enabling it
# would also require `import glob` and `from tqdm import tqdm` at module level.
'''
def tune_batch(input_dir, ref_file, output_dir=None, device='cpu', output_format='.wav'):
current_dir = os.path.dirname(os.path.realpath(__file__))
checkpoints_dir = os.path.join(current_dir, 'checkpoints')
ckpt_converter = os.path.join(checkpoints_dir, 'converter')
if not os.path.exists(ckpt_converter):
os.makedirs(ckpt_converter, exist_ok=True)
download_checkpoint(ckpt_converter)
tone_color_converter = ToneColorConverter(os.path.join(ckpt_converter, 'config.json'), device=device)
tone_color_converter.load_ckpt(os.path.join(ckpt_converter, 'checkpoint.pth'))
target_se, _ = se_extractor.get_se(ref_file, tone_color_converter, vad=True)
# Use default output directory 'out' if not provided
if output_dir is None:
output_dir = os.path.join(current_dir, 'out')
os.makedirs(output_dir, exist_ok=True)
# Check for any audio files in the input directory (wav, mp3, flac) using glob
audio_extensions = ('*.wav', '*.mp3', '*.flac')
audio_files = []
for extension in audio_extensions:
audio_files.extend(glob.glob(os.path.join(input_dir, extension)))
    for audio_file in tqdm(audio_files, desc="Tune file", total=len(audio_files)):
# Extract source SE from audio file
source_se, _ = se_extractor.get_se(audio_file, tone_color_converter, vad=True)
# Run the tone color converter
filename_without_extension = os.path.splitext(os.path.basename(audio_file))[0]
output_filename = f"{filename_without_extension}_tuned{output_format}"
output_file = os.path.join(output_dir, output_filename)
tone_color_converter.convert(
audio_src_path=audio_file,
src_se=source_se,
tgt_se=target_se,
output_path=output_file,
)
print(f"Converted {audio_file} to {output_file}")
return output_dir
def main_single(args):
tune_one(input_file=args.input, ref_file=args.ref, output_file=args.output, device=args.device)
def main_batch(args):
output_dir = tune_batch(
input_dir=args.input_dir,
ref_file=args.ref_file,
output_dir=args.output_dir,
device=args.device,
output_format=args.output_format
)
print(f"Batch processing complete. Converted files are saved in {output_dir}")
'''
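# Hypothetical end-to-end sketch tying the pieces above together for a single
# SRT file; every path below is a placeholder, not something this repo ships.
if __name__ == "__main__":
    logger.info(load_model(
        xtts_checkpoint="ready/model.pth",
        xtts_config="ready/config.json",
        xtts_vocab="ready/vocab.json",
        xtts_speaker="ready/speakers_xtts.pth",
    ))
    combined = process_srt_and_generate_audio(
        srt_file="subtitles.srt",
        lang="en",
        speaker_reference_audio="ready/reference.wav",
        temperature=0.75,
        length_penalty=1.0,
        repetition_penalty=5.0,
        top_k=50,
        top_p=0.85,
        speed=1.0,
        sentence_split=True,
        use_config=False,
    )
    # Optionally re-tune the merged track toward the reference voice.
    logger.info(f"Done: {convert_voice('ready/reference.wav', combined)}")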