{
"cells": [
{
"cell_type": "code",
"execution_count": 27,
"id": "d5a1b342-1dfb-4d79-98f7-250aa3b72d93",
"metadata": {},
"outputs": [],
"source": [
"# coding: utf-8\n",
"import os\n",
"import torch\n",
"from vocos import Vocos\n",
"import logging\n",
"import py3langid as langid\n",
"langid.set_languages(['en', 'zh', 'ja', 'vi'])\n",
"\n",
"import pathlib\n",
"import platform\n",
"if platform.system().lower() == 'windows':\n",
" temp = pathlib.PosixPath\n",
" pathlib.PosixPath = pathlib.WindowsPath\n",
"else:\n",
" temp = pathlib.WindowsPath\n",
" pathlib.WindowsPath = pathlib.PosixPath\n",
"\n",
"import numpy as np\n",
"from data.tokenizer import (\n",
" AudioTokenizer,\n",
" tokenize_audio,\n",
")\n",
"from data.collation import get_text_token_collater\n",
"from models.vallex import VALLE\n",
"from utils.g2p import PhonemeBpeTokenizer\n",
"from utils.sentence_cutter import split_text_into_sentences\n",
"\n",
"from macros import *\n",
"\n",
"device = torch.device(\"cpu\")\n",
"# if torch.cuda.is_available():\n",
"# device = torch.device(\"cuda\", 0)\n",
"\n",
"url = 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'\n",
"\n",
"checkpoints_dir = \"./checkpoints/\"\n",
"\n",
"model_checkpoint_name = \"vallex-checkpoint.pt\"\n",
"\n",
"model = None\n",
"\n",
"codec = None\n",
"\n",
"vocos = None\n",
"\n",
"text_tokenizer = PhonemeBpeTokenizer(tokenizer_path=\"./utils/g2p/bpe_69.json\")\n",
"text_collater = get_text_token_collater()\n",
"\n",
"def preload_models(model_path):\n",
" global model, codec, vocos\n",
" # if not os.path.exists(checkpoints_dir): os.mkdir(checkpoints_dir)\n",
" # if not os.path.exists(os.path.join(checkpoints_dir, model_checkpoint_name)):\n",
" # import wget\n",
" # try:\n",
" # logging.info(\n",
" # \"Downloading model from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt ...\")\n",
" # # download from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt to ./checkpoints/vallex-checkpoint.pt\n",
" # wget.download(\"https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt\",\n",
" # out=\"./checkpoints/vallex-checkpoint.pt\", bar=wget.bar_adaptive)\n",
" # except Exception as e:\n",
" # logging.info(e)\n",
" # raise Exception(\n",
" # \"\\n Model weights download failed, please go to 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'\"\n",
" # \"\\n manually download model weights and put it to {} .\".format(os.getcwd() + \"\\checkpoints\"))\n",
" # # VALL-E\n",
" model = VALLE(\n",
" N_DIM,\n",
" NUM_HEAD,\n",
" NUM_LAYERS,\n",
" norm_first=True,\n",
" add_prenet=False,\n",
" prefix_mode=PREFIX_MODE,\n",
" share_embedding=True,\n",
" nar_scale_factor=1.0,\n",
" prepend_bos=True,\n",
" num_quantizers=NUM_QUANTIZERS,\n",
" ).to(device)\n",
" checkpoint = torch.load(model_path, map_location='cpu')\n",
" missing_keys, unexpected_keys = model.load_state_dict(\n",
" checkpoint[\"model\"], strict=True\n",
" )\n",
" assert not missing_keys\n",
" model.eval()\n",
"\n",
" # Encodec\n",
" codec = AudioTokenizer(device)\n",
" \n",
" vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device)\n",
"\n",
"@torch.no_grad()\n",
"def generate_audio(text, prompt=None, language='auto', accent='no-accent'):\n",
" global model, codec, vocos, text_tokenizer, text_collater\n",
" text = text.replace(\"\\n\", \"\").strip(\" \")\n",
" # detect language\n",
" if language == \"auto\":\n",
" language = langid.classify(text)[0]\n",
" lang_token = lang2token[language]\n",
" lang = token2lang[lang_token]\n",
" text = lang_token + text + lang_token\n",
"\n",
" # load prompt\n",
" if prompt is not None:\n",
" prompt_path = prompt\n",
" if not os.path.exists(prompt_path):\n",
" prompt_path = \"./presets/\" + prompt + \".npz\"\n",
" if not os.path.exists(prompt_path):\n",
" prompt_path = \"./customs/\" + prompt + \".npz\"\n",
" if not os.path.exists(prompt_path):\n",
" raise ValueError(f\"Cannot find prompt {prompt}\")\n",
" prompt_data = np.load(prompt_path)\n",
" audio_prompts = prompt_data['audio_tokens']\n",
" text_prompts = prompt_data['text_tokens']\n",
" lang_pr = prompt_data['lang_code']\n",
" lang_pr = code2lang[int(lang_pr)]\n",
"\n",
" # numpy to tensor\n",
" audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)\n",
" text_prompts = torch.tensor(text_prompts).type(torch.int32)\n",
" else:\n",
" audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)\n",
" text_prompts = torch.zeros([1, 0]).type(torch.int32)\n",
" lang_pr = lang if lang != 'mix' else 'en'\n",
"\n",
" enroll_x_lens = text_prompts.shape[-1]\n",
" logging.info(f\"synthesize text: {text}\")\n",
" phone_tokens, langs = text_tokenizer.tokenize(text=f\"_{text}\".strip())\n",
" text_tokens, text_tokens_lens = text_collater(\n",
" [\n",
" phone_tokens\n",
" ]\n",
" )\n",
" text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)\n",
" text_tokens_lens += enroll_x_lens\n",
" # accent control\n",
" lang = lang if accent == \"no-accent\" else token2lang[langdropdown2token[accent]]\n",
" encoded_frames = model.inference(\n",
" text_tokens.to(device),\n",
" text_tokens_lens.to(device),\n",
" audio_prompts,\n",
" enroll_x_lens=enroll_x_lens,\n",
" top_k=-100,\n",
" temperature=1,\n",
" prompt_language=lang_pr,\n",
" text_language=langs if accent == \"no-accent\" else lang,\n",
" )\n",
" # Decode with Vocos\n",
" frames = encoded_frames.permute(2,0,1)\n",
" features = vocos.codes_to_features(frames)\n",
" samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))\n",
"\n",
" return samples.squeeze().cpu().numpy()\n",
"\n",
"@torch.no_grad()\n",
"def generate_audio_from_long_text(text, prompt=None, language='auto', accent='no-accent', mode='sliding-window'):\n",
" \"\"\"\n",
" For long audio generation, two modes are available.\n",
" fixed-prompt: This mode will keep using the same prompt the user has provided, and generate audio sentence by sentence.\n",
" sliding-window: This mode will use the last sentence as the prompt for the next sentence, but has some concern on speaker maintenance.\n",
" \"\"\"\n",
" global model, codec, vocos, text_tokenizer, text_collater\n",
" if prompt is None or prompt == \"\":\n",
" mode = 'sliding-window' # If no prompt is given, use sliding-window mode\n",
" sentences = split_text_into_sentences(text)\n",
" # detect language\n",
" if language == \"auto\":\n",
" language = langid.classify(text)[0]\n",
"\n",
" # if initial prompt is given, encode it\n",
" if prompt is not None and prompt != \"\":\n",
" prompt_path = prompt\n",
" if not os.path.exists(prompt_path):\n",
" prompt_path = \"./presets/\" + prompt + \".npz\"\n",
" if not os.path.exists(prompt_path):\n",
" prompt_path = \"./customs/\" + prompt + \".npz\"\n",
" if not os.path.exists(prompt_path):\n",
" raise ValueError(f\"Cannot find prompt {prompt}\")\n",
" prompt_data = np.load(prompt_path)\n",
" audio_prompts = prompt_data['audio_tokens']\n",
" text_prompts = prompt_data['text_tokens']\n",
" lang_pr = prompt_data['lang_code']\n",
" lang_pr = code2lang[int(lang_pr)]\n",
"\n",
" # numpy to tensor\n",
" audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)\n",
" text_prompts = torch.tensor(text_prompts).type(torch.int32)\n",
" else:\n",
" audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)\n",
" text_prompts = torch.zeros([1, 0]).type(torch.int32)\n",
" lang_pr = language if language != 'mix' else 'en'\n",
" if mode == 'fixed-prompt':\n",
" complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)\n",
" for text in sentences:\n",
" text = text.replace(\"\\n\", \"\").strip(\" \")\n",
" if text == \"\":\n",
" continue\n",
" lang_token = lang2token[language]\n",
" lang = token2lang[lang_token]\n",
" text = lang_token + text + lang_token\n",
"\n",
" enroll_x_lens = text_prompts.shape[-1]\n",
" logging.info(f\"synthesize text: {text}\")\n",
" phone_tokens, langs = text_tokenizer.tokenize(text=f\"_{text}\".strip())\n",
" text_tokens, text_tokens_lens = text_collater(\n",
" [\n",
" phone_tokens\n",
" ]\n",
" )\n",
" text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)\n",
" text_tokens_lens += enroll_x_lens\n",
" # accent control\n",
" lang = lang if accent == \"no-accent\" else token2lang[langdropdown2token[accent]]\n",
" encoded_frames = model.inference(\n",
" text_tokens.to(device),\n",
" text_tokens_lens.to(device),\n",
" audio_prompts,\n",
" enroll_x_lens=enroll_x_lens,\n",
" top_k=-100,\n",
" temperature=1,\n",
" prompt_language=lang_pr,\n",
" text_language=langs if accent == \"no-accent\" else lang,\n",
" )\n",
" complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)\n",
" # Decode with Vocos\n",
" frames = complete_tokens.permute(1,0,2)\n",
" features = vocos.codes_to_features(frames)\n",
" samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))\n",
" return samples.squeeze().cpu().numpy()\n",
" elif mode == \"sliding-window\":\n",
" complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)\n",
" original_audio_prompts = audio_prompts\n",
" original_text_prompts = text_prompts\n",
" for text in sentences:\n",
" text = text.replace(\"\\n\", \"\").strip(\" \")\n",
" if text == \"\":\n",
" continue\n",
" lang_token = lang2token[language]\n",
" lang = token2lang[lang_token]\n",
" text = lang_token + text + lang_token\n",
"\n",
" enroll_x_lens = text_prompts.shape[-1]\n",
" logging.info(f\"synthesize text: {text}\")\n",
" phone_tokens, langs = text_tokenizer.tokenize(text=f\"_{text}\".strip())\n",
" text_tokens, text_tokens_lens = text_collater(\n",
" [\n",
" phone_tokens\n",
" ]\n",
" )\n",
" text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)\n",
" text_tokens_lens += enroll_x_lens\n",
" # accent control\n",
" lang = lang if accent == \"no-accent\" else token2lang[langdropdown2token[accent]]\n",
" encoded_frames = model.inference(\n",
" text_tokens.to(device),\n",
" text_tokens_lens.to(device),\n",
" audio_prompts,\n",
" enroll_x_lens=enroll_x_lens,\n",
" top_k=-100,\n",
" temperature=1,\n",
" prompt_language=lang_pr,\n",
" text_language=langs if accent == \"no-accent\" else lang,\n",
" )\n",
" complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)\n",
" if torch.rand(1) < 0.5:\n",
" audio_prompts = encoded_frames[:, :, -NUM_QUANTIZERS:]\n",
" text_prompts = text_tokens[:, enroll_x_lens:]\n",
" else:\n",
" audio_prompts = original_audio_prompts\n",
" text_prompts = original_text_prompts\n",
" # Decode with Vocos\n",
" frames = complete_tokens.permute(1,0,2)\n",
" features = vocos.codes_to_features(frames)\n",
" samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))\n",
" return samples.squeeze().cpu().numpy()\n",
" else:\n",
" raise ValueError(f\"No such mode {mode}\")"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "a91524e3-407f-4baa-be5e-061d3fb97091",
"metadata": {},
"outputs": [],
"source": [
"# from utils.generation import SAMPLE_RATE, generate_audio, preload_models\n",
"from scipy.io.wavfile import write as write_wav\n",
"from IPython.display import Audio\n",
"model = 'exp/valle_dev/epoch-1.pt'\n",
"# download and load all models\n",
"preload_models(model)\n"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "ce872b2f-57c4-450e-989b-95932a923b47",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"VALL-E EOS [0 -> 603]\n"
]
},
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# generate audio from text\n",
"text_prompt = \"\"\"\n",
"他整個身體是瀰漫在 空間之間的 他所有的力氣 他所有的生長的力氣 已經消耗盡了.\n",
"\"\"\"\n",
"audio_array = generate_audio(text_prompt)\n",
"\n",
"# save audio to disk\n",
"# write_wav(\"ep10.wav\", SAMPLE_RATE, audio_array)\n",
"\n",
"# play text in notebook\n",
"Audio(audio_array, rate=SAMPLE_RATE)"
]
},
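{
"cell_type": "code",
"execution_count": null,
"id": "b2f4c6d8-1a3e-4f5b-9c7d-2e8a0b4c6d1f",
"metadata": {},
"outputs": [],
"source": [
"# A minimal long-form synthesis sketch; the output filename here is illustrative.\n",
"# 'fixed-prompt' reuses the supplied prompt for every sentence, while\n",
"# 'sliding-window' re-prompts with the previous sentence and may drift in voice.\n",
"long_text = \"\"\"\n",
"This is the first sentence of a longer passage. Here is a second sentence.\n",
"And a third sentence closes the example.\n",
"\"\"\"\n",
"audio_array = generate_audio_from_long_text(long_text, prompt=None, language='auto', mode='sliding-window')\n",
"write_wav(\"long_text_demo.wav\", SAMPLE_RATE, audio_array)\n",
"Audio(audio_array, rate=SAMPLE_RATE)"
]
},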
{
"cell_type": "code",
"execution_count": null,
"id": "1768174c-4d95-4c27-9e9f-ff720a1a4feb",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}