diff --git "a/infer.ipynb" "b/infer.ipynb"
new file mode 100644
--- /dev/null
+++ "b/infer.ipynb"
@@ -0,0 +1,378 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "d5a1b342-1dfb-4d79-98f7-250aa3b72d93",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# coding: utf-8\n",
+ "import os\n",
+ "import torch\n",
+ "from vocos import Vocos\n",
+ "import logging\n",
+ "import py3langid as langid\n",
+ "langid.set_languages(['en', 'zh', 'ja', 'vi'])\n",
+ "\n",
+ "import pathlib\n",
+ "import platform\n",
+ "if platform.system().lower() == 'windows':\n",
+ " temp = pathlib.PosixPath\n",
+ " pathlib.PosixPath = pathlib.WindowsPath\n",
+ "else:\n",
+ " temp = pathlib.WindowsPath\n",
+ " pathlib.WindowsPath = pathlib.PosixPath\n",
+ "\n",
+ "import numpy as np\n",
+ "from data.tokenizer import (\n",
+ " AudioTokenizer,\n",
+ " tokenize_audio,\n",
+ ")\n",
+ "from data.collation import get_text_token_collater\n",
+ "from models.vallex import VALLE\n",
+ "from utils.g2p import PhonemeBpeTokenizer\n",
+ "from utils.sentence_cutter import split_text_into_sentences\n",
+ "\n",
+ "from macros import *\n",
+ "\n",
+ "device = torch.device(\"cpu\")\n",
+ "# if torch.cuda.is_available():\n",
+ "# device = torch.device(\"cuda\", 0)\n",
+ "\n",
+ "url = 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'\n",
+ "\n",
+ "checkpoints_dir = \"./checkpoints/\"\n",
+ "\n",
+ "model_checkpoint_name = \"vallex-checkpoint.pt\"\n",
+ "\n",
+ "model = None\n",
+ "\n",
+ "codec = None\n",
+ "\n",
+ "vocos = None\n",
+ "\n",
+ "text_tokenizer = PhonemeBpeTokenizer(tokenizer_path=\"./utils/g2p/bpe_69.json\")\n",
+ "text_collater = get_text_token_collater()\n",
+ "\n",
+ "def preload_models(model_path):\n",
+ " global model, codec, vocos\n",
+ " # if not os.path.exists(checkpoints_dir): os.mkdir(checkpoints_dir)\n",
+ " # if not os.path.exists(os.path.join(checkpoints_dir, model_checkpoint_name)):\n",
+ " # import wget\n",
+ " # try:\n",
+ " # logging.info(\n",
+ " # \"Downloading model from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt ...\")\n",
+ " # # download from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt to ./checkpoints/vallex-checkpoint.pt\n",
+ " # wget.download(\"https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt\",\n",
+ " # out=\"./checkpoints/vallex-checkpoint.pt\", bar=wget.bar_adaptive)\n",
+ " # except Exception as e:\n",
+ " # logging.info(e)\n",
+ " # raise Exception(\n",
+ " # \"\\n Model weights download failed, please go to 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'\"\n",
+ " # \"\\n manually download model weights and put it to {} .\".format(os.getcwd() + \"\\checkpoints\"))\n",
+ " # # VALL-E\n",
+ " model = VALLE(\n",
+ " N_DIM,\n",
+ " NUM_HEAD,\n",
+ " NUM_LAYERS,\n",
+ " norm_first=True,\n",
+ " add_prenet=False,\n",
+ " prefix_mode=PREFIX_MODE,\n",
+ " share_embedding=True,\n",
+ " nar_scale_factor=1.0,\n",
+ " prepend_bos=True,\n",
+ " num_quantizers=NUM_QUANTIZERS,\n",
+ " ).to(device)\n",
+ " checkpoint = torch.load(model_path, map_location='cpu')\n",
+ " missing_keys, unexpected_keys = model.load_state_dict(\n",
+ " checkpoint[\"model\"], strict=True\n",
+ " )\n",
+ " assert not missing_keys\n",
+ " model.eval()\n",
+ "\n",
+ " # Encodec\n",
+ " codec = AudioTokenizer(device)\n",
+ " \n",
+ " vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device)\n",
+ "\n",
+ "@torch.no_grad()\n",
+ "def generate_audio(text, prompt=None, language='auto', accent='no-accent'):\n",
+ " global model, codec, vocos, text_tokenizer, text_collater\n",
+ " text = text.replace(\"\\n\", \"\").strip(\" \")\n",
+ " # detect language\n",
+ " if language == \"auto\":\n",
+ " language = langid.classify(text)[0]\n",
+ " lang_token = lang2token[language]\n",
+ " lang = token2lang[lang_token]\n",
+ " text = lang_token + text + lang_token\n",
+ "\n",
+ " # load prompt\n",
+ " if prompt is not None:\n",
+ " prompt_path = prompt\n",
+ " if not os.path.exists(prompt_path):\n",
+ " prompt_path = \"./presets/\" + prompt + \".npz\"\n",
+ " if not os.path.exists(prompt_path):\n",
+ " prompt_path = \"./customs/\" + prompt + \".npz\"\n",
+ " if not os.path.exists(prompt_path):\n",
+ " raise ValueError(f\"Cannot find prompt {prompt}\")\n",
+ " prompt_data = np.load(prompt_path)\n",
+ " audio_prompts = prompt_data['audio_tokens']\n",
+ " text_prompts = prompt_data['text_tokens']\n",
+ " lang_pr = prompt_data['lang_code']\n",
+ " lang_pr = code2lang[int(lang_pr)]\n",
+ "\n",
+ " # numpy to tensor\n",
+ " audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)\n",
+ " text_prompts = torch.tensor(text_prompts).type(torch.int32)\n",
+ " else:\n",
+ " audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)\n",
+ " text_prompts = torch.zeros([1, 0]).type(torch.int32)\n",
+ " lang_pr = lang if lang != 'mix' else 'en'\n",
+ "\n",
+ " enroll_x_lens = text_prompts.shape[-1]\n",
+ " logging.info(f\"synthesize text: {text}\")\n",
+ " phone_tokens, langs = text_tokenizer.tokenize(text=f\"_{text}\".strip())\n",
+ " text_tokens, text_tokens_lens = text_collater(\n",
+ " [\n",
+ " phone_tokens\n",
+ " ]\n",
+ " )\n",
+ " text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)\n",
+ " text_tokens_lens += enroll_x_lens\n",
+ " # accent control\n",
+ " lang = lang if accent == \"no-accent\" else token2lang[langdropdown2token[accent]]\n",
+ " encoded_frames = model.inference(\n",
+ " text_tokens.to(device),\n",
+ " text_tokens_lens.to(device),\n",
+ " audio_prompts,\n",
+ " enroll_x_lens=enroll_x_lens,\n",
+ " top_k=-100,\n",
+ " temperature=1,\n",
+ " prompt_language=lang_pr,\n",
+ " text_language=langs if accent == \"no-accent\" else lang,\n",
+ " )\n",
+ " # Decode with Vocos\n",
+ " frames = encoded_frames.permute(2,0,1)\n",
+ " features = vocos.codes_to_features(frames)\n",
+ " samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))\n",
+ "\n",
+ " return samples.squeeze().cpu().numpy()\n",
+ "\n",
+ "@torch.no_grad()\n",
+ "def generate_audio_from_long_text(text, prompt=None, language='auto', accent='no-accent', mode='sliding-window'):\n",
+ " \"\"\"\n",
+ " For long audio generation, two modes are available.\n",
+ " fixed-prompt: This mode will keep using the same prompt the user has provided, and generate audio sentence by sentence.\n",
+ " sliding-window: This mode will use the last sentence as the prompt for the next sentence, but has some concern on speaker maintenance.\n",
+ " \"\"\"\n",
+ " global model, codec, vocos, text_tokenizer, text_collater\n",
+ " if prompt is None or prompt == \"\":\n",
+ " mode = 'sliding-window' # If no prompt is given, use sliding-window mode\n",
+ " sentences = split_text_into_sentences(text)\n",
+ " # detect language\n",
+ " if language == \"auto\":\n",
+ " language = langid.classify(text)[0]\n",
+ "\n",
+ " # if initial prompt is given, encode it\n",
+ " if prompt is not None and prompt != \"\":\n",
+ " prompt_path = prompt\n",
+ " if not os.path.exists(prompt_path):\n",
+ " prompt_path = \"./presets/\" + prompt + \".npz\"\n",
+ " if not os.path.exists(prompt_path):\n",
+ " prompt_path = \"./customs/\" + prompt + \".npz\"\n",
+ " if not os.path.exists(prompt_path):\n",
+ " raise ValueError(f\"Cannot find prompt {prompt}\")\n",
+ " prompt_data = np.load(prompt_path)\n",
+ " audio_prompts = prompt_data['audio_tokens']\n",
+ " text_prompts = prompt_data['text_tokens']\n",
+ " lang_pr = prompt_data['lang_code']\n",
+ " lang_pr = code2lang[int(lang_pr)]\n",
+ "\n",
+ " # numpy to tensor\n",
+ " audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)\n",
+ " text_prompts = torch.tensor(text_prompts).type(torch.int32)\n",
+ " else:\n",
+ " audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)\n",
+ " text_prompts = torch.zeros([1, 0]).type(torch.int32)\n",
+ " lang_pr = language if language != 'mix' else 'en'\n",
+ " if mode == 'fixed-prompt':\n",
+ " complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)\n",
+ " for text in sentences:\n",
+ " text = text.replace(\"\\n\", \"\").strip(\" \")\n",
+ " if text == \"\":\n",
+ " continue\n",
+ " lang_token = lang2token[language]\n",
+ " lang = token2lang[lang_token]\n",
+ " text = lang_token + text + lang_token\n",
+ "\n",
+ " enroll_x_lens = text_prompts.shape[-1]\n",
+ " logging.info(f\"synthesize text: {text}\")\n",
+ " phone_tokens, langs = text_tokenizer.tokenize(text=f\"_{text}\".strip())\n",
+ " text_tokens, text_tokens_lens = text_collater(\n",
+ " [\n",
+ " phone_tokens\n",
+ " ]\n",
+ " )\n",
+ " text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)\n",
+ " text_tokens_lens += enroll_x_lens\n",
+ " # accent control\n",
+ " lang = lang if accent == \"no-accent\" else token2lang[langdropdown2token[accent]]\n",
+ " encoded_frames = model.inference(\n",
+ " text_tokens.to(device),\n",
+ " text_tokens_lens.to(device),\n",
+ " audio_prompts,\n",
+ " enroll_x_lens=enroll_x_lens,\n",
+ " top_k=-100,\n",
+ " temperature=1,\n",
+ " prompt_language=lang_pr,\n",
+ " text_language=langs if accent == \"no-accent\" else lang,\n",
+ " )\n",
+ " complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)\n",
+ " # Decode with Vocos\n",
+ " frames = complete_tokens.permute(1,0,2)\n",
+ " features = vocos.codes_to_features(frames)\n",
+ " samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))\n",
+ " return samples.squeeze().cpu().numpy()\n",
+ " elif mode == \"sliding-window\":\n",
+ " complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)\n",
+ " original_audio_prompts = audio_prompts\n",
+ " original_text_prompts = text_prompts\n",
+ " for text in sentences:\n",
+ " text = text.replace(\"\\n\", \"\").strip(\" \")\n",
+ " if text == \"\":\n",
+ " continue\n",
+ " lang_token = lang2token[language]\n",
+ " lang = token2lang[lang_token]\n",
+ " text = lang_token + text + lang_token\n",
+ "\n",
+ " enroll_x_lens = text_prompts.shape[-1]\n",
+ " logging.info(f\"synthesize text: {text}\")\n",
+ " phone_tokens, langs = text_tokenizer.tokenize(text=f\"_{text}\".strip())\n",
+ " text_tokens, text_tokens_lens = text_collater(\n",
+ " [\n",
+ " phone_tokens\n",
+ " ]\n",
+ " )\n",
+ " text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)\n",
+ " text_tokens_lens += enroll_x_lens\n",
+ " # accent control\n",
+ " lang = lang if accent == \"no-accent\" else token2lang[langdropdown2token[accent]]\n",
+ " encoded_frames = model.inference(\n",
+ " text_tokens.to(device),\n",
+ " text_tokens_lens.to(device),\n",
+ " audio_prompts,\n",
+ " enroll_x_lens=enroll_x_lens,\n",
+ " top_k=-100,\n",
+ " temperature=1,\n",
+ " prompt_language=lang_pr,\n",
+ " text_language=langs if accent == \"no-accent\" else lang,\n",
+ " )\n",
+ " complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)\n",
+ " if torch.rand(1) < 0.5:\n",
+ " audio_prompts = encoded_frames[:, :, -NUM_QUANTIZERS:]\n",
+ " text_prompts = text_tokens[:, enroll_x_lens:]\n",
+ " else:\n",
+ " audio_prompts = original_audio_prompts\n",
+ " text_prompts = original_text_prompts\n",
+ " # Decode with Vocos\n",
+ " frames = complete_tokens.permute(1,0,2)\n",
+ " features = vocos.codes_to_features(frames)\n",
+ " samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))\n",
+ " return samples.squeeze().cpu().numpy()\n",
+ " else:\n",
+ " raise ValueError(f\"No such mode {mode}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "a91524e3-407f-4baa-be5e-061d3fb97091",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# from utils.generation import SAMPLE_RATE, generate_audio, preload_models\n",
+ "from scipy.io.wavfile import write as write_wav\n",
+ "from IPython.display import Audio\n",
+ "model = 'exp/valle_dev/epoch-1.pt'\n",
+ "# download and load all models\n",
+ "preload_models(model)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "ce872b2f-57c4-450e-989b-95932a923b47",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "VALL-E EOS [0 -> 603]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 33,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# generate audio from text\n",
+ "text_prompt = \"\"\"\n",
+ "他整個身體是瀰漫在 空間之間的 他所有的力氣 他所有的生長的力氣 已經消耗盡了.\n",
+ "\"\"\"\n",
+ "audio_array = generate_audio(text_prompt)\n",
+ "\n",
+ "# save audio to disk\n",
+ "# write_wav(\"ep10.wav\", SAMPLE_RATE, audio_array)\n",
+ "\n",
+ "# play text in notebook\n",
+ "Audio(audio_array, rate=SAMPLE_RATE)"
+ ]
+ },
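+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3f6c2b1a-9e47-4d2a-8b55-1c7d0a9e6f21",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# A minimal long-text sketch (the text and output filename are placeholders):\n",
+    "# 'sliding-window' re-prompts with the previously generated sentence, while\n",
+    "# 'fixed-prompt' would reuse the same prompt for every sentence.\n",
+    "long_text = \"This is the first sentence of a longer passage. This is the second sentence, generated with the previous one as its prompt.\"\n",
+    "long_audio = generate_audio_from_long_text(long_text, prompt=None, language='auto', mode='sliding-window')\n",
+    "write_wav(\"long_text.wav\", SAMPLE_RATE, long_audio)\n",
+    "Audio(long_audio, rate=SAMPLE_RATE)"
+   ]
+  },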
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1768174c-4d95-4c27-9e9f-ff720a1a4feb",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}