""" Synthesize a given text using the trained DiT models. """ import json import os os.environ["NLTK_DATA"] = "nltk_data" import torch import yaml from g2p_en import G2p from vocos import Vocos from sample import sample def synthesize( text, duration_model_config, duration_model_checkpoint, acoustic_model_config, acoustic_model_checkpoint, speaker_id, cfg_scale=4.0, num_sampling_steps=1000, ): """ Synthesize speech from text using trained DiT models. Args: text (str): Input text to synthesize duration_model_config (str): Path to duration model config file duration_model_checkpoint (str): Path to duration model checkpoint acoustic_model_config (str): Path to acoustic model config file acoustic_model_checkpoint (str): Path to acoustic model checkpoint speaker_id (str): Speaker ID to use for synthesis cfg_scale (float): Classifier-free guidance scale (default: 4.0) num_sampling_steps (int): Number of sampling steps for diffusion (default: 1000) Returns: numpy.ndarray: Audio waveform array int: Sample rate (24000) """ print("Text:", text) # Read duration model config with open(duration_model_config, "r") as f: duration_config = yaml.safe_load(f) # Get data directory from data_path data_dir = os.path.dirname(duration_config["data"]["data_path"]) # Read maps.json from same directory with open(os.path.join(data_dir, "maps.json"), "r") as f: maps = json.load(f) phone_to_idx = maps["phone_to_idx"] phone_kind_to_idx = maps["phone_kind_to_idx"] speaker_id_to_idx = maps["speaker_id_to_idx"] # Step 1: Text to phonemes def text_to_phonemes(text, insert_empty=True): g2p = G2p() phonemes = g2p(text) words = [] word = [] for p in phonemes: if p == " ": if len(word) > 0: words.append(word) word = [] else: word.append(p) if len(word) > 0: words.append(word) phones = [] phone_kinds = [] for word in words: for i, p in enumerate(word): if p in [",", ".", "!", "?", ";", ":"]: p = "EMPTY" elif p in phone_to_idx: pass else: continue if p == "EMPTY": phone_kind = "EMPTY" elif len(word) == 1: phone_kind = "WORD" elif i == 0: phone_kind = "START" elif i == len(word) - 1: phone_kind = "END" else: phone_kind = "MIDDLE" phones.append(p) phone_kinds.append(phone_kind) if insert_empty: if phones[0] != "EMPTY": phones.insert(0, "EMPTY") phone_kinds.insert(0, "EMPTY") if phones[-1] != "EMPTY": phones.append("EMPTY") phone_kinds.append("EMPTY") return phones, phone_kinds phonemes, phone_kinds = text_to_phonemes(text) # Convert phonemes to indices phoneme_indices = [phone_to_idx[p] for p in phonemes] phone_kind_indices = [phone_kind_to_idx[p] for p in phone_kinds] print("Phonemes:", phonemes) # Step 2: Duration prediction device = torch.device("cuda") # if torch.cuda.is_available() else "cpu") torch_phoneme_indices = torch.tensor(phoneme_indices)[None, :].long().to(device) torch_speaker_id = torch.full_like(torch_phoneme_indices, int(speaker_id)) torch_phone_kind_indices = ( torch.tensor(phone_kind_indices)[None, :].long().to(device) ) samples = sample( duration_model_config, duration_model_checkpoint, cfg_scale=cfg_scale, num_sampling_steps=num_sampling_steps, seed=0, speaker_id=torch_speaker_id, phone=torch_phoneme_indices, phone_kind=torch_phone_kind_indices, ) phoneme_durations = samples[-1][0, 0] # Step 3: Acoustic prediction # First, we need to convert phoneme durations to number of frames per phoneme (min 1 frame) SAMPLE_RATE = 24000 HOP_LENGTH = 256 N_FFT = 1024 N_MELS = 100 time_per_frame = HOP_LENGTH / SAMPLE_RATE # convert predicted durations to raw durations using data mean and std in the config if duration_config["data"]["normalize"]: mean = duration_config["data"]["data_mean"] std = duration_config["data"]["data_std"] raw_durations = phoneme_durations * std + mean else: raw_durations = phoneme_durations raw_durations = raw_durations.clamp(min=time_per_frame, max=1.0) end_time = torch.cumsum(raw_durations, dim=0) end_frame = end_time / time_per_frame int_end_frame = end_frame.floor().int() repeated_phoneme_indices = [] repeated_phone_kind_indices = [] for i in range(len(phonemes)): repeated_phoneme_indices.extend( [phoneme_indices[i]] * (int_end_frame[i] - len(repeated_phoneme_indices)) ) repeated_phone_kind_indices.extend( [phone_kind_indices[i]] * (int_end_frame[i] - len(repeated_phone_kind_indices)) ) torch_phoneme_indices = ( torch.tensor(repeated_phoneme_indices)[None, :].long().to(device) ) torch_speaker_id = torch.full_like(torch_phoneme_indices, int(speaker_id)) torch_phone_kind_indices = ( torch.tensor(repeated_phone_kind_indices)[None, :].long().to(device) ) samples = sample( acoustic_model_config, acoustic_model_checkpoint, cfg_scale=cfg_scale, num_sampling_steps=num_sampling_steps, seed=0, speaker_id=torch_speaker_id, phone=torch_phoneme_indices, phone_kind=torch_phone_kind_indices, ) mel = samples[-1][0] # compute raw mel if acoustic model normalize is true acoustic_config = yaml.safe_load(open(acoustic_model_config, "r")) if acoustic_config["data"]["normalize"]: mean = acoustic_config["data"]["data_mean"] std = acoustic_config["data"]["data_std"] raw_mel = mel * std + mean else: raw_mel = mel # Step 4: Vocoder vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz") audio = vocos.decode(raw_mel.cpu()[None, :, :]).squeeze().cpu().numpy() return audio, SAMPLE_RATE