diff --git "a/infer.ipynb" "b/infer.ipynb" new file mode 100644--- /dev/null +++ "b/infer.ipynb" @@ -0,0 +1,378 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 27, + "id": "d5a1b342-1dfb-4d79-98f7-250aa3b72d93", + "metadata": {}, + "outputs": [], + "source": [ + "# coding: utf-8\n", + "import os\n", + "import torch\n", + "from vocos import Vocos\n", + "import logging\n", + "import py3langid as langid\n", + "langid.set_languages(['en', 'zh', 'ja', 'vi'])\n", + "\n", + "import pathlib\n", + "import platform\n", + "if platform.system().lower() == 'windows':\n", + " temp = pathlib.PosixPath\n", + " pathlib.PosixPath = pathlib.WindowsPath\n", + "else:\n", + " temp = pathlib.WindowsPath\n", + " pathlib.WindowsPath = pathlib.PosixPath\n", + "\n", + "import numpy as np\n", + "from data.tokenizer import (\n", + " AudioTokenizer,\n", + " tokenize_audio,\n", + ")\n", + "from data.collation import get_text_token_collater\n", + "from models.vallex import VALLE\n", + "from utils.g2p import PhonemeBpeTokenizer\n", + "from utils.sentence_cutter import split_text_into_sentences\n", + "\n", + "from macros import *\n", + "\n", + "device = torch.device(\"cpu\")\n", + "# if torch.cuda.is_available():\n", + "# device = torch.device(\"cuda\", 0)\n", + "\n", + "url = 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'\n", + "\n", + "checkpoints_dir = \"./checkpoints/\"\n", + "\n", + "model_checkpoint_name = \"vallex-checkpoint.pt\"\n", + "\n", + "model = None\n", + "\n", + "codec = None\n", + "\n", + "vocos = None\n", + "\n", + "text_tokenizer = PhonemeBpeTokenizer(tokenizer_path=\"./utils/g2p/bpe_69.json\")\n", + "text_collater = get_text_token_collater()\n", + "\n", + "def preload_models(model_path):\n", + " global model, codec, vocos\n", + " # if not os.path.exists(checkpoints_dir): os.mkdir(checkpoints_dir)\n", + " # if not os.path.exists(os.path.join(checkpoints_dir, model_checkpoint_name)):\n", + " # import wget\n", + " # try:\n", + " # logging.info(\n", + " # \"Downloading model from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt ...\")\n", + " # # download from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt to ./checkpoints/vallex-checkpoint.pt\n", + " # wget.download(\"https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt\",\n", + " # out=\"./checkpoints/vallex-checkpoint.pt\", bar=wget.bar_adaptive)\n", + " # except Exception as e:\n", + " # logging.info(e)\n", + " # raise Exception(\n", + " # \"\\n Model weights download failed, please go to 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'\"\n", + " # \"\\n manually download model weights and put it to {} .\".format(os.getcwd() + \"\\checkpoints\"))\n", + " # # VALL-E\n", + " model = VALLE(\n", + " N_DIM,\n", + " NUM_HEAD,\n", + " NUM_LAYERS,\n", + " norm_first=True,\n", + " add_prenet=False,\n", + " prefix_mode=PREFIX_MODE,\n", + " share_embedding=True,\n", + " nar_scale_factor=1.0,\n", + " prepend_bos=True,\n", + " num_quantizers=NUM_QUANTIZERS,\n", + " ).to(device)\n", + " checkpoint = torch.load(model_path, map_location='cpu')\n", + " missing_keys, unexpected_keys = model.load_state_dict(\n", + " checkpoint[\"model\"], strict=True\n", + " )\n", + " assert not missing_keys\n", + " model.eval()\n", + "\n", + " # Encodec\n", + " codec = AudioTokenizer(device)\n", + " \n", + " vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device)\n", + "\n", + 
"@torch.no_grad()\n", + "def generate_audio(text, prompt=None, language='auto', accent='no-accent'):\n", + " global model, codec, vocos, text_tokenizer, text_collater\n", + " text = text.replace(\"\\n\", \"\").strip(\" \")\n", + " # detect language\n", + " if language == \"auto\":\n", + " language = langid.classify(text)[0]\n", + " lang_token = lang2token[language]\n", + " lang = token2lang[lang_token]\n", + " text = lang_token + text + lang_token\n", + "\n", + " # load prompt\n", + " if prompt is not None:\n", + " prompt_path = prompt\n", + " if not os.path.exists(prompt_path):\n", + " prompt_path = \"./presets/\" + prompt + \".npz\"\n", + " if not os.path.exists(prompt_path):\n", + " prompt_path = \"./customs/\" + prompt + \".npz\"\n", + " if not os.path.exists(prompt_path):\n", + " raise ValueError(f\"Cannot find prompt {prompt}\")\n", + " prompt_data = np.load(prompt_path)\n", + " audio_prompts = prompt_data['audio_tokens']\n", + " text_prompts = prompt_data['text_tokens']\n", + " lang_pr = prompt_data['lang_code']\n", + " lang_pr = code2lang[int(lang_pr)]\n", + "\n", + " # numpy to tensor\n", + " audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)\n", + " text_prompts = torch.tensor(text_prompts).type(torch.int32)\n", + " else:\n", + " audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)\n", + " text_prompts = torch.zeros([1, 0]).type(torch.int32)\n", + " lang_pr = lang if lang != 'mix' else 'en'\n", + "\n", + " enroll_x_lens = text_prompts.shape[-1]\n", + " logging.info(f\"synthesize text: {text}\")\n", + " phone_tokens, langs = text_tokenizer.tokenize(text=f\"_{text}\".strip())\n", + " text_tokens, text_tokens_lens = text_collater(\n", + " [\n", + " phone_tokens\n", + " ]\n", + " )\n", + " text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)\n", + " text_tokens_lens += enroll_x_lens\n", + " # accent control\n", + " lang = lang if accent == \"no-accent\" else token2lang[langdropdown2token[accent]]\n", + " encoded_frames = model.inference(\n", + " text_tokens.to(device),\n", + " text_tokens_lens.to(device),\n", + " audio_prompts,\n", + " enroll_x_lens=enroll_x_lens,\n", + " top_k=-100,\n", + " temperature=1,\n", + " prompt_language=lang_pr,\n", + " text_language=langs if accent == \"no-accent\" else lang,\n", + " )\n", + " # Decode with Vocos\n", + " frames = encoded_frames.permute(2,0,1)\n", + " features = vocos.codes_to_features(frames)\n", + " samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))\n", + "\n", + " return samples.squeeze().cpu().numpy()\n", + "\n", + "@torch.no_grad()\n", + "def generate_audio_from_long_text(text, prompt=None, language='auto', accent='no-accent', mode='sliding-window'):\n", + " \"\"\"\n", + " For long audio generation, two modes are available.\n", + " fixed-prompt: This mode will keep using the same prompt the user has provided, and generate audio sentence by sentence.\n", + " sliding-window: This mode will use the last sentence as the prompt for the next sentence, but has some concern on speaker maintenance.\n", + " \"\"\"\n", + " global model, codec, vocos, text_tokenizer, text_collater\n", + " if prompt is None or prompt == \"\":\n", + " mode = 'sliding-window' # If no prompt is given, use sliding-window mode\n", + " sentences = split_text_into_sentences(text)\n", + " # detect language\n", + " if language == \"auto\":\n", + " language = langid.classify(text)[0]\n", + "\n", + " # if initial prompt is given, encode it\n", + " if prompt is not None and prompt != 
\"\":\n", + " prompt_path = prompt\n", + " if not os.path.exists(prompt_path):\n", + " prompt_path = \"./presets/\" + prompt + \".npz\"\n", + " if not os.path.exists(prompt_path):\n", + " prompt_path = \"./customs/\" + prompt + \".npz\"\n", + " if not os.path.exists(prompt_path):\n", + " raise ValueError(f\"Cannot find prompt {prompt}\")\n", + " prompt_data = np.load(prompt_path)\n", + " audio_prompts = prompt_data['audio_tokens']\n", + " text_prompts = prompt_data['text_tokens']\n", + " lang_pr = prompt_data['lang_code']\n", + " lang_pr = code2lang[int(lang_pr)]\n", + "\n", + " # numpy to tensor\n", + " audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)\n", + " text_prompts = torch.tensor(text_prompts).type(torch.int32)\n", + " else:\n", + " audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)\n", + " text_prompts = torch.zeros([1, 0]).type(torch.int32)\n", + " lang_pr = language if language != 'mix' else 'en'\n", + " if mode == 'fixed-prompt':\n", + " complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)\n", + " for text in sentences:\n", + " text = text.replace(\"\\n\", \"\").strip(\" \")\n", + " if text == \"\":\n", + " continue\n", + " lang_token = lang2token[language]\n", + " lang = token2lang[lang_token]\n", + " text = lang_token + text + lang_token\n", + "\n", + " enroll_x_lens = text_prompts.shape[-1]\n", + " logging.info(f\"synthesize text: {text}\")\n", + " phone_tokens, langs = text_tokenizer.tokenize(text=f\"_{text}\".strip())\n", + " text_tokens, text_tokens_lens = text_collater(\n", + " [\n", + " phone_tokens\n", + " ]\n", + " )\n", + " text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)\n", + " text_tokens_lens += enroll_x_lens\n", + " # accent control\n", + " lang = lang if accent == \"no-accent\" else token2lang[langdropdown2token[accent]]\n", + " encoded_frames = model.inference(\n", + " text_tokens.to(device),\n", + " text_tokens_lens.to(device),\n", + " audio_prompts,\n", + " enroll_x_lens=enroll_x_lens,\n", + " top_k=-100,\n", + " temperature=1,\n", + " prompt_language=lang_pr,\n", + " text_language=langs if accent == \"no-accent\" else lang,\n", + " )\n", + " complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)\n", + " # Decode with Vocos\n", + " frames = complete_tokens.permute(1,0,2)\n", + " features = vocos.codes_to_features(frames)\n", + " samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))\n", + " return samples.squeeze().cpu().numpy()\n", + " elif mode == \"sliding-window\":\n", + " complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)\n", + " original_audio_prompts = audio_prompts\n", + " original_text_prompts = text_prompts\n", + " for text in sentences:\n", + " text = text.replace(\"\\n\", \"\").strip(\" \")\n", + " if text == \"\":\n", + " continue\n", + " lang_token = lang2token[language]\n", + " lang = token2lang[lang_token]\n", + " text = lang_token + text + lang_token\n", + "\n", + " enroll_x_lens = text_prompts.shape[-1]\n", + " logging.info(f\"synthesize text: {text}\")\n", + " phone_tokens, langs = text_tokenizer.tokenize(text=f\"_{text}\".strip())\n", + " text_tokens, text_tokens_lens = text_collater(\n", + " [\n", + " phone_tokens\n", + " ]\n", + " )\n", + " text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)\n", + " text_tokens_lens += enroll_x_lens\n", + " # accent control\n", + " lang = lang if accent == \"no-accent\" else 
token2lang[langdropdown2token[accent]]\n", + " encoded_frames = model.inference(\n", + " text_tokens.to(device),\n", + " text_tokens_lens.to(device),\n", + " audio_prompts,\n", + " enroll_x_lens=enroll_x_lens,\n", + " top_k=-100,\n", + " temperature=1,\n", + " prompt_language=lang_pr,\n", + " text_language=langs if accent == \"no-accent\" else lang,\n", + " )\n", + " complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)\n", + " if torch.rand(1) < 0.5: # half the time slide the prompt window to the sentence just generated; otherwise reset to the original prompt to re-anchor the speaker\n", + " audio_prompts = encoded_frames[:, :, -NUM_QUANTIZERS:]\n", + " text_prompts = text_tokens[:, enroll_x_lens:]\n", + " else:\n", + " audio_prompts = original_audio_prompts\n", + " text_prompts = original_text_prompts\n", + " # Decode with Vocos\n", + " frames = complete_tokens.permute(1,0,2)\n", + " features = vocos.codes_to_features(frames)\n", + " samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))\n", + " return samples.squeeze().cpu().numpy()\n", + " else:\n", + " raise ValueError(f\"No such mode {mode}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "a91524e3-407f-4baa-be5e-061d3fb97091", + "metadata": {}, + "outputs": [], + "source": [ + "# from utils.generation import SAMPLE_RATE, generate_audio, preload_models\n", + "from scipy.io.wavfile import write as write_wav\n", + "from IPython.display import Audio\n", + "checkpoint_path = 'exp/valle_dev/epoch-1.pt'\n", + "# load the VALL-E model, Encodec tokenizer, and Vocos vocoder from the fine-tuned checkpoint\n", + "preload_models(checkpoint_path)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "ce872b2f-57c4-450e-989b-95932a923b47", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "VALL-E EOS [0 -> 603]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# generate audio from text (the Chinese prompt roughly means: his whole body is diffused through the surrounding space; all his strength, all the strength he had for growing, has been used up)\n", + "text_prompt = \"\"\"\n", + "他整個身體是瀰漫在 空間之間的 他所有的力氣 他所有的生長的力氣 已經消耗盡了.\n", + "\"\"\"\n", + "audio_array = generate_audio(text_prompt)\n", + "\n", + "# save audio to disk\n", + "# write_wav(\"ep10.wav\", SAMPLE_RATE, audio_array)\n", + "\n", + "# play audio in the notebook\n", + "Audio(audio_array, rate=SAMPLE_RATE)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1768174c-4d95-4c27-9e9f-ff720a1a4feb", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}
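A follow-up cell could exercise the long-text path defined above as well. The sketch below is a minimal illustration, not part of this diff: the preset name "myvoice" and the output filename are assumptions (any .npz saved under ./presets/ or ./customs/ with audio_tokens, text_tokens, and lang_code entries would work), while generate_audio_from_long_text, write_wav, SAMPLE_RATE, and Audio are exactly the names already used in the notebook.

    # hypothetical follow-up cell: long-form synthesis with a fixed voice prompt
    long_text = """
    VALL-E X can synthesize long passages by splitting them into sentences,
    generating each sentence in turn, and concatenating the codec tokens.
    """
    # "myvoice" is a placeholder; it must exist as ./presets/myvoice.npz or ./customs/myvoice.npz
    audio_array = generate_audio_from_long_text(long_text, prompt="myvoice", mode="fixed-prompt")
    write_wav("long_text.wav", SAMPLE_RATE, audio_array)  # save to disk
    Audio(audio_array, rate=SAMPLE_RATE)  # or play inline

Using mode="fixed-prompt" keeps the same speaker prompt for every sentence, which trades the smoother transitions of "sliding-window" for more stable speaker identity.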