mrfakename commited on
Commit
dd217c7
·
verified ·
1 Parent(s): 07a600b

Upload folder using huggingface_hub

Browse files
.gitignore ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Customed
2
+ .vscode/
3
+ tests/
4
+ runs/
5
+ data/
6
+ ckpts/
7
+ wandb/
8
+ results/
9
+
10
+
11
+
12
+ # Byte-compiled / optimized / DLL files
13
+ __pycache__/
14
+ *.py[cod]
15
+ *$py.class
16
+
17
+ # C extensions
18
+ *.so
19
+
20
+ # Distribution / packaging
21
+ .Python
22
+ build/
23
+ develop-eggs/
24
+ dist/
25
+ downloads/
26
+ eggs/
27
+ .eggs/
28
+ lib/
29
+ lib64/
30
+ parts/
31
+ sdist/
32
+ var/
33
+ wheels/
34
+ share/python-wheels/
35
+ *.egg-info/
36
+ .installed.cfg
37
+ *.egg
38
+ MANIFEST
39
+
40
+ # PyInstaller
41
+ # Usually these files are written by a python script from a template
42
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
43
+ *.manifest
44
+ *.spec
45
+
46
+ # Installer logs
47
+ pip-log.txt
48
+ pip-delete-this-directory.txt
49
+
50
+ # Unit test / coverage reports
51
+ htmlcov/
52
+ .tox/
53
+ .nox/
54
+ .coverage
55
+ .coverage.*
56
+ .cache
57
+ nosetests.xml
58
+ coverage.xml
59
+ *.cover
60
+ *.py,cover
61
+ .hypothesis/
62
+ .pytest_cache/
63
+ cover/
64
+
65
+ # Translations
66
+ *.mo
67
+ *.pot
68
+
69
+ # Django stuff:
70
+ *.log
71
+ local_settings.py
72
+ db.sqlite3
73
+ db.sqlite3-journal
74
+
75
+ # Flask stuff:
76
+ instance/
77
+ .webassets-cache
78
+
79
+ # Scrapy stuff:
80
+ .scrapy
81
+
82
+ # Sphinx documentation
83
+ docs/_build/
84
+
85
+ # PyBuilder
86
+ .pybuilder/
87
+ target/
88
+
89
+ # Jupyter Notebook
90
+ .ipynb_checkpoints
91
+
92
+ # IPython
93
+ profile_default/
94
+ ipython_config.py
95
+
96
+ # pyenv
97
+ # For a library or package, you might want to ignore these files since the code is
98
+ # intended to run in multiple environments; otherwise, check them in:
99
+ # .python-version
100
+
101
+ # pipenv
102
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
103
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
104
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
105
+ # install all needed dependencies.
106
+ #Pipfile.lock
107
+
108
+ # poetry
109
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
110
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
111
+ # commonly ignored for libraries.
112
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
113
+ #poetry.lock
114
+
115
+ # pdm
116
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
117
+ #pdm.lock
118
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
119
+ # in version control.
120
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
121
+ .pdm.toml
122
+ .pdm-python
123
+ .pdm-build/
124
+
125
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
126
+ __pypackages__/
127
+
128
+ # Celery stuff
129
+ celerybeat-schedule
130
+ celerybeat.pid
131
+
132
+ # SageMath parsed files
133
+ *.sage.py
134
+
135
+ # Environments
136
+ .env
137
+ .venv
138
+ env/
139
+ venv/
140
+ ENV/
141
+ env.bak/
142
+ venv.bak/
143
+
144
+ # Spyder project settings
145
+ .spyderproject
146
+ .spyproject
147
+
148
+ # Rope project settings
149
+ .ropeproject
150
+
151
+ # mkdocs documentation
152
+ /site
153
+
154
+ # mypy
155
+ .mypy_cache/
156
+ .dmypy.json
157
+ dmypy.json
158
+
159
+ # Pyre type checker
160
+ .pyre/
161
+
162
+ # pytype static type analyzer
163
+ .pytype/
164
+
165
+ # Cython debug symbols
166
+ cython_debug/
167
+
168
+ # PyCharm
169
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
170
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
171
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
172
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
173
+ #.idea/
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Yushen CHEN
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
- title: F5 TTS Test
3
- emoji: 🏢
4
- colorFrom: indigo
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 5.1.0
8
  app_file: app.py
9
- pinned: false
 
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: F5-TTS
3
+ emoji: 🗣️
4
+ colorFrom: green
5
+ colorTo: green
6
  sdk: gradio
 
7
  app_file: app.py
8
+ pinned: true
9
+ short_description: 'F5-TTS & E2-TTS: Zero-Shot Voice Cloning (Unofficial Demo)'
10
+ sdk_version: 5.1.0
11
  ---
12
 
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
README_REPO.md ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching
2
+
3
+ [![python](https://img.shields.io/badge/Python-3.10-brightgreen)](https://github.com/SWivid/F5-TTS)
4
+ [![arXiv](https://img.shields.io/badge/arXiv-2410.06885-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.06885)
5
+ [![demo](https://img.shields.io/badge/GitHub-Demo%20page-blue.svg)](https://swivid.github.io/F5-TTS/)
6
+ [![space](https://img.shields.io/badge/🤗-Space%20demo-yellow)](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
7
+
8
+ **F5-TTS**: Diffusion Transformer with ConvNeXt V2, faster trained and inference.
9
+
10
+ **E2 TTS**: Flat-UNet Transformer, closest reproduction.
11
+
12
+ **Sway Sampling**: Inference-time flow step sampling strategy, greatly improves performance
13
+
14
+ ## Installation
15
+
16
+ Clone the repository:
17
+
18
+ ```bash
19
+ git clone https://github.com/SWivid/F5-TTS.git
20
+ cd F5-TTS
21
+ ```
22
+
23
+ Install torch with your CUDA version, e.g. :
24
+
25
+ ```bash
26
+ pip install torch==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
27
+ pip install torchaudio==2.3.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
28
+ ```
29
+
30
+ Install other packages:
31
+
32
+ ```bash
33
+ pip install -r requirements.txt
34
+ ```
35
+
36
+ ## Prepare Dataset
37
+
38
+ Example data processing scripts for Emilia and Wenetspeech4TTS, and you may tailor your own one along with a Dataset class in `model/dataset.py`.
39
+
40
+ ```bash
41
+ # prepare custom dataset up to your need
42
+ # download corresponding dataset first, and fill in the path in scripts
43
+
44
+ # Prepare the Emilia dataset
45
+ python scripts/prepare_emilia.py
46
+
47
+ # Prepare the Wenetspeech4TTS dataset
48
+ python scripts/prepare_wenetspeech4tts.py
49
+ ```
50
+
51
+ ## Training
52
+
53
+ Once your datasets are prepared, you can start the training process.
54
+
55
+ ```bash
56
+ # setup accelerate config, e.g. use multi-gpu ddp, fp16
57
+ # will be to: ~/.cache/huggingface/accelerate/default_config.yaml
58
+ accelerate config
59
+ accelerate launch train.py
60
+ ```
61
+ An initial guidance on Finetuning [#57](https://github.com/SWivid/F5-TTS/discussions/57).
62
+
63
+ ## Inference
64
+
65
+ To run inference with pretrained models, download the checkpoints from [🤗 Hugging Face](https://huggingface.co/SWivid/F5-TTS), or automatically downloaded with `inference-cli` and `gradio_app`.
66
+
67
+ Currently support 30s for a single generation, which is the **TOTAL** length of prompt audio and the generated. Batch inference with chunks is supported by `inference-cli` and `gradio_app`.
68
+ - To avoid possible inference failures, make sure you have seen through the following instructions.
69
+ - A longer prompt audio allows shorter generated output. The part longer than 30s cannot be generated properly. Consider using a prompt audio <15s.
70
+ - Uppercased letters will be uttered letter by letter, so use lowercased letters for normal words.
71
+ - Add some spaces (blank: " ") or punctuations (e.g. "," ".") to explicitly introduce some pauses. If first few words skipped in code-switched generation (cuz different speed with different languages), this might help.
72
+
73
+ ### CLI Inference
74
+
75
+ Either you can specify everything in `inference-cli.toml` or override with flags. Leave `--ref_text ""` will have ASR model transcribe the reference audio automatically (use extra GPU memory). If encounter network error, consider use local ckpt, just set `ckpt_path` in `inference-cli.py`
76
+
77
+ ```bash
78
+ python inference-cli.py \
79
+ --model "F5-TTS" \
80
+ --ref_audio "tests/ref_audio/test_en_1_ref_short.wav" \
81
+ --ref_text "Some call me nature, others call me mother nature." \
82
+ --gen_text "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
83
+
84
+ python inference-cli.py \
85
+ --model "E2-TTS" \
86
+ --ref_audio "tests/ref_audio/test_zh_1_ref_short.wav" \
87
+ --ref_text "对,这就是我,万人敬仰的太乙真人。" \
88
+ --gen_text "突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道,我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"
89
+ ```
90
+
91
+ ### Gradio App
92
+ Currently supported features:
93
+ - Chunk inference
94
+ - Podcast Generation
95
+ - Multiple Speech-Type Generation
96
+
97
+ You can launch a Gradio app (web interface) to launch a GUI for inference (will load ckpt from Huggingface, you may set `ckpt_path` to local file in `gradio_app.py`). Currently load ASR model, F5-TTS and E2 TTS all in once, thus use more GPU memory than `inference-cli`.
98
+
99
+ ```bash
100
+ python gradio_app.py
101
+ ```
102
+
103
+ You can specify the port/host:
104
+
105
+ ```bash
106
+ python gradio_app.py --port 7860 --host 0.0.0.0
107
+ ```
108
+
109
+ Or launch a share link:
110
+
111
+ ```bash
112
+ python gradio_app.py --share
113
+ ```
114
+
115
+ ### Speech Editing
116
+
117
+ To test speech editing capabilities, use the following command.
118
+
119
+ ```bash
120
+ python speech_edit.py
121
+ ```
122
+
123
+ ## Evaluation
124
+
125
+ ### Prepare Test Datasets
126
+
127
+ 1. Seed-TTS test set: Download from [seed-tts-eval](https://github.com/BytedanceSpeech/seed-tts-eval).
128
+ 2. LibriSpeech test-clean: Download from [OpenSLR](http://www.openslr.org/12/).
129
+ 3. Unzip the downloaded datasets and place them in the data/ directory.
130
+ 4. Update the path for the test-clean data in `scripts/eval_infer_batch.py`
131
+ 5. Our filtered LibriSpeech-PC 4-10s subset is already under data/ in this repo
132
+
133
+ ### Batch Inference for Test Set
134
+
135
+ To run batch inference for evaluations, execute the following commands:
136
+
137
+ ```bash
138
+ # batch inference for evaluations
139
+ accelerate config # if not set before
140
+ bash scripts/eval_infer_batch.sh
141
+ ```
142
+
143
+ ### Download Evaluation Model Checkpoints
144
+
145
+ 1. Chinese ASR Model: [Paraformer-zh](https://huggingface.co/funasr/paraformer-zh)
146
+ 2. English ASR Model: [Faster-Whisper](https://huggingface.co/Systran/faster-whisper-large-v3)
147
+ 3. WavLM Model: Download from [Google Drive](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view).
148
+
149
+ ### Objective Evaluation
150
+
151
+ **Some Notes**
152
+
153
+ For faster-whisper with CUDA 11:
154
+
155
+ ```bash
156
+ pip install --force-reinstall ctranslate2==3.24.0
157
+ ```
158
+
159
+ (Recommended) To avoid possible ASR failures, such as abnormal repetitions in output:
160
+
161
+ ```bash
162
+ pip install faster-whisper==0.10.1
163
+ ```
164
+
165
+ Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
166
+ ```bash
167
+ # Evaluation for Seed-TTS test set
168
+ python scripts/eval_seedtts_testset.py
169
+
170
+ # Evaluation for LibriSpeech-PC test-clean (cross-sentence)
171
+ python scripts/eval_librispeech_test_clean.py
172
+ ```
173
+
174
+ ## Acknowledgements
175
+
176
+ - [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective
177
+ - [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763) valuable datasets
178
+ - [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
179
+ - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
180
+ - [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) as vocoder
181
+ - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
182
+ - [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech) for evaluation tools
183
+ - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
184
+
185
+ ## Citation
186
+ ```
187
+ @article{chen-etal-2024-f5tts,
188
+ title={F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching},
189
+ author={Yushen Chen and Zhikang Niu and Ziyang Ma and Keqi Deng and Chunhui Wang and Jian Zhao and Kai Yu and Xie Chen},
190
+ journal={arXiv preprint arXiv:2410.06885},
191
+ year={2024},
192
+ }
193
+ ```
194
+ ## License
195
+
196
+ Our code is released under MIT License.
app.py ADDED
@@ -0,0 +1,824 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import torch
4
+ import torchaudio
5
+ import gradio as gr
6
+ import numpy as np
7
+ import tempfile
8
+ from einops import rearrange
9
+ from vocos import Vocos
10
+ from pydub import AudioSegment, silence
11
+ from model import CFM, UNetT, DiT, MMDiT
12
+ from cached_path import cached_path
13
+ from model.utils import (
14
+ load_checkpoint,
15
+ get_tokenizer,
16
+ convert_char_to_pinyin,
17
+ save_spectrogram,
18
+ )
19
+ from transformers import pipeline
20
+ import librosa
21
+ import click
22
+ import soundfile as sf
23
+
24
+ try:
25
+ import spaces
26
+ USING_SPACES = True
27
+ except ImportError:
28
+ USING_SPACES = False
29
+
30
+ def gpu_decorator(func):
31
+ if USING_SPACES:
32
+ return spaces.GPU(func)
33
+ else:
34
+ return func
35
+
36
+
37
+
38
+ SPLIT_WORDS = [
39
+ "but", "however", "nevertheless", "yet", "still",
40
+ "therefore", "thus", "hence", "consequently",
41
+ "moreover", "furthermore", "additionally",
42
+ "meanwhile", "alternatively", "otherwise",
43
+ "namely", "specifically", "for example", "such as",
44
+ "in fact", "indeed", "notably",
45
+ "in contrast", "on the other hand", "conversely",
46
+ "in conclusion", "to summarize", "finally"
47
+ ]
48
+
49
+ device = (
50
+ "cuda"
51
+ if torch.cuda.is_available()
52
+ else "mps" if torch.backends.mps.is_available() else "cpu"
53
+ )
54
+
55
+ print(f"Using {device} device")
56
+
57
+ pipe = pipeline(
58
+ "automatic-speech-recognition",
59
+ model="openai/whisper-large-v3-turbo",
60
+ torch_dtype=torch.float16,
61
+ device=device,
62
+ )
63
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
64
+
65
+ # --------------------- Settings -------------------- #
66
+
67
+ target_sample_rate = 24000
68
+ n_mel_channels = 100
69
+ hop_length = 256
70
+ target_rms = 0.1
71
+ nfe_step = 32 # 16, 32
72
+ cfg_strength = 2.0
73
+ ode_method = "euler"
74
+ sway_sampling_coef = -1.0
75
+ speed = 1.0
76
+ # fix_duration = 27 # None or float (duration in seconds)
77
+ fix_duration = None
78
+
79
+
80
+ def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
81
+ ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
82
+ # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
83
+ vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
84
+ model = CFM(
85
+ transformer=model_cls(
86
+ **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
87
+ ),
88
+ mel_spec_kwargs=dict(
89
+ target_sample_rate=target_sample_rate,
90
+ n_mel_channels=n_mel_channels,
91
+ hop_length=hop_length,
92
+ ),
93
+ odeint_kwargs=dict(
94
+ method=ode_method,
95
+ ),
96
+ vocab_char_map=vocab_char_map,
97
+ ).to(device)
98
+
99
+ model = load_checkpoint(model, ckpt_path, device, use_ema = True)
100
+
101
+ return model
102
+
103
+
104
+ # load models
105
+ F5TTS_model_cfg = dict(
106
+ dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
107
+ )
108
+ E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
109
+
110
+ F5TTS_ema_model = load_model(
111
+ "F5-TTS", "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000
112
+ )
113
+ E2TTS_ema_model = load_model(
114
+ "E2-TTS", "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000
115
+ )
116
+
117
+ def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
118
+ if len(text.encode('utf-8')) <= max_chars:
119
+ return [text]
120
+ if text[-1] not in ['。', '.', '!', '!', '?', '?']:
121
+ text += '.'
122
+
123
+ sentences = re.split('([。.!?!?])', text)
124
+ sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]
125
+
126
+ batches = []
127
+ current_batch = ""
128
+
129
+ def split_by_words(text):
130
+ words = text.split()
131
+ current_word_part = ""
132
+ word_batches = []
133
+ for word in words:
134
+ if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
135
+ current_word_part += word + ' '
136
+ else:
137
+ if current_word_part:
138
+ # Try to find a suitable split word
139
+ for split_word in split_words:
140
+ split_index = current_word_part.rfind(' ' + split_word + ' ')
141
+ if split_index != -1:
142
+ word_batches.append(current_word_part[:split_index].strip())
143
+ current_word_part = current_word_part[split_index:].strip() + ' '
144
+ break
145
+ else:
146
+ # If no suitable split word found, just append the current part
147
+ word_batches.append(current_word_part.strip())
148
+ current_word_part = ""
149
+ current_word_part += word + ' '
150
+ if current_word_part:
151
+ word_batches.append(current_word_part.strip())
152
+ return word_batches
153
+
154
+ for sentence in sentences:
155
+ if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
156
+ current_batch += sentence
157
+ else:
158
+ # If adding this sentence would exceed the limit
159
+ if current_batch:
160
+ batches.append(current_batch)
161
+ current_batch = ""
162
+
163
+ # If the sentence itself is longer than max_chars, split it
164
+ if len(sentence.encode('utf-8')) > max_chars:
165
+ # First, try to split by colon
166
+ colon_parts = sentence.split(':')
167
+ if len(colon_parts) > 1:
168
+ for part in colon_parts:
169
+ if len(part.encode('utf-8')) <= max_chars:
170
+ batches.append(part)
171
+ else:
172
+ # If colon part is still too long, split by comma
173
+ comma_parts = re.split('[,,]', part)
174
+ if len(comma_parts) > 1:
175
+ current_comma_part = ""
176
+ for comma_part in comma_parts:
177
+ if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
178
+ current_comma_part += comma_part + ','
179
+ else:
180
+ if current_comma_part:
181
+ batches.append(current_comma_part.rstrip(','))
182
+ current_comma_part = comma_part + ','
183
+ if current_comma_part:
184
+ batches.append(current_comma_part.rstrip(','))
185
+ else:
186
+ # If no comma, split by words
187
+ batches.extend(split_by_words(part))
188
+ else:
189
+ # If no colon, split by comma
190
+ comma_parts = re.split('[,,]', sentence)
191
+ if len(comma_parts) > 1:
192
+ current_comma_part = ""
193
+ for comma_part in comma_parts:
194
+ if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
195
+ current_comma_part += comma_part + ','
196
+ else:
197
+ if current_comma_part:
198
+ batches.append(current_comma_part.rstrip(','))
199
+ current_comma_part = comma_part + ','
200
+ if current_comma_part:
201
+ batches.append(current_comma_part.rstrip(','))
202
+ else:
203
+ # If no comma, split by words
204
+ batches.extend(split_by_words(sentence))
205
+ else:
206
+ current_batch = sentence
207
+
208
+ if current_batch:
209
+ batches.append(current_batch)
210
+
211
+ return batches
212
+
213
+ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence, progress=gr.Progress()):
214
+ if exp_name == "F5-TTS":
215
+ ema_model = F5TTS_ema_model
216
+ elif exp_name == "E2-TTS":
217
+ ema_model = E2TTS_ema_model
218
+
219
+ audio, sr = ref_audio
220
+ if audio.shape[0] > 1:
221
+ audio = torch.mean(audio, dim=0, keepdim=True)
222
+
223
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
224
+ if rms < target_rms:
225
+ audio = audio * target_rms / rms
226
+ if sr != target_sample_rate:
227
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
228
+ audio = resampler(audio)
229
+ audio = audio.to(device)
230
+
231
+ generated_waves = []
232
+ spectrograms = []
233
+
234
+ for i, gen_text in enumerate(progress.tqdm(gen_text_batches)):
235
+ # Prepare the text
236
+ if len(ref_text[-1].encode('utf-8')) == 1:
237
+ ref_text = ref_text + " "
238
+ text_list = [ref_text + gen_text]
239
+ final_text_list = convert_char_to_pinyin(text_list)
240
+
241
+ # Calculate duration
242
+ ref_audio_len = audio.shape[-1] // hop_length
243
+ zh_pause_punc = r"。,、;:?!"
244
+ ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
245
+ gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
246
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
247
+
248
+ # inference
249
+ with torch.inference_mode():
250
+ generated, _ = ema_model.sample(
251
+ cond=audio,
252
+ text=final_text_list,
253
+ duration=duration,
254
+ steps=nfe_step,
255
+ cfg_strength=cfg_strength,
256
+ sway_sampling_coef=sway_sampling_coef,
257
+ )
258
+
259
+ generated = generated[:, ref_audio_len:, :]
260
+ generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
261
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
262
+ if rms < target_rms:
263
+ generated_wave = generated_wave * rms / target_rms
264
+
265
+ # wav -> numpy
266
+ generated_wave = generated_wave.squeeze().cpu().numpy()
267
+
268
+ generated_waves.append(generated_wave)
269
+ spectrograms.append(generated_mel_spec[0].cpu().numpy())
270
+
271
+ # Combine all generated waves
272
+ final_wave = np.concatenate(generated_waves)
273
+
274
+ # Remove silence
275
+ if remove_silence:
276
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
277
+ sf.write(f.name, final_wave, target_sample_rate)
278
+ aseg = AudioSegment.from_file(f.name)
279
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
280
+ non_silent_wave = AudioSegment.silent(duration=0)
281
+ for non_silent_seg in non_silent_segs:
282
+ non_silent_wave += non_silent_seg
283
+ aseg = non_silent_wave
284
+ aseg.export(f.name, format="wav")
285
+ final_wave, _ = torchaudio.load(f.name)
286
+ final_wave = final_wave.squeeze().cpu().numpy()
287
+
288
+ # Create a combined spectrogram
289
+ combined_spectrogram = np.concatenate(spectrograms, axis=1)
290
+
291
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
292
+ spectrogram_path = tmp_spectrogram.name
293
+ save_spectrogram(combined_spectrogram, spectrogram_path)
294
+
295
+ return (target_sample_rate, final_wave), spectrogram_path
296
+
297
+ def infer(ref_audio_orig, ref_text, gen_text, exp_name, remove_silence, custom_split_words=''):
298
+ if not custom_split_words.strip():
299
+ custom_words = [word.strip() for word in custom_split_words.split(',')]
300
+ global SPLIT_WORDS
301
+ SPLIT_WORDS = custom_words
302
+
303
+ print(gen_text)
304
+
305
+ gr.Info("Converting audio...")
306
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
307
+ aseg = AudioSegment.from_file(ref_audio_orig)
308
+
309
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
310
+ non_silent_wave = AudioSegment.silent(duration=0)
311
+ for non_silent_seg in non_silent_segs:
312
+ non_silent_wave += non_silent_seg
313
+ aseg = non_silent_wave
314
+
315
+ audio_duration = len(aseg)
316
+ if audio_duration > 15000:
317
+ gr.Warning("Audio is over 15s, clipping to only first 15s.")
318
+ aseg = aseg[:15000]
319
+ aseg.export(f.name, format="wav")
320
+ ref_audio = f.name
321
+
322
+ if not ref_text.strip():
323
+ gr.Info("No reference text provided, transcribing reference audio...")
324
+ ref_text = pipe(
325
+ ref_audio,
326
+ chunk_length_s=30,
327
+ batch_size=128,
328
+ generate_kwargs={"task": "transcribe"},
329
+ return_timestamps=False,
330
+ )["text"].strip()
331
+ gr.Info("Finished transcription")
332
+ else:
333
+ gr.Info("Using custom reference text...")
334
+
335
+ # Split the input text into batches
336
+ audio, sr = torchaudio.load(ref_audio)
337
+ max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
338
+ gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
339
+ print('ref_text', ref_text)
340
+ for i, gen_text in enumerate(gen_text_batches):
341
+ print(f'gen_text {i}', gen_text)
342
+
343
+ gr.Info(f"Generating audio using {exp_name} in {len(gen_text_batches)} batches")
344
+ return infer_batch((audio, sr), ref_text, gen_text_batches, exp_name, remove_silence)
345
+
346
+ def generate_podcast(script, speaker1_name, ref_audio1, ref_text1, speaker2_name, ref_audio2, ref_text2, exp_name, remove_silence):
347
+ # Split the script into speaker blocks
348
+ speaker_pattern = re.compile(f"^({re.escape(speaker1_name)}|{re.escape(speaker2_name)}):", re.MULTILINE)
349
+ speaker_blocks = speaker_pattern.split(script)[1:] # Skip the first empty element
350
+
351
+ generated_audio_segments = []
352
+
353
+ for i in range(0, len(speaker_blocks), 2):
354
+ speaker = speaker_blocks[i]
355
+ text = speaker_blocks[i+1].strip()
356
+
357
+ # Determine which speaker is talking
358
+ if speaker == speaker1_name:
359
+ ref_audio = ref_audio1
360
+ ref_text = ref_text1
361
+ elif speaker == speaker2_name:
362
+ ref_audio = ref_audio2
363
+ ref_text = ref_text2
364
+ else:
365
+ continue # Skip if the speaker is neither speaker1 nor speaker2
366
+
367
+ # Generate audio for this block
368
+ audio, _ = infer(ref_audio, ref_text, text, exp_name, remove_silence)
369
+
370
+ # Convert the generated audio to a numpy array
371
+ sr, audio_data = audio
372
+
373
+ # Save the audio data as a WAV file
374
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
375
+ sf.write(temp_file.name, audio_data, sr)
376
+ audio_segment = AudioSegment.from_wav(temp_file.name)
377
+
378
+ generated_audio_segments.append(audio_segment)
379
+
380
+ # Add a short pause between speakers
381
+ pause = AudioSegment.silent(duration=500) # 500ms pause
382
+ generated_audio_segments.append(pause)
383
+
384
+ # Concatenate all audio segments
385
+ final_podcast = sum(generated_audio_segments)
386
+
387
+ # Export the final podcast
388
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
389
+ podcast_path = temp_file.name
390
+ final_podcast.export(podcast_path, format="wav")
391
+
392
+ return podcast_path
393
+
394
+ def parse_speechtypes_text(gen_text):
395
+ # Pattern to find (Emotion)
396
+ pattern = r'\((.*?)\)'
397
+
398
+ # Split the text by the pattern
399
+ tokens = re.split(pattern, gen_text)
400
+
401
+ segments = []
402
+
403
+ current_emotion = 'Regular'
404
+
405
+ for i in range(len(tokens)):
406
+ if i % 2 == 0:
407
+ # This is text
408
+ text = tokens[i].strip()
409
+ if text:
410
+ segments.append({'emotion': current_emotion, 'text': text})
411
+ else:
412
+ # This is emotion
413
+ emotion = tokens[i].strip()
414
+ current_emotion = emotion
415
+
416
+ return segments
417
+
418
+ def update_speed(new_speed):
419
+ global speed
420
+ speed = new_speed
421
+ return f"Speed set to: {speed}"
422
+
423
+ with gr.Blocks() as app_credits:
424
+ gr.Markdown("""
425
+ # Credits
426
+
427
+ * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
428
+ * [RootingInLoad](https://github.com/RootingInLoad) for the podcast generation
429
+ """)
430
+ with gr.Blocks() as app_tts:
431
+ gr.Markdown("# Batched TTS")
432
+ ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
433
+ gen_text_input = gr.Textbox(label="Text to Generate", lines=10)
434
+ model_choice = gr.Radio(
435
+ choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
436
+ )
437
+ generate_btn = gr.Button("Synthesize", variant="primary")
438
+ with gr.Accordion("Advanced Settings", open=False):
439
+ ref_text_input = gr.Textbox(
440
+ label="Reference Text",
441
+ info="Leave blank to automatically transcribe the reference audio. If you enter text it will override automatic transcription.",
442
+ lines=2,
443
+ )
444
+ remove_silence = gr.Checkbox(
445
+ label="Remove Silences",
446
+ info="The model tends to produce silences, especially on longer audio. We can manually remove silences if needed. Note that this is an experimental feature and may produce strange results. This will also increase generation time.",
447
+ value=True,
448
+ )
449
+ split_words_input = gr.Textbox(
450
+ label="Custom Split Words",
451
+ info="Enter custom words to split on, separated by commas. Leave blank to use default list.",
452
+ lines=2,
453
+ )
454
+ speed_slider = gr.Slider(
455
+ label="Speed",
456
+ minimum=0.3,
457
+ maximum=2.0,
458
+ value=speed,
459
+ step=0.1,
460
+ info="Adjust the speed of the audio.",
461
+ )
462
+ speed_slider.change(update_speed, inputs=speed_slider)
463
+
464
+ audio_output = gr.Audio(label="Synthesized Audio")
465
+ spectrogram_output = gr.Image(label="Spectrogram")
466
+
467
+ generate_btn.click(
468
+ infer,
469
+ inputs=[
470
+ ref_audio_input,
471
+ ref_text_input,
472
+ gen_text_input,
473
+ model_choice,
474
+ remove_silence,
475
+ split_words_input,
476
+ ],
477
+ outputs=[audio_output, spectrogram_output],
478
+ )
479
+
480
+ with gr.Blocks() as app_podcast:
481
+ gr.Markdown("# Podcast Generation")
482
+ speaker1_name = gr.Textbox(label="Speaker 1 Name")
483
+ ref_audio_input1 = gr.Audio(label="Reference Audio (Speaker 1)", type="filepath")
484
+ ref_text_input1 = gr.Textbox(label="Reference Text (Speaker 1)", lines=2)
485
+
486
+ speaker2_name = gr.Textbox(label="Speaker 2 Name")
487
+ ref_audio_input2 = gr.Audio(label="Reference Audio (Speaker 2)", type="filepath")
488
+ ref_text_input2 = gr.Textbox(label="Reference Text (Speaker 2)", lines=2)
489
+
490
+ script_input = gr.Textbox(label="Podcast Script", lines=10,
491
+ placeholder="Enter the script with speaker names at the start of each block, e.g.:\nSean: How did you start studying...\n\nMeghan: I came to my interest in technology...\nIt was a long journey...\n\nSean: That's fascinating. Can you elaborate...")
492
+
493
+ podcast_model_choice = gr.Radio(
494
+ choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
495
+ )
496
+ podcast_remove_silence = gr.Checkbox(
497
+ label="Remove Silences",
498
+ value=True,
499
+ )
500
+ generate_podcast_btn = gr.Button("Generate Podcast", variant="primary")
501
+ podcast_output = gr.Audio(label="Generated Podcast")
502
+
503
+ def podcast_generation(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence):
504
+ return generate_podcast(script, speaker1, ref_audio1, ref_text1, speaker2, ref_audio2, ref_text2, model, remove_silence)
505
+
506
+ generate_podcast_btn.click(
507
+ podcast_generation,
508
+ inputs=[
509
+ script_input,
510
+ speaker1_name,
511
+ ref_audio_input1,
512
+ ref_text_input1,
513
+ speaker2_name,
514
+ ref_audio_input2,
515
+ ref_text_input2,
516
+ podcast_model_choice,
517
+ podcast_remove_silence,
518
+ ],
519
+ outputs=podcast_output,
520
+ )
521
+
522
+ def parse_emotional_text(gen_text):
523
+ # Pattern to find (Emotion)
524
+ pattern = r'\((.*?)\)'
525
+
526
+ # Split the text by the pattern
527
+ tokens = re.split(pattern, gen_text)
528
+
529
+ segments = []
530
+
531
+ current_emotion = 'Regular'
532
+
533
+ for i in range(len(tokens)):
534
+ if i % 2 == 0:
535
+ # This is text
536
+ text = tokens[i].strip()
537
+ if text:
538
+ segments.append({'emotion': current_emotion, 'text': text})
539
+ else:
540
+ # This is emotion
541
+ emotion = tokens[i].strip()
542
+ current_emotion = emotion
543
+
544
+ return segments
545
+
546
+ with gr.Blocks() as app_emotional:
547
+ # New section for emotional generation
548
+ gr.Markdown(
549
+ """
550
+ # Multiple Speech-Type Generation
551
+
552
+ This section allows you to upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the "Add Speech Type" button. Enter your text in the format shown below, and the system will generate speech using the appropriate emotions. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
553
+
554
+ **Example Input:**
555
+
556
+ (Regular) Hello, I'd like to order a sandwich please. (Surprised) What do you mean you're out of bread? (Sad) I really wanted a sandwich though... (Angry) You know what, darn you and your little shop, you suck! (Whisper) I'll just go back home and cry now. (Shouting) Why me?!
557
+ """
558
+ )
559
+
560
+ gr.Markdown("Upload different audio clips for each speech type. 'Regular' emotion is mandatory. You can add additional speech types by clicking the 'Add Speech Type' button.")
561
+
562
+ # Regular speech type (mandatory)
563
+ with gr.Row():
564
+ regular_name = gr.Textbox(value='Regular', label='Speech Type Name', interactive=False)
565
+ regular_audio = gr.Audio(label='Regular Reference Audio', type='filepath')
566
+ regular_ref_text = gr.Textbox(label='Reference Text (Regular)', lines=2)
567
+
568
+ # Additional speech types (up to 9 more)
569
+ max_speech_types = 10
570
+ speech_type_names = []
571
+ speech_type_audios = []
572
+ speech_type_ref_texts = []
573
+ speech_type_delete_btns = []
574
+
575
+ for i in range(max_speech_types - 1):
576
+ with gr.Row():
577
+ name_input = gr.Textbox(label='Speech Type Name', visible=False)
578
+ audio_input = gr.Audio(label='Reference Audio', type='filepath', visible=False)
579
+ ref_text_input = gr.Textbox(label='Reference Text', lines=2, visible=False)
580
+ delete_btn = gr.Button("Delete", variant="secondary", visible=False)
581
+ speech_type_names.append(name_input)
582
+ speech_type_audios.append(audio_input)
583
+ speech_type_ref_texts.append(ref_text_input)
584
+ speech_type_delete_btns.append(delete_btn)
585
+
586
+ # Button to add speech type
587
+ add_speech_type_btn = gr.Button("Add Speech Type")
588
+
589
+ # Keep track of current number of speech types
590
+ speech_type_count = gr.State(value=0)
591
+
592
+ # Function to add a speech type
593
+ def add_speech_type_fn(speech_type_count):
594
+ if speech_type_count < max_speech_types - 1:
595
+ speech_type_count += 1
596
+ # Prepare updates for the components
597
+ name_updates = []
598
+ audio_updates = []
599
+ ref_text_updates = []
600
+ delete_btn_updates = []
601
+ for i in range(max_speech_types - 1):
602
+ if i < speech_type_count:
603
+ name_updates.append(gr.update(visible=True))
604
+ audio_updates.append(gr.update(visible=True))
605
+ ref_text_updates.append(gr.update(visible=True))
606
+ delete_btn_updates.append(gr.update(visible=True))
607
+ else:
608
+ name_updates.append(gr.update())
609
+ audio_updates.append(gr.update())
610
+ ref_text_updates.append(gr.update())
611
+ delete_btn_updates.append(gr.update())
612
+ else:
613
+ # Optionally, show a warning
614
+ # gr.Warning("Maximum number of speech types reached.")
615
+ name_updates = [gr.update() for _ in range(max_speech_types - 1)]
616
+ audio_updates = [gr.update() for _ in range(max_speech_types - 1)]
617
+ ref_text_updates = [gr.update() for _ in range(max_speech_types - 1)]
618
+ delete_btn_updates = [gr.update() for _ in range(max_speech_types - 1)]
619
+ return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
620
+
621
+ add_speech_type_btn.click(
622
+ add_speech_type_fn,
623
+ inputs=speech_type_count,
624
+ outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
625
+ )
626
+
627
+ # Function to delete a speech type
628
+ def make_delete_speech_type_fn(index):
629
+ def delete_speech_type_fn(speech_type_count):
630
+ # Prepare updates
631
+ name_updates = []
632
+ audio_updates = []
633
+ ref_text_updates = []
634
+ delete_btn_updates = []
635
+
636
+ for i in range(max_speech_types - 1):
637
+ if i == index:
638
+ name_updates.append(gr.update(visible=False, value=''))
639
+ audio_updates.append(gr.update(visible=False, value=None))
640
+ ref_text_updates.append(gr.update(visible=False, value=''))
641
+ delete_btn_updates.append(gr.update(visible=False))
642
+ else:
643
+ name_updates.append(gr.update())
644
+ audio_updates.append(gr.update())
645
+ ref_text_updates.append(gr.update())
646
+ delete_btn_updates.append(gr.update())
647
+
648
+ speech_type_count = max(0, speech_type_count - 1)
649
+
650
+ return [speech_type_count] + name_updates + audio_updates + ref_text_updates + delete_btn_updates
651
+
652
+ return delete_speech_type_fn
653
+
654
+ for i, delete_btn in enumerate(speech_type_delete_btns):
655
+ delete_fn = make_delete_speech_type_fn(i)
656
+ delete_btn.click(
657
+ delete_fn,
658
+ inputs=speech_type_count,
659
+ outputs=[speech_type_count] + speech_type_names + speech_type_audios + speech_type_ref_texts + speech_type_delete_btns
660
+ )
661
+
662
+ # Text input for the prompt
663
+ gen_text_input_emotional = gr.Textbox(label="Text to Generate", lines=10)
664
+
665
+ # Model choice
666
+ model_choice_emotional = gr.Radio(
667
+ choices=["F5-TTS", "E2-TTS"], label="Choose TTS Model", value="F5-TTS"
668
+ )
669
+
670
+ with gr.Accordion("Advanced Settings", open=False):
671
+ remove_silence_emotional = gr.Checkbox(
672
+ label="Remove Silences",
673
+ value=True,
674
+ )
675
+
676
+ # Generate button
677
+ generate_emotional_btn = gr.Button("Generate Emotional Speech", variant="primary")
678
+
679
+ # Output audio
680
+ audio_output_emotional = gr.Audio(label="Synthesized Audio")
681
+
682
+ def generate_emotional_speech(
683
+ regular_audio,
684
+ regular_ref_text,
685
+ gen_text,
686
+ *args,
687
+ ):
688
+ num_additional_speech_types = max_speech_types - 1
689
+ speech_type_names_list = args[:num_additional_speech_types]
690
+ speech_type_audios_list = args[num_additional_speech_types:2 * num_additional_speech_types]
691
+ speech_type_ref_texts_list = args[2 * num_additional_speech_types:3 * num_additional_speech_types]
692
+ model_choice = args[3 * num_additional_speech_types]
693
+ remove_silence = args[3 * num_additional_speech_types + 1]
694
+
695
+ # Collect the speech types and their audios into a dict
696
+ speech_types = {'Regular': {'audio': regular_audio, 'ref_text': regular_ref_text}}
697
+
698
+ for name_input, audio_input, ref_text_input in zip(speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list):
699
+ if name_input and audio_input:
700
+ speech_types[name_input] = {'audio': audio_input, 'ref_text': ref_text_input}
701
+
702
+ # Parse the gen_text into segments
703
+ segments = parse_speechtypes_text(gen_text)
704
+
705
+ # For each segment, generate speech
706
+ generated_audio_segments = []
707
+ current_emotion = 'Regular'
708
+
709
+ for segment in segments:
710
+ emotion = segment['emotion']
711
+ text = segment['text']
712
+
713
+ if emotion in speech_types:
714
+ current_emotion = emotion
715
+ else:
716
+ # If emotion not available, default to Regular
717
+ current_emotion = 'Regular'
718
+
719
+ ref_audio = speech_types[current_emotion]['audio']
720
+ ref_text = speech_types[current_emotion].get('ref_text', '')
721
+
722
+ # Generate speech for this segment
723
+ audio, _ = infer(ref_audio, ref_text, text, model_choice, remove_silence, "")
724
+ sr, audio_data = audio
725
+
726
+ generated_audio_segments.append(audio_data)
727
+
728
+ # Concatenate all audio segments
729
+ if generated_audio_segments:
730
+ final_audio_data = np.concatenate(generated_audio_segments)
731
+ return (sr, final_audio_data)
732
+ else:
733
+ gr.Warning("No audio generated.")
734
+ return None
735
+
736
+ generate_emotional_btn.click(
737
+ generate_emotional_speech,
738
+ inputs=[
739
+ regular_audio,
740
+ regular_ref_text,
741
+ gen_text_input_emotional,
742
+ ] + speech_type_names + speech_type_audios + speech_type_ref_texts + [
743
+ model_choice_emotional,
744
+ remove_silence_emotional,
745
+ ],
746
+ outputs=audio_output_emotional,
747
+ )
748
+
749
+ # Validation function to disable Generate button if speech types are missing
750
+ def validate_speech_types(
751
+ gen_text,
752
+ regular_name,
753
+ *args
754
+ ):
755
+ num_additional_speech_types = max_speech_types - 1
756
+ speech_type_names_list = args[:num_additional_speech_types]
757
+
758
+ # Collect the speech types names
759
+ speech_types_available = set()
760
+ if regular_name:
761
+ speech_types_available.add(regular_name)
762
+ for name_input in speech_type_names_list:
763
+ if name_input:
764
+ speech_types_available.add(name_input)
765
+
766
+ # Parse the gen_text to get the speech types used
767
+ segments = parse_emotional_text(gen_text)
768
+ speech_types_in_text = set(segment['emotion'] for segment in segments)
769
+
770
+ # Check if all speech types in text are available
771
+ missing_speech_types = speech_types_in_text - speech_types_available
772
+
773
+ if missing_speech_types:
774
+ # Disable the generate button
775
+ return gr.update(interactive=False)
776
+ else:
777
+ # Enable the generate button
778
+ return gr.update(interactive=True)
779
+
780
+ gen_text_input_emotional.change(
781
+ validate_speech_types,
782
+ inputs=[gen_text_input_emotional, regular_name] + speech_type_names,
783
+ outputs=generate_emotional_btn
784
+ )
785
+ with gr.Blocks() as app:
786
+ gr.Markdown(
787
+ """
788
+ # E2/F5 TTS
789
+
790
+ This is a local web UI for F5 TTS with advanced batch processing support. This app supports the following TTS models:
791
+
792
+ * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
793
+ * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
794
+
795
+ The checkpoints support English and Chinese.
796
+
797
+ If you're having issues, try converting your reference audio to WAV or MP3, clipping it to 15s, and shortening your prompt.
798
+
799
+ **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<15s). Ensure the audio is fully uploaded before generating.**
800
+ """
801
+ )
802
+ gr.TabbedInterface([app_tts, app_podcast, app_emotional, app_credits], ["TTS", "Podcast", "Multi-Style", "Credits"])
803
+
804
+ @click.command()
805
+ @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
806
+ @click.option("--host", "-H", default=None, help="Host to run the app on")
807
+ @click.option(
808
+ "--share",
809
+ "-s",
810
+ default=False,
811
+ is_flag=True,
812
+ help="Share the app via Gradio share link",
813
+ )
814
+ @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
815
+ def main(port, host, share, api):
816
+ global app
817
+ print(f"Starting app...")
818
+ app.queue(api_open=api).launch(
819
+ server_name=host, server_port=port, share=share, show_api=api
820
+ )
821
+
822
+
823
+ if __name__ == "__main__":
824
+ main()
inference-cli.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import torch
3
+ import torchaudio
4
+ import numpy as np
5
+ import tempfile
6
+ from einops import rearrange
7
+ from vocos import Vocos
8
+ from pydub import AudioSegment, silence
9
+ from model import CFM, UNetT, DiT, MMDiT
10
+ from cached_path import cached_path
11
+ from model.utils import (
12
+ load_checkpoint,
13
+ get_tokenizer,
14
+ convert_char_to_pinyin,
15
+ save_spectrogram,
16
+ )
17
+ from transformers import pipeline
18
+ import soundfile as sf
19
+ import tomli
20
+ import argparse
21
+ import tqdm
22
+ from pathlib import Path
23
+
24
+ parser = argparse.ArgumentParser(
25
+ prog="python3 inference-cli.py",
26
+ description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
27
+ epilog="Specify options above to override one or more settings from config.",
28
+ )
29
+ parser.add_argument(
30
+ "-c",
31
+ "--config",
32
+ help="Configuration file. Default=cli-config.toml",
33
+ default="inference-cli.toml",
34
+ )
35
+ parser.add_argument(
36
+ "-m",
37
+ "--model",
38
+ help="F5-TTS | E2-TTS",
39
+ )
40
+ parser.add_argument(
41
+ "-r",
42
+ "--ref_audio",
43
+ type=str,
44
+ help="Reference audio file < 15 seconds."
45
+ )
46
+ parser.add_argument(
47
+ "-s",
48
+ "--ref_text",
49
+ type=str,
50
+ default="666",
51
+ help="Subtitle for the reference audio."
52
+ )
53
+ parser.add_argument(
54
+ "-t",
55
+ "--gen_text",
56
+ type=str,
57
+ help="Text to generate.",
58
+ )
59
+ parser.add_argument(
60
+ "-o",
61
+ "--output_dir",
62
+ type=str,
63
+ help="Path to output folder..",
64
+ )
65
+ parser.add_argument(
66
+ "--remove_silence",
67
+ help="Remove silence.",
68
+ )
69
+ args = parser.parse_args()
70
+
71
+ config = tomli.load(open(args.config, "rb"))
72
+
73
+ ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
74
+ ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
75
+ gen_text = args.gen_text if args.gen_text else config["gen_text"]
76
+ output_dir = args.output_dir if args.output_dir else config["output_dir"]
77
+ model = args.model if args.model else config["model"]
78
+ remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
79
+ wave_path = Path(output_dir)/"out.wav"
80
+ spectrogram_path = Path(output_dir)/"out.png"
81
+
82
+ SPLIT_WORDS = [
83
+ "but", "however", "nevertheless", "yet", "still",
84
+ "therefore", "thus", "hence", "consequently",
85
+ "moreover", "furthermore", "additionally",
86
+ "meanwhile", "alternatively", "otherwise",
87
+ "namely", "specifically", "for example", "such as",
88
+ "in fact", "indeed", "notably",
89
+ "in contrast", "on the other hand", "conversely",
90
+ "in conclusion", "to summarize", "finally"
91
+ ]
92
+
93
+ device = (
94
+ "cuda"
95
+ if torch.cuda.is_available()
96
+ else "mps" if torch.backends.mps.is_available() else "cpu"
97
+ )
98
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
99
+
100
+ print(f"Using {device} device")
101
+
102
+ # --------------------- Settings -------------------- #
103
+
104
+ target_sample_rate = 24000
105
+ n_mel_channels = 100
106
+ hop_length = 256
107
+ target_rms = 0.1
108
+ nfe_step = 32 # 16, 32
109
+ cfg_strength = 2.0
110
+ ode_method = "euler"
111
+ sway_sampling_coef = -1.0
112
+ speed = 1.0
113
+ # fix_duration = 27 # None or float (duration in seconds)
114
+ fix_duration = None
115
+
116
+ def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
117
+ ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
118
+ # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
119
+ vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
120
+ model = CFM(
121
+ transformer=model_cls(
122
+ **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
123
+ ),
124
+ mel_spec_kwargs=dict(
125
+ target_sample_rate=target_sample_rate,
126
+ n_mel_channels=n_mel_channels,
127
+ hop_length=hop_length,
128
+ ),
129
+ odeint_kwargs=dict(
130
+ method=ode_method,
131
+ ),
132
+ vocab_char_map=vocab_char_map,
133
+ ).to(device)
134
+
135
+ model = load_checkpoint(model, ckpt_path, device, use_ema = True)
136
+
137
+ return model
138
+
139
+
140
+ # load models
141
+ F5TTS_model_cfg = dict(
142
+ dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
143
+ )
144
+ E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
145
+
146
+ def split_text_into_batches(text, max_chars=200, split_words=SPLIT_WORDS):
147
+ if len(text.encode('utf-8')) <= max_chars:
148
+ return [text]
149
+ if text[-1] not in ['。', '.', '!', '!', '?', '?']:
150
+ text += '.'
151
+
152
+ sentences = re.split('([。.!?!?])', text)
153
+ sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2])]
154
+
155
+ batches = []
156
+ current_batch = ""
157
+
158
+ def split_by_words(text):
159
+ words = text.split()
160
+ current_word_part = ""
161
+ word_batches = []
162
+ for word in words:
163
+ if len(current_word_part.encode('utf-8')) + len(word.encode('utf-8')) + 1 <= max_chars:
164
+ current_word_part += word + ' '
165
+ else:
166
+ if current_word_part:
167
+ # Try to find a suitable split word
168
+ for split_word in split_words:
169
+ split_index = current_word_part.rfind(' ' + split_word + ' ')
170
+ if split_index != -1:
171
+ word_batches.append(current_word_part[:split_index].strip())
172
+ current_word_part = current_word_part[split_index:].strip() + ' '
173
+ break
174
+ else:
175
+ # If no suitable split word found, just append the current part
176
+ word_batches.append(current_word_part.strip())
177
+ current_word_part = ""
178
+ current_word_part += word + ' '
179
+ if current_word_part:
180
+ word_batches.append(current_word_part.strip())
181
+ return word_batches
182
+
183
+ for sentence in sentences:
184
+ if len(current_batch.encode('utf-8')) + len(sentence.encode('utf-8')) <= max_chars:
185
+ current_batch += sentence
186
+ else:
187
+ # If adding this sentence would exceed the limit
188
+ if current_batch:
189
+ batches.append(current_batch)
190
+ current_batch = ""
191
+
192
+ # If the sentence itself is longer than max_chars, split it
193
+ if len(sentence.encode('utf-8')) > max_chars:
194
+ # First, try to split by colon
195
+ colon_parts = sentence.split(':')
196
+ if len(colon_parts) > 1:
197
+ for part in colon_parts:
198
+ if len(part.encode('utf-8')) <= max_chars:
199
+ batches.append(part)
200
+ else:
201
+ # If colon part is still too long, split by comma
202
+ comma_parts = re.split('[,,]', part)
203
+ if len(comma_parts) > 1:
204
+ current_comma_part = ""
205
+ for comma_part in comma_parts:
206
+ if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
207
+ current_comma_part += comma_part + ','
208
+ else:
209
+ if current_comma_part:
210
+ batches.append(current_comma_part.rstrip(','))
211
+ current_comma_part = comma_part + ','
212
+ if current_comma_part:
213
+ batches.append(current_comma_part.rstrip(','))
214
+ else:
215
+ # If no comma, split by words
216
+ batches.extend(split_by_words(part))
217
+ else:
218
+ # If no colon, split by comma
219
+ comma_parts = re.split('[,,]', sentence)
220
+ if len(comma_parts) > 1:
221
+ current_comma_part = ""
222
+ for comma_part in comma_parts:
223
+ if len(current_comma_part.encode('utf-8')) + len(comma_part.encode('utf-8')) <= max_chars:
224
+ current_comma_part += comma_part + ','
225
+ else:
226
+ if current_comma_part:
227
+ batches.append(current_comma_part.rstrip(','))
228
+ current_comma_part = comma_part + ','
229
+ if current_comma_part:
230
+ batches.append(current_comma_part.rstrip(','))
231
+ else:
232
+ # If no comma, split by words
233
+ batches.extend(split_by_words(sentence))
234
+ else:
235
+ current_batch = sentence
236
+
237
+ if current_batch:
238
+ batches.append(current_batch)
239
+
240
+ return batches
241
+
242
+ def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence):
243
+ if model == "F5-TTS":
244
+ ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
245
+ elif model == "E2-TTS":
246
+ ema_model = load_model(model, "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
247
+
248
+ audio, sr = ref_audio
249
+ if audio.shape[0] > 1:
250
+ audio = torch.mean(audio, dim=0, keepdim=True)
251
+
252
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
253
+ if rms < target_rms:
254
+ audio = audio * target_rms / rms
255
+ if sr != target_sample_rate:
256
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
257
+ audio = resampler(audio)
258
+ audio = audio.to(device)
259
+
260
+ generated_waves = []
261
+ spectrograms = []
262
+
263
+ for i, gen_text in enumerate(tqdm.tqdm(gen_text_batches)):
264
+ # Prepare the text
265
+ if len(ref_text[-1].encode('utf-8')) == 1:
266
+ ref_text = ref_text + " "
267
+ text_list = [ref_text + gen_text]
268
+ final_text_list = convert_char_to_pinyin(text_list)
269
+
270
+ # Calculate duration
271
+ ref_audio_len = audio.shape[-1] // hop_length
272
+ zh_pause_punc = r"。,、;:?!"
273
+ ref_text_len = len(ref_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, ref_text))
274
+ gen_text_len = len(gen_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gen_text))
275
+ duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
276
+
277
+ # inference
278
+ with torch.inference_mode():
279
+ generated, _ = ema_model.sample(
280
+ cond=audio,
281
+ text=final_text_list,
282
+ duration=duration,
283
+ steps=nfe_step,
284
+ cfg_strength=cfg_strength,
285
+ sway_sampling_coef=sway_sampling_coef,
286
+ )
287
+
288
+ generated = generated[:, ref_audio_len:, :]
289
+ generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
290
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
291
+ if rms < target_rms:
292
+ generated_wave = generated_wave * rms / target_rms
293
+
294
+ # wav -> numpy
295
+ generated_wave = generated_wave.squeeze().cpu().numpy()
296
+
297
+ generated_waves.append(generated_wave)
298
+ spectrograms.append(generated_mel_spec[0].cpu().numpy())
299
+
300
+ # Combine all generated waves
301
+ final_wave = np.concatenate(generated_waves)
302
+
303
+ with open(wave_path, "wb") as f:
304
+ sf.write(f.name, final_wave, target_sample_rate)
305
+ # Remove silence
306
+ if remove_silence:
307
+ aseg = AudioSegment.from_file(f.name)
308
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
309
+ non_silent_wave = AudioSegment.silent(duration=0)
310
+ for non_silent_seg in non_silent_segs:
311
+ non_silent_wave += non_silent_seg
312
+ aseg = non_silent_wave
313
+ aseg.export(f.name, format="wav")
314
+ print(f.name)
315
+
316
+ # Create a combined spectrogram
317
+ combined_spectrogram = np.concatenate(spectrograms, axis=1)
318
+ save_spectrogram(combined_spectrogram, spectrogram_path)
319
+ print(spectrogram_path)
320
+
321
+
322
+ def infer(ref_audio_orig, ref_text, gen_text, model, remove_silence, custom_split_words):
323
+ if not custom_split_words.strip():
324
+ custom_words = [word.strip() for word in custom_split_words.split(',')]
325
+ global SPLIT_WORDS
326
+ SPLIT_WORDS = custom_words
327
+
328
+ print(gen_text)
329
+
330
+ print("Converting audio...")
331
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
332
+ aseg = AudioSegment.from_file(ref_audio_orig)
333
+
334
+ non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
335
+ non_silent_wave = AudioSegment.silent(duration=0)
336
+ for non_silent_seg in non_silent_segs:
337
+ non_silent_wave += non_silent_seg
338
+ aseg = non_silent_wave
339
+
340
+ audio_duration = len(aseg)
341
+ if audio_duration > 15000:
342
+ print("Audio is over 15s, clipping to only first 15s.")
343
+ aseg = aseg[:15000]
344
+ aseg.export(f.name, format="wav")
345
+ ref_audio = f.name
346
+
347
+ if not ref_text.strip():
348
+ print("No reference text provided, transcribing reference audio...")
349
+ pipe = pipeline(
350
+ "automatic-speech-recognition",
351
+ model="openai/whisper-large-v3-turbo",
352
+ torch_dtype=torch.float16,
353
+ device=device,
354
+ )
355
+ ref_text = pipe(
356
+ ref_audio,
357
+ chunk_length_s=30,
358
+ batch_size=128,
359
+ generate_kwargs={"task": "transcribe"},
360
+ return_timestamps=False,
361
+ )["text"].strip()
362
+ print("Finished transcription")
363
+ else:
364
+ print("Using custom reference text...")
365
+
366
+ # Split the input text into batches
367
+ audio, sr = torchaudio.load(ref_audio)
368
+ max_chars = int(len(ref_text.encode('utf-8')) / (audio.shape[-1] / sr) * (30 - audio.shape[-1] / sr))
369
+ gen_text_batches = split_text_into_batches(gen_text, max_chars=max_chars)
370
+ print('ref_text', ref_text)
371
+ for i, gen_text in enumerate(gen_text_batches):
372
+ print(f'gen_text {i}', gen_text)
373
+
374
+ print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
375
+ return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence)
376
+
377
+
378
+ infer(ref_audio, ref_text, gen_text, model, remove_silence, ",".join(SPLIT_WORDS))
inference-cli.toml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # F5-TTS | E2-TTS
2
+ model = "F5-TTS"
3
+ ref_audio = "tests/ref_audio/test_en_1_ref_short.wav"
4
+ # If an empty "", transcribes the reference audio automatically.
5
+ ref_text = "Some call me nature, others call me mother nature."
6
+ gen_text = "I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."
7
+ remove_silence = true
8
+ output_dir = "tests"
model/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from model.cfm import CFM
2
+
3
+ from model.backbones.unett import UNetT
4
+ from model.backbones.dit import DiT
5
+ from model.backbones.mmdit import MMDiT
6
+
7
+ from model.trainer import Trainer
model/backbones/README.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Backbones quick introduction
2
+
3
+
4
+ ### unett.py
5
+ - flat unet transformer
6
+ - structure same as in e2-tts & voicebox paper except using rotary pos emb
7
+ - update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat
8
+
9
+ ### dit.py
10
+ - adaln-zero dit
11
+ - embedded timestep as condition
12
+ - concatted noised_input + masked_cond + embedded_text, linear proj in
13
+ - possible abs pos emb & convnextv2 blocks for embedded text before concat
14
+ - possible long skip connection (first layer to last layer)
15
+
16
+ ### mmdit.py
17
+ - sd3 structure
18
+ - timestep as condition
19
+ - left stream: text embedded and applied a abs pos emb
20
+ - right stream: masked_cond & noised_input concatted and with same conv pos emb as unett
model/backbones/dit.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ from torch import nn
14
+ import torch.nn.functional as F
15
+
16
+ from einops import repeat
17
+
18
+ from x_transformers.x_transformers import RotaryEmbedding
19
+
20
+ from model.modules import (
21
+ TimestepEmbedding,
22
+ ConvNeXtV2Block,
23
+ ConvPositionEmbedding,
24
+ DiTBlock,
25
+ AdaLayerNormZero_Final,
26
+ precompute_freqs_cis, get_pos_embed_indices,
27
+ )
28
+
29
+
30
+ # Text embedding
31
+
32
+ class TextEmbedding(nn.Module):
33
+ def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
34
+ super().__init__()
35
+ self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
36
+
37
+ if conv_layers > 0:
38
+ self.extra_modeling = True
39
+ self.precompute_max_pos = 4096 # ~44s of 24khz audio
40
+ self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
41
+ self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
42
+ else:
43
+ self.extra_modeling = False
44
+
45
+ def forward(self, text: int['b nt'], seq_len, drop_text = False):
46
+ batch, text_len = text.shape[0], text.shape[1]
47
+ text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
48
+ text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
49
+ text = F.pad(text, (0, seq_len - text_len), value = 0)
50
+
51
+ if drop_text: # cfg for text
52
+ text = torch.zeros_like(text)
53
+
54
+ text = self.text_embed(text) # b n -> b n d
55
+
56
+ # possible extra modeling
57
+ if self.extra_modeling:
58
+ # sinus pos emb
59
+ batch_start = torch.zeros((batch,), dtype=torch.long)
60
+ pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
61
+ text_pos_embed = self.freqs_cis[pos_idx]
62
+ text = text + text_pos_embed
63
+
64
+ # convnextv2 blocks
65
+ text = self.text_blocks(text)
66
+
67
+ return text
68
+
69
+
70
+ # noised input audio and context mixing embedding
71
+
72
+ class InputEmbedding(nn.Module):
73
+ def __init__(self, mel_dim, text_dim, out_dim):
74
+ super().__init__()
75
+ self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
76
+ self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
77
+
78
+ def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
79
+ if drop_audio_cond: # cfg for cond audio
80
+ cond = torch.zeros_like(cond)
81
+
82
+ x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
83
+ x = self.conv_pos_embed(x) + x
84
+ return x
85
+
86
+
87
+ # Transformer backbone using DiT blocks
88
+
89
+ class DiT(nn.Module):
90
+ def __init__(self, *,
91
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
92
+ mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
93
+ long_skip_connection = False,
94
+ ):
95
+ super().__init__()
96
+
97
+ self.time_embed = TimestepEmbedding(dim)
98
+ if text_dim is None:
99
+ text_dim = mel_dim
100
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
101
+ self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
102
+
103
+ self.rotary_embed = RotaryEmbedding(dim_head)
104
+
105
+ self.dim = dim
106
+ self.depth = depth
107
+
108
+ self.transformer_blocks = nn.ModuleList(
109
+ [
110
+ DiTBlock(
111
+ dim = dim,
112
+ heads = heads,
113
+ dim_head = dim_head,
114
+ ff_mult = ff_mult,
115
+ dropout = dropout
116
+ )
117
+ for _ in range(depth)
118
+ ]
119
+ )
120
+ self.long_skip_connection = nn.Linear(dim * 2, dim, bias = False) if long_skip_connection else None
121
+
122
+ self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
123
+ self.proj_out = nn.Linear(dim, mel_dim)
124
+
125
+ def forward(
126
+ self,
127
+ x: float['b n d'], # nosied input audio
128
+ cond: float['b n d'], # masked cond audio
129
+ text: int['b nt'], # text
130
+ time: float['b'] | float[''], # time step
131
+ drop_audio_cond, # cfg for cond audio
132
+ drop_text, # cfg for text
133
+ mask: bool['b n'] | None = None,
134
+ ):
135
+ batch, seq_len = x.shape[0], x.shape[1]
136
+ if time.ndim == 0:
137
+ time = repeat(time, ' -> b', b = batch)
138
+
139
+ # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
140
+ t = self.time_embed(time)
141
+ text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
142
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
143
+
144
+ rope = self.rotary_embed.forward_from_seq_len(seq_len)
145
+
146
+ if self.long_skip_connection is not None:
147
+ residual = x
148
+
149
+ for block in self.transformer_blocks:
150
+ x = block(x, t, mask = mask, rope = rope)
151
+
152
+ if self.long_skip_connection is not None:
153
+ x = self.long_skip_connection(torch.cat((x, residual), dim = -1))
154
+
155
+ x = self.norm_out(x, t)
156
+ output = self.proj_out(x)
157
+
158
+ return output
model/backbones/mmdit.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from einops import repeat
16
+
17
+ from x_transformers.x_transformers import RotaryEmbedding
18
+
19
+ from model.modules import (
20
+ TimestepEmbedding,
21
+ ConvPositionEmbedding,
22
+ MMDiTBlock,
23
+ AdaLayerNormZero_Final,
24
+ precompute_freqs_cis, get_pos_embed_indices,
25
+ )
26
+
27
+
28
+ # text embedding
29
+
30
+ class TextEmbedding(nn.Module):
31
+ def __init__(self, out_dim, text_num_embeds):
32
+ super().__init__()
33
+ self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
34
+
35
+ self.precompute_max_pos = 1024
36
+ self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
37
+
38
+ def forward(self, text: int['b nt'], drop_text = False) -> int['b nt d']:
39
+ text = text + 1
40
+ if drop_text:
41
+ text = torch.zeros_like(text)
42
+ text = self.text_embed(text)
43
+
44
+ # sinus pos emb
45
+ batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
46
+ batch_text_len = text.shape[1]
47
+ pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
48
+ text_pos_embed = self.freqs_cis[pos_idx]
49
+
50
+ text = text + text_pos_embed
51
+
52
+ return text
53
+
54
+
55
+ # noised input & masked cond audio embedding
56
+
57
+ class AudioEmbedding(nn.Module):
58
+ def __init__(self, in_dim, out_dim):
59
+ super().__init__()
60
+ self.linear = nn.Linear(2 * in_dim, out_dim)
61
+ self.conv_pos_embed = ConvPositionEmbedding(out_dim)
62
+
63
+ def forward(self, x: float['b n d'], cond: float['b n d'], drop_audio_cond = False):
64
+ if drop_audio_cond:
65
+ cond = torch.zeros_like(cond)
66
+ x = torch.cat((x, cond), dim = -1)
67
+ x = self.linear(x)
68
+ x = self.conv_pos_embed(x) + x
69
+ return x
70
+
71
+
72
+ # Transformer backbone using MM-DiT blocks
73
+
74
+ class MMDiT(nn.Module):
75
+ def __init__(self, *,
76
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
77
+ text_num_embeds = 256, mel_dim = 100,
78
+ ):
79
+ super().__init__()
80
+
81
+ self.time_embed = TimestepEmbedding(dim)
82
+ self.text_embed = TextEmbedding(dim, text_num_embeds)
83
+ self.audio_embed = AudioEmbedding(mel_dim, dim)
84
+
85
+ self.rotary_embed = RotaryEmbedding(dim_head)
86
+
87
+ self.dim = dim
88
+ self.depth = depth
89
+
90
+ self.transformer_blocks = nn.ModuleList(
91
+ [
92
+ MMDiTBlock(
93
+ dim = dim,
94
+ heads = heads,
95
+ dim_head = dim_head,
96
+ dropout = dropout,
97
+ ff_mult = ff_mult,
98
+ context_pre_only = i == depth - 1,
99
+ )
100
+ for i in range(depth)
101
+ ]
102
+ )
103
+ self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
104
+ self.proj_out = nn.Linear(dim, mel_dim)
105
+
106
+ def forward(
107
+ self,
108
+ x: float['b n d'], # nosied input audio
109
+ cond: float['b n d'], # masked cond audio
110
+ text: int['b nt'], # text
111
+ time: float['b'] | float[''], # time step
112
+ drop_audio_cond, # cfg for cond audio
113
+ drop_text, # cfg for text
114
+ mask: bool['b n'] | None = None,
115
+ ):
116
+ batch = x.shape[0]
117
+ if time.ndim == 0:
118
+ time = repeat(time, ' -> b', b = batch)
119
+
120
+ # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
121
+ t = self.time_embed(time)
122
+ c = self.text_embed(text, drop_text = drop_text)
123
+ x = self.audio_embed(x, cond, drop_audio_cond = drop_audio_cond)
124
+
125
+ seq_len = x.shape[1]
126
+ text_len = text.shape[1]
127
+ rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
128
+ rope_text = self.rotary_embed.forward_from_seq_len(text_len)
129
+
130
+ for block in self.transformer_blocks:
131
+ c, x = block(x, c, t, mask = mask, rope = rope_audio, c_rope = rope_text)
132
+
133
+ x = self.norm_out(x, t)
134
+ output = self.proj_out(x)
135
+
136
+ return output
model/backbones/unett.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import Literal
12
+
13
+ import torch
14
+ from torch import nn
15
+ import torch.nn.functional as F
16
+
17
+ from einops import repeat, pack, unpack
18
+
19
+ from x_transformers import RMSNorm
20
+ from x_transformers.x_transformers import RotaryEmbedding
21
+
22
+ from model.modules import (
23
+ TimestepEmbedding,
24
+ ConvNeXtV2Block,
25
+ ConvPositionEmbedding,
26
+ Attention,
27
+ AttnProcessor,
28
+ FeedForward,
29
+ precompute_freqs_cis, get_pos_embed_indices,
30
+ )
31
+
32
+
33
+ # Text embedding
34
+
35
+ class TextEmbedding(nn.Module):
36
+ def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
37
+ super().__init__()
38
+ self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
39
+
40
+ if conv_layers > 0:
41
+ self.extra_modeling = True
42
+ self.precompute_max_pos = 4096 # ~44s of 24khz audio
43
+ self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
44
+ self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
45
+ else:
46
+ self.extra_modeling = False
47
+
48
+ def forward(self, text: int['b nt'], seq_len, drop_text = False):
49
+ batch, text_len = text.shape[0], text.shape[1]
50
+ text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
51
+ text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
52
+ text = F.pad(text, (0, seq_len - text_len), value = 0)
53
+
54
+ if drop_text: # cfg for text
55
+ text = torch.zeros_like(text)
56
+
57
+ text = self.text_embed(text) # b n -> b n d
58
+
59
+ # possible extra modeling
60
+ if self.extra_modeling:
61
+ # sinus pos emb
62
+ batch_start = torch.zeros((batch,), dtype=torch.long)
63
+ pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
64
+ text_pos_embed = self.freqs_cis[pos_idx]
65
+ text = text + text_pos_embed
66
+
67
+ # convnextv2 blocks
68
+ text = self.text_blocks(text)
69
+
70
+ return text
71
+
72
+
73
+ # noised input audio and context mixing embedding
74
+
75
+ class InputEmbedding(nn.Module):
76
+ def __init__(self, mel_dim, text_dim, out_dim):
77
+ super().__init__()
78
+ self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
79
+ self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
80
+
81
+ def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
82
+ if drop_audio_cond: # cfg for cond audio
83
+ cond = torch.zeros_like(cond)
84
+
85
+ x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
86
+ x = self.conv_pos_embed(x) + x
87
+ return x
88
+
89
+
90
+ # Flat UNet Transformer backbone
91
+
92
+ class UNetT(nn.Module):
93
+ def __init__(self, *,
94
+ dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
95
+ mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
96
+ skip_connect_type: Literal['add', 'concat', 'none'] = 'concat',
97
+ ):
98
+ super().__init__()
99
+ assert depth % 2 == 0, "UNet-Transformer's depth should be even."
100
+
101
+ self.time_embed = TimestepEmbedding(dim)
102
+ if text_dim is None:
103
+ text_dim = mel_dim
104
+ self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
105
+ self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
106
+
107
+ self.rotary_embed = RotaryEmbedding(dim_head)
108
+
109
+ # transformer layers & skip connections
110
+
111
+ self.dim = dim
112
+ self.skip_connect_type = skip_connect_type
113
+ needs_skip_proj = skip_connect_type == 'concat'
114
+
115
+ self.depth = depth
116
+ self.layers = nn.ModuleList([])
117
+
118
+ for idx in range(depth):
119
+ is_later_half = idx >= (depth // 2)
120
+
121
+ attn_norm = RMSNorm(dim)
122
+ attn = Attention(
123
+ processor = AttnProcessor(),
124
+ dim = dim,
125
+ heads = heads,
126
+ dim_head = dim_head,
127
+ dropout = dropout,
128
+ )
129
+
130
+ ff_norm = RMSNorm(dim)
131
+ ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
132
+
133
+ skip_proj = nn.Linear(dim * 2, dim, bias = False) if needs_skip_proj and is_later_half else None
134
+
135
+ self.layers.append(nn.ModuleList([
136
+ skip_proj,
137
+ attn_norm,
138
+ attn,
139
+ ff_norm,
140
+ ff,
141
+ ]))
142
+
143
+ self.norm_out = RMSNorm(dim)
144
+ self.proj_out = nn.Linear(dim, mel_dim)
145
+
146
+ def forward(
147
+ self,
148
+ x: float['b n d'], # nosied input audio
149
+ cond: float['b n d'], # masked cond audio
150
+ text: int['b nt'], # text
151
+ time: float['b'] | float[''], # time step
152
+ drop_audio_cond, # cfg for cond audio
153
+ drop_text, # cfg for text
154
+ mask: bool['b n'] | None = None,
155
+ ):
156
+ batch, seq_len = x.shape[0], x.shape[1]
157
+ if time.ndim == 0:
158
+ time = repeat(time, ' -> b', b = batch)
159
+
160
+ # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
161
+ t = self.time_embed(time)
162
+ text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
163
+ x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
164
+
165
+ # postfix time t to input x, [b n d] -> [b n+1 d]
166
+ x, ps = pack((t, x), 'b * d')
167
+ if mask is not None:
168
+ mask = F.pad(mask, (1, 0), value=1)
169
+
170
+ rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
171
+
172
+ # flat unet transformer
173
+ skip_connect_type = self.skip_connect_type
174
+ skips = []
175
+ for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
176
+ layer = idx + 1
177
+
178
+ # skip connection logic
179
+ is_first_half = layer <= (self.depth // 2)
180
+ is_later_half = not is_first_half
181
+
182
+ if is_first_half:
183
+ skips.append(x)
184
+
185
+ if is_later_half:
186
+ skip = skips.pop()
187
+ if skip_connect_type == 'concat':
188
+ x = torch.cat((x, skip), dim = -1)
189
+ x = maybe_skip_proj(x)
190
+ elif skip_connect_type == 'add':
191
+ x = x + skip
192
+
193
+ # attention and feedforward blocks
194
+ x = attn(attn_norm(x), rope = rope, mask = mask) + x
195
+ x = ff(ff_norm(x)) + x
196
+
197
+ assert len(skips) == 0
198
+
199
+ _, x = unpack(self.norm_out(x), ps, 'b * d')
200
+
201
+ return self.proj_out(x)
model/cfm.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import Callable
12
+ from random import random
13
+
14
+ import torch
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+ from torch.nn.utils.rnn import pad_sequence
18
+
19
+ from torchdiffeq import odeint
20
+
21
+ from einops import rearrange
22
+
23
+ from model.modules import MelSpec
24
+
25
+ from model.utils import (
26
+ default, exists,
27
+ list_str_to_idx, list_str_to_tensor,
28
+ lens_to_mask, mask_from_frac_lengths,
29
+ )
30
+
31
+
32
+ class CFM(nn.Module):
33
+ def __init__(
34
+ self,
35
+ transformer: nn.Module,
36
+ sigma = 0.,
37
+ odeint_kwargs: dict = dict(
38
+ # atol = 1e-5,
39
+ # rtol = 1e-5,
40
+ method = 'euler' # 'midpoint'
41
+ ),
42
+ audio_drop_prob = 0.3,
43
+ cond_drop_prob = 0.2,
44
+ num_channels = None,
45
+ mel_spec_module: nn.Module | None = None,
46
+ mel_spec_kwargs: dict = dict(),
47
+ frac_lengths_mask: tuple[float, float] = (0.7, 1.),
48
+ vocab_char_map: dict[str: int] | None = None
49
+ ):
50
+ super().__init__()
51
+
52
+ self.frac_lengths_mask = frac_lengths_mask
53
+
54
+ # mel spec
55
+ self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
56
+ num_channels = default(num_channels, self.mel_spec.n_mel_channels)
57
+ self.num_channels = num_channels
58
+
59
+ # classifier-free guidance
60
+ self.audio_drop_prob = audio_drop_prob
61
+ self.cond_drop_prob = cond_drop_prob
62
+
63
+ # transformer
64
+ self.transformer = transformer
65
+ dim = transformer.dim
66
+ self.dim = dim
67
+
68
+ # conditional flow related
69
+ self.sigma = sigma
70
+
71
+ # sampling related
72
+ self.odeint_kwargs = odeint_kwargs
73
+
74
+ # vocab map for tokenization
75
+ self.vocab_char_map = vocab_char_map
76
+
77
+ @property
78
+ def device(self):
79
+ return next(self.parameters()).device
80
+
81
+ @torch.no_grad()
82
+ def sample(
83
+ self,
84
+ cond: float['b n d'] | float['b nw'],
85
+ text: int['b nt'] | list[str],
86
+ duration: int | int['b'],
87
+ *,
88
+ lens: int['b'] | None = None,
89
+ steps = 32,
90
+ cfg_strength = 1.,
91
+ sway_sampling_coef = None,
92
+ seed: int | None = None,
93
+ max_duration = 4096,
94
+ vocoder: Callable[[float['b d n']], float['b nw']] | None = None,
95
+ no_ref_audio = False,
96
+ duplicate_test = False,
97
+ t_inter = 0.1,
98
+ edit_mask = None,
99
+ ):
100
+ self.eval()
101
+
102
+ # raw wave
103
+
104
+ if cond.ndim == 2:
105
+ cond = self.mel_spec(cond)
106
+ cond = rearrange(cond, 'b d n -> b n d')
107
+ assert cond.shape[-1] == self.num_channels
108
+
109
+ batch, cond_seq_len, device = *cond.shape[:2], cond.device
110
+ if not exists(lens):
111
+ lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)
112
+
113
+ # text
114
+
115
+ if isinstance(text, list):
116
+ if exists(self.vocab_char_map):
117
+ text = list_str_to_idx(text, self.vocab_char_map).to(device)
118
+ else:
119
+ text = list_str_to_tensor(text).to(device)
120
+ assert text.shape[0] == batch
121
+
122
+ if exists(text):
123
+ text_lens = (text != -1).sum(dim = -1)
124
+ lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
125
+
126
+ # duration
127
+
128
+ cond_mask = lens_to_mask(lens)
129
+ if edit_mask is not None:
130
+ cond_mask = cond_mask & edit_mask
131
+
132
+ if isinstance(duration, int):
133
+ duration = torch.full((batch,), duration, device = device, dtype = torch.long)
134
+
135
+ duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
136
+ duration = duration.clamp(max = max_duration)
137
+ max_duration = duration.amax()
138
+
139
+ # duplicate test corner for inner time step oberservation
140
+ if duplicate_test:
141
+ test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2*cond_seq_len), value = 0.)
142
+
143
+ cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
144
+ cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
145
+ cond_mask = rearrange(cond_mask, '... -> ... 1')
146
+ step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in
147
+
148
+ if batch > 1:
149
+ mask = lens_to_mask(duration)
150
+ else: # save memory and speed up, as single inference need no mask currently
151
+ mask = None
152
+
153
+ # test for no ref audio
154
+ if no_ref_audio:
155
+ cond = torch.zeros_like(cond)
156
+
157
+ # neural ode
158
+
159
+ def fn(t, x):
160
+ # at each step, conditioning is fixed
161
+ # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
162
+
163
+ # predict flow
164
+ pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = False, drop_text = False)
165
+ if cfg_strength < 1e-5:
166
+ return pred
167
+
168
+ null_pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = True, drop_text = True)
169
+ return pred + (pred - null_pred) * cfg_strength
170
+
171
+ # noise input
172
+ # to make sure batch inference result is same with different batch size, and for sure single inference
173
+ # still some difference maybe due to convolutional layers
174
+ y0 = []
175
+ for dur in duration:
176
+ if exists(seed):
177
+ torch.manual_seed(seed)
178
+ y0.append(torch.randn(dur, self.num_channels, device = self.device))
179
+ y0 = pad_sequence(y0, padding_value = 0, batch_first = True)
180
+
181
+ t_start = 0
182
+
183
+ # duplicate test corner for inner time step oberservation
184
+ if duplicate_test:
185
+ t_start = t_inter
186
+ y0 = (1 - t_start) * y0 + t_start * test_cond
187
+ steps = int(steps * (1 - t_start))
188
+
189
+ t = torch.linspace(t_start, 1, steps, device = self.device)
190
+ if sway_sampling_coef is not None:
191
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
192
+
193
+ trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
194
+
195
+ sampled = trajectory[-1]
196
+ out = sampled
197
+ out = torch.where(cond_mask, cond, out)
198
+
199
+ if exists(vocoder):
200
+ out = rearrange(out, 'b n d -> b d n')
201
+ out = vocoder(out)
202
+
203
+ return out, trajectory
204
+
205
+ def forward(
206
+ self,
207
+ inp: float['b n d'] | float['b nw'], # mel or raw wave
208
+ text: int['b nt'] | list[str],
209
+ *,
210
+ lens: int['b'] | None = None,
211
+ noise_scheduler: str | None = None,
212
+ ):
213
+ # handle raw wave
214
+ if inp.ndim == 2:
215
+ inp = self.mel_spec(inp)
216
+ inp = rearrange(inp, 'b d n -> b n d')
217
+ assert inp.shape[-1] == self.num_channels
218
+
219
+ batch, seq_len, dtype, device, σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
220
+
221
+ # handle text as string
222
+ if isinstance(text, list):
223
+ if exists(self.vocab_char_map):
224
+ text = list_str_to_idx(text, self.vocab_char_map).to(device)
225
+ else:
226
+ text = list_str_to_tensor(text).to(device)
227
+ assert text.shape[0] == batch
228
+
229
+ # lens and mask
230
+ if not exists(lens):
231
+ lens = torch.full((batch,), seq_len, device = device)
232
+
233
+ mask = lens_to_mask(lens, length = seq_len) # useless here, as collate_fn will pad to max length in batch
234
+
235
+ # get a random span to mask out for training conditionally
236
+ frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
237
+ rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
238
+
239
+ if exists(mask):
240
+ rand_span_mask &= mask
241
+
242
+ # mel is x1
243
+ x1 = inp
244
+
245
+ # x0 is gaussian noise
246
+ x0 = torch.randn_like(x1)
247
+
248
+ # time step
249
+ time = torch.rand((batch,), dtype = dtype, device = self.device)
250
+ # TODO. noise_scheduler
251
+
252
+ # sample xt (φ_t(x) in the paper)
253
+ t = rearrange(time, 'b -> b 1 1')
254
+ φ = (1 - t) * x0 + t * x1
255
+ flow = x1 - x0
256
+
257
+ # only predict what is within the random mask span for infilling
258
+ cond = torch.where(
259
+ rand_span_mask[..., None],
260
+ torch.zeros_like(x1), x1
261
+ )
262
+
263
+ # transformer and cfg training with a drop rate
264
+ drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
265
+ if random() < self.cond_drop_prob: # p_uncond in voicebox paper
266
+ drop_audio_cond = True
267
+ drop_text = True
268
+ else:
269
+ drop_text = False
270
+
271
+ # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
272
+ # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
273
+ pred = self.transformer(x = φ, cond = cond, text = text, time = time, drop_audio_cond = drop_audio_cond, drop_text = drop_text)
274
+
275
+ # flow matching loss
276
+ loss = F.mse_loss(pred, flow, reduction = 'none')
277
+ loss = loss[rand_span_mask]
278
+
279
+ return loss.mean(), cond, pred
model/dataset.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ from tqdm import tqdm
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch.utils.data import Dataset, Sampler
8
+ import torchaudio
9
+ from datasets import load_dataset, load_from_disk
10
+ from datasets import Dataset as Dataset_
11
+
12
+ from einops import rearrange
13
+
14
+ from model.modules import MelSpec
15
+
16
+
17
+ class HFDataset(Dataset):
18
+ def __init__(
19
+ self,
20
+ hf_dataset: Dataset,
21
+ target_sample_rate = 24_000,
22
+ n_mel_channels = 100,
23
+ hop_length = 256,
24
+ ):
25
+ self.data = hf_dataset
26
+ self.target_sample_rate = target_sample_rate
27
+ self.hop_length = hop_length
28
+ self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
29
+
30
+ def get_frame_len(self, index):
31
+ row = self.data[index]
32
+ audio = row['audio']['array']
33
+ sample_rate = row['audio']['sampling_rate']
34
+ return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
35
+
36
+ def __len__(self):
37
+ return len(self.data)
38
+
39
+ def __getitem__(self, index):
40
+ row = self.data[index]
41
+ audio = row['audio']['array']
42
+
43
+ # logger.info(f"Audio shape: {audio.shape}")
44
+
45
+ sample_rate = row['audio']['sampling_rate']
46
+ duration = audio.shape[-1] / sample_rate
47
+
48
+ if duration > 30 or duration < 0.3:
49
+ return self.__getitem__((index + 1) % len(self.data))
50
+
51
+ audio_tensor = torch.from_numpy(audio).float()
52
+
53
+ if sample_rate != self.target_sample_rate:
54
+ resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
55
+ audio_tensor = resampler(audio_tensor)
56
+
57
+ audio_tensor = rearrange(audio_tensor, 't -> 1 t')
58
+
59
+ mel_spec = self.mel_spectrogram(audio_tensor)
60
+
61
+ mel_spec = rearrange(mel_spec, '1 d t -> d t')
62
+
63
+ text = row['text']
64
+
65
+ return dict(
66
+ mel_spec = mel_spec,
67
+ text = text,
68
+ )
69
+
70
+
71
+ class CustomDataset(Dataset):
72
+ def __init__(
73
+ self,
74
+ custom_dataset: Dataset,
75
+ durations = None,
76
+ target_sample_rate = 24_000,
77
+ hop_length = 256,
78
+ n_mel_channels = 100,
79
+ preprocessed_mel = False,
80
+ ):
81
+ self.data = custom_dataset
82
+ self.durations = durations
83
+ self.target_sample_rate = target_sample_rate
84
+ self.hop_length = hop_length
85
+ self.preprocessed_mel = preprocessed_mel
86
+ if not preprocessed_mel:
87
+ self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels)
88
+
89
+ def get_frame_len(self, index):
90
+ if self.durations is not None: # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
91
+ return self.durations[index] * self.target_sample_rate / self.hop_length
92
+ return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
93
+
94
+ def __len__(self):
95
+ return len(self.data)
96
+
97
+ def __getitem__(self, index):
98
+ row = self.data[index]
99
+ audio_path = row["audio_path"]
100
+ text = row["text"]
101
+ duration = row["duration"]
102
+
103
+ if self.preprocessed_mel:
104
+ mel_spec = torch.tensor(row["mel_spec"])
105
+
106
+ else:
107
+ audio, source_sample_rate = torchaudio.load(audio_path)
108
+
109
+ if duration > 30 or duration < 0.3:
110
+ return self.__getitem__((index + 1) % len(self.data))
111
+
112
+ if source_sample_rate != self.target_sample_rate:
113
+ resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
114
+ audio = resampler(audio)
115
+
116
+ mel_spec = self.mel_spectrogram(audio)
117
+ mel_spec = rearrange(mel_spec, '1 d t -> d t')
118
+
119
+ return dict(
120
+ mel_spec = mel_spec,
121
+ text = text,
122
+ )
123
+
124
+
125
+ # Dynamic Batch Sampler
126
+
127
+ class DynamicBatchSampler(Sampler[list[int]]):
128
+ """ Extension of Sampler that will do the following:
129
+ 1. Change the batch size (essentially number of sequences)
130
+ in a batch to ensure that the total number of frames are less
131
+ than a certain threshold.
132
+ 2. Make sure the padding efficiency in the batch is high.
133
+ """
134
+
135
+ def __init__(self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False):
136
+ self.sampler = sampler
137
+ self.frames_threshold = frames_threshold
138
+ self.max_samples = max_samples
139
+
140
+ indices, batches = [], []
141
+ data_source = self.sampler.data_source
142
+
143
+ for idx in tqdm(self.sampler, desc=f"Sorting with sampler... if slow, check whether dataset is provided with duration"):
144
+ indices.append((idx, data_source.get_frame_len(idx)))
145
+ indices.sort(key=lambda elem : elem[1])
146
+
147
+ batch = []
148
+ batch_frames = 0
149
+ for idx, frame_len in tqdm(indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"):
150
+ if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
151
+ batch.append(idx)
152
+ batch_frames += frame_len
153
+ else:
154
+ if len(batch) > 0:
155
+ batches.append(batch)
156
+ if frame_len <= self.frames_threshold:
157
+ batch = [idx]
158
+ batch_frames = frame_len
159
+ else:
160
+ batch = []
161
+ batch_frames = 0
162
+
163
+ if not drop_last and len(batch) > 0:
164
+ batches.append(batch)
165
+
166
+ del indices
167
+
168
+ # if want to have different batches between epochs, may just set a seed and log it in ckpt
169
+ # cuz during multi-gpu training, although the batch on per gpu not change between epochs, the formed general minibatch is different
170
+ # e.g. for epoch n, use (random_seed + n)
171
+ random.seed(random_seed)
172
+ random.shuffle(batches)
173
+
174
+ self.batches = batches
175
+
176
+ def __iter__(self):
177
+ return iter(self.batches)
178
+
179
+ def __len__(self):
180
+ return len(self.batches)
181
+
182
+
183
+ # Load dataset
184
+
185
+ def load_dataset(
186
+ dataset_name: str,
187
+ tokenizer: str,
188
+ dataset_type: str = "CustomDataset",
189
+ audio_type: str = "raw",
190
+ mel_spec_kwargs: dict = dict()
191
+ ) -> CustomDataset:
192
+
193
+ print("Loading dataset ...")
194
+
195
+ if dataset_type == "CustomDataset":
196
+ if audio_type == "raw":
197
+ try:
198
+ train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
199
+ except:
200
+ train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
201
+ preprocessed_mel = False
202
+ elif audio_type == "mel":
203
+ train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
204
+ preprocessed_mel = True
205
+ with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'r', encoding='utf-8') as f:
206
+ data_dict = json.load(f)
207
+ durations = data_dict["duration"]
208
+ train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
209
+
210
+ elif dataset_type == "HFDataset":
211
+ print("Should manually modify the path of huggingface dataset to your need.\n" +
212
+ "May also the corresponding script cuz different dataset may have different format.")
213
+ pre, post = dataset_name.split("_")
214
+ train_dataset = HFDataset(load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),)
215
+
216
+ return train_dataset
217
+
218
+
219
+ # collation
220
+
221
+ def collate_fn(batch):
222
+ mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
223
+ mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
224
+ max_mel_length = mel_lengths.amax()
225
+
226
+ padded_mel_specs = []
227
+ for spec in mel_specs: # TODO. maybe records mask for attention here
228
+ padding = (0, max_mel_length - spec.size(-1))
229
+ padded_spec = F.pad(spec, padding, value = 0)
230
+ padded_mel_specs.append(padded_spec)
231
+
232
+ mel_specs = torch.stack(padded_mel_specs)
233
+
234
+ text = [item['text'] for item in batch]
235
+ text_lengths = torch.LongTensor([len(item) for item in text])
236
+
237
+ return dict(
238
+ mel = mel_specs,
239
+ mel_lengths = mel_lengths,
240
+ text = text,
241
+ text_lengths = text_lengths,
242
+ )
model/ecapa_tdnn.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # just for speaker similarity evaluation, third-party code
2
+
3
+ # From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/
4
+ # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
5
+
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ ''' Res2Conv1d + BatchNorm1d + ReLU
13
+ '''
14
+
15
+ class Res2Conv1dReluBn(nn.Module):
16
+ '''
17
+ in_channels == out_channels == channels
18
+ '''
19
+
20
+ def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
21
+ super().__init__()
22
+ assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
23
+ self.scale = scale
24
+ self.width = channels // scale
25
+ self.nums = scale if scale == 1 else scale - 1
26
+
27
+ self.convs = []
28
+ self.bns = []
29
+ for i in range(self.nums):
30
+ self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
31
+ self.bns.append(nn.BatchNorm1d(self.width))
32
+ self.convs = nn.ModuleList(self.convs)
33
+ self.bns = nn.ModuleList(self.bns)
34
+
35
+ def forward(self, x):
36
+ out = []
37
+ spx = torch.split(x, self.width, 1)
38
+ for i in range(self.nums):
39
+ if i == 0:
40
+ sp = spx[i]
41
+ else:
42
+ sp = sp + spx[i]
43
+ # Order: conv -> relu -> bn
44
+ sp = self.convs[i](sp)
45
+ sp = self.bns[i](F.relu(sp))
46
+ out.append(sp)
47
+ if self.scale != 1:
48
+ out.append(spx[self.nums])
49
+ out = torch.cat(out, dim=1)
50
+
51
+ return out
52
+
53
+
54
+ ''' Conv1d + BatchNorm1d + ReLU
55
+ '''
56
+
57
+ class Conv1dReluBn(nn.Module):
58
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
59
+ super().__init__()
60
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
61
+ self.bn = nn.BatchNorm1d(out_channels)
62
+
63
+ def forward(self, x):
64
+ return self.bn(F.relu(self.conv(x)))
65
+
66
+
67
+ ''' The SE connection of 1D case.
68
+ '''
69
+
70
+ class SE_Connect(nn.Module):
71
+ def __init__(self, channels, se_bottleneck_dim=128):
72
+ super().__init__()
73
+ self.linear1 = nn.Linear(channels, se_bottleneck_dim)
74
+ self.linear2 = nn.Linear(se_bottleneck_dim, channels)
75
+
76
+ def forward(self, x):
77
+ out = x.mean(dim=2)
78
+ out = F.relu(self.linear1(out))
79
+ out = torch.sigmoid(self.linear2(out))
80
+ out = x * out.unsqueeze(2)
81
+
82
+ return out
83
+
84
+
85
+ ''' SE-Res2Block of the ECAPA-TDNN architecture.
86
+ '''
87
+
88
+ # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
89
+ # return nn.Sequential(
90
+ # Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
91
+ # Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
92
+ # Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
93
+ # SE_Connect(channels)
94
+ # )
95
+
96
+ class SE_Res2Block(nn.Module):
97
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
98
+ super().__init__()
99
+ self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
100
+ self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
101
+ self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
102
+ self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
103
+
104
+ self.shortcut = None
105
+ if in_channels != out_channels:
106
+ self.shortcut = nn.Conv1d(
107
+ in_channels=in_channels,
108
+ out_channels=out_channels,
109
+ kernel_size=1,
110
+ )
111
+
112
+ def forward(self, x):
113
+ residual = x
114
+ if self.shortcut:
115
+ residual = self.shortcut(x)
116
+
117
+ x = self.Conv1dReluBn1(x)
118
+ x = self.Res2Conv1dReluBn(x)
119
+ x = self.Conv1dReluBn2(x)
120
+ x = self.SE_Connect(x)
121
+
122
+ return x + residual
123
+
124
+
125
+ ''' Attentive weighted mean and standard deviation pooling.
126
+ '''
127
+
128
+ class AttentiveStatsPool(nn.Module):
129
+ def __init__(self, in_dim, attention_channels=128, global_context_att=False):
130
+ super().__init__()
131
+ self.global_context_att = global_context_att
132
+
133
+ # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
134
+ if global_context_att:
135
+ self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper
136
+ else:
137
+ self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper
138
+ self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
139
+
140
+ def forward(self, x):
141
+
142
+ if self.global_context_att:
143
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
144
+ context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
145
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
146
+ else:
147
+ x_in = x
148
+
149
+ # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
150
+ alpha = torch.tanh(self.linear1(x_in))
151
+ # alpha = F.relu(self.linear1(x_in))
152
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
153
+ mean = torch.sum(alpha * x, dim=2)
154
+ residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
155
+ std = torch.sqrt(residuals.clamp(min=1e-9))
156
+ return torch.cat([mean, std], dim=1)
157
+
158
+
159
+ class ECAPA_TDNN(nn.Module):
160
+ def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
161
+ feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
162
+ super().__init__()
163
+
164
+ self.feat_type = feat_type
165
+ self.feature_selection = feature_selection
166
+ self.update_extract = update_extract
167
+ self.sr = sr
168
+
169
+ torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
170
+ try:
171
+ local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
172
+ self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source='local', config_path=config_path)
173
+ except:
174
+ self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
175
+
176
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
177
+ self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
178
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
179
+ self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
180
+
181
+ self.feat_num = self.get_feat_num()
182
+ self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
183
+
184
+ if feat_type != 'fbank' and feat_type != 'mfcc':
185
+ freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
186
+ for name, param in self.feature_extract.named_parameters():
187
+ for freeze_val in freeze_list:
188
+ if freeze_val in name:
189
+ param.requires_grad = False
190
+ break
191
+
192
+ if not self.update_extract:
193
+ for param in self.feature_extract.parameters():
194
+ param.requires_grad = False
195
+
196
+ self.instance_norm = nn.InstanceNorm1d(feat_dim)
197
+ # self.channels = [channels] * 4 + [channels * 3]
198
+ self.channels = [channels] * 4 + [1536]
199
+
200
+ self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
201
+ self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
202
+ self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
203
+ self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
204
+
205
+ # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
206
+ cat_channels = channels * 3
207
+ self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
208
+ self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
209
+ self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
210
+ self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
211
+
212
+
213
+ def get_feat_num(self):
214
+ self.feature_extract.eval()
215
+ wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
216
+ with torch.no_grad():
217
+ features = self.feature_extract(wav)
218
+ select_feature = features[self.feature_selection]
219
+ if isinstance(select_feature, (list, tuple)):
220
+ return len(select_feature)
221
+ else:
222
+ return 1
223
+
224
+ def get_feat(self, x):
225
+ if self.update_extract:
226
+ x = self.feature_extract([sample for sample in x])
227
+ else:
228
+ with torch.no_grad():
229
+ if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
230
+ x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
231
+ else:
232
+ x = self.feature_extract([sample for sample in x])
233
+
234
+ if self.feat_type == 'fbank':
235
+ x = x.log()
236
+
237
+ if self.feat_type != "fbank" and self.feat_type != "mfcc":
238
+ x = x[self.feature_selection]
239
+ if isinstance(x, (list, tuple)):
240
+ x = torch.stack(x, dim=0)
241
+ else:
242
+ x = x.unsqueeze(0)
243
+ norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
244
+ x = (norm_weights * x).sum(dim=0)
245
+ x = torch.transpose(x, 1, 2) + 1e-6
246
+
247
+ x = self.instance_norm(x)
248
+ return x
249
+
250
+ def forward(self, x):
251
+ x = self.get_feat(x)
252
+
253
+ out1 = self.layer1(x)
254
+ out2 = self.layer2(out1)
255
+ out3 = self.layer3(out2)
256
+ out4 = self.layer4(out3)
257
+
258
+ out = torch.cat([out2, out3, out4], dim=1)
259
+ out = F.relu(self.conv(out))
260
+ out = self.bn(self.pooling(out))
261
+ out = self.linear(out)
262
+
263
+ return out
264
+
265
+
266
+ def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
267
+ return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
268
+ feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
model/modules.py ADDED
@@ -0,0 +1,575 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import Optional
12
+ import math
13
+
14
+ import torch
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+ import torchaudio
18
+
19
+ from einops import rearrange
20
+ from x_transformers.x_transformers import apply_rotary_pos_emb
21
+
22
+
23
+ # raw wav to mel spec
24
+
25
+ class MelSpec(nn.Module):
26
+ def __init__(
27
+ self,
28
+ filter_length = 1024,
29
+ hop_length = 256,
30
+ win_length = 1024,
31
+ n_mel_channels = 100,
32
+ target_sample_rate = 24_000,
33
+ normalize = False,
34
+ power = 1,
35
+ norm = None,
36
+ center = True,
37
+ ):
38
+ super().__init__()
39
+ self.n_mel_channels = n_mel_channels
40
+
41
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(
42
+ sample_rate = target_sample_rate,
43
+ n_fft = filter_length,
44
+ win_length = win_length,
45
+ hop_length = hop_length,
46
+ n_mels = n_mel_channels,
47
+ power = power,
48
+ center = center,
49
+ normalized = normalize,
50
+ norm = norm,
51
+ )
52
+
53
+ self.register_buffer('dummy', torch.tensor(0), persistent = False)
54
+
55
+ def forward(self, inp):
56
+ if len(inp.shape) == 3:
57
+ inp = rearrange(inp, 'b 1 nw -> b nw')
58
+
59
+ assert len(inp.shape) == 2
60
+
61
+ if self.dummy.device != inp.device:
62
+ self.to(inp.device)
63
+
64
+ mel = self.mel_stft(inp)
65
+ mel = mel.clamp(min = 1e-5).log()
66
+ return mel
67
+
68
+
69
+ # sinusoidal position embedding
70
+
71
+ class SinusPositionEmbedding(nn.Module):
72
+ def __init__(self, dim):
73
+ super().__init__()
74
+ self.dim = dim
75
+
76
+ def forward(self, x, scale=1000):
77
+ device = x.device
78
+ half_dim = self.dim // 2
79
+ emb = math.log(10000) / (half_dim - 1)
80
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
81
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
82
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
83
+ return emb
84
+
85
+
86
+ # convolutional position embedding
87
+
88
+ class ConvPositionEmbedding(nn.Module):
89
+ def __init__(self, dim, kernel_size = 31, groups = 16):
90
+ super().__init__()
91
+ assert kernel_size % 2 != 0
92
+ self.conv1d = nn.Sequential(
93
+ nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
94
+ nn.Mish(),
95
+ nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
96
+ nn.Mish(),
97
+ )
98
+
99
+ def forward(self, x: float['b n d'], mask: bool['b n'] | None = None):
100
+ if mask is not None:
101
+ mask = mask[..., None]
102
+ x = x.masked_fill(~mask, 0.)
103
+
104
+ x = rearrange(x, 'b n d -> b d n')
105
+ x = self.conv1d(x)
106
+ out = rearrange(x, 'b d n -> b n d')
107
+
108
+ if mask is not None:
109
+ out = out.masked_fill(~mask, 0.)
110
+
111
+ return out
112
+
113
+
114
+ # rotary positional embedding related
115
+
116
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.):
117
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
118
+ # has some connection to NTK literature
119
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
120
+ # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
121
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
122
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
123
+ t = torch.arange(end, device=freqs.device) # type: ignore
124
+ freqs = torch.outer(t, freqs).float() # type: ignore
125
+ freqs_cos = torch.cos(freqs) # real part
126
+ freqs_sin = torch.sin(freqs) # imaginary part
127
+ return torch.cat([freqs_cos, freqs_sin], dim=-1)
128
+
129
+ def get_pos_embed_indices(start, length, max_pos, scale=1.):
130
+ # length = length if isinstance(length, int) else length.max()
131
+ scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
132
+ pos = start.unsqueeze(1) + (
133
+ torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) *
134
+ scale.unsqueeze(1)).long()
135
+ # avoid extra long error.
136
+ pos = torch.where(pos < max_pos, pos, max_pos - 1)
137
+ return pos
138
+
139
+
140
+ # Global Response Normalization layer (Instance Normalization ?)
141
+
142
+ class GRN(nn.Module):
143
+ def __init__(self, dim):
144
+ super().__init__()
145
+ self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
146
+ self.beta = nn.Parameter(torch.zeros(1, 1, dim))
147
+
148
+ def forward(self, x):
149
+ Gx = torch.norm(x, p=2, dim=1, keepdim=True)
150
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
151
+ return self.gamma * (x * Nx) + self.beta + x
152
+
153
+
154
+ # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
155
+ # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
156
+
157
+ class ConvNeXtV2Block(nn.Module):
158
+ def __init__(
159
+ self,
160
+ dim: int,
161
+ intermediate_dim: int,
162
+ dilation: int = 1,
163
+ ):
164
+ super().__init__()
165
+ padding = (dilation * (7 - 1)) // 2
166
+ self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation) # depthwise conv
167
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
168
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
169
+ self.act = nn.GELU()
170
+ self.grn = GRN(intermediate_dim)
171
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
172
+
173
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
174
+ residual = x
175
+ x = x.transpose(1, 2) # b n d -> b d n
176
+ x = self.dwconv(x)
177
+ x = x.transpose(1, 2) # b d n -> b n d
178
+ x = self.norm(x)
179
+ x = self.pwconv1(x)
180
+ x = self.act(x)
181
+ x = self.grn(x)
182
+ x = self.pwconv2(x)
183
+ return residual + x
184
+
185
+
186
+ # AdaLayerNormZero
187
+ # return with modulated x for attn input, and params for later mlp modulation
188
+
189
+ class AdaLayerNormZero(nn.Module):
190
+ def __init__(self, dim):
191
+ super().__init__()
192
+
193
+ self.silu = nn.SiLU()
194
+ self.linear = nn.Linear(dim, dim * 6)
195
+
196
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
197
+
198
+ def forward(self, x, emb = None):
199
+ emb = self.linear(self.silu(emb))
200
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
201
+
202
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
203
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
204
+
205
+
206
+ # AdaLayerNormZero for final layer
207
+ # return only with modulated x for attn input, cuz no more mlp modulation
208
+
209
+ class AdaLayerNormZero_Final(nn.Module):
210
+ def __init__(self, dim):
211
+ super().__init__()
212
+
213
+ self.silu = nn.SiLU()
214
+ self.linear = nn.Linear(dim, dim * 2)
215
+
216
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
217
+
218
+ def forward(self, x, emb):
219
+ emb = self.linear(self.silu(emb))
220
+ scale, shift = torch.chunk(emb, 2, dim=1)
221
+
222
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
223
+ return x
224
+
225
+
226
+ # FeedForward
227
+
228
+ class FeedForward(nn.Module):
229
+ def __init__(self, dim, dim_out = None, mult = 4, dropout = 0., approximate: str = 'none'):
230
+ super().__init__()
231
+ inner_dim = int(dim * mult)
232
+ dim_out = dim_out if dim_out is not None else dim
233
+
234
+ activation = nn.GELU(approximate=approximate)
235
+ project_in = nn.Sequential(
236
+ nn.Linear(dim, inner_dim),
237
+ activation
238
+ )
239
+ self.ff = nn.Sequential(
240
+ project_in,
241
+ nn.Dropout(dropout),
242
+ nn.Linear(inner_dim, dim_out)
243
+ )
244
+
245
+ def forward(self, x):
246
+ return self.ff(x)
247
+
248
+
249
+ # Attention with possible joint part
250
+ # modified from diffusers/src/diffusers/models/attention_processor.py
251
+
252
+ class Attention(nn.Module):
253
+ def __init__(
254
+ self,
255
+ processor: JointAttnProcessor | AttnProcessor,
256
+ dim: int,
257
+ heads: int = 8,
258
+ dim_head: int = 64,
259
+ dropout: float = 0.0,
260
+ context_dim: Optional[int] = None, # if not None -> joint attention
261
+ context_pre_only = None,
262
+ ):
263
+ super().__init__()
264
+
265
+ if not hasattr(F, "scaled_dot_product_attention"):
266
+ raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
267
+
268
+ self.processor = processor
269
+
270
+ self.dim = dim
271
+ self.heads = heads
272
+ self.inner_dim = dim_head * heads
273
+ self.dropout = dropout
274
+
275
+ self.context_dim = context_dim
276
+ self.context_pre_only = context_pre_only
277
+
278
+ self.to_q = nn.Linear(dim, self.inner_dim)
279
+ self.to_k = nn.Linear(dim, self.inner_dim)
280
+ self.to_v = nn.Linear(dim, self.inner_dim)
281
+
282
+ if self.context_dim is not None:
283
+ self.to_k_c = nn.Linear(context_dim, self.inner_dim)
284
+ self.to_v_c = nn.Linear(context_dim, self.inner_dim)
285
+ if self.context_pre_only is not None:
286
+ self.to_q_c = nn.Linear(context_dim, self.inner_dim)
287
+
288
+ self.to_out = nn.ModuleList([])
289
+ self.to_out.append(nn.Linear(self.inner_dim, dim))
290
+ self.to_out.append(nn.Dropout(dropout))
291
+
292
+ if self.context_pre_only is not None and not self.context_pre_only:
293
+ self.to_out_c = nn.Linear(self.inner_dim, dim)
294
+
295
+ def forward(
296
+ self,
297
+ x: float['b n d'], # noised input x
298
+ c: float['b n d'] = None, # context c
299
+ mask: bool['b n'] | None = None,
300
+ rope = None, # rotary position embedding for x
301
+ c_rope = None, # rotary position embedding for c
302
+ ) -> torch.Tensor:
303
+ if c is not None:
304
+ return self.processor(self, x, c = c, mask = mask, rope = rope, c_rope = c_rope)
305
+ else:
306
+ return self.processor(self, x, mask = mask, rope = rope)
307
+
308
+
309
+ # Attention processor
310
+
311
+ class AttnProcessor:
312
+ def __init__(self):
313
+ pass
314
+
315
+ def __call__(
316
+ self,
317
+ attn: Attention,
318
+ x: float['b n d'], # noised input x
319
+ mask: bool['b n'] | None = None,
320
+ rope = None, # rotary position embedding
321
+ ) -> torch.FloatTensor:
322
+
323
+ batch_size = x.shape[0]
324
+
325
+ # `sample` projections.
326
+ query = attn.to_q(x)
327
+ key = attn.to_k(x)
328
+ value = attn.to_v(x)
329
+
330
+ # apply rotary position embedding
331
+ if rope is not None:
332
+ freqs, xpos_scale = rope
333
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
334
+
335
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
336
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
337
+
338
+ # attention
339
+ inner_dim = key.shape[-1]
340
+ head_dim = inner_dim // attn.heads
341
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
342
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
343
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
344
+
345
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
346
+ if mask is not None:
347
+ attn_mask = mask
348
+ attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
349
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
350
+ else:
351
+ attn_mask = None
352
+
353
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
354
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
355
+ x = x.to(query.dtype)
356
+
357
+ # linear proj
358
+ x = attn.to_out[0](x)
359
+ # dropout
360
+ x = attn.to_out[1](x)
361
+
362
+ if mask is not None:
363
+ mask = rearrange(mask, 'b n -> b n 1')
364
+ x = x.masked_fill(~mask, 0.)
365
+
366
+ return x
367
+
368
+
369
+ # Joint Attention processor for MM-DiT
370
+ # modified from diffusers/src/diffusers/models/attention_processor.py
371
+
372
+ class JointAttnProcessor:
373
+ def __init__(self):
374
+ pass
375
+
376
+ def __call__(
377
+ self,
378
+ attn: Attention,
379
+ x: float['b n d'], # noised input x
380
+ c: float['b nt d'] = None, # context c, here text
381
+ mask: bool['b n'] | None = None,
382
+ rope = None, # rotary position embedding for x
383
+ c_rope = None, # rotary position embedding for c
384
+ ) -> torch.FloatTensor:
385
+ residual = x
386
+
387
+ batch_size = c.shape[0]
388
+
389
+ # `sample` projections.
390
+ query = attn.to_q(x)
391
+ key = attn.to_k(x)
392
+ value = attn.to_v(x)
393
+
394
+ # `context` projections.
395
+ c_query = attn.to_q_c(c)
396
+ c_key = attn.to_k_c(c)
397
+ c_value = attn.to_v_c(c)
398
+
399
+ # apply rope for context and noised input independently
400
+ if rope is not None:
401
+ freqs, xpos_scale = rope
402
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
403
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
404
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
405
+ if c_rope is not None:
406
+ freqs, xpos_scale = c_rope
407
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
408
+ c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
409
+ c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
410
+
411
+ # attention
412
+ query = torch.cat([query, c_query], dim=1)
413
+ key = torch.cat([key, c_key], dim=1)
414
+ value = torch.cat([value, c_value], dim=1)
415
+
416
+ inner_dim = key.shape[-1]
417
+ head_dim = inner_dim // attn.heads
418
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
419
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
420
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
421
+
422
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
423
+ if mask is not None:
424
+ attn_mask = F.pad(mask, (0, c.shape[1]), value = True) # no mask for c (text)
425
+ attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
426
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
427
+ else:
428
+ attn_mask = None
429
+
430
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
431
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
432
+ x = x.to(query.dtype)
433
+
434
+ # Split the attention outputs.
435
+ x, c = (
436
+ x[:, :residual.shape[1]],
437
+ x[:, residual.shape[1]:],
438
+ )
439
+
440
+ # linear proj
441
+ x = attn.to_out[0](x)
442
+ # dropout
443
+ x = attn.to_out[1](x)
444
+ if not attn.context_pre_only:
445
+ c = attn.to_out_c(c)
446
+
447
+ if mask is not None:
448
+ mask = rearrange(mask, 'b n -> b n 1')
449
+ x = x.masked_fill(~mask, 0.)
450
+ # c = c.masked_fill(~mask, 0.) # no mask for c (text)
451
+
452
+ return x, c
453
+
454
+
455
+ # DiT Block
456
+
457
+ class DiTBlock(nn.Module):
458
+
459
+ def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1):
460
+ super().__init__()
461
+
462
+ self.attn_norm = AdaLayerNormZero(dim)
463
+ self.attn = Attention(
464
+ processor = AttnProcessor(),
465
+ dim = dim,
466
+ heads = heads,
467
+ dim_head = dim_head,
468
+ dropout = dropout,
469
+ )
470
+
471
+ self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
472
+ self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
473
+
474
+ def forward(self, x, t, mask = None, rope = None): # x: noised input, t: time embedding
475
+ # pre-norm & modulation for attention input
476
+ norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
477
+
478
+ # attention
479
+ attn_output = self.attn(x=norm, mask=mask, rope=rope)
480
+
481
+ # process attention output for input x
482
+ x = x + gate_msa.unsqueeze(1) * attn_output
483
+
484
+ norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
485
+ ff_output = self.ff(norm)
486
+ x = x + gate_mlp.unsqueeze(1) * ff_output
487
+
488
+ return x
489
+
490
+
491
+ # MMDiT Block https://arxiv.org/abs/2403.03206
492
+
493
+ class MMDiTBlock(nn.Module):
494
+ r"""
495
+ modified from diffusers/src/diffusers/models/attention.py
496
+
497
+ notes.
498
+ _c: context related. text, cond, etc. (left part in sd3 fig2.b)
499
+ _x: noised input related. (right part)
500
+ context_pre_only: last layer only do prenorm + modulation cuz no more ffn
501
+ """
502
+
503
+ def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1, context_pre_only = False):
504
+ super().__init__()
505
+
506
+ self.context_pre_only = context_pre_only
507
+
508
+ self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
509
+ self.attn_norm_x = AdaLayerNormZero(dim)
510
+ self.attn = Attention(
511
+ processor = JointAttnProcessor(),
512
+ dim = dim,
513
+ heads = heads,
514
+ dim_head = dim_head,
515
+ dropout = dropout,
516
+ context_dim = dim,
517
+ context_pre_only = context_pre_only,
518
+ )
519
+
520
+ if not context_pre_only:
521
+ self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
522
+ self.ff_c = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
523
+ else:
524
+ self.ff_norm_c = None
525
+ self.ff_c = None
526
+ self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
527
+ self.ff_x = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
528
+
529
+ def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised input, c: context, t: time embedding
530
+ # pre-norm & modulation for attention input
531
+ if self.context_pre_only:
532
+ norm_c = self.attn_norm_c(c, t)
533
+ else:
534
+ norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
535
+ norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
536
+
537
+ # attention
538
+ x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
539
+
540
+ # process attention output for context c
541
+ if self.context_pre_only:
542
+ c = None
543
+ else: # if not last layer
544
+ c = c + c_gate_msa.unsqueeze(1) * c_attn_output
545
+
546
+ norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
547
+ c_ff_output = self.ff_c(norm_c)
548
+ c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
549
+
550
+ # process attention output for input x
551
+ x = x + x_gate_msa.unsqueeze(1) * x_attn_output
552
+
553
+ norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
554
+ x_ff_output = self.ff_x(norm_x)
555
+ x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
556
+
557
+ return c, x
558
+
559
+
560
+ # time step conditioning embedding
561
+
562
+ class TimestepEmbedding(nn.Module):
563
+ def __init__(self, dim, freq_embed_dim=256):
564
+ super().__init__()
565
+ self.time_embed = SinusPositionEmbedding(freq_embed_dim)
566
+ self.time_mlp = nn.Sequential(
567
+ nn.Linear(freq_embed_dim, dim),
568
+ nn.SiLU(),
569
+ nn.Linear(dim, dim)
570
+ )
571
+
572
+ def forward(self, timestep: float['b']):
573
+ time_hidden = self.time_embed(timestep)
574
+ time = self.time_mlp(time_hidden) # b d
575
+ return time
model/trainer.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import gc
5
+ from tqdm import tqdm
6
+ import wandb
7
+
8
+ import torch
9
+ from torch.optim import AdamW
10
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler
11
+ from torch.optim.lr_scheduler import LinearLR, SequentialLR
12
+
13
+ from einops import rearrange
14
+
15
+ from accelerate import Accelerator
16
+ from accelerate.utils import DistributedDataParallelKwargs
17
+
18
+ from ema_pytorch import EMA
19
+
20
+ from model import CFM
21
+ from model.utils import exists, default
22
+ from model.dataset import DynamicBatchSampler, collate_fn
23
+
24
+
25
+ # trainer
26
+
27
+ class Trainer:
28
+ def __init__(
29
+ self,
30
+ model: CFM,
31
+ epochs,
32
+ learning_rate,
33
+ num_warmup_updates = 20000,
34
+ save_per_updates = 1000,
35
+ checkpoint_path = None,
36
+ batch_size = 32,
37
+ batch_size_type: str = "sample",
38
+ max_samples = 32,
39
+ grad_accumulation_steps = 1,
40
+ max_grad_norm = 1.0,
41
+ noise_scheduler: str | None = None,
42
+ duration_predictor: torch.nn.Module | None = None,
43
+ wandb_project = "test_e2-tts",
44
+ wandb_run_name = "test_run",
45
+ wandb_resume_id: str = None,
46
+ last_per_steps = None,
47
+ accelerate_kwargs: dict = dict(),
48
+ ema_kwargs: dict = dict()
49
+ ):
50
+
51
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
52
+
53
+ self.accelerator = Accelerator(
54
+ log_with = "wandb",
55
+ kwargs_handlers = [ddp_kwargs],
56
+ gradient_accumulation_steps = grad_accumulation_steps,
57
+ **accelerate_kwargs
58
+ )
59
+
60
+ if exists(wandb_resume_id):
61
+ init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
62
+ else:
63
+ init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
64
+ self.accelerator.init_trackers(
65
+ project_name = wandb_project,
66
+ init_kwargs=init_kwargs,
67
+ config={"epochs": epochs,
68
+ "learning_rate": learning_rate,
69
+ "num_warmup_updates": num_warmup_updates,
70
+ "batch_size": batch_size,
71
+ "batch_size_type": batch_size_type,
72
+ "max_samples": max_samples,
73
+ "grad_accumulation_steps": grad_accumulation_steps,
74
+ "max_grad_norm": max_grad_norm,
75
+ "gpus": self.accelerator.num_processes,
76
+ "noise_scheduler": noise_scheduler}
77
+ )
78
+
79
+ self.model = model
80
+
81
+ if self.is_main:
82
+ self.ema_model = EMA(
83
+ model,
84
+ include_online_model = False,
85
+ **ema_kwargs
86
+ )
87
+
88
+ self.ema_model.to(self.accelerator.device)
89
+
90
+ self.epochs = epochs
91
+ self.num_warmup_updates = num_warmup_updates
92
+ self.save_per_updates = save_per_updates
93
+ self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
94
+ self.checkpoint_path = default(checkpoint_path, 'ckpts/test_e2-tts')
95
+
96
+ self.batch_size = batch_size
97
+ self.batch_size_type = batch_size_type
98
+ self.max_samples = max_samples
99
+ self.grad_accumulation_steps = grad_accumulation_steps
100
+ self.max_grad_norm = max_grad_norm
101
+
102
+ self.noise_scheduler = noise_scheduler
103
+
104
+ self.duration_predictor = duration_predictor
105
+
106
+ self.optimizer = AdamW(model.parameters(), lr=learning_rate)
107
+ self.model, self.optimizer = self.accelerator.prepare(
108
+ self.model, self.optimizer
109
+ )
110
+
111
+ @property
112
+ def is_main(self):
113
+ return self.accelerator.is_main_process
114
+
115
+ def save_checkpoint(self, step, last=False):
116
+ self.accelerator.wait_for_everyone()
117
+ if self.is_main:
118
+ checkpoint = dict(
119
+ model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
120
+ optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
121
+ ema_model_state_dict = self.ema_model.state_dict(),
122
+ scheduler_state_dict = self.scheduler.state_dict(),
123
+ step = step
124
+ )
125
+ if not os.path.exists(self.checkpoint_path):
126
+ os.makedirs(self.checkpoint_path)
127
+ if last == True:
128
+ self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
129
+ print(f"Saved last checkpoint at step {step}")
130
+ else:
131
+ self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
132
+
133
+ def load_checkpoint(self):
134
+ if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not os.listdir(self.checkpoint_path):
135
+ return 0
136
+
137
+ self.accelerator.wait_for_everyone()
138
+ if "model_last.pt" in os.listdir(self.checkpoint_path):
139
+ latest_checkpoint = "model_last.pt"
140
+ else:
141
+ latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
142
+ # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
143
+ checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")
144
+
145
+ if self.is_main:
146
+ self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
147
+
148
+ if 'step' in checkpoint:
149
+ self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
150
+ self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
151
+ if self.scheduler:
152
+ self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
153
+ step = checkpoint['step']
154
+ else:
155
+ checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
156
+ self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
157
+ step = 0
158
+
159
+ del checkpoint; gc.collect()
160
+ return step
161
+
162
+ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
163
+
164
+ if exists(resumable_with_seed):
165
+ generator = torch.Generator()
166
+ generator.manual_seed(resumable_with_seed)
167
+ else:
168
+ generator = None
169
+
170
+ if self.batch_size_type == "sample":
171
+ train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
172
+ batch_size=self.batch_size, shuffle=True, generator=generator)
173
+ elif self.batch_size_type == "frame":
174
+ self.accelerator.even_batches = False
175
+ sampler = SequentialSampler(train_dataset)
176
+ batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False)
177
+ train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
178
+ batch_sampler=batch_sampler)
179
+ else:
180
+ raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
181
+
182
+ # accelerator.prepare() dispatches batches to devices;
183
+ # which means the length of dataloader calculated before, should consider the number of devices
184
+ warmup_steps = self.num_warmup_updates * self.accelerator.num_processes # consider a fixed warmup steps while using accelerate multi-gpu ddp
185
+ # otherwise by default with split_batches=False, warmup steps change with num_processes
186
+ total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
187
+ decay_steps = total_steps - warmup_steps
188
+ warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
189
+ decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
190
+ self.scheduler = SequentialLR(self.optimizer,
191
+ schedulers=[warmup_scheduler, decay_scheduler],
192
+ milestones=[warmup_steps])
193
+ train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler) # actual steps = 1 gpu steps / gpus
194
+ start_step = self.load_checkpoint()
195
+ global_step = start_step
196
+
197
+ if exists(resumable_with_seed):
198
+ orig_epoch_step = len(train_dataloader)
199
+ skipped_epoch = int(start_step // orig_epoch_step)
200
+ skipped_batch = start_step % orig_epoch_step
201
+ skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
202
+ else:
203
+ skipped_epoch = 0
204
+
205
+ for epoch in range(skipped_epoch, self.epochs):
206
+ self.model.train()
207
+ if exists(resumable_with_seed) and epoch == skipped_epoch:
208
+ progress_bar = tqdm(skipped_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process,
209
+ initial=skipped_batch, total=orig_epoch_step)
210
+ else:
211
+ progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
212
+
213
+ for batch in progress_bar:
214
+ with self.accelerator.accumulate(self.model):
215
+ text_inputs = batch['text']
216
+ mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
217
+ mel_lengths = batch["mel_lengths"]
218
+
219
+ # TODO. add duration predictor training
220
+ if self.duration_predictor is not None and self.accelerator.is_local_main_process:
221
+ dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
222
+ self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
223
+
224
+ loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler)
225
+ self.accelerator.backward(loss)
226
+
227
+ if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
228
+ self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
229
+
230
+ self.optimizer.step()
231
+ self.scheduler.step()
232
+ self.optimizer.zero_grad()
233
+
234
+ if self.is_main:
235
+ self.ema_model.update()
236
+
237
+ global_step += 1
238
+
239
+ if self.accelerator.is_local_main_process:
240
+ self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
241
+
242
+ progress_bar.set_postfix(step=str(global_step), loss=loss.item())
243
+
244
+ if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
245
+ self.save_checkpoint(global_step)
246
+
247
+ if global_step % self.last_per_steps == 0:
248
+ self.save_checkpoint(global_step, last=True)
249
+
250
+ self.accelerator.end_training()
model/utils.py ADDED
@@ -0,0 +1,574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import re
5
+ import math
6
+ import random
7
+ import string
8
+ from tqdm import tqdm
9
+ from collections import defaultdict
10
+
11
+ import matplotlib
12
+ matplotlib.use("Agg")
13
+ import matplotlib.pylab as plt
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from torch.nn.utils.rnn import pad_sequence
18
+ import torchaudio
19
+
20
+ import einx
21
+ from einops import rearrange, reduce
22
+
23
+ import jieba
24
+ from pypinyin import lazy_pinyin, Style
25
+ import zhconv
26
+ from zhon.hanzi import punctuation
27
+ from jiwer import compute_measures
28
+
29
+ from funasr import AutoModel
30
+ from faster_whisper import WhisperModel
31
+
32
+ from model.ecapa_tdnn import ECAPA_TDNN_SMALL
33
+ from model.modules import MelSpec
34
+
35
+
36
+ # seed everything
37
+
38
+ def seed_everything(seed = 0):
39
+ random.seed(seed)
40
+ os.environ['PYTHONHASHSEED'] = str(seed)
41
+ torch.manual_seed(seed)
42
+ torch.cuda.manual_seed(seed)
43
+ torch.cuda.manual_seed_all(seed)
44
+ torch.backends.cudnn.deterministic = True
45
+ torch.backends.cudnn.benchmark = False
46
+
47
+ # helpers
48
+
49
+ def exists(v):
50
+ return v is not None
51
+
52
+ def default(v, d):
53
+ return v if exists(v) else d
54
+
55
+ # tensor helpers
56
+
57
+ def lens_to_mask(
58
+ t: int['b'],
59
+ length: int | None = None
60
+ ) -> bool['b n']:
61
+
62
+ if not exists(length):
63
+ length = t.amax()
64
+
65
+ seq = torch.arange(length, device = t.device)
66
+ return einx.less('n, b -> b n', seq, t)
67
+
68
+ def mask_from_start_end_indices(
69
+ seq_len: int['b'],
70
+ start: int['b'],
71
+ end: int['b']
72
+ ):
73
+ max_seq_len = seq_len.max().item()
74
+ seq = torch.arange(max_seq_len, device = start.device).long()
75
+ return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)
76
+
77
+ def mask_from_frac_lengths(
78
+ seq_len: int['b'],
79
+ frac_lengths: float['b']
80
+ ):
81
+ lengths = (frac_lengths * seq_len).long()
82
+ max_start = seq_len - lengths
83
+
84
+ rand = torch.rand_like(frac_lengths)
85
+ start = (max_start * rand).long().clamp(min = 0)
86
+ end = start + lengths
87
+
88
+ return mask_from_start_end_indices(seq_len, start, end)
89
+
90
+ def maybe_masked_mean(
91
+ t: float['b n d'],
92
+ mask: bool['b n'] = None
93
+ ) -> float['b d']:
94
+
95
+ if not exists(mask):
96
+ return t.mean(dim = 1)
97
+
98
+ t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
99
+ num = reduce(t, 'b n d -> b d', 'sum')
100
+ den = reduce(mask.float(), 'b n -> b', 'sum')
101
+
102
+ return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))
103
+
104
+
105
+ # simple utf-8 tokenizer, since paper went character based
106
+ def list_str_to_tensor(
107
+ text: list[str],
108
+ padding_value = -1
109
+ ) -> int['b nt']:
110
+ list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text] # ByT5 style
111
+ text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True)
112
+ return text
113
+
114
+ # char tokenizer, based on custom dataset's extracted .txt file
115
+ def list_str_to_idx(
116
+ text: list[str] | list[list[str]],
117
+ vocab_char_map: dict[str, int], # {char: idx}
118
+ padding_value = -1
119
+ ) -> int['b nt']:
120
+ list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
121
+ text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True)
122
+ return text
123
+
124
+
125
+ # Get tokenizer
126
+
127
+ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
128
+ '''
129
+ tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
130
+ - "char" for char-wise tokenizer, need .txt vocab_file
131
+ - "byte" for utf-8 tokenizer
132
+ vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
133
+ - if use "char", derived from unfiltered character & symbol counts of custom dataset
134
+ - if use "byte", set to 256 (unicode byte range)
135
+ '''
136
+ if tokenizer in ["pinyin", "char"]:
137
+ with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
138
+ vocab_char_map = {}
139
+ for i, char in enumerate(f):
140
+ vocab_char_map[char[:-1]] = i
141
+ vocab_size = len(vocab_char_map)
142
+ assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
143
+
144
+ elif tokenizer == "byte":
145
+ vocab_char_map = None
146
+ vocab_size = 256
147
+
148
+ return vocab_char_map, vocab_size
149
+
150
+
151
+ # convert char to pinyin
152
+
153
+ def convert_char_to_pinyin(text_list, polyphone = True):
154
+ final_text_list = []
155
+ god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
156
+ custom_trans = str.maketrans({';': ','}) # add custom trans here, to address oov
157
+ for text in text_list:
158
+ char_list = []
159
+ text = text.translate(god_knows_why_en_testset_contains_zh_quote)
160
+ text = text.translate(custom_trans)
161
+ for seg in jieba.cut(text):
162
+ seg_byte_len = len(bytes(seg, 'UTF-8'))
163
+ if seg_byte_len == len(seg): # if pure alphabets and symbols
164
+ if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
165
+ char_list.append(" ")
166
+ char_list.extend(seg)
167
+ elif polyphone and seg_byte_len == 3 * len(seg): # if pure chinese characters
168
+ seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
169
+ for c in seg:
170
+ if c not in "。,、;:?!《》【】—…":
171
+ char_list.append(" ")
172
+ char_list.append(c)
173
+ else: # if mixed chinese characters, alphabets and symbols
174
+ for c in seg:
175
+ if ord(c) < 256:
176
+ char_list.extend(c)
177
+ else:
178
+ if c not in "。,、;:?!《》【】—…":
179
+ char_list.append(" ")
180
+ char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
181
+ else: # if is zh punc
182
+ char_list.append(c)
183
+ final_text_list.append(char_list)
184
+
185
+ return final_text_list
186
+
187
+
188
+ # save spectrogram
189
+ def save_spectrogram(spectrogram, path):
190
+ plt.figure(figsize=(12, 4))
191
+ plt.imshow(spectrogram, origin='lower', aspect='auto')
192
+ plt.colorbar()
193
+ plt.savefig(path)
194
+ plt.close()
195
+
196
+
197
+ # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
198
+ def get_seedtts_testset_metainfo(metalst):
199
+ f = open(metalst); lines = f.readlines(); f.close()
200
+ metainfo = []
201
+ for line in lines:
202
+ if len(line.strip().split('|')) == 5:
203
+ utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
204
+ elif len(line.strip().split('|')) == 4:
205
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
206
+ gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
207
+ if not os.path.isabs(prompt_wav):
208
+ prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
209
+ metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
210
+ return metainfo
211
+
212
+
213
+ # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
214
+ def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
215
+ f = open(metalst); lines = f.readlines(); f.close()
216
+ metainfo = []
217
+ for line in lines:
218
+ ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
219
+
220
+ # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
221
+ ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
222
+ ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
223
+
224
+ # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
225
+ gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
226
+ gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
227
+
228
+ metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
229
+
230
+ return metainfo
231
+
232
+
233
+ # padded to max length mel batch
234
+ def padded_mel_batch(ref_mels):
235
+ max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
236
+ padded_ref_mels = []
237
+ for mel in ref_mels:
238
+ padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0)
239
+ padded_ref_mels.append(padded_ref_mel)
240
+ padded_ref_mels = torch.stack(padded_ref_mels)
241
+ padded_ref_mels = rearrange(padded_ref_mels, 'b d n -> b n d')
242
+ return padded_ref_mels
243
+
244
+
245
+ # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
246
+
247
+ def get_inference_prompt(
248
+ metainfo,
249
+ speed = 1., tokenizer = "pinyin", polyphone = True,
250
+ target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1,
251
+ use_truth_duration = False,
252
+ infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40,
253
+ ):
254
+ prompts_all = []
255
+
256
+ min_tokens = min_secs * target_sample_rate // hop_length
257
+ max_tokens = max_secs * target_sample_rate // hop_length
258
+
259
+ batch_accum = [0] * num_buckets
260
+ utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \
261
+ ([[] for _ in range(num_buckets)] for _ in range(6))
262
+
263
+ mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
264
+
265
+ for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
266
+
267
+ # Audio
268
+ ref_audio, ref_sr = torchaudio.load(prompt_wav)
269
+ ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
270
+ if ref_rms < target_rms:
271
+ ref_audio = ref_audio * target_rms / ref_rms
272
+ assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
273
+ if ref_sr != target_sample_rate:
274
+ resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
275
+ ref_audio = resampler(ref_audio)
276
+
277
+ # Text
278
+ if len(prompt_text[-1].encode('utf-8')) == 1:
279
+ prompt_text = prompt_text + " "
280
+ text = [prompt_text + gt_text]
281
+ if tokenizer == "pinyin":
282
+ text_list = convert_char_to_pinyin(text, polyphone = polyphone)
283
+ else:
284
+ text_list = text
285
+
286
+ # Duration, mel frame length
287
+ ref_mel_len = ref_audio.shape[-1] // hop_length
288
+ if use_truth_duration:
289
+ gt_audio, gt_sr = torchaudio.load(gt_wav)
290
+ if gt_sr != target_sample_rate:
291
+ resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
292
+ gt_audio = resampler(gt_audio)
293
+ total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
294
+
295
+ # # test vocoder resynthesis
296
+ # ref_audio = gt_audio
297
+ else:
298
+ zh_pause_punc = r"。,、;:?!"
299
+ ref_text_len = len(prompt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, prompt_text))
300
+ gen_text_len = len(gt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gt_text))
301
+ total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
302
+
303
+ # to mel spectrogram
304
+ ref_mel = mel_spectrogram(ref_audio)
305
+ ref_mel = rearrange(ref_mel, '1 d n -> d n')
306
+
307
+ # deal with batch
308
+ assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
309
+ assert min_tokens <= total_mel_len <= max_tokens, \
310
+ f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
311
+ bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
312
+
313
+ utts[bucket_i].append(utt)
314
+ ref_rms_list[bucket_i].append(ref_rms)
315
+ ref_mels[bucket_i].append(ref_mel)
316
+ ref_mel_lens[bucket_i].append(ref_mel_len)
317
+ total_mel_lens[bucket_i].append(total_mel_len)
318
+ final_text_list[bucket_i].extend(text_list)
319
+
320
+ batch_accum[bucket_i] += total_mel_len
321
+
322
+ if batch_accum[bucket_i] >= infer_batch_size:
323
+ # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
324
+ prompts_all.append((
325
+ utts[bucket_i],
326
+ ref_rms_list[bucket_i],
327
+ padded_mel_batch(ref_mels[bucket_i]),
328
+ ref_mel_lens[bucket_i],
329
+ total_mel_lens[bucket_i],
330
+ final_text_list[bucket_i]
331
+ ))
332
+ batch_accum[bucket_i] = 0
333
+ utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], []
334
+
335
+ # add residual
336
+ for bucket_i, bucket_frames in enumerate(batch_accum):
337
+ if bucket_frames > 0:
338
+ prompts_all.append((
339
+ utts[bucket_i],
340
+ ref_rms_list[bucket_i],
341
+ padded_mel_batch(ref_mels[bucket_i]),
342
+ ref_mel_lens[bucket_i],
343
+ total_mel_lens[bucket_i],
344
+ final_text_list[bucket_i]
345
+ ))
346
+ # not only leave easy work for last workers
347
+ random.seed(666)
348
+ random.shuffle(prompts_all)
349
+
350
+ return prompts_all
351
+
352
+
353
+ # get wav_res_ref_text of seed-tts test metalst
354
+ # https://github.com/BytedanceSpeech/seed-tts-eval
355
+
356
+ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
357
+ f = open(metalst)
358
+ lines = f.readlines()
359
+ f.close()
360
+
361
+ test_set_ = []
362
+ for line in tqdm(lines):
363
+ if len(line.strip().split('|')) == 5:
364
+ utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
365
+ elif len(line.strip().split('|')) == 4:
366
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
367
+
368
+ if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')):
369
+ continue
370
+ gen_wav = os.path.join(gen_wav_dir, utt + '.wav')
371
+ if not os.path.isabs(prompt_wav):
372
+ prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
373
+
374
+ test_set_.append((gen_wav, prompt_wav, gt_text))
375
+
376
+ num_jobs = len(gpus)
377
+ if num_jobs == 1:
378
+ return [(gpus[0], test_set_)]
379
+
380
+ wav_per_job = len(test_set_) // num_jobs + 1
381
+ test_set = []
382
+ for i in range(num_jobs):
383
+ test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
384
+
385
+ return test_set
386
+
387
+
388
+ # get librispeech test-clean cross sentence test
389
+
390
+ def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False):
391
+ f = open(metalst)
392
+ lines = f.readlines()
393
+ f.close()
394
+
395
+ test_set_ = []
396
+ for line in tqdm(lines):
397
+ ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
398
+
399
+ if eval_ground_truth:
400
+ gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
401
+ gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
402
+ else:
403
+ if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')):
404
+ raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
405
+ gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav')
406
+
407
+ ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
408
+ ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
409
+
410
+ test_set_.append((gen_wav, ref_wav, gen_txt))
411
+
412
+ num_jobs = len(gpus)
413
+ if num_jobs == 1:
414
+ return [(gpus[0], test_set_)]
415
+
416
+ wav_per_job = len(test_set_) // num_jobs + 1
417
+ test_set = []
418
+ for i in range(num_jobs):
419
+ test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
420
+
421
+ return test_set
422
+
423
+
424
+ # load asr model
425
+
426
+ def load_asr_model(lang, ckpt_dir = ""):
427
+ if lang == "zh":
428
+ model = AutoModel(
429
+ model = os.path.join(ckpt_dir, "paraformer-zh"),
430
+ # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
431
+ # punc_model = os.path.join(ckpt_dir, "ct-punc"),
432
+ # spk_model = os.path.join(ckpt_dir, "cam++"),
433
+ disable_update=True,
434
+ ) # following seed-tts setting
435
+ elif lang == "en":
436
+ model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
437
+ model = WhisperModel(model_size, device="cuda", compute_type="float16")
438
+ return model
439
+
440
+
441
+ # WER Evaluation, the way Seed-TTS does
442
+
443
+ def run_asr_wer(args):
444
+ rank, lang, test_set, ckpt_dir = args
445
+
446
+ if lang == "zh":
447
+ torch.cuda.set_device(rank)
448
+ elif lang == "en":
449
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
450
+ else:
451
+ raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
452
+
453
+ asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
454
+
455
+ punctuation_all = punctuation + string.punctuation
456
+ wers = []
457
+
458
+ for gen_wav, prompt_wav, truth in tqdm(test_set):
459
+ if lang == "zh":
460
+ res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
461
+ hypo = res[0]["text"]
462
+ hypo = zhconv.convert(hypo, 'zh-cn')
463
+ elif lang == "en":
464
+ segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
465
+ hypo = ''
466
+ for segment in segments:
467
+ hypo = hypo + ' ' + segment.text
468
+
469
+ # raw_truth = truth
470
+ # raw_hypo = hypo
471
+
472
+ for x in punctuation_all:
473
+ truth = truth.replace(x, '')
474
+ hypo = hypo.replace(x, '')
475
+
476
+ truth = truth.replace(' ', ' ')
477
+ hypo = hypo.replace(' ', ' ')
478
+
479
+ if lang == "zh":
480
+ truth = " ".join([x for x in truth])
481
+ hypo = " ".join([x for x in hypo])
482
+ elif lang == "en":
483
+ truth = truth.lower()
484
+ hypo = hypo.lower()
485
+
486
+ measures = compute_measures(truth, hypo)
487
+ wer = measures["wer"]
488
+
489
+ # ref_list = truth.split(" ")
490
+ # subs = measures["substitutions"] / len(ref_list)
491
+ # dele = measures["deletions"] / len(ref_list)
492
+ # inse = measures["insertions"] / len(ref_list)
493
+
494
+ wers.append(wer)
495
+
496
+ return wers
497
+
498
+
499
+ # SIM Evaluation
500
+
501
+ def run_sim(args):
502
+ rank, test_set, ckpt_dir = args
503
+ device = f"cuda:{rank}"
504
+
505
+ model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
506
+ state_dict = torch.load(ckpt_dir, map_location=lambda storage, loc: storage)
507
+ model.load_state_dict(state_dict['model'], strict=False)
508
+
509
+ use_gpu=True if torch.cuda.is_available() else False
510
+ if use_gpu:
511
+ model = model.cuda(device)
512
+ model.eval()
513
+
514
+ sim_list = []
515
+ for wav1, wav2, truth in tqdm(test_set):
516
+
517
+ wav1, sr1 = torchaudio.load(wav1)
518
+ wav2, sr2 = torchaudio.load(wav2)
519
+
520
+ resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
521
+ resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
522
+ wav1 = resample1(wav1)
523
+ wav2 = resample2(wav2)
524
+
525
+ if use_gpu:
526
+ wav1 = wav1.cuda(device)
527
+ wav2 = wav2.cuda(device)
528
+ with torch.no_grad():
529
+ emb1 = model(wav1)
530
+ emb2 = model(wav2)
531
+
532
+ sim = F.cosine_similarity(emb1, emb2)[0].item()
533
+ # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
534
+ sim_list.append(sim)
535
+
536
+ return sim_list
537
+
538
+
539
+ # filter func for dirty data with many repetitions
540
+
541
+ def repetition_found(text, length = 2, tolerance = 10):
542
+ pattern_count = defaultdict(int)
543
+ for i in range(len(text) - length + 1):
544
+ pattern = text[i:i + length]
545
+ pattern_count[pattern] += 1
546
+ for pattern, count in pattern_count.items():
547
+ if count > tolerance:
548
+ return True
549
+ return False
550
+
551
+
552
+ # load model checkpoint for inference
553
+
554
+ def load_checkpoint(model, ckpt_path, device, use_ema = True):
555
+ from ema_pytorch import EMA
556
+
557
+ ckpt_type = ckpt_path.split(".")[-1]
558
+ if ckpt_type == "safetensors":
559
+ from safetensors.torch import load_file
560
+ checkpoint = load_file(ckpt_path, device=device)
561
+ else:
562
+ checkpoint = torch.load(ckpt_path, map_location=device)
563
+
564
+ if use_ema == True:
565
+ ema_model = EMA(model, include_online_model = False).to(device)
566
+ if ckpt_type == "safetensors":
567
+ ema_model.load_state_dict(checkpoint)
568
+ else:
569
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
570
+ ema_model.copy_params_from_ema_to_model()
571
+ else:
572
+ model.load_state_dict(checkpoint['model_state_dict'])
573
+
574
+ return model
requirements.txt ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate>=0.33.0
2
+ cached_path
3
+ click
4
+ datasets
5
+ einops>=0.8.0
6
+ einx>=0.3.0
7
+ ema_pytorch>=0.5.2
8
+ faster_whisper
9
+ funasr
10
+ gradio
11
+ jieba
12
+ jiwer
13
+ librosa
14
+ matplotlib
15
+ numpy==1.23.5
16
+ pydub
17
+ pypinyin
18
+ safetensors
19
+ soundfile
20
+ # torch>=2.0
21
+ # torchaudio>=2.3.0
22
+ torchdiffeq
23
+ tqdm>=4.65.0
24
+ transformers
25
+ vocos
26
+ wandb
27
+ x_transformers>=1.31.14
28
+ zhconv
29
+ zhon
scripts/count_max_epoch.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''ADAPTIVE BATCH SIZE'''
2
+ print('Adaptive batch size: using grouping batch sampler, frames_per_gpu fixed fed in')
3
+ print(' -> least padding, gather wavs with accumulated frames in a batch\n')
4
+
5
+ # data
6
+ total_hours = 95282
7
+ mel_hop_length = 256
8
+ mel_sampling_rate = 24000
9
+
10
+ # target
11
+ wanted_max_updates = 1000000
12
+
13
+ # train params
14
+ gpus = 8
15
+ frames_per_gpu = 38400 # 8 * 38400 = 307200
16
+ grad_accum = 1
17
+
18
+ # intermediate
19
+ mini_batch_frames = frames_per_gpu * grad_accum * gpus
20
+ mini_batch_hours = mini_batch_frames * mel_hop_length / mel_sampling_rate / 3600
21
+ updates_per_epoch = total_hours / mini_batch_hours
22
+ steps_per_epoch = updates_per_epoch * grad_accum
23
+
24
+ # result
25
+ epochs = wanted_max_updates / updates_per_epoch
26
+ print(f"epochs should be set to: {epochs:.0f} ({epochs/grad_accum:.1f} x gd_acum {grad_accum})")
27
+ print(f"progress_bar should show approx. 0/{updates_per_epoch:.0f} updates")
28
+ print(f" or approx. 0/{steps_per_epoch:.0f} steps")
29
+
30
+ # others
31
+ print(f"total {total_hours:.0f} hours")
32
+ print(f"mini-batch of {mini_batch_frames:.0f} frames, {mini_batch_hours:.2f} hours per mini-batch")
scripts/count_params_gflops.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os
2
+ sys.path.append(os.getcwd())
3
+
4
+ from model import M2_TTS, UNetT, DiT, MMDiT
5
+
6
+ import torch
7
+ import thop
8
+
9
+
10
+ ''' ~155M '''
11
+ # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4)
12
+ # transformer = UNetT(dim = 768, depth = 20, heads = 12, ff_mult = 4, text_dim = 512, conv_layers = 4)
13
+ # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2)
14
+ # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4)
15
+ # transformer = DiT(dim = 768, depth = 18, heads = 12, ff_mult = 2, text_dim = 512, conv_layers = 4, long_skip_connection = True)
16
+ # transformer = MMDiT(dim = 512, depth = 16, heads = 16, ff_mult = 2)
17
+
18
+ ''' ~335M '''
19
+ # FLOPs: 622.1 G, Params: 333.2 M
20
+ # transformer = UNetT(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
21
+ # FLOPs: 363.4 G, Params: 335.8 M
22
+ transformer = DiT(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
23
+
24
+
25
+ model = M2_TTS(transformer=transformer)
26
+ target_sample_rate = 24000
27
+ n_mel_channels = 100
28
+ hop_length = 256
29
+ duration = 20
30
+ frame_length = int(duration * target_sample_rate / hop_length)
31
+ text_length = 150
32
+
33
+ flops, params = thop.profile(model, inputs=(torch.randn(1, frame_length, n_mel_channels), torch.zeros(1, text_length, dtype=torch.long)))
34
+ print(f"FLOPs: {flops / 1e9} G")
35
+ print(f"Params: {params / 1e6} M")
scripts/eval_infer_batch.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os
2
+ sys.path.append(os.getcwd())
3
+
4
+ import time
5
+ import random
6
+ from tqdm import tqdm
7
+ import argparse
8
+
9
+ import torch
10
+ import torchaudio
11
+ from accelerate import Accelerator
12
+ from einops import rearrange
13
+ from vocos import Vocos
14
+
15
+ from model import CFM, UNetT, DiT
16
+ from model.utils import (
17
+ load_checkpoint,
18
+ get_tokenizer,
19
+ get_seedtts_testset_metainfo,
20
+ get_librispeech_test_clean_metainfo,
21
+ get_inference_prompt,
22
+ )
23
+
24
+ accelerator = Accelerator()
25
+ device = f"cuda:{accelerator.process_index}"
26
+
27
+
28
+ # --------------------- Dataset Settings -------------------- #
29
+
30
+ target_sample_rate = 24000
31
+ n_mel_channels = 100
32
+ hop_length = 256
33
+ target_rms = 0.1
34
+
35
+ tokenizer = "pinyin"
36
+
37
+
38
+ # ---------------------- infer setting ---------------------- #
39
+
40
+ parser = argparse.ArgumentParser(description="batch inference")
41
+
42
+ parser.add_argument('-s', '--seed', default=None, type=int)
43
+ parser.add_argument('-d', '--dataset', default="Emilia_ZH_EN")
44
+ parser.add_argument('-n', '--expname', required=True)
45
+ parser.add_argument('-c', '--ckptstep', default=1200000, type=int)
46
+
47
+ parser.add_argument('-nfe', '--nfestep', default=32, type=int)
48
+ parser.add_argument('-o', '--odemethod', default="euler")
49
+ parser.add_argument('-ss', '--swaysampling', default=-1, type=float)
50
+
51
+ parser.add_argument('-t', '--testset', required=True)
52
+
53
+ args = parser.parse_args()
54
+
55
+
56
+ seed = args.seed
57
+ dataset_name = args.dataset
58
+ exp_name = args.expname
59
+ ckpt_step = args.ckptstep
60
+ ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
61
+
62
+ nfe_step = args.nfestep
63
+ ode_method = args.odemethod
64
+ sway_sampling_coef = args.swaysampling
65
+
66
+ testset = args.testset
67
+
68
+
69
+ infer_batch_size = 1 # max frames. 1 for ddp single inference (recommended)
70
+ cfg_strength = 2.
71
+ speed = 1.
72
+ use_truth_duration = False
73
+ no_ref_audio = False
74
+
75
+
76
+ if exp_name == "F5TTS_Base":
77
+ model_cls = DiT
78
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
79
+
80
+ elif exp_name == "E2TTS_Base":
81
+ model_cls = UNetT
82
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
83
+
84
+
85
+ if testset == "ls_pc_test_clean":
86
+ metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
87
+ librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
88
+ metainfo = get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path)
89
+
90
+ elif testset == "seedtts_test_zh":
91
+ metalst = "data/seedtts_testset/zh/meta.lst"
92
+ metainfo = get_seedtts_testset_metainfo(metalst)
93
+
94
+ elif testset == "seedtts_test_en":
95
+ metalst = "data/seedtts_testset/en/meta.lst"
96
+ metainfo = get_seedtts_testset_metainfo(metalst)
97
+
98
+
99
+ # path to save genereted wavs
100
+ if seed is None: seed = random.randint(-10000, 10000)
101
+ output_dir = f"results/{exp_name}_{ckpt_step}/{testset}/" \
102
+ f"seed{seed}_{ode_method}_nfe{nfe_step}" \
103
+ f"{f'_ss{sway_sampling_coef}' if sway_sampling_coef else ''}" \
104
+ f"_cfg{cfg_strength}_speed{speed}" \
105
+ f"{'_gt-dur' if use_truth_duration else ''}" \
106
+ f"{'_no-ref-audio' if no_ref_audio else ''}"
107
+
108
+
109
+ # -------------------------------------------------#
110
+
111
+ use_ema = True
112
+
113
+ prompts_all = get_inference_prompt(
114
+ metainfo,
115
+ speed = speed,
116
+ tokenizer = tokenizer,
117
+ target_sample_rate = target_sample_rate,
118
+ n_mel_channels = n_mel_channels,
119
+ hop_length = hop_length,
120
+ target_rms = target_rms,
121
+ use_truth_duration = use_truth_duration,
122
+ infer_batch_size = infer_batch_size,
123
+ )
124
+
125
+ # Vocoder model
126
+ local = False
127
+ if local:
128
+ vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
129
+ vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
130
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
131
+ vocos.load_state_dict(state_dict)
132
+ vocos.eval()
133
+ else:
134
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
135
+
136
+ # Tokenizer
137
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
138
+
139
+ # Model
140
+ model = CFM(
141
+ transformer = model_cls(
142
+ **model_cfg,
143
+ text_num_embeds = vocab_size,
144
+ mel_dim = n_mel_channels
145
+ ),
146
+ mel_spec_kwargs = dict(
147
+ target_sample_rate = target_sample_rate,
148
+ n_mel_channels = n_mel_channels,
149
+ hop_length = hop_length,
150
+ ),
151
+ odeint_kwargs = dict(
152
+ method = ode_method,
153
+ ),
154
+ vocab_char_map = vocab_char_map,
155
+ ).to(device)
156
+
157
+ model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
158
+
159
+ if not os.path.exists(output_dir) and accelerator.is_main_process:
160
+ os.makedirs(output_dir)
161
+
162
+ # start batch inference
163
+ accelerator.wait_for_everyone()
164
+ start = time.time()
165
+
166
+ with accelerator.split_between_processes(prompts_all) as prompts:
167
+
168
+ for prompt in tqdm(prompts, disable=not accelerator.is_local_main_process):
169
+ utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = prompt
170
+ ref_mels = ref_mels.to(device)
171
+ ref_mel_lens = torch.tensor(ref_mel_lens, dtype = torch.long).to(device)
172
+ total_mel_lens = torch.tensor(total_mel_lens, dtype = torch.long).to(device)
173
+
174
+ # Inference
175
+ with torch.inference_mode():
176
+ generated, _ = model.sample(
177
+ cond = ref_mels,
178
+ text = final_text_list,
179
+ duration = total_mel_lens,
180
+ lens = ref_mel_lens,
181
+ steps = nfe_step,
182
+ cfg_strength = cfg_strength,
183
+ sway_sampling_coef = sway_sampling_coef,
184
+ no_ref_audio = no_ref_audio,
185
+ seed = seed,
186
+ )
187
+ # Final result
188
+ for i, gen in enumerate(generated):
189
+ gen = gen[ref_mel_lens[i]:total_mel_lens[i], :].unsqueeze(0)
190
+ gen_mel_spec = rearrange(gen, '1 n d -> 1 d n')
191
+ generated_wave = vocos.decode(gen_mel_spec.cpu())
192
+ if ref_rms_list[i] < target_rms:
193
+ generated_wave = generated_wave * ref_rms_list[i] / target_rms
194
+ torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave, target_sample_rate)
195
+
196
+ accelerator.wait_for_everyone()
197
+ if accelerator.is_main_process:
198
+ timediff = time.time() - start
199
+ print(f"Done batch inference in {timediff / 60 :.2f} minutes.")
scripts/eval_infer_batch.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # e.g. F5-TTS, 16 NFE
4
+ accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_zh" -nfe 16
5
+ accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "seedtts_test_en" -nfe 16
6
+ accelerate launch scripts/eval_infer_batch.py -n "F5TTS_Base" -t "ls_pc_test_clean" -nfe 16
7
+
8
+ # e.g. Vanilla E2 TTS, 32 NFE
9
+ accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_zh" -o "midpoint" -ss 0
10
+ accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "seedtts_test_en" -o "midpoint" -ss 0
11
+ accelerate launch scripts/eval_infer_batch.py -n "E2TTS_Base" -t "ls_pc_test_clean" -o "midpoint" -ss 0
12
+
13
+ # etc.
scripts/eval_librispeech_test_clean.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
2
+
3
+ import sys, os
4
+ sys.path.append(os.getcwd())
5
+
6
+ import multiprocessing as mp
7
+ import numpy as np
8
+
9
+ from model.utils import (
10
+ get_librispeech_test,
11
+ run_asr_wer,
12
+ run_sim,
13
+ )
14
+
15
+
16
+ eval_task = "wer" # sim | wer
17
+ lang = "en"
18
+ metalst = "data/librispeech_pc_test_clean_cross_sentence.lst"
19
+ librispeech_test_clean_path = "<SOME_PATH>/LibriSpeech/test-clean" # test-clean path
20
+ gen_wav_dir = "PATH_TO_GENERATED" # generated wavs
21
+
22
+ gpus = [0,1,2,3,4,5,6,7]
23
+ test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path)
24
+
25
+ ## In LibriSpeech, some speakers utilized varying voice characteristics for different characters in the book,
26
+ ## leading to a low similarity for the ground truth in some cases.
27
+ # test_set = get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = True) # eval ground truth
28
+
29
+ local = False
30
+ if local: # use local custom checkpoint dir
31
+ asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
32
+ else:
33
+ asr_ckpt_dir = "" # auto download to cache dir
34
+
35
+ wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
36
+
37
+
38
+ # --------------------------- WER ---------------------------
39
+
40
+ if eval_task == "wer":
41
+ wers = []
42
+
43
+ with mp.Pool(processes=len(gpus)) as pool:
44
+ args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
45
+ results = pool.map(run_asr_wer, args)
46
+ for wers_ in results:
47
+ wers.extend(wers_)
48
+
49
+ wer = round(np.mean(wers)*100, 3)
50
+ print(f"\nTotal {len(wers)} samples")
51
+ print(f"WER : {wer}%")
52
+
53
+
54
+ # --------------------------- SIM ---------------------------
55
+
56
+ if eval_task == "sim":
57
+ sim_list = []
58
+
59
+ with mp.Pool(processes=len(gpus)) as pool:
60
+ args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
61
+ results = pool.map(run_sim, args)
62
+ for sim_ in results:
63
+ sim_list.extend(sim_)
64
+
65
+ sim = round(sum(sim_list)/len(sim_list), 3)
66
+ print(f"\nTotal {len(sim_list)} samples")
67
+ print(f"SIM : {sim}")
scripts/eval_seedtts_testset.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Evaluate with Seed-TTS testset
2
+
3
+ import sys, os
4
+ sys.path.append(os.getcwd())
5
+
6
+ import multiprocessing as mp
7
+ import numpy as np
8
+
9
+ from model.utils import (
10
+ get_seed_tts_test,
11
+ run_asr_wer,
12
+ run_sim,
13
+ )
14
+
15
+
16
+ eval_task = "wer" # sim | wer
17
+ lang = "zh" # zh | en
18
+ metalst = f"data/seedtts_testset/{lang}/meta.lst" # seed-tts testset
19
+ # gen_wav_dir = f"data/seedtts_testset/{lang}/wavs" # ground truth wavs
20
+ gen_wav_dir = f"PATH_TO_GENERATED" # generated wavs
21
+
22
+
23
+ # NOTE. paraformer-zh result will be slightly different according to the number of gpus, cuz batchsize is different
24
+ # zh 1.254 seems a result of 4 workers wer_seed_tts
25
+ gpus = [0,1,2,3,4,5,6,7]
26
+ test_set = get_seed_tts_test(metalst, gen_wav_dir, gpus)
27
+
28
+ local = False
29
+ if local: # use local custom checkpoint dir
30
+ if lang == "zh":
31
+ asr_ckpt_dir = "../checkpoints/funasr" # paraformer-zh dir under funasr
32
+ elif lang == "en":
33
+ asr_ckpt_dir = "../checkpoints/Systran/faster-whisper-large-v3"
34
+ else:
35
+ asr_ckpt_dir = "" # auto download to cache dir
36
+
37
+ wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
38
+
39
+
40
+ # --------------------------- WER ---------------------------
41
+
42
+ if eval_task == "wer":
43
+ wers = []
44
+
45
+ with mp.Pool(processes=len(gpus)) as pool:
46
+ args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
47
+ results = pool.map(run_asr_wer, args)
48
+ for wers_ in results:
49
+ wers.extend(wers_)
50
+
51
+ wer = round(np.mean(wers)*100, 3)
52
+ print(f"\nTotal {len(wers)} samples")
53
+ print(f"WER : {wer}%")
54
+
55
+
56
+ # --------------------------- SIM ---------------------------
57
+
58
+ if eval_task == "sim":
59
+ sim_list = []
60
+
61
+ with mp.Pool(processes=len(gpus)) as pool:
62
+ args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
63
+ results = pool.map(run_sim, args)
64
+ for sim_ in results:
65
+ sim_list.extend(sim_)
66
+
67
+ sim = round(sum(sim_list)/len(sim_list), 3)
68
+ print(f"\nTotal {len(sim_list)} samples")
69
+ print(f"SIM : {sim}")
scripts/prepare_emilia.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Emilia Dataset: https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07
2
+ # if use updated new version, i.e. WebDataset, feel free to modify / draft your own script
3
+
4
+ # generate audio text map for Emilia ZH & EN
5
+ # evaluate for vocab size
6
+
7
+ import sys, os
8
+ sys.path.append(os.getcwd())
9
+
10
+ from pathlib import Path
11
+ import json
12
+ from tqdm import tqdm
13
+ from concurrent.futures import ProcessPoolExecutor
14
+
15
+ from datasets import Dataset
16
+ from datasets.arrow_writer import ArrowWriter
17
+
18
+ from model.utils import (
19
+ repetition_found,
20
+ convert_char_to_pinyin,
21
+ )
22
+
23
+
24
+ out_zh = {"ZH_B00041_S06226", "ZH_B00042_S09204", "ZH_B00065_S09430", "ZH_B00065_S09431", "ZH_B00066_S09327", "ZH_B00066_S09328"}
25
+ zh_filters = ["い", "て"]
26
+ # seems synthesized audios, or heavily code-switched
27
+ out_en = {
28
+ "EN_B00013_S00913", "EN_B00042_S00120", "EN_B00055_S04111", "EN_B00061_S00693", "EN_B00061_S01494", "EN_B00061_S03375",
29
+
30
+ "EN_B00059_S00092", "EN_B00111_S04300", "EN_B00100_S03759", "EN_B00087_S03811", "EN_B00059_S00950", "EN_B00089_S00946", "EN_B00078_S05127", "EN_B00070_S04089", "EN_B00074_S09659", "EN_B00061_S06983", "EN_B00061_S07060", "EN_B00059_S08397", "EN_B00082_S06192", "EN_B00091_S01238", "EN_B00089_S07349", "EN_B00070_S04343", "EN_B00061_S02400", "EN_B00076_S01262", "EN_B00068_S06467", "EN_B00076_S02943", "EN_B00064_S05954", "EN_B00061_S05386", "EN_B00066_S06544", "EN_B00076_S06944", "EN_B00072_S08620", "EN_B00076_S07135", "EN_B00076_S09127", "EN_B00065_S00497", "EN_B00059_S06227", "EN_B00063_S02859", "EN_B00075_S01547", "EN_B00061_S08286", "EN_B00079_S02901", "EN_B00092_S03643", "EN_B00096_S08653", "EN_B00063_S04297", "EN_B00063_S04614", "EN_B00079_S04698", "EN_B00104_S01666", "EN_B00061_S09504", "EN_B00061_S09694", "EN_B00065_S05444", "EN_B00063_S06860", "EN_B00065_S05725", "EN_B00069_S07628", "EN_B00083_S03875", "EN_B00071_S07665", "EN_B00071_S07665", "EN_B00062_S04187", "EN_B00065_S09873", "EN_B00065_S09922", "EN_B00084_S02463", "EN_B00067_S05066", "EN_B00106_S08060", "EN_B00073_S06399", "EN_B00073_S09236", "EN_B00087_S00432", "EN_B00085_S05618", "EN_B00064_S01262", "EN_B00072_S01739", "EN_B00059_S03913", "EN_B00069_S04036", "EN_B00067_S05623", "EN_B00060_S05389", "EN_B00060_S07290", "EN_B00062_S08995",
31
+ }
32
+ en_filters = ["ا", "い", "て"]
33
+
34
+
35
+ def deal_with_audio_dir(audio_dir):
36
+ audio_jsonl = audio_dir.with_suffix(".jsonl")
37
+ sub_result, durations = [], []
38
+ vocab_set = set()
39
+ bad_case_zh = 0
40
+ bad_case_en = 0
41
+ with open(audio_jsonl, "r") as f:
42
+ lines = f.readlines()
43
+ for line in tqdm(lines, desc=f"{audio_jsonl.stem}"):
44
+ obj = json.loads(line)
45
+ text = obj["text"]
46
+ if obj['language'] == "zh":
47
+ if obj["wav"].split("/")[1] in out_zh or any(f in text for f in zh_filters) or repetition_found(text):
48
+ bad_case_zh += 1
49
+ continue
50
+ else:
51
+ text = text.translate(str.maketrans({',': ',', '!': '!', '?': '?'})) # not "。" cuz much code-switched
52
+ if obj['language'] == "en":
53
+ if obj["wav"].split("/")[1] in out_en or any(f in text for f in en_filters) or repetition_found(text, length=4):
54
+ bad_case_en += 1
55
+ continue
56
+ if tokenizer == "pinyin":
57
+ text = convert_char_to_pinyin([text], polyphone = polyphone)[0]
58
+ duration = obj["duration"]
59
+ sub_result.append({"audio_path": str(audio_dir.parent / obj["wav"]), "text": text, "duration": duration})
60
+ durations.append(duration)
61
+ vocab_set.update(list(text))
62
+ return sub_result, durations, vocab_set, bad_case_zh, bad_case_en
63
+
64
+
65
+ def main():
66
+ assert tokenizer in ["pinyin", "char"]
67
+ result = []
68
+ duration_list = []
69
+ text_vocab_set = set()
70
+ total_bad_case_zh = 0
71
+ total_bad_case_en = 0
72
+
73
+ # process raw data
74
+ executor = ProcessPoolExecutor(max_workers=max_workers)
75
+ futures = []
76
+ for lang in langs:
77
+ dataset_path = Path(os.path.join(dataset_dir, lang))
78
+ [
79
+ futures.append(executor.submit(deal_with_audio_dir, audio_dir))
80
+ for audio_dir in dataset_path.iterdir()
81
+ if audio_dir.is_dir()
82
+ ]
83
+ for futures in tqdm(futures, total=len(futures)):
84
+ sub_result, durations, vocab_set, bad_case_zh, bad_case_en = futures.result()
85
+ result.extend(sub_result)
86
+ duration_list.extend(durations)
87
+ text_vocab_set.update(vocab_set)
88
+ total_bad_case_zh += bad_case_zh
89
+ total_bad_case_en += bad_case_en
90
+ executor.shutdown()
91
+
92
+ # save preprocessed dataset to disk
93
+ if not os.path.exists(f"data/{dataset_name}"):
94
+ os.makedirs(f"data/{dataset_name}")
95
+ print(f"\nSaving to data/{dataset_name} ...")
96
+ # dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list}) # oom
97
+ # dataset.save_to_disk(f"data/{dataset_name}/raw", max_shard_size="2GB")
98
+ with ArrowWriter(path=f"data/{dataset_name}/raw.arrow") as writer:
99
+ for line in tqdm(result, desc=f"Writing to raw.arrow ..."):
100
+ writer.write(line)
101
+
102
+ # dup a json separately saving duration in case for DynamicBatchSampler ease
103
+ with open(f"data/{dataset_name}/duration.json", 'w', encoding='utf-8') as f:
104
+ json.dump({"duration": duration_list}, f, ensure_ascii=False)
105
+
106
+ # vocab map, i.e. tokenizer
107
+ # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
108
+ # if tokenizer == "pinyin":
109
+ # text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
110
+ with open(f"data/{dataset_name}/vocab.txt", "w") as f:
111
+ for vocab in sorted(text_vocab_set):
112
+ f.write(vocab + "\n")
113
+
114
+ print(f"\nFor {dataset_name}, sample count: {len(result)}")
115
+ print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
116
+ print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")
117
+ if "ZH" in langs: print(f"Bad zh transcription case: {total_bad_case_zh}")
118
+ if "EN" in langs: print(f"Bad en transcription case: {total_bad_case_en}\n")
119
+
120
+
121
+ if __name__ == "__main__":
122
+
123
+ max_workers = 32
124
+
125
+ tokenizer = "pinyin" # "pinyin" | "char"
126
+ polyphone = True
127
+
128
+ langs = ["ZH", "EN"]
129
+ dataset_dir = "<SOME_PATH>/Emilia_Dataset/raw"
130
+ dataset_name = f"Emilia_{'_'.join(langs)}_{tokenizer}"
131
+ print(f"\nPrepare for {dataset_name}\n")
132
+
133
+ main()
134
+
135
+ # Emilia ZH & EN
136
+ # samples count 37837916 (after removal)
137
+ # pinyin vocab size 2543 (polyphone)
138
+ # total duration 95281.87 (hours)
139
+ # bad zh asr cnt 230435 (samples)
140
+ # bad eh asr cnt 37217 (samples)
141
+
142
+ # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
143
+ # please be careful if using pretrained model, make sure the vocab.txt is same
scripts/prepare_wenetspeech4tts.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # generate audio text map for WenetSpeech4TTS
2
+ # evaluate for vocab size
3
+
4
+ import sys, os
5
+ sys.path.append(os.getcwd())
6
+
7
+ import json
8
+ from tqdm import tqdm
9
+ from concurrent.futures import ProcessPoolExecutor
10
+
11
+ import torchaudio
12
+ from datasets import Dataset
13
+
14
+ from model.utils import convert_char_to_pinyin
15
+
16
+
17
+ def deal_with_sub_path_files(dataset_path, sub_path):
18
+ print(f"Dealing with: {sub_path}")
19
+
20
+ text_dir = os.path.join(dataset_path, sub_path, "txts")
21
+ audio_dir = os.path.join(dataset_path, sub_path, "wavs")
22
+ text_files = os.listdir(text_dir)
23
+
24
+ audio_paths, texts, durations = [], [], []
25
+ for text_file in tqdm(text_files):
26
+ with open(os.path.join(text_dir, text_file), 'r', encoding='utf-8') as file:
27
+ first_line = file.readline().split("\t")
28
+ audio_nm = first_line[0]
29
+ audio_path = os.path.join(audio_dir, audio_nm + ".wav")
30
+ text = first_line[1].strip()
31
+
32
+ audio_paths.append(audio_path)
33
+
34
+ if tokenizer == "pinyin":
35
+ texts.extend(convert_char_to_pinyin([text], polyphone = polyphone))
36
+ elif tokenizer == "char":
37
+ texts.append(text)
38
+
39
+ audio, sample_rate = torchaudio.load(audio_path)
40
+ durations.append(audio.shape[-1] / sample_rate)
41
+
42
+ return audio_paths, texts, durations
43
+
44
+
45
+ def main():
46
+ assert tokenizer in ["pinyin", "char"]
47
+
48
+ audio_path_list, text_list, duration_list = [], [], []
49
+
50
+ executor = ProcessPoolExecutor(max_workers=max_workers)
51
+ futures = []
52
+ for dataset_path in dataset_paths:
53
+ sub_items = os.listdir(dataset_path)
54
+ sub_paths = [item for item in sub_items if os.path.isdir(os.path.join(dataset_path, item))]
55
+ for sub_path in sub_paths:
56
+ futures.append(executor.submit(deal_with_sub_path_files, dataset_path, sub_path))
57
+ for future in tqdm(futures, total=len(futures)):
58
+ audio_paths, texts, durations = future.result()
59
+ audio_path_list.extend(audio_paths)
60
+ text_list.extend(texts)
61
+ duration_list.extend(durations)
62
+ executor.shutdown()
63
+
64
+ if not os.path.exists("data"):
65
+ os.makedirs("data")
66
+
67
+ print(f"\nSaving to data/{dataset_name}_{tokenizer} ...")
68
+ dataset = Dataset.from_dict({"audio_path": audio_path_list, "text": text_list, "duration": duration_list})
69
+ dataset.save_to_disk(f"data/{dataset_name}_{tokenizer}/raw", max_shard_size="2GB") # arrow format
70
+
71
+ with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'w', encoding='utf-8') as f:
72
+ json.dump({"duration": duration_list}, f, ensure_ascii=False) # dup a json separately saving duration in case for DynamicBatchSampler ease
73
+
74
+ print("\nEvaluating vocab size (all characters and symbols / all phonemes) ...")
75
+ text_vocab_set = set()
76
+ for text in tqdm(text_list):
77
+ text_vocab_set.update(list(text))
78
+
79
+ # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
80
+ if tokenizer == "pinyin":
81
+ text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
82
+
83
+ with open(f"data/{dataset_name}_{tokenizer}/vocab.txt", "w") as f:
84
+ for vocab in sorted(text_vocab_set):
85
+ f.write(vocab + "\n")
86
+ print(f"\nFor {dataset_name}, sample count: {len(text_list)}")
87
+ print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}\n")
88
+
89
+
90
+ if __name__ == "__main__":
91
+
92
+ max_workers = 32
93
+
94
+ tokenizer = "pinyin" # "pinyin" | "char"
95
+ polyphone = True
96
+ dataset_choice = 1 # 1: Premium, 2: Standard, 3: Basic
97
+
98
+ dataset_name = ["WenetSpeech4TTS_Premium", "WenetSpeech4TTS_Standard", "WenetSpeech4TTS_Basic"][dataset_choice-1]
99
+ dataset_paths = [
100
+ "<SOME_PATH>/WenetSpeech4TTS/Basic",
101
+ "<SOME_PATH>/WenetSpeech4TTS/Standard",
102
+ "<SOME_PATH>/WenetSpeech4TTS/Premium",
103
+ ][-dataset_choice:]
104
+ print(f"\nChoose Dataset: {dataset_name}\n")
105
+
106
+ main()
107
+
108
+ # Results (if adding alphabets with accents and symbols):
109
+ # WenetSpeech4TTS Basic Standard Premium
110
+ # samples count 3932473 1941220 407494
111
+ # pinyin vocab size 1349 1348 1344 (no polyphone)
112
+ # - - 1459 (polyphone)
113
+ # char vocab size 5264 5219 5042
114
+
115
+ # vocab size may be slightly different due to jieba tokenizer and pypinyin (e.g. way of polyphoneme)
116
+ # please be careful if using pretrained model, make sure the vocab.txt is same
speech_edit.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torchaudio
6
+ from einops import rearrange
7
+ from vocos import Vocos
8
+
9
+ from model import CFM, UNetT, DiT, MMDiT
10
+ from model.utils import (
11
+ load_checkpoint,
12
+ get_tokenizer,
13
+ convert_char_to_pinyin,
14
+ save_spectrogram,
15
+ )
16
+
17
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
18
+
19
+
20
+ # --------------------- Dataset Settings -------------------- #
21
+
22
+ target_sample_rate = 24000
23
+ n_mel_channels = 100
24
+ hop_length = 256
25
+ target_rms = 0.1
26
+
27
+ tokenizer = "pinyin"
28
+ dataset_name = "Emilia_ZH_EN"
29
+
30
+
31
+ # ---------------------- infer setting ---------------------- #
32
+
33
+ seed = None # int | None
34
+
35
+ exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
36
+ ckpt_step = 1200000
37
+
38
+ nfe_step = 32 # 16, 32
39
+ cfg_strength = 2.
40
+ ode_method = 'euler' # euler | midpoint
41
+ sway_sampling_coef = -1.
42
+ speed = 1.
43
+
44
+ if exp_name == "F5TTS_Base":
45
+ model_cls = DiT
46
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
47
+
48
+ elif exp_name == "E2TTS_Base":
49
+ model_cls = UNetT
50
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
51
+
52
+ ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
53
+ output_dir = "tests"
54
+
55
+ # [leverage https://github.com/MahmoudAshraf97/ctc-forced-aligner to get char level alignment]
56
+ # pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git
57
+ # [write the origin_text into a file, e.g. tests/test_edit.txt]
58
+ # ctc-forced-aligner --audio_path "tests/ref_audio/test_en_1_ref_short.wav" --text_path "tests/test_edit.txt" --language "zho" --romanize --split_size "char"
59
+ # [result will be saved at same path of audio file]
60
+ # [--language "zho" for Chinese, "eng" for English]
61
+ # [if local ckpt, set --alignment_model "../checkpoints/mms-300m-1130-forced-aligner"]
62
+
63
+ audio_to_edit = "tests/ref_audio/test_en_1_ref_short.wav"
64
+ origin_text = "Some call me nature, others call me mother nature."
65
+ target_text = "Some call me optimist, others call me realist."
66
+ parts_to_edit = [[1.42, 2.44], [4.04, 4.9], ] # stard_ends of "nature" & "mother nature", in seconds
67
+ fix_duration = [1.2, 1, ] # fix duration for "optimist" & "realist", in seconds
68
+
69
+ # audio_to_edit = "tests/ref_audio/test_zh_1_ref_short.wav"
70
+ # origin_text = "对,这就是我,万人敬仰的太乙真人。"
71
+ # target_text = "对,那就是你,万人敬仰的太白金星。"
72
+ # parts_to_edit = [[0.84, 1.4], [1.92, 2.4], [4.26, 6.26], ]
73
+ # fix_duration = None # use origin text duration
74
+
75
+
76
+ # -------------------------------------------------#
77
+
78
+ use_ema = True
79
+
80
+ if not os.path.exists(output_dir):
81
+ os.makedirs(output_dir)
82
+
83
+ # Vocoder model
84
+ local = False
85
+ if local:
86
+ vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
87
+ vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
88
+ state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
89
+ vocos.load_state_dict(state_dict)
90
+ vocos.eval()
91
+ else:
92
+ vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
93
+
94
+ # Tokenizer
95
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
96
+
97
+ # Model
98
+ model = CFM(
99
+ transformer = model_cls(
100
+ **model_cfg,
101
+ text_num_embeds = vocab_size,
102
+ mel_dim = n_mel_channels
103
+ ),
104
+ mel_spec_kwargs = dict(
105
+ target_sample_rate = target_sample_rate,
106
+ n_mel_channels = n_mel_channels,
107
+ hop_length = hop_length,
108
+ ),
109
+ odeint_kwargs = dict(
110
+ method = ode_method,
111
+ ),
112
+ vocab_char_map = vocab_char_map,
113
+ ).to(device)
114
+
115
+ model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
116
+
117
+ # Audio
118
+ audio, sr = torchaudio.load(audio_to_edit)
119
+ if audio.shape[0] > 1:
120
+ audio = torch.mean(audio, dim=0, keepdim=True)
121
+ rms = torch.sqrt(torch.mean(torch.square(audio)))
122
+ if rms < target_rms:
123
+ audio = audio * target_rms / rms
124
+ if sr != target_sample_rate:
125
+ resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
126
+ audio = resampler(audio)
127
+ offset = 0
128
+ audio_ = torch.zeros(1, 0)
129
+ edit_mask = torch.zeros(1, 0, dtype=torch.bool)
130
+ for part in parts_to_edit:
131
+ start, end = part
132
+ part_dur = end - start if fix_duration is None else fix_duration.pop(0)
133
+ part_dur = part_dur * target_sample_rate
134
+ start = start * target_sample_rate
135
+ audio_ = torch.cat((audio_, audio[:, round(offset):round(start)], torch.zeros(1, round(part_dur))), dim = -1)
136
+ edit_mask = torch.cat((edit_mask,
137
+ torch.ones(1, round((start - offset) / hop_length), dtype = torch.bool),
138
+ torch.zeros(1, round(part_dur / hop_length), dtype = torch.bool)
139
+ ), dim = -1)
140
+ offset = end * target_sample_rate
141
+ # audio = torch.cat((audio_, audio[:, round(offset):]), dim = -1)
142
+ edit_mask = F.pad(edit_mask, (0, audio.shape[-1] // hop_length - edit_mask.shape[-1] + 1), value = True)
143
+ audio = audio.to(device)
144
+ edit_mask = edit_mask.to(device)
145
+
146
+ # Text
147
+ text_list = [target_text]
148
+ if tokenizer == "pinyin":
149
+ final_text_list = convert_char_to_pinyin(text_list)
150
+ else:
151
+ final_text_list = [text_list]
152
+ print(f"text : {text_list}")
153
+ print(f"pinyin: {final_text_list}")
154
+
155
+ # Duration
156
+ ref_audio_len = 0
157
+ duration = audio.shape[-1] // hop_length
158
+
159
+ # Inference
160
+ with torch.inference_mode():
161
+ generated, trajectory = model.sample(
162
+ cond = audio,
163
+ text = final_text_list,
164
+ duration = duration,
165
+ steps = nfe_step,
166
+ cfg_strength = cfg_strength,
167
+ sway_sampling_coef = sway_sampling_coef,
168
+ seed = seed,
169
+ edit_mask = edit_mask,
170
+ )
171
+ print(f"Generated mel: {generated.shape}")
172
+
173
+ # Final result
174
+ generated = generated[:, ref_audio_len:, :]
175
+ generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
176
+ generated_wave = vocos.decode(generated_mel_spec.cpu())
177
+ if rms < target_rms:
178
+ generated_wave = generated_wave * rms / target_rms
179
+
180
+ save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single_edit.png")
181
+ torchaudio.save(f"{output_dir}/test_single_edit.wav", generated_wave, target_sample_rate)
182
+ print(f"Generated wav: {generated_wave.shape}")
train.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model import CFM, UNetT, DiT, MMDiT, Trainer
2
+ from model.utils import get_tokenizer
3
+ from model.dataset import load_dataset
4
+
5
+
6
+ # -------------------------- Dataset Settings --------------------------- #
7
+
8
+ target_sample_rate = 24000
9
+ n_mel_channels = 100
10
+ hop_length = 256
11
+
12
+ tokenizer = "pinyin"
13
+ dataset_name = "Emilia_ZH_EN"
14
+
15
+
16
+ # -------------------------- Training Settings -------------------------- #
17
+
18
+ exp_name = "F5TTS_Base" # F5TTS_Base | E2TTS_Base
19
+
20
+ learning_rate = 7.5e-5
21
+
22
+ batch_size_per_gpu = 38400 # 8 GPUs, 8 * 38400 = 307200
23
+ batch_size_type = "frame" # "frame" or "sample"
24
+ max_samples = 64 # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
25
+ grad_accumulation_steps = 1 # note: updates = steps / grad_accumulation_steps
26
+ max_grad_norm = 1.
27
+
28
+ epochs = 11 # use linear decay, thus epochs control the slope
29
+ num_warmup_updates = 20000 # warmup steps
30
+ save_per_updates = 50000 # save checkpoint per steps
31
+ last_per_steps = 5000 # save last checkpoint per steps
32
+
33
+ # model params
34
+ if exp_name == "F5TTS_Base":
35
+ wandb_resume_id = None
36
+ model_cls = DiT
37
+ model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)
38
+ elif exp_name == "E2TTS_Base":
39
+ wandb_resume_id = None
40
+ model_cls = UNetT
41
+ model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)
42
+
43
+
44
+ # ----------------------------------------------------------------------- #
45
+
46
+ def main():
47
+
48
+ vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
49
+
50
+ mel_spec_kwargs = dict(
51
+ target_sample_rate = target_sample_rate,
52
+ n_mel_channels = n_mel_channels,
53
+ hop_length = hop_length,
54
+ )
55
+
56
+ e2tts = CFM(
57
+ transformer = model_cls(
58
+ **model_cfg,
59
+ text_num_embeds = vocab_size,
60
+ mel_dim = n_mel_channels
61
+ ),
62
+ mel_spec_kwargs = mel_spec_kwargs,
63
+ vocab_char_map = vocab_char_map,
64
+ )
65
+
66
+ trainer = Trainer(
67
+ e2tts,
68
+ epochs,
69
+ learning_rate,
70
+ num_warmup_updates = num_warmup_updates,
71
+ save_per_updates = save_per_updates,
72
+ checkpoint_path = f'ckpts/{exp_name}',
73
+ batch_size = batch_size_per_gpu,
74
+ batch_size_type = batch_size_type,
75
+ max_samples = max_samples,
76
+ grad_accumulation_steps = grad_accumulation_steps,
77
+ max_grad_norm = max_grad_norm,
78
+ wandb_project = "CFM-TTS",
79
+ wandb_run_name = exp_name,
80
+ wandb_resume_id = wandb_resume_id,
81
+ last_per_steps = last_per_steps,
82
+ )
83
+
84
+ train_dataset = load_dataset(dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
85
+ trainer.train(train_dataset,
86
+ resumable_with_seed = 666 # seed for shuffling dataset
87
+ )
88
+
89
+
90
+ if __name__ == '__main__':
91
+ main()