{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6c1f0e2a-7b3d-4e5f-8a9b-0c1d2e3f4a5b",
   "metadata": {},
   "source": [
    "# VALL-E X inference\n",
    "\n",
    "Loads a locally trained VALL-E X checkpoint and synthesizes speech from text. The codes predicted by VALL-E X are decoded to a waveform with the Vocos vocoder (`charactr/vocos-encodec-24khz`).\n",
    "\n",
    "* `generate_audio` synthesizes a single utterance.\n",
    "* `generate_audio_from_long_text` splits long text into sentences and synthesizes them one after another, either always re-using the given voice prompt (`fixed-prompt`) or sometimes using the previously generated sentence as the next prompt (`sliding-window`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5a1b342-1dfb-4d79-98f7-250aa3b72d93",
   "metadata": {},
   "outputs": [],
   "source": [
    "# coding: utf-8\n",
    "import logging\n",
    "import os\n",
    "import pathlib\n",
    "import platform\n",
    "\n",
    "import py3langid as langid\n",
    "import torch\n",
    "from vocos import Vocos\n",
    "\n",
    "# Restrict automatic language detection to the languages handled below.\n",
    "langid.set_languages(['en', 'zh', 'ja', 'vi'])\n",
    "\n",
    "# Alias the pathlib class of the other OS to the native one, presumably so that\n",
    "# torch.load can unpickle checkpoints saved on a different OS that contain path\n",
    "# objects. NOTE: this is a process-wide monkey-patch; `temp` keeps the original class.\n",
    "if platform.system().lower() == 'windows':\n",
    "    temp = pathlib.PosixPath\n",
    "    pathlib.PosixPath = pathlib.WindowsPath\n",
    "else:\n",
    "    temp = pathlib.WindowsPath\n",
    "    pathlib.WindowsPath = pathlib.PosixPath\n",
    "\n",
    "import numpy as np\n",
    "from IPython.display import Audio\n",
    "from scipy.io.wavfile import write as write_wav\n",
    "\n",
    "from data.tokenizer import (\n",
    "    AudioTokenizer,\n",
    "    tokenize_audio,\n",
    ")\n",
    "from data.collation import get_text_token_collater\n",
    "from models.vallex import VALLE\n",
    "from utils.g2p import PhonemeBpeTokenizer\n",
    "from utils.sentence_cutter import split_text_into_sentences\n",
    "\n",
    "# Expected to provide the model hyper-parameters (N_DIM, NUM_HEAD, NUM_LAYERS,\n",
    "# PREFIX_MODE, NUM_QUANTIZERS), SAMPLE_RATE and the language tables\n",
    "# (lang2token, token2lang, code2lang, langdropdown2token) used below.\n",
    "from macros import *\n",
    "\n",
    "# Inference runs on the CPU by default; set USE_CUDA = True to use the first GPU when available.\n",
    "USE_CUDA = False\n",
    "if USE_CUDA and torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\", 0)\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "# Public checkpoint location (automatic download is disabled: preload_models takes a local path).\n",
    "url = 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'\n",
    "checkpoints_dir = \"./checkpoints/\"\n",
    "model_checkpoint_name = \"vallex-checkpoint.pt\"\n",
    "\n",
    "# Set by preload_models().\n",
    "model = None\n",
    "codec = None\n",
    "vocos = None\n",
    "\n",
    "text_tokenizer = PhonemeBpeTokenizer(tokenizer_path=\"./utils/g2p/bpe_69.json\")\n",
    "text_collater = get_text_token_collater()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e4b2c9d-1a3f-4b6e-9c8d-2f5a6b7c8d9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preload_models(model_path):\n",
    "    \"\"\"Load every model needed for synthesis into the module-level globals.\n",
    "\n",
    "    Builds VALL-E X from the hyper-parameters in `macros`, loads the weights stored\n",
    "    under the \"model\" key of the checkpoint at `model_path`, then loads the EnCodec\n",
    "    audio tokenizer into `codec` and the Vocos vocoder into `vocos`.\n",
    "    `model` is only replaced once the checkpoint has loaded, so a failed load does\n",
    "    not leave a randomly initialised model behind.\n",
    "\n",
    "    Args:\n",
    "        model_path: path of a local VALL-E X checkpoint (automatic download is\n",
    "            disabled; the public checkpoint is at `url`).\n",
    "    \"\"\"\n",
    "    global model, codec, vocos\n",
    "    vallex = VALLE(\n",
    "        N_DIM,\n",
    "        NUM_HEAD,\n",
    "        NUM_LAYERS,\n",
    "        norm_first=True,\n",
    "        add_prenet=False,\n",
    "        prefix_mode=PREFIX_MODE,\n",
    "        share_embedding=True,\n",
    "        nar_scale_factor=1.0,\n",
    "        prepend_bos=True,\n",
    "        num_quantizers=NUM_QUANTIZERS,\n",
    "    ).to(device)\n",
    "    # NOTE(security): torch.load unpickles the file - only load checkpoints you trust.\n",
    "    checkpoint = torch.load(model_path, map_location='cpu')\n",
    "    missing_keys, unexpected_keys = vallex.load_state_dict(\n",
    "        checkpoint[\"model\"], strict=True\n",
    "    )\n",
    "    assert not missing_keys\n",
    "    vallex.eval()\n",
    "    model = vallex\n",
    "\n",
    "    # Encodec\n",
    "    codec = AudioTokenizer(device)\n",
    "    # Vocos turns the predicted codes back into a waveform (see _decode_with_vocos).\n",
    "    vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device)\n",
    "\n",
    "\n",
    "def _load_prompt(prompt):\n",
    "    \"\"\"Load a voice prompt stored as an .npz file.\n",
    "\n",
    "    `prompt` is tried as a path first, then as a name under ./presets/ and then\n",
    "    ./customs/ (with \".npz\" appended).\n",
    "\n",
    "    Returns:\n",
    "        (audio_prompts, text_prompts, lang_pr): int32 tensors built from the\n",
    "        'audio_tokens' (moved to `device`) and 'text_tokens' arrays, and the prompt\n",
    "        language decoded from 'lang_code' through code2lang.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if none of the candidate files exists.\n",
    "    \"\"\"\n",
    "    prompt_path = prompt\n",
    "    if not os.path.exists(prompt_path):\n",
    "        prompt_path = \"./presets/\" + prompt + \".npz\"\n",
    "    if not os.path.exists(prompt_path):\n",
    "        prompt_path = \"./customs/\" + prompt + \".npz\"\n",
    "    if not os.path.exists(prompt_path):\n",
    "        raise ValueError(f\"Cannot find prompt {prompt}\")\n",
    "    # np.load keeps an .npz archive open; the context manager closes it once the\n",
    "    # arrays have been copied into tensors.\n",
    "    with np.load(prompt_path) as prompt_data:\n",
    "        audio_prompts = torch.tensor(prompt_data['audio_tokens']).type(torch.int32).to(device)\n",
    "        text_prompts = torch.tensor(prompt_data['text_tokens']).type(torch.int32)\n",
    "        lang_pr = code2lang[int(prompt_data['lang_code'])]\n",
    "    return audio_prompts, text_prompts, lang_pr\n",
    "\n",
    "\n",
    "def _empty_prompt():\n",
    "    \"\"\"Return zero-length (audio_prompts, text_prompts) tensors, used when no voice prompt is given.\"\"\"\n",
    "    audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)\n",
    "    text_prompts = torch.zeros([1, 0]).type(torch.int32)\n",
    "    return audio_prompts, text_prompts\n",
    "\n",
    "\n",
    "def _synthesize_tokens(text, text_prompts, audio_prompts, lang_pr, language, accent):\n",
    "    \"\"\"Run VALL-E X inference for one cleaned piece of text.\n",
    "\n",
    "    The text is wrapped in the language token of `language`, phonemised, and\n",
    "    appended to `text_prompts` before calling `model.inference`.\n",
    "\n",
    "    Args:\n",
    "        text: text to speak, newlines already removed.\n",
    "        text_prompts, audio_prompts: prompt tokens (zero-length when there is no prompt).\n",
    "        lang_pr: language of the prompt.\n",
    "        language: key of lang2token used to tag `text`.\n",
    "        accent: \"no-accent\" to use the per-token languages returned by the tokenizer,\n",
    "            otherwise a key of langdropdown2token that forces one target language.\n",
    "\n",
    "    Returns:\n",
    "        (encoded_frames, text_tokens, enroll_x_lens): the generated codes (the callers\n",
    "        treat them as (1, T, NUM_QUANTIZERS)), the prompt+text token sequence fed to the\n",
    "        model, and the number of prompt text tokens at its start.\n",
    "    \"\"\"\n",
    "    lang_token = lang2token[language]\n",
    "    lang = token2lang[lang_token]\n",
    "    text = lang_token + text + lang_token\n",
    "\n",
    "    enroll_x_lens = text_prompts.shape[-1]\n",
    "    logging.info(f\"synthesize text: {text}\")\n",
    "    phone_tokens, langs = text_tokenizer.tokenize(text=f\"_{text}\".strip())\n",
    "    text_tokens, text_tokens_lens = text_collater([phone_tokens])\n",
    "    text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)\n",
    "    text_tokens_lens += enroll_x_lens\n",
    "    # accent control: replace the per-token languages by a single forced language\n",
    "    if accent != \"no-accent\":\n",
    "        lang = token2lang[langdropdown2token[accent]]\n",
    "    encoded_frames = model.inference(\n",
    "        text_tokens.to(device),\n",
    "        text_tokens_lens.to(device),\n",
    "        audio_prompts,\n",
    "        enroll_x_lens=enroll_x_lens,\n",
    "        top_k=-100,\n",
    "        temperature=1,\n",
    "        prompt_language=lang_pr,\n",
    "        text_language=langs if accent == \"no-accent\" else lang,\n",
    "    )\n",
    "    return encoded_frames, text_tokens, enroll_x_lens\n",
    "\n",
    "\n",
    "def _decode_with_vocos(frames):\n",
    "    \"\"\"Decode codes laid out as (NUM_QUANTIZERS, 1, T) into a squeezed numpy waveform.\"\"\"\n",
    "    features = vocos.codes_to_features(frames)\n",
    "    # Per the Vocos README, bandwidth id 2 is the 6 kbps setting of the EnCodec-based model.\n",
    "    samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))\n",
    "    return samples.squeeze().cpu().numpy()\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def generate_audio(text, prompt=None, language='auto', accent='no-accent'):\n",
    "    \"\"\"Synthesize `text` as a single utterance and return the waveform as a numpy array.\n",
    "\n",
    "    Args:\n",
    "        text: text to speak; newlines are removed and surrounding spaces stripped.\n",
    "        prompt: optional voice prompt (.npz path, or a name under ./presets/ or ./customs/).\n",
    "        language: key of lang2token, or 'auto' to detect it with langid.\n",
    "        accent: 'no-accent' or a key of langdropdown2token.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `prompt` is given but cannot be found.\n",
    "    \"\"\"\n",
    "    text = text.replace(\"\\n\", \"\").strip(\" \")\n",
    "    if language == \"auto\":\n",
    "        language = langid.classify(text)[0]\n",
    "    lang = token2lang[lang2token[language]]\n",
    "\n",
    "    if prompt is not None:\n",
    "        audio_prompts, text_prompts, lang_pr = _load_prompt(prompt)\n",
    "    else:\n",
    "        audio_prompts, text_prompts = _empty_prompt()\n",
    "        lang_pr = lang if lang != 'mix' else 'en'\n",
    "\n",
    "    encoded_frames, _, _ = _synthesize_tokens(\n",
    "        text, text_prompts, audio_prompts, lang_pr, language, accent\n",
    "    )\n",
    "    # (1, T, NUM_QUANTIZERS) -> (NUM_QUANTIZERS, 1, T)\n",
    "    return _decode_with_vocos(encoded_frames.permute(2, 0, 1))\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def generate_audio_from_long_text(text, prompt=None, language='auto', accent='no-accent', mode='sliding-window'):\n",
    "    \"\"\"\n",
    "    For long audio generation, two modes are available.\n",
    "    fixed-prompt: This mode will keep using the same prompt the user has provided, and generate audio sentence by sentence.\n",
    "    sliding-window: This mode will use the last sentence as the prompt for the next sentence, but has some concern on speaker maintenance.\n",
    "\n",
    "    In sliding-window mode, after each sentence the just-generated sentence becomes the\n",
    "    next prompt with probability 0.5 (torch.rand); otherwise the original prompt is\n",
    "    restored. Without a prompt (None or \"\") sliding-window mode is always used.\n",
    "    The sentences' codes are concatenated and decoded once with Vocos.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: for an unknown `mode`, or if `prompt` cannot be found.\n",
    "    \"\"\"\n",
    "    if prompt is None or prompt == \"\":\n",
    "        mode = 'sliding-window'  # If no prompt is given, use sliding-window mode\n",
    "    if mode not in ('fixed-prompt', 'sliding-window'):\n",
    "        raise ValueError(f\"No such mode {mode}\")\n",
    "\n",
    "    sentences = split_text_into_sentences(text)\n",
    "    if language == \"auto\":\n",
    "        language = langid.classify(text)[0]\n",
    "\n",
    "    if prompt is not None and prompt != \"\":\n",
    "        audio_prompts, text_prompts, lang_pr = _load_prompt(prompt)\n",
    "    else:\n",
    "        audio_prompts, text_prompts = _empty_prompt()\n",
    "        lang_pr = language if language != 'mix' else 'en'\n",
    "    original_audio_prompts = audio_prompts\n",
    "    original_text_prompts = text_prompts\n",
    "\n",
    "    # Collect per-sentence codes and concatenate once at the end. The leading empty\n",
    "    # int64 chunk keeps the result dtype (and the all-empty case) unchanged.\n",
    "    chunks = [torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)]\n",
    "    for sentence in sentences:\n",
    "        sentence = sentence.replace(\"\\n\", \"\").strip(\" \")\n",
    "        if sentence == \"\":\n",
    "            continue\n",
    "        encoded_frames, text_tokens, enroll_x_lens = _synthesize_tokens(\n",
    "            sentence, text_prompts, audio_prompts, lang_pr, language, accent\n",
    "        )\n",
    "        chunks.append(encoded_frames.transpose(2, 1))\n",
    "        if mode == 'sliding-window':\n",
    "            if torch.rand(1) < 0.5:\n",
    "                # prompt the next sentence with this sentence's codes and text tokens\n",
    "                audio_prompts = encoded_frames[:, :, -NUM_QUANTIZERS:]\n",
    "                text_prompts = text_tokens[:, enroll_x_lens:]\n",
    "            else:\n",
    "                audio_prompts = original_audio_prompts\n",
    "                text_prompts = original_text_prompts\n",
    "\n",
    "    complete_tokens = torch.cat(chunks, dim=-1)\n",
    "    # (1, NUM_QUANTIZERS, T) -> (NUM_QUANTIZERS, 1, T)\n",
    "    return _decode_with_vocos(complete_tokens.permute(1, 0, 2))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b8d4f6a-3c5e-4a7b-8d9f-1e2a3b4c5d6e",
   "metadata": {},
   "source": [
    "## Load models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a91524e3-407f-4baa-be5e-061d3fb97091",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local VALL-E X checkpoint produced by training (relative to the notebook's working directory).\n",
    "MODEL_PATH = 'exp/valle_dev/epoch-1.pt'\n",
    "preload_models(MODEL_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a1c3e5f-7b2d-4f6a-8c0e-2d4f6a8b0c1e",
   "metadata": {},
   "source": [
    "## Synthesize speech"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce872b2f-57c4-450e-989b-95932a923b47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Text to synthesize; its language is detected automatically (language='auto').\n",
    "text_to_speak = \"\"\"\n",
    "他整個身體是瀰漫在 空間之間的 他所有的力氣 他所有的生長的力氣 已經消耗盡了.\n",
    "\"\"\"\n",
    "audio_array = generate_audio(text_to_speak)\n",
    "\n",
    "# Set to a file name (e.g. \"ep10.wav\") to also save the audio to disk.\n",
    "OUTPUT_WAV_PATH = None\n",
    "if OUTPUT_WAV_PATH:\n",
    "    write_wav(OUTPUT_WAV_PATH, SAMPLE_RATE, audio_array)\n",
    "\n",
    "# play the audio in the notebook\n",
    "Audio(audio_array, rate=SAMPLE_RATE)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}