ITS-C4SF733\Administrator committed · commit 324bf29 · 1 parent: 00e6160

    all resource

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
- .gitignore +12 -0
- .ipynb_checkpoints/infer-checkpoint.ipynb +6 -0
- .ipynb_checkpoints/test-checkpoint.py +71 -0
- .ipynb_checkpoints/train-checkpoint.py +1068 -0
- LICENSE +21 -0
- customs/.ipynb_checkpoints/make_custom_dataset-checkpoint.py +83 -0
- customs/make_custom_dataset.py +83 -0
- customs/ph.txt +0 -0
- data/.ipynb_checkpoints/dataset-checkpoint.py +251 -0
- data/__init__.py +3 -0
- data/collation.py +120 -0
- data/datamodule.py +419 -0
- data/dataset.py +251 -0
- data/fbank.py +212 -0
- data/input_strategies.py +159 -0
- data/tokenizer.py +376 -0
- descriptions.py +31 -0
- examples.py +24 -0
- exp/valle_dev/log/log-train-2023-11-01-00-19-48 +117 -0
- exp/valle_dev/log/log-train-2023-11-01-01-01-00 +172 -0
- exp/valle_dev/tensorboard/events.out.tfevents.1698769188.vallex1-4110961-iaas.58414.0 +3 -0
- exp/valle_dev/tensorboard/events.out.tfevents.1698771660.vallex1-4110961-iaas.58697.0 +3 -0
- images/vallex_framework.jpg +0 -0
- infer.ipynb +0 -0
- launch-ui.py +432 -0
- macros.py +51 -0
- makedata.ipynb +0 -0
- model-card.md +33 -0
- models/__init__.py +136 -0
- models/macros.py +11 -0
- models/transformer.py +394 -0
- models/vallex.py +1353 -0
- models/visualizer.py +106 -0
- modules/__init__.py +0 -0
- modules/activation.py +612 -0
- modules/embedding.py +97 -0
- modules/optim.py +1105 -0
- modules/scaling.py +1401 -0
- modules/scheduler.py +78 -0
- modules/transformer.py +683 -0
- nltk_data/tokenizers/punkt/.DS_Store +0 -0
- nltk_data/tokenizers/punkt/PY3/README +98 -0
- nltk_data/tokenizers/punkt/PY3/czech.pickle +3 -0
- nltk_data/tokenizers/punkt/PY3/danish.pickle +3 -0
- nltk_data/tokenizers/punkt/PY3/dutch.pickle +3 -0
- nltk_data/tokenizers/punkt/PY3/english.pickle +3 -0
- nltk_data/tokenizers/punkt/PY3/estonian.pickle +3 -0
- nltk_data/tokenizers/punkt/PY3/finnish.pickle +3 -0
- nltk_data/tokenizers/punkt/PY3/french.pickle +3 -0
- nltk_data/tokenizers/punkt/PY3/german.pickle +3 -0
.gitignore
ADDED
@@ -0,0 +1,12 @@
+.git/objects/pack/pack-c489e96424f1ee71e0c94eaadab46ec057581843.pack
+exp/valle_dev/best-train-loss.pt
+exp/valle_dev/checkpoint-10000.pt
+exp/valle_dev/epoch-1.pt
+exp/valle_dev/best-valid-loss.pt
+checkpoints/vallex-checkpoint_modified.pt
+venv
+.idea
+__pycache__
+checkpoints_backup/vallex-checkpoint.pt
+checkpoints/vallex-checkpoint.pt
+whisper/medium.pt
.ipynb_checkpoints/infer-checkpoint.ipynb
ADDED
@@ -0,0 +1,6 @@
+{
+ "cells": [],
+ "metadata": {},
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
.ipynb_checkpoints/test-checkpoint.py
ADDED
@@ -0,0 +1,71 @@
+import os
+import torch
+import logging
+from data.dataset import create_dataloader
+from macros import *
+from data.tokenizer import (
+    AudioTokenizer,
+    tokenize_audio,
+)
+from data.collation import get_text_token_collater
+from models.vallex import VALLE
+if torch.cuda.is_available():
+    device = torch.device("cuda", 0)
+from vocos import Vocos
+from pathlib import Path
+import platform
+import pathlib
+
+plt = platform.system()
+print("Operating System:", plt)
+
+if plt == 'Linux':
+    pathlib.WindowsPath = pathlib.PosixPath
+
+def get_model(device):
+    url = 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'
+
+    checkpoints_dir = "./checkpoints"
+
+    model_checkpoint_name = "vallex-checkpoint_modified.pt"
+    if not os.path.exists(checkpoints_dir): os.mkdir(checkpoints_dir)
+    if not os.path.exists(os.path.join(checkpoints_dir, model_checkpoint_name)):
+        import wget
+        print("3")
+        try:
+            logging.info(
+                "Downloading model from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt ...")
+            # download from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt to ./checkpoints/vallex-checkpoint.pt
+            wget.download("https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt",
+                          out="./checkpoints/vallex-checkpoint.pt", bar=wget.bar_adaptive)
+        except Exception as e:
+            logging.info(e)
+            raise Exception(
+                "\n Model weights download failed, please go to 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'"
+                "\n manually download model weights and put it to {} .".format(os.getcwd() + "\checkpoints"))
+    # VALL-E
+    model = VALLE(
+        N_DIM,
+        NUM_HEAD,
+        NUM_LAYERS,
+        norm_first=True,
+        add_prenet=False,
+        prefix_mode=PREFIX_MODE,
+        share_embedding=True,
+        nar_scale_factor=1.0,
+        prepend_bos=True,
+        num_quantizers=NUM_QUANTIZERS,
+    ).to(device)
+    checkpoint_path = Path(checkpoints_dir) / model_checkpoint_name
+    checkpoint = torch.load(checkpoint_path, map_location='cpu')
+    missing_keys, unexpected_keys = model.load_state_dict(
+        checkpoint["model"], strict=True
+    )
+    assert not missing_keys
+
+    # Encodec
+    codec = AudioTokenizer(device)
+
+    vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device)
+
+    return model, codec, vocos
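Note: a minimal usage sketch of the `get_model` helper above, for orientation only. It assumes the repository root is the working directory, that the same helper is importable from `test.py` (as `train-checkpoint.py` does with `from test import get_model`), and that the modified checkpoint already sits in `./checkpoints`; it is not part of the committed code.

import torch
from test import get_model  # the helper shown above

# Pick a device; get_model moves the VALL-E model, codec and Vocos onto it.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
model, codec, vocos = get_model(device)
model.eval()
print(sum(p.numel() for p in model.parameters()), "parameters loaded")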
.ipynb_checkpoints/train-checkpoint.py
ADDED
@@ -0,0 +1,1068 @@
+#!/usr/bin/env python3
+# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+#                                            Wei Kang,
+#                                            Mingshuang Luo)
+# Copyright 2023 (authors: Feiteng Li)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+python3 bin/trainer.py \
+    --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 \
+    --max-duration 40 --model-name valle \
+    --exp-dir exp/valle
+    --dtype "bfloat16" \
+"""
+import warnings
+warnings.filterwarnings("ignore")
+import argparse
+import copy
+import logging
+import os
+
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+
+import random
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from train_utils.utils import *
+from train_utils.icefall.utils import *
+from train_utils.lhotse.utils import *
+from test import get_model
+from customs.make_custom_dataset import create_dataset
+
+LRSchedulerType = torch.optim.lr_scheduler._LRScheduler
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+    if isinstance(model, DDP):
+        # get underlying nn.Module
+        model = model.module
+
+    for module in model.modules():
+        if hasattr(module, "batch_count"):
+            module.batch_count = batch_count
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Should various information be logged in tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=20,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="exp/valle_dev",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--optimizer-name",
+        type=str,
+        default="ScaledAdam",
+        help="The optimizer.",
+    )
+    parser.add_argument(
+        "--scheduler-name",
+        type=str,
+        default="Eden",
+        help="The scheduler.",
+    )
+    parser.add_argument(
+        "--base-lr", type=float, default=0.005, help="The base learning rate."
+    )
+    parser.add_argument(
+        "--warmup-steps",
+        type=int,
+        default=200,
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators intended for reproducibility",
+    )
+
+    parser.add_argument(
+        "--inf-check",
+        type=str2bool,
+        default=False,
+        help="Add hooks to check for infinite module outputs and gradients.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=10000,
+        # default=100,
+        help="""Save checkpoint after processing this number of batches"
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train %% save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 0.
+        """,
+    )
+    parser.add_argument(
+        "--valid-interval",
+        type=int,
+        default=10000,
+        help="""Run validation if batch_idx %% valid_interval is 0.""",
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=20,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=0,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training. Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+        model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--accumulate-grad-steps",
+        type=int,
+        default=1,
+        help="""update gradient when batch_idx_train %% accumulate_grad_steps == 0.
+        """,
+    )
+
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        default="float16",
+        help="Training dtype: float32 bfloat16 float16.",
+    )
+
+    parser.add_argument(
+        "--filter-min-duration",
+        type=float,
+        default=0.0,
+        help="Keep only utterances with duration > this.",
+    )
+    parser.add_argument(
+        "--filter-max-duration",
+        type=float,
+        default=20.0,
+        help="Keep only utterances with duration < this.",
+    )
+
+    parser.add_argument(
+        "--train-stage",
+        type=int,
+        default=0,
+        help="""0: train all modules, For VALL-E, support 1: AR Decoder 2: NAR Decoder(s)
+        """,
+    )
+
+    parser.add_argument(
+        "--visualize",
+        type=str2bool,
+        default=False,
+        help="visualize model results in eval step.",
+    )
+
+    parser.add_argument(
+        "--oom-check",
+        type=str2bool,
+        default=True,
+        help="perform OOM check on dataloader batches before starting training.",
+    )
+
+    parser.add_argument(
+        "--train_dir",
+        default='/home/ubuntu/VALL-E-X/JS_Dataset/JS_Dataset/train_tune'
+    )
+
+    parser.add_argument(
+        "--valid_dir",
+        default='/home/ubuntu/VALL-E-X/JS_Dataset/JS_Dataset/valid_tune'
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used to writing statistics to tensorboard. It
+                           contains number of batches trained so far across
+                           epochs.
+
+        - log_interval: Print training loss if batch_idx % log_interval` is 0
+
+        - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+        - valid_interval: Run validation if batch_idx % valid_interval is 0
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 100,  # 10: debug 100: train
+            "reset_interval": 200,
+            "valid_interval": 10000,
+        }
+    )
+
+    return params
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    model_avg: nn.Module = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+    """Load checkpoint from file.
+
+    If params.start_batch is positive, it will load the checkpoint from
+    `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+    params.start_epoch is larger than 1, it will load the checkpoint from
+    `params.start_epoch - 1`.
+
+    Apart from loading state dict for `model` and `optimizer` it also updates
+    `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The scheduler that we are using.
+    Returns:
+      Return a dict containing previously saved training info.
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    if isinstance(model, DDP):
+        raise ValueError("load_checkpoint before DDP")
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    saved_stage = saved_params.get("train_stage", 0)
+    if params.train_stage != saved_stage:
+        # switch training stage
+        if params.train_stage and saved_stage:  # switch between 1 and 2
+            params.start_epoch = 1
+            params.start_batch = 0
+        else:
+            # switch between 0 and 1/2
+            assert params.num_epochs >= params.start_epoch
+            params.batch_idx_train = saved_params["batch_idx_train"]
+
+        for key in ["optimizer", "grad_scaler", "sampler"]:
+            if key in saved_params:
+                saved_params.pop(key)
+
+        # when base on stage 0, we keep scheduler
+        if saved_stage != 0:
+            for key in ["scheduler"]:
+                if key in saved_params:
+                    saved_params.pop(key)
+
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        if best_train_filename.is_file():
+            copyfile(
+                src=best_train_filename,
+                dst=params.exp_dir / f"best-train-loss-stage{saved_stage}.pt",
+            )
+
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        if best_valid_filename.is_file():
+            copyfile(
+                src=best_valid_filename,
+                dst=params.exp_dir / f"best-valid-loss-stage{saved_stage}.pt",
+            )
+    else:
+
+        keys = [
+            "best_train_epoch",
+            "best_valid_epoch",
+            "batch_idx_train",
+            "best_train_loss",
+            "best_valid_loss",
+        ]
+        for k in keys:
+            params[k] = saved_params[k]
+
+        if params.start_batch > 0:
+            if "cur_epoch" in saved_params:
+                params["start_epoch"] = saved_params["cur_epoch"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler=None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+        The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    batch: dict,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute transducer loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Zipformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+      warmup: a floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
+    """
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
+    # at entry, TextTokens is (N, P)
+    text_tokens = batch["text_tokens"].to(device)
+    text_tokens_lens = batch["text_tokens_lens"].to(device)
+    assert text_tokens.ndim == 2
+
+    audio_features = batch["audio_features"].to(device)
+    audio_features_lens = batch["audio_features_lens"].to(device)
+    assert audio_features.ndim == 3
+
+    with torch.set_grad_enabled(is_training):
+        predicts, loss, metrics = model(
+            x=text_tokens,
+            x_lens=text_tokens_lens,
+            y=audio_features,
+            y_lens=audio_features_lens,
+            train_stage=params.train_stage,
+        )
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (audio_features_lens).sum().item()
+        info["utterances"] = text_tokens.size(0)
+
+    # Note: We use reduction=sum while computing the loss.
+    info["loss"] = loss.detach().cpu().item()
+    for metric in metrics:
+        info[metric] = metrics[metric].detach().cpu().item()
+    del metrics
+
+    return predicts, loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        predicts, loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    if params.visualize:
+        output_dir = Path(
+            f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}"
+        )
+        output_dir.mkdir(parents=True, exist_ok=True)
+        if isinstance(model, DDP):
+            model.module.visualize(predicts, batch, output_dir=output_dir)
+        else:
+            model.visualize(predicts, batch, output_dir=output_dir)
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    rng: random.Random,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler, we call step() every step.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      rng:
+        Random for selecting.
+      scaler:
+        The scaler used for mix precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+    tot_loss = MetricsTracker()
+    iter_dl = iter(train_dl)
+
+    dtype, enabled = torch.float32, False
+    if params.dtype in ["bfloat16", "bf16"]:
+        dtype, enabled = torch.bfloat16, True
+    elif params.dtype in ["float16", "fp16"]:
+        dtype, enabled = torch.float16, True
+
+    batch_idx = 0
+    accumulation_steps = 5  # number of gradient accumulation steps
+    grad_accumulation_count = 0  # counter used to track gradient accumulation
+
+    while True:
+        try:
+            batch = next(iter_dl)
+        except StopIteration:
+            logging.info("Reaches end of dataloader.")
+            break
+
+        batch_idx += 1
+        params.batch_idx_train += 1
+        batch_size = len(batch["text"])
+
+        try:
+            with torch.cuda.amp.autocast(dtype=dtype, enabled=enabled):
+                _, loss, loss_info = compute_loss(
+                    params=params,
+                    model=model,
+                    batch=batch,
+                    is_training=True,
+                )
+
+            # summary stats
+            tot_loss = (
+                tot_loss * (1 - 1 / params.reset_interval)
+            ) + loss_info * (1 / params.reset_interval)
+
+            # gradient accumulation
+            scaler.scale(loss / accumulation_steps).backward()
+            grad_accumulation_count += 1
+
+            if grad_accumulation_count % accumulation_steps == 0 or params.batch_idx_train >= params.accumulate_grad_steps:
+                if (
+                    params.batch_idx_train % params.accumulate_grad_steps
+                    == 0
+                ):
+                    if params.optimizer_name not in ["ScaledAdam", "Eve"]:
+                        # Unscales the gradients of optimizer's assigned params in-place
+                        scaler.unscale_(optimizer)
+                        # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
+                        torch.nn.utils.clip_grad_norm_(
+                            model.parameters(), 1.0
+                        )
+
+                    scaler.step(optimizer)
+                    scaler.update()
+                    optimizer.zero_grad()
+                    grad_accumulation_count = 0  # reset the gradient accumulation counter
+
+                for k in range(params.accumulate_grad_steps):
+                    if isinstance(scheduler, Eden):
+                        scheduler.step_batch(params.batch_idx_train)
+                    else:
+                        scheduler.step()
+
+                set_batch_count(model, params.batch_idx_train)
+        except:  # noqa
+            display_and_save_batch(batch, params=params)
+            raise
+
+        if params.average_period > 0:
+            if (
+                params.batch_idx_train > 0
+                and params.batch_idx_train % params.average_period == 0
+            ):
+                # Perform Operation in rank 0
+                if rank == 0:
+                    update_averaged_model(
+                        params=params,
+                        model_cur=model,
+                        model_avg=model_avg,
+                    )
+
+        if (
+            params.batch_idx_train > 0
+            and params.batch_idx_train % params.save_every_n == 0
+        ):
+            # Perform Operation in rank 0
+            if rank == 0:
+                save_checkpoint_with_global_batch_idx(
+                    out_dir=params.exp_dir,
+                    global_batch_idx=params.batch_idx_train,
+                    model=model,
+                    model_avg=model_avg,
+                    params=params,
+                    optimizer=optimizer,
+                    scheduler=scheduler,
+                    sampler=None,
+                    scaler=scaler,
+                    rank=rank,
+                )
+                remove_checkpoints(
+                    out_dir=params.exp_dir,
+                    topk=params.keep_last_k,
+                    # rank=rank,
+                )
+
+        if batch_idx % 100 == 0 and params.dtype in ["float16", "fp16"]:
+            # If the grad scale was less than 1, try increasing it. The _growth_interval
+            # of the grad scaler is configurable, but we can't configure it to have different
+            # behavior depending on the current grad scale.
+            cur_grad_scale = scaler._scale.item()
+            if cur_grad_scale < 1.0 or (
+                cur_grad_scale < 8.0 and batch_idx % 400 == 0
+            ):
+                scaler.update(cur_grad_scale * 2.0)
+
+            if cur_grad_scale < 0.01:
+                logging.warning(f"Grad scale is small: {cur_grad_scale}")
+                if cur_grad_scale < 1.0e-05:
+                    raise RuntimeError(
+                        f"grad_scale is too small, exiting: {cur_grad_scale}"
+                    )
+
+        if batch_idx % params.log_interval == 0:
+            cur_lr = scheduler.get_last_lr()[0]
+            cur_grad_scale = (
+                scaler._scale.item()
+                if params.dtype in ["float16", "fp16"]
+                else 1.0
+            )
+
+            logging.info(
+                f"Epoch {params.cur_epoch}, "
+                f"batch {batch_idx}, train_loss[{loss_info}], "
+                f"tot_loss[{tot_loss}], "
+                f"batch size: {batch_size}, "
+                f"lr: {cur_lr:.2e}"
+                + (
+                    f", grad_scale: {cur_grad_scale}"
+                    if params.dtype in ["float16", "fp16"]
+                    else ""
+                )
+            )
+
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
+                loss_info.write_summary(
+                    tb_writer,
+                    "train/current_",
+                    params.batch_idx_train,
+                )
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
+                tot_loss.write_summary(
+                    tb_writer, "train/tot_", params.batch_idx_train
+                )
+                if params.dtype in ["float16", "fp16"]:
+                    tb_writer.add_scalar(
+                        "train/grad_scale",
+                        cur_grad_scale,
+                        params.batch_idx_train,
+                    )
+
+        if params.batch_idx_train % params.valid_interval == 0:
+            # Calculate validation loss in Rank 0
+            model.eval()
+            logging.info("Computing validation loss")
+            with torch.cuda.amp.autocast(dtype=dtype):
+                valid_info = compute_validation_loss(
+                    params=params,
+                    model=model,
+                    valid_dl=valid_dl,
+                    world_size=world_size,
+                )
+            logging.info(
+                f"Epoch {params.cur_epoch}, validation: {valid_info}"
+            )
+            logging.info(
+                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+            )
+
+            if tb_writer is not None:
+                valid_info.write_summary(
+                    tb_writer, "train/valid_", params.batch_idx_train
+                )
+
+            model.train()
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    params.train_loss = loss_value
+    if params.train_loss < params.best_train_loss:
+        params.best_train_epoch = params.cur_epoch
+        params.best_train_loss = params.train_loss
+
+def run(rank, world_size, args):
+    """
+    Args:
+      rank:
+        It is a value between 0 and `world_size-1`, which is
+        passed automatically by `mp.spawn()` in :func:`main`.
+        The node with rank 0 is responsible for saving checkpoint.
+      world_size:
+        Number of GPUs for DDP training.
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    rng = random.Random(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        if params.train_stage:
+            tb_writer = SummaryWriter(
+                log_dir=f"{params.exp_dir}/tensorboard_stage{params.train_stage}"
+            )
+        else:
+            tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    torch.backends.cudnn.allow_tf32 = True
+    torch.backends.cuda.matmul.allow_tf32 = True
+
+    logging.info(f"Device: {device}")
+    logging.info(params)
+
+    logging.info("About to create model")
+    model, codec, vocos = get_model(device)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0 and params.average_period > 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+    if params.train_stage:
+        _model = model.module if isinstance(model, DDP) else model
+        model_parameters = _model.stage_parameters(params.train_stage)
+    else:
+        model_parameters = model.parameters()
+
+    if params.optimizer_name == "ScaledAdam":
+        parameters_names = []
+        if params.train_stage:  # != 0
+            _model = model.module if isinstance(model, DDP) else model
+            parameters_names.append(
+                [
+                    name_param_pair[0]
+                    for name_param_pair in _model.stage_named_parameters(
+                        params.train_stage
+                    )
+                ]
+            )
+        else:
+            parameters_names.append(
+                [
+                    name_param_pair[0]
+                    for name_param_pair in model.named_parameters()
+                ]
+            )
+
+        optimizer = ScaledAdam(
+            model_parameters,
+            lr=params.base_lr,
+            betas=(0.9, 0.95),
+            clipping_scale=2.0,
+            parameters_names=parameters_names,
+            show_dominant_parameters=False,
+            clipping_update_period=1000,
+        )
+    elif params.optimizer_name == "Eve":
+        optimizer = Eve(
+            model_parameters,
+            lr=params.base_lr,
+            betas=(0.9, 0.98),
+            target_rms=0.1,
+        )
+    elif params.optimizer_name == "AdamW":
+        optimizer = torch.optim.AdamW(
+            model_parameters,
+            lr=params.base_lr,
+            betas=(0.9, 0.95),
+            weight_decay=1e-2,
+            eps=1e-8,
+        )
+    elif params.optimizer_name == "Adam":
+        optimizer = torch.optim.Adam(
+            model_parameters,
+            lr=params.base_lr,
+            betas=(0.9, 0.95),
+            eps=1e-8,
+        )
+    else:
+        raise NotImplementedError()
+
+    scheduler = get_scheduler(params, optimizer)
+    optimizer.zero_grad()
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = create_dataset(params.train_dir, dataloader_process_only=False)
+    valid_dl = create_dataset(params.valid_dir, dataloader_process_only=False)
+
+    scaler = GradScaler(
+        enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0
+    )
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
+    for epoch in range(params.start_epoch, params.num_epochs + 1):
+        if isinstance(scheduler, Eden):
+            scheduler.step_epoch(epoch - 1)
+
+        fix_random_seed(params.seed + epoch - 1)
+        train_dl.batch_sampler.set_epoch(epoch - 1)
+
+        if tb_writer is not None:
+            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+        params.cur_epoch = epoch
+
+        train_one_epoch(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            train_dl=train_dl,
+            valid_dl=valid_dl,
+            rng=rng,
+            scaler=scaler,
+            tb_writer=tb_writer,
+            world_size=world_size,
+            rank=rank,
+        )
+
+        save_checkpoint(
+            params=params,
+            model=model,
+            model_avg=model_avg,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            sampler=None,
+            scaler=scaler,
+            rank=rank,
+        )
+
+    logging.info("Done!")
+
+    if world_size > 1:
+        torch.distributed.barrier()
+        cleanup_dist()
+
+
+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+    """
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    world_size = args.world_size
+    assert world_size >= 1
+    if world_size > 1:
+        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+    else:
+        run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
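Note: `train_one_epoch` above combines AMP loss scaling with a hard-coded `accumulation_steps = 5` on top of the `--accumulate-grad-steps` flag. For reference, here is a minimal, self-contained sketch of that gradient-accumulation pattern using a toy linear model (an assumption for illustration, not the repository's classes), assuming a CUDA device with AMP support:

import torch
import torch.nn as nn

model = nn.Linear(16, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()
accumulation_steps = 5  # mirrors the hard-coded value in train_one_epoch

for step in range(1, 101):
    x = torch.randn(8, 16, device="cuda")
    y = torch.randn(8, 1, device="cuda")
    with torch.cuda.amp.autocast(dtype=torch.float16):
        loss = nn.functional.mse_loss(model(x), y)
    # Divide the loss so the accumulated gradient approximates one large batch.
    scaler.scale(loss / accumulation_steps).backward()
    if step % accumulation_steps == 0:
        scaler.step(optimizer)   # unscales internally, skips the step on inf/NaN
        scaler.update()
        optimizer.zero_grad()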
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Songting
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
customs/.ipynb_checkpoints/make_custom_dataset-checkpoint.py
ADDED
@@ -0,0 +1,83 @@
(Jupyter checkpoint copy; the 83 added lines are identical to customs/make_custom_dataset.py below.)
customs/make_custom_dataset.py
ADDED
@@ -0,0 +1,83 @@
+import h5py
+import glob
+import torch
+import numpy as np
+import os
+import torchaudio
+import soundfile as sf
+from utils.g2p.symbols import symbols
+from utils.g2p import PhonemeBpeTokenizer
+from utils.prompt_making import make_prompt, make_transcript
+from data.collation import get_text_token_collater
+from data.dataset import create_dataloader
+
+# Mappings from symbol to numeric ID and vice versa:
+_symbol_to_id = {s: i for i, s in enumerate(symbols)}
+_id_to_symbol = {i: s for i, s in enumerate(symbols)}
+from data.tokenizer import (
+    AudioTokenizer,
+    tokenize_audio,
+)
+
+tokenizer_path = "./utils/g2p/bpe_175.json"
+tokenizer = PhonemeBpeTokenizer(tokenizer_path)
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+def make_prompts(name, audio_prompt_path, transcript=None):
+    text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_175.json")
+    text_collater = get_text_token_collater()
+    codec = AudioTokenizer(device)
+    wav_pr, sr = torchaudio.load(audio_prompt_path)
+    # check length
+    if wav_pr.size(-1) / sr > 15:
+        raise ValueError(f"Prompt too long, expect length below 15 seconds, got {wav_pr / sr} seconds.")
+    if wav_pr.size(0) == 2:
+        wav_pr = wav_pr.mean(0, keepdim=True)
+    text_pr, lang_pr = make_transcript(name, wav_pr, sr, transcript)
+
+    # tokenize audio
+    encoded_frames = tokenize_audio(codec, (wav_pr, sr))
+    audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()
+
+    # tokenize text
+    phonemes, langs = text_tokenizer.tokenize(text=f"{text_pr}".strip())
+    text_tokens, enroll_x_lens = text_collater(
+        [
+            phonemes
+        ]
+    )
+
+    return audio_tokens, text_tokens, langs, text_pr
+
+def create_dataset(data_dir, dataloader_process_only):
+    if dataloader_process_only:
+        h5_output_path = f"{data_dir}/audio_sum.hdf5"
+        ann_output_path = f"{data_dir}/audio_ann_sum.txt"
+        # audio_folder = os.path.join(data_dir, 'audio')
+        audio_paths = glob.glob(f"{data_dir}/*.wav")  # Change this to match your audio file extension
+
+        # Create or open an HDF5 file
+        with h5py.File(h5_output_path, 'w') as h5_file:
+            # Loop through each audio and text file, assuming they have the same stem
+            for audio_path in audio_paths:
+                stem = os.path.splitext(os.path.basename(audio_path))[0]
+                audio_tokens, text_tokens, langs, text = make_prompts(name=stem, audio_prompt_path=audio_path)
+
+                text_tokens = text_tokens.squeeze(0)
+                # Create a group for each stem
+                grp = h5_file.create_group(stem)
+                # Add audio and text tokens as datasets to the group
+                grp.create_dataset('audio', data=audio_tokens)
+                # grp.create_dataset('text', data=text_tokens)
+
+                with open(ann_output_path, 'a', encoding='utf-8') as ann_file:
+                    try:
+                        audio, sample_rate = sf.read(audio_path)
+                        duration = len(audio) / sample_rate
+                        ann_file.write(f'{stem}|{duration}|{langs[0]}|{text}\n')  # append a newline
+                        print(f"Successfully wrote to {ann_output_path}")
+                    except Exception as e:
+                        print(f"An error occurred: {e}")
+    else:
+        dataloader = create_dataloader(data_dir=data_dir, max_duration=20)
+        return dataloader
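Note: a minimal usage sketch of `create_dataset` above. The directory path is a hypothetical example, not part of the repository; the function expects a folder of *.wav files and is called once to write the HDF5/annotation files and once more to build a dataloader.

from customs.make_custom_dataset import create_dataset

data_dir = "./my_dataset/train"  # hypothetical folder containing *.wav files

# Pass 1: write audio_sum.hdf5 and audio_ann_sum.txt next to the wav files.
create_dataset(data_dir, dataloader_process_only=True)

# Pass 2: build a dataloader over the prepared files (max_duration=20 inside).
train_dl = create_dataset(data_dir, dataloader_process_only=False)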
customs/ph.txt
ADDED
File without changes
data/.ipynb_checkpoints/dataset-checkpoint.py
ADDED
@@ -0,0 +1,251 @@
data/__init__.py
ADDED
@@ -0,0 +1,3 @@
# from .datamodule import *
# from .tokenizer import *
from .collation import *
data/collation.py
ADDED
@@ -0,0 +1,120 @@
from pathlib import Path
from typing import List, Tuple

import numpy as np
import torch

from utils import SymbolTable


class TextTokenCollater:
    """Collate list of text tokens

    Map sentences to integers. Sentences are padded to equal length.
    Beginning and end-of-sequence symbols can be added.

    Example:
        >>> token_collater = TextTokenCollater(text_tokens)
        >>> tokens_batch, tokens_lens = token_collater(text)

    Returns:
        tokens_batch: IntTensor of shape (B, L)
            B: batch dimension, number of input sentences
            L: length of the longest sentence
        tokens_lens: IntTensor of shape (B,)
            Length of each sentence after adding <eos> and <bos>
            but before padding.
    """

    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
    ):
        self.pad_symbol = pad_symbol

        self.add_eos = add_eos
        self.add_bos = add_bos

        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol

        unique_tokens = (
            [pad_symbol]
            + ([bos_symbol] if add_bos else [])
            + ([eos_symbol] if add_eos else [])
            + sorted(text_tokens)
        )

        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = [token for token in unique_tokens]

    def index(
        self, tokens_list: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        seqs, seq_lens = [], []
        for tokens in tokens_list:
            assert (
                all([True if s in self.token2idx else False for s in tokens])
                is True
            )
            seq = (
                ([self.bos_symbol] if self.add_bos else [])
                + list(tokens)
                + ([self.eos_symbol] if self.add_eos else [])
            )
            seqs.append(seq)
            seq_lens.append(len(seq))

        max_len = max(seq_lens)
        for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)):
            seq.extend([self.pad_symbol] * (max_len - seq_len))

        tokens = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )
        tokens_lens = torch.IntTensor(seq_lens)

        return tokens, tokens_lens

    def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        tokens_seqs = [[p for p in text] for text in texts]
        max_len = len(max(tokens_seqs, key=len))

        seqs = [
            ([self.bos_symbol] if self.add_bos else [])
            + list(seq)
            + ([self.eos_symbol] if self.add_eos else [])
            + [self.pad_symbol] * (max_len - len(seq))
            for seq in tokens_seqs
        ]

        tokens_batch = torch.from_numpy(
            np.array(
                [seq for seq in seqs],
                dtype=np.int64,
            )
        )

        tokens_lens = torch.IntTensor(
            [
                len(seq) + int(self.add_eos) + int(self.add_bos)
                for seq in tokens_seqs
            ]
        )

        return tokens_batch, tokens_lens


def get_text_token_collater() -> TextTokenCollater:
    collater = TextTokenCollater(
        ['0'], add_bos=False, add_eos=False
    )
    return collater
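A minimal sketch of how the collater above is used elsewhere in this commit (data/dataset.py calls it with a single phoneme-id sequence per utterance): it returns a padded LongTensor of ids plus per-sequence lengths. The ids below are made-up values, not real tokenizer output.

# Illustrative only; phoneme_ids stands in for PhonemeBpeTokenizer.tokenize() output
from data.collation import get_text_token_collater

text_collater = get_text_token_collater()        # built with add_bos=False, add_eos=False
phoneme_ids = [12, 7, 7, 31]                     # assumed BPE ids for one utterance
tokens, tokens_lens = text_collater([phoneme_ids])
print(tokens.shape)    # torch.Size([1, 4])
print(tokens_lens)     # tensor([4], dtype=torch.int32)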
data/datamodule.py
ADDED
@@ -0,0 +1,419 @@
# Copyright 2023 (authors: Feiteng Li)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import torch
# from icefall.utils import str2bool
# from lhotse import CutSet, load_manifest_lazy
# from lhotse.dataset import (
#     CutConcatenate,
#     DynamicBucketingSampler,
#     PrecomputedFeatures,
#     SingleCutSampler,
#     SpecAugment,
# )
# from lhotse.dataset.input_strategies import OnTheFlyFeatures
# from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

from data.collation import get_text_token_collater
# from data.dataset import SpeechSynthesisDataset
from data.fbank import get_fbank_extractor
from data.input_strategies import PromptedPrecomputedFeatures

# PrecomputedFeatures = PrecomputedFeatures


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


def _get_input_strategy(input_strategy, dataset, cuts):
    if input_strategy == "PromptedPrecomputedFeatures":
        return PromptedPrecomputedFeatures(dataset, cuts)

    return eval(input_strategy)()


class TtsDataModule:
    """
    DataModule for VALL-E TTS experiments.
    It assumes there is always one train and valid dataloader.

    It contains all the common data pipeline modules used in TTS
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - cut concatenation[not used & tested yet],
    - augmentation[not used & tested yet],
    - on-the-fly feature extraction[not used & tested yet]

    This class should be derived for specific corpora used in TTS tasks.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="TTS data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )
        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/tokenized"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            type=int,
            default=40.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=10,
            help="The number of buckets for the DynamicBucketingSampler"
            "(you might want to increase it for larger datasets).",
        )
        group.add_argument(
            "--concatenate-cuts",
            type=str2bool,
            default=False,
            help="When enabled, utterances (cuts) will be concatenated "
            "to minimize the amount of padding.",
        )
        group.add_argument(
            "--duration-factor",
            type=float,
            default=1.0,
            help="Determines the maximum duration of a concatenated cut "
            "relative to the duration of the longest cut in a batch.",
        )
        group.add_argument(
            "--gap",
            type=float,
            default=0.1,
            help="The amount of padding (in seconds) inserted between "
            "concatenated cuts. This padding is filled with noise when "
            "noise augmentation is used.",
        )
        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available.",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=False,
            help="Whether to drop last batch. Used by sampler.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help="When enabled, each batch will have the "
            "field: batch['supervisions']['cut'] with the cuts that "
            "were used to construct it.",
        )

        group.add_argument(
            "--num-workers",
            type=int,
            default=8,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--enable-spec-aug",
            type=str2bool,
            default=False,
            help="When enabled, use SpecAugment for training dataset.",
        )

        group.add_argument(
            "--spec-aug-time-warp-factor",
            type=int,
            default=80,
            help="Used only when --enable-spec-aug is True. "
            "It specifies the factor for time warping in SpecAugment. "
            "Larger values mean more warping. "
            "A value less than 1 means to disable time warp.",
        )

        group.add_argument(
            "--input-strategy",
            type=str,
            default="PrecomputedFeatures",
            help="AudioSamples or PrecomputedFeatures or PromptedPrecomputedFeatures",
        )

        group.add_argument(
            "--dataset",
            type=str,
            default="ljspeech",
            help="--input-strategy PromptedPrecomputedFeatures needs dataset name to prepare prompts.",
        )

        parser.add_argument(
            "--text-tokens",
            type=str,
            default="data/tokenized/unique_text_tokens.k2symbols",
            help="Path to the unique text tokens file",
        )

        parser.add_argument(
            "--sampling-rate",
            type=int,
            default=24000,
            help="""Audio sampling rate.""",
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        transforms = []

        if self.args.concatenate_cuts:
            logging.info(
                f"Using cut concatenation with duration factor "
                f"{self.args.duration_factor} and gap {self.args.gap}."
            )
            # Cut concatenation should be the first transform in the list,
            # so that if we e.g. mix noise in, it will fill the gaps between
            # different utterances.
            transforms = [
                CutConcatenate(
                    duration_factor=self.args.duration_factor, gap=self.args.gap
                )
            ] + transforms

        input_transforms = []
        if self.args.enable_spec_aug:
            logging.info("Enable SpecAugment")
            logging.info(
                f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
            )
            # Set the value of num_frame_masks according to Lhotse's version.
            # In different Lhotse's versions, the default of num_frame_masks is
            # different.
            num_frame_masks = 10
            num_frame_masks_parameter = inspect.signature(
                SpecAugment.__init__
            ).parameters["num_frame_masks"]
            if num_frame_masks_parameter.default == 1:
                num_frame_masks = 2
            logging.info(f"Num frame mask: {num_frame_masks}")
            input_transforms.append(
                SpecAugment(
                    time_warp_factor=self.args.spec_aug_time_warp_factor,
                    num_frame_masks=num_frame_masks,
                    features_mask_size=27,
                    num_feature_masks=2,
                    frames_mask_size=100,
                )
            )
        else:
            logging.info("Disable SpecAugment")

        logging.info("About to create train dataset")
        if self.args.on_the_fly_feats:
            # NOTE: the PerturbSpeed transform should be added only if we
            # remove it from data prep stage.
            # Add on-the-fly speed perturbation; since originally it would
            # have increased epoch size by 3, we will apply prob 2/3 and use
            # 3x more epochs.
            # Speed perturbation probably should come first before
            # concatenation, but in principle the transforms order doesn't have
            # to be strict (e.g. could be randomized)
            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms  # noqa
            # Drop feats to be on the safe side.
            train = SpeechSynthesisDataset(
                get_text_token_collater(self.args.text_tokens),
                cut_transforms=transforms,
                feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
                feature_transforms=input_transforms,
            )
        else:
            train = SpeechSynthesisDataset(
                get_text_token_collater(self.args.text_tokens),
                feature_input_strategy=_get_input_strategy(
                    self.args.input_strategy, self.args.dataset, cuts_train
                ),
                cut_transforms=transforms,
                feature_transforms=input_transforms,
            )

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info(
                "Using SingleCutSampler and sort by duration(ascending=True)."
            )
            cuts_train = cuts_train.to_eager().sort_by_duration(ascending=True)
            train_sampler = SingleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            validate = SpeechSynthesisDataset(
                get_text_token_collater(self.args.text_tokens),
                feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
                cut_transforms=[],
            )
        else:
            validate = SpeechSynthesisDataset(
                get_text_token_collater(self.args.text_tokens),
                feature_input_strategy=_get_input_strategy(
                    self.args.input_strategy, self.args.dataset, cuts_valid
                ),
                cut_transforms=[],
            )
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=4,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.debug("About to create test dataset")
        test = SpeechSynthesisDataset(
            get_text_token_collater(self.args.text_tokens),
            feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor())
            if self.args.on_the_fly_feats
            else _get_input_strategy(
                self.args.input_strategy, self.args.dataset, cuts
            ),
            cut_transforms=[],
        )
        sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.debug("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    @lru_cache()
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "cuts_train.jsonl.gz"
        )

    @lru_cache()
    def dev_cuts(self) -> CutSet:
        logging.info("About to get dev cuts")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")

    @lru_cache()
    def test_cuts(self) -> CutSet:
        logging.info("About to get test cuts")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz")
data/dataset.py
ADDED
@@ -0,0 +1,251 @@
# Copyright 2023 (authors: Feiteng Li)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
modified from lhotse.dataset.speech_synthesis.py
"""

import torch
import math
import h5py
from tokenizers import Tokenizer
from typing import Union, List
import numpy as np
from tqdm import tqdm
from utils.g2p import PhonemeBpeTokenizer
from data.collation import get_text_token_collater
text_collater = get_text_token_collater()

_pad = '_'
_punctuation = ',.!?-~…'
_letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
symbols = [_pad] + list(_punctuation) + list(_letters)

language_dict = {
    'en': 0,
    'zh': 1,
    'ja': 2,
    'vi': 3,
}

def seq2phone(tokens: Union[List, np.ndarray]):
    """
    Convert tokenized phoneme ID sequence back to phoneme string
    :param tokens: phoneme tokens
    :return: recovered phoneme sequence
    """
    phones = "".join([symbols[i] for i in tokens])
    return phones

class DynamicBatchSampler(torch.utils.data.Sampler):
    def __init__(self, sampler, num_tokens_fn, num_buckets=100, min_size=0, max_size=1000,
                 max_tokens=None, max_sentences=None, drop_last=False):
        """
        :param sampler:
        :param num_tokens_fn: function that returns the length of the sample at a given idx
        :param num_buckets: number of buckets; samples of similar length are grouped into the same batch
        :param min_size: minimum sample length; shorter samples are filtered out (also used to build the buckets)
        :param max_size: maximum sample length
        :param max_sentences: batch size; the final batch size is controlled jointly by max_sentences and max_tokens
        """
        super(DynamicBatchSampler, self).__init__(sampler)
        self.sampler = sampler
        self.num_tokens_fn = num_tokens_fn
        self.num_buckets = num_buckets

        self.min_size = min_size
        self.max_size = max_size

        assert max_size <= max_tokens, "max_size should be smaller than max tokens"
        assert max_tokens is not None or max_sentences is not None, \
            "max_tokens and max_sentences should not be null at the same time, please specify one parameter at least"
        self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
        self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
        self.drop_last = drop_last

    def set_epoch(self, epoch):
        self.sampler.set_epoch(epoch)

    def is_batch_full(self, num_tokens, batch):
        if len(batch) == 0:
            return False
        if len(batch) == self.max_sentences:
            return True
        if num_tokens > self.max_tokens:
            return True
        return False

    def __iter__(self):
        buckets = [[] for _ in range(self.num_buckets)]
        sample_len = [0] * self.num_buckets

        for idx in self.sampler:
            idx_length = self.num_tokens_fn(idx)
            if not (self.min_size <= idx_length <= self.max_size):
                print("sentence at index {} of size {} exceeds max_tokens, the sentence is ignored".format(idx, idx_length))
                continue

            index_buckets = math.floor((idx_length - self.min_size) / (self.max_size - self.min_size + 1)
                                       * self.num_buckets)
            sample_len[index_buckets] = max(sample_len[index_buckets], idx_length)

            num_tokens = (len(buckets[index_buckets]) + 1) * sample_len[index_buckets]
            if self.is_batch_full(num_tokens, buckets[index_buckets]):
                # yield this batch
                yield buckets[index_buckets]
                buckets[index_buckets] = []
                sample_len[index_buckets] = 0

            buckets[index_buckets].append(idx)

        # process left-over
        leftover_batch = []
        leftover_sample_len = 0
        leftover = [idx for bucket in buckets for idx in bucket]
        for idx in leftover:
            idx_length = self.num_tokens_fn(idx)
            leftover_sample_len = max(leftover_sample_len, idx_length)
            num_tokens = (len(leftover_batch) + 1) * leftover_sample_len
            if self.is_batch_full(num_tokens, leftover_batch):
                yield leftover_batch
                leftover_batch = []
                leftover_sample_len = 0
            leftover_batch.append(idx)

        if len(leftover_batch) > 0 and not self.drop_last:
            yield leftover_batch

    def __len__(self):
        # we do not know the exactly batch size, so do not call len(dataloader)
        pass


class AudioDataset(torch.utils.data.Dataset):
    def __init__(self, h5_path, ann_path, tokenizer_path):
        self.h5_path = h5_path
        with open(ann_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
            ls = [l.split("|") for l in lines]
            ls_T = list(zip(*ls))
            #del ls_T[-1]
            self.h5_paths, self.durations, self.langs, self.texts = \
                list(ls_T[0]), list(ls_T[1]), list(ls_T[2]), list(ls_T[3])
            self.durations = [float(dur) for dur in self.durations]
        self.tokenizer = PhonemeBpeTokenizer(tokenizer_path)
        self._archive = None

    def __len__(self):
        return len(self.h5_paths)

    def get_dur(self, idx):
        return self.durations[idx]

    @property
    def archive(self):
        if self._archive is None:  # lazy loading here!
            self._archive = h5py.File(self.h5_path, "r")
        return self._archive

    def __getitem__(self, idx):
        archive = self.archive
        h5_path = self.h5_paths[idx]
        sub = archive[h5_path]
        audio_tokens = sub['audio'][()]
        #phone_tokens = sub['text'][()]
        dur = self.durations[idx]
        lang = self.langs[idx]
        text = self.texts[idx]
        # tokenization should be done within dataloader
        #phones = seq2phone(phone_tokens)
        #phones = phones.replace(" ", "_")
        phonemes, langs = self.tokenizer.tokenize(text=f"{text}".strip())
        cptpho_tokens, enroll_x_lens = text_collater([phonemes])
        cptpho_tokens = cptpho_tokens.squeeze(0)
        text_token_lens = enroll_x_lens[0]
        '''
        if not len(phones):
            cptpho_tokens = self.tokenizer.encode(text).ids
        else:
            cptpho_tokens = self.tokenizer.encode(phones).ids
        '''
        assert len(cptpho_tokens)
        return {
            'utt_id': h5_path,
            'text': text,
            'audio': None,
            'audio_lens': None,
            'audio_features': audio_tokens,
            'audio_features_lens': audio_tokens.shape[1],
            'text_tokens': np.array(cptpho_tokens),
            'text_tokens_lens': text_token_lens,
            'language': language_dict[lang],
        }

def collate(batch):
    utt_id_s = [b['utt_id'] for b in batch]
    text_s = [b['text'] for b in batch]

    audio_s = [b['audio'] for b in batch]
    audio_lens_s = [b['audio_lens'] for b in batch]

    audio_features_lens_s = [b['audio_features_lens'] for b in batch]
    # create an empty tensor with maximum audio feature length
    audio_features_s = torch.zeros([len(batch), max(audio_features_lens_s), 8], dtype=torch.int64) - 1  # audio pad with -1

    text_tokens_lens_s = [b['text_tokens_lens'] for b in batch]
    # create an empty tensor with maximum text tokens length
    text_tokens_s = torch.zeros([len(batch), max(text_tokens_lens_s)], dtype=torch.int64) + 3  # [PAD] token id 3

    language_s = [b['language'] for b in batch]

    for i, b in enumerate(batch):
        audio_features = b['audio_features']
        audio_features_lens = b['audio_features_lens']
        audio_features_s[i, :audio_features_lens, :] = torch.LongTensor(audio_features)

        text_tokens = b['text_tokens']
        text_tokens_lens = b['text_tokens_lens']
        text_tokens_s[i, :text_tokens_lens] = torch.LongTensor(text_tokens)

    batch = {
        'utt_id': utt_id_s,
        'text': text_s,
        'audio': audio_s,
        'audio_lens': audio_lens_s,
        'audio_features': audio_features_s,
        'audio_features_lens': torch.LongTensor(np.array(audio_features_lens_s)),
        'text_tokens': text_tokens_s,
        'text_tokens_lens': torch.LongTensor(np.array(text_tokens_lens_s)),
        'languages': torch.LongTensor(np.array(language_s)),
    }
    return batch

def create_dataloader(data_dir="/root/valle/egs/mix", n_gpus=1, rank=0, num_workers=0, num_buckets=10, max_duration=120):
    train_dataset = AudioDataset(h5_path=f"{data_dir}/audio_sum.hdf5",
                                 ann_path=f"{data_dir}/audio_ann_sum.txt",
                                 tokenizer_path=f"{data_dir}/bpe_175.json")
    ran_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )
    dynamic_sampler = DynamicBatchSampler(ran_sampler, train_dataset.get_dur, num_buckets=num_buckets, max_size=20,
                                          max_tokens=max_duration)

    train_loader = torch.utils.data.DataLoader(train_dataset, num_workers=num_workers, collate_fn=collate,
                                               batch_sampler=dynamic_sampler)

    return train_loader
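A minimal single-process sketch of driving the dataloader defined above; the data_dir is an assumed folder that already contains audio_sum.hdf5, audio_ann_sum.txt and bpe_175.json produced by the preprocessing step.

# Illustrative only; rank/num_replicas are passed explicitly, so no process group is needed
from data.dataset import create_dataloader

train_loader = create_dataloader(data_dir="./customs", n_gpus=1, rank=0, max_duration=20)
for batch in train_loader:
    print(batch['text_tokens'].shape, batch['audio_features'].shape)
    break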
data/fbank.py
ADDED
@@ -0,0 +1,212 @@
# Copyright 2023 (authors: Feiteng Li)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
# from lhotse.features.base import FeatureExtractor
# from lhotse.utils import EPSILON, Seconds, compute_num_frames
from librosa.filters import mel as librosa_mel_fn


@dataclass
class BigVGANFbankConfig:
    # Spectrogram-related part
    # Note that frame_length and frame_shift will be converted to milliseconds before torchaudio/Kaldi sees them
    frame_length: Seconds = 1024 / 24000.0
    frame_shift: Seconds = 256 / 24000.0
    remove_dc_offset: bool = True
    round_to_power_of_two: bool = True

    # Fbank-related part
    low_freq: float = 0.0
    high_freq: float = 12000.0
    num_mel_bins: int = 100
    use_energy: bool = False

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "BigVGANFbankConfig":
        return BigVGANFbankConfig(**data)


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


# https://github.com/NVIDIA/BigVGAN
# bigvgan_24khz_100band https://drive.google.com/drive/folders/1EpxX6AsxjCbbk0mmAhE0td6eYiABr8Oz
class BigVGANFbank(FeatureExtractor):
    name = "fbank"
    config_type = BigVGANFbankConfig

    def __init__(self, config: Optional[Any] = None):
        super(BigVGANFbank, self).__init__(config)
        sampling_rate = 24000
        self.mel_basis = torch.from_numpy(
            librosa_mel_fn(
                sampling_rate,
                1024,
                self.config.num_mel_bins,
                self.config.low_freq,
                self.config.high_freq,
            ).astype(np.float32)
        )
        self.hann_window = torch.hann_window(1024)

    def _feature_fn(self, samples, **kwargs):
        win_length, n_fft = 1024, 1024
        hop_size = 256
        if True:
            sampling_rate = 24000
            duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
            expected_num_frames = compute_num_frames(
                duration=duration,
                frame_shift=self.frame_shift,
                sampling_rate=sampling_rate,
            )
            pad_size = (
                (expected_num_frames - 1) * hop_size
                + win_length
                - samples.shape[-1]
            )
            assert pad_size >= 0

            y = torch.nn.functional.pad(
                samples,
                (0, pad_size),
                mode="constant",
            )
        else:
            y = torch.nn.functional.pad(
                samples,
                (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
                mode="reflect",
            )

        y = y.squeeze(1)

        # complex tensor as default, then use view_as_real for future pytorch compatibility
        spec = torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_length,
            window=self.hann_window,
            center=False,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        spec = torch.matmul(self.mel_basis, spec)
        spec = spectral_normalize_torch(spec)

        return spec.transpose(2, 1).squeeze(0)

    def extract(
        self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
    ) -> np.ndarray:
        assert sampling_rate == 24000
        params = asdict(self.config)
        params.update({"sample_frequency": sampling_rate, "snip_edges": False})
        params["frame_shift"] *= 1000.0
        params["frame_length"] *= 1000.0
        if not isinstance(samples, torch.Tensor):
            samples = torch.from_numpy(samples)
        # Torchaudio Kaldi feature extractors expect the channel dimension to be first.
        if len(samples.shape) == 1:
            samples = samples.unsqueeze(0)
        features = self._feature_fn(samples, **params).to(torch.float32)
        return features.numpy()

    @property
    def frame_shift(self) -> Seconds:
        return self.config.frame_shift

    def feature_dim(self, sampling_rate: int) -> int:
        return self.config.num_mel_bins

    @staticmethod
    def mix(
        features_a: np.ndarray,
        features_b: np.ndarray,
        energy_scaling_factor_b: float,
    ) -> np.ndarray:
        return np.log(
            np.maximum(
                # protection against log(0); max with EPSILON is adequate since these are energies (always >= 0)
                EPSILON,
                np.exp(features_a)
                + energy_scaling_factor_b * np.exp(features_b),
            )
        )

    @staticmethod
    def compute_energy(features: np.ndarray) -> float:
        return float(np.sum(np.exp(features)))


def get_fbank_extractor() -> BigVGANFbank:
    return BigVGANFbank(BigVGANFbankConfig())


if __name__ == "__main__":
    extractor = BigVGANFbank(BigVGANFbankConfig())

    samples = torch.from_numpy(np.random.random([1000]).astype(np.float32))
    samples = torch.clip(samples, -1.0, 1.0)
    fbank = extractor.extract(samples, 24000.0)
    print(f"fbank {fbank.shape}")

    from scipy.io.wavfile import read

    MAX_WAV_VALUE = 32768.0

    sampling_rate, samples = read(
        "egs/libritts/prompts/5639_40744_000000_000002.wav"
    )
    print(f"samples: [{samples.min()}, {samples.max()}]")
    fbank = extractor.extract(samples.astype(np.float32) / MAX_WAV_VALUE, 24000)
    print(f"fbank {fbank.shape}")

    import matplotlib.pyplot as plt

    _ = plt.figure(figsize=(18, 10))
    plt.imshow(
        X=fbank.transpose(1, 0),
        cmap=plt.get_cmap("jet"),
        aspect="auto",
        interpolation="nearest",
    )
    plt.gca().invert_yaxis()
    plt.savefig("egs/libritts/prompts/5639_40744_000000_000002.png")
    plt.close()

    print("fbank test PASS!")
data/input_strategies.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from collections import defaultdict
|
3 |
+
from concurrent.futures import ThreadPoolExecutor
|
4 |
+
from typing import Tuple, Type
|
5 |
+
|
6 |
+
# from lhotse import CutSet
|
7 |
+
# from lhotse.dataset.collation import collate_features
|
8 |
+
# from lhotse.dataset.input_strategies import (
|
9 |
+
# ExecutorType,
|
10 |
+
# PrecomputedFeatures,
|
11 |
+
#     _get_executor,
# )
# from lhotse.utils import fastcopy


class PromptedFeatures:
    def __init__(self, prompts, features):
        self.prompts = prompts
        self.features = features

    def to(self, device):
        return PromptedFeatures(
            self.prompts.to(device), self.features.to(device)
        )

    def sum(self):
        return self.features.sum()

    @property
    def ndim(self):
        return self.features.ndim

    @property
    def data(self):
        return (self.prompts, self.features)


# class PromptedPrecomputedFeatures(PrecomputedFeatures):
#     """
#     :class:`InputStrategy` that reads pre-computed features, whose manifests
#     are attached to cuts, from disk.
#
#     It automatically pads the feature matrices with pre or post feature.
#
#     .. automethod:: __call__
#     """
#
#     def __init__(
#         self,
#         dataset: str,
#         cuts: CutSet,
#         num_workers: int = 0,
#         executor_type: Type[ExecutorType] = ThreadPoolExecutor,
#     ) -> None:
#         super(PromptedPrecomputedFeatures, self).__init__(
#             num_workers, executor_type
#         )
#
#         self.utt2neighbors = defaultdict(lambda: [])
#
#         if dataset.lower() == "libritts":
#             # 909_131041_000013_000002
#             # 909_131041_000013_000003
#             speaker2utts = defaultdict(lambda: [])
#
#             utt2cut = {}
#             for cut in cuts:
#                 speaker = cut.supervisions[0].speaker
#                 speaker2utts[speaker].append(cut.id)
#                 utt2cut[cut.id] = cut
#
#             for spk in speaker2utts:
#                 uttids = sorted(speaker2utts[spk])
#                 # Using the property of sorted keys to find the previous utterance.
#                 # The keys have structure speaker_book_x_y, e.g. 1089_134691_000004_000001
#                 if len(uttids) == 1:
#                     self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
#                     continue
#
#                 utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
#                 utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
#
#                 for utt in utt2prevutt:
#                     self.utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]])
#
#                 for utt in utt2postutt:
#                     self.utt2neighbors[utt].append(utt2cut[utt2postutt[utt]])
#         elif dataset.lower() == "ljspeech":
#             utt2cut = {}
#             uttids = []
#             for cut in cuts:
#                 uttids.append(cut.id)
#                 utt2cut[cut.id] = cut
#
#             if len(uttids) == 1:
#                 self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
#             else:
#                 # Using the property of sorted keys to find the previous utterance.
#                 # The keys have structure: LJ001-0010
#                 utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
#                 utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
#
#                 for utt in utt2postutt:
#                     postutt = utt2postutt[utt]
#                     if utt[:5] == postutt[:5]:
#                         self.utt2neighbors[utt].append(utt2cut[postutt])
#
#                 for utt in utt2prevutt:
#                     prevutt = utt2prevutt[utt]
#                     if utt[:5] == prevutt[:5] or not self.utt2neighbors[utt]:
#                         self.utt2neighbors[utt].append(utt2cut[prevutt])
#         else:
#             raise ValueError
#
#     def __call__(
#         self, cuts: CutSet
#     ) -> Tuple[PromptedFeatures, PromptedFeatures]:
#         """
#         Reads the pre-computed features from disk/other storage.
#         The returned shape is ``(B, T, F) => (batch_size, num_frames, num_features)``.
#
#         :return: a tensor with collated features, and a tensor of ``num_frames`` of each cut before padding.
#         """
#         features, features_lens = collate_features(
#             cuts,
#             executor=_get_executor(
#                 self.num_workers, executor_type=self._executor_type
#             ),
#         )
#
#         prompts_cuts = []
#         for k, cut in enumerate(cuts):
#             prompts_cut = random.choice(self.utt2neighbors[cut.id])
#             prompts_cuts.append(fastcopy(prompts_cut, id=f"{cut.id}-{str(k)}"))
#
#         mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0])
#         # prompts_cuts = CutSet.from_cuts(prompts_cuts).truncate(
#         #     max_duration=mini_duration,
#         #     offset_type="random",
#         #     preserve_id=True,
#         # )
#         prompts_cuts = CutSet(
#             cuts={k: cut for k, cut in enumerate(prompts_cuts)}
#         ).truncate(
#             max_duration=mini_duration,
#             offset_type="random",
#             preserve_id=False,
#         )
#
#         prompts, prompts_lens = collate_features(
#             prompts_cuts,
#             executor=_get_executor(
#                 self.num_workers, executor_type=self._executor_type
#             ),
#         )
#
#         return PromptedFeatures(prompts, features), PromptedFeatures(
#             prompts_lens, features_lens
#         )
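A minimal usage sketch of the PromptedFeatures wrapper defined above (not part of the commit; the tensor shapes are illustrative assumptions). It simply keeps a prompt tensor and a target feature tensor together so they move between devices and unpack as one unit:

import torch

from data.input_strategies import PromptedFeatures

# Stand-in tensors; in training these come from the collator, not from zeros().
prompts = torch.zeros(2, 150, 8)    # assumed (batch, prompt_frames, num_quantizers)
features = torch.zeros(2, 600, 8)   # assumed (batch, target_frames, num_quantizers)

pf = PromptedFeatures(prompts, features)
pf = pf.to("cuda" if torch.cuda.is_available() else "cpu")  # moves both tensors
prompts_out, features_out = pf.data  # unpack back into (prompts, features)
total = pf.sum()                     # delegates to features.sum(), mirroring a plain tensor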
data/tokenizer.py
ADDED
@@ -0,0 +1,376 @@
#!/usr/bin/env python3
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Pattern, Union

import numpy as np
import torch
import torchaudio
from encodec import EncodecModel
from encodec.utils import convert_audio
from phonemizer.backend import EspeakBackend
from phonemizer.backend.espeak.language_switch import LanguageSwitch
from phonemizer.backend.espeak.words_mismatch import WordMismatch
from phonemizer.punctuation import Punctuation
from phonemizer.separator import Separator

try:
    from pypinyin import Style, pinyin
    from pypinyin.style._utils import get_finals, get_initials
except Exception:
    pass


class PypinyinBackend:
    """PypinyinBackend for Chinese. Most code is referenced from espnet.
    There are two output types, pinyin or initials_finals: one looks
    like "ni1 hao3", the other like "n i1 h ao3".
    """

    def __init__(
        self,
        backend="initials_finals",
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
    ) -> None:
        self.backend = backend
        self.punctuation_marks = punctuation_marks

    def phonemize(
        self, text: List[str], separator: Separator, strip=True, njobs=1
    ) -> List[str]:
        assert isinstance(text, List)
        phonemized = []
        for _text in text:
            _text = re.sub(" +", " ", _text.strip())
            _text = _text.replace(" ", separator.word)
            phones = []
            if self.backend == "pypinyin":
                for n, py in enumerate(
                    pinyin(
                        _text, style=Style.TONE3, neutral_tone_with_five=True
                    )
                ):
                    if all([c in self.punctuation_marks for c in py[0]]):
                        if len(phones):
                            assert phones[-1] == separator.syllable
                            phones.pop(-1)

                        phones.extend(list(py[0]))
                    else:
                        phones.extend([py[0], separator.syllable])
            elif self.backend == "pypinyin_initials_finals":
                for n, py in enumerate(
                    pinyin(
                        _text, style=Style.TONE3, neutral_tone_with_five=True
                    )
                ):
                    if all([c in self.punctuation_marks for c in py[0]]):
                        if len(phones):
                            assert phones[-1] == separator.syllable
                            phones.pop(-1)
                        phones.extend(list(py[0]))
                    else:
                        if py[0][-1].isalnum():
                            initial = get_initials(py[0], strict=False)
                            if py[0][-1].isdigit():
                                final = (
                                    get_finals(py[0][:-1], strict=False)
                                    + py[0][-1]
                                )
                            else:
                                final = get_finals(py[0], strict=False)
                            phones.extend(
                                [
                                    initial,
                                    separator.phone,
                                    final,
                                    separator.syllable,
                                ]
                            )
                        else:
                            raise ValueError
            else:
                raise NotImplementedError
            phonemized.append(
                "".join(phones).rstrip(f"{separator.word}{separator.syllable}")
            )
        return phonemized


class TextTokenizer:
    """Phonemize Text."""

    def __init__(
        self,
        language="en-us",
        backend="espeak",
        separator=Separator(word="_", syllable="-", phone="|"),
        preserve_punctuation=True,
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
        with_stress: bool = False,
        tie: Union[bool, str] = False,
        language_switch: LanguageSwitch = "keep-flags",
        words_mismatch: WordMismatch = "ignore",
    ) -> None:
        if backend == "espeak":
            phonemizer = EspeakBackend(
                language,
                punctuation_marks=punctuation_marks,
                preserve_punctuation=preserve_punctuation,
                with_stress=with_stress,
                tie=tie,
                language_switch=language_switch,
                words_mismatch=words_mismatch,
            )
        elif backend in ["pypinyin", "pypinyin_initials_finals"]:
            phonemizer = PypinyinBackend(
                backend=backend,
                punctuation_marks=punctuation_marks + separator.word,
            )
        else:
            raise NotImplementedError(f"{backend}")

        self.backend = phonemizer
        self.separator = separator

    def to_list(self, phonemized: str) -> List[str]:
        fields = []
        for word in phonemized.split(self.separator.word):
            # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
            pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
            fields.extend(
                [p for p in pp if p != self.separator.phone]
                + [self.separator.word]
            )
        assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
            self.separator.phone
        )
        return fields[:-1]

    def __call__(self, text, strip=True) -> List[List[str]]:
        if isinstance(text, str):
            text = [text]

        phonemized = self.backend.phonemize(
            text, separator=self.separator, strip=strip, njobs=1
        )
        return [self.to_list(p) for p in phonemized]


def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
    phonemes = tokenizer([text.strip()])
    return phonemes[0]  # k2symbols


def remove_encodec_weight_norm(model):
    # Strip the weight-norm reparametrization from EnCodec's conv layers;
    # the produced codes are unchanged (see the __main__ check below).
    from encodec.modules import SConv1d
    from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
    from torch.nn.utils import remove_weight_norm

    encoder = model.encoder.model
    for key in encoder._modules:
        if isinstance(encoder._modules[key], SEANetResnetBlock):
            remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
            block_modules = encoder._modules[key].block._modules
            for skey in block_modules:
                if isinstance(block_modules[skey], SConv1d):
                    remove_weight_norm(block_modules[skey].conv.conv)
        elif isinstance(encoder._modules[key], SConv1d):
            remove_weight_norm(encoder._modules[key].conv.conv)

    decoder = model.decoder.model
    for key in decoder._modules:
        if isinstance(decoder._modules[key], SEANetResnetBlock):
            remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
            block_modules = decoder._modules[key].block._modules
            for skey in block_modules:
                if isinstance(block_modules[skey], SConv1d):
                    remove_weight_norm(block_modules[skey].conv.conv)
        elif isinstance(decoder._modules[key], SConvTranspose1d):
            remove_weight_norm(decoder._modules[key].convtr.convtr)
        elif isinstance(decoder._modules[key], SConv1d):
            remove_weight_norm(decoder._modules[key].conv.conv)


class AudioTokenizer:
    """EnCodec audio."""

    def __init__(
        self,
        device: Any = None,
    ) -> None:
        # Instantiate a pretrained EnCodec model
        model = EncodecModel.encodec_model_24khz()
        model.set_target_bandwidth(6.0)
        remove_encodec_weight_norm(model)

        if not device:
            device = torch.device("cpu")
            if torch.cuda.is_available():
                device = torch.device("cuda:0")

        self._device = device

        self.codec = model.to(device)
        self.sample_rate = model.sample_rate
        self.channels = model.channels

    @property
    def device(self):
        return self._device

    def encode(self, wav: torch.Tensor) -> torch.Tensor:
        return self.codec.encode(wav.to(self.device))

    def decode(self, frames: torch.Tensor) -> torch.Tensor:
        return self.codec.decode(frames)


def tokenize_audio(tokenizer: AudioTokenizer, audio):
    # Load and pre-process the audio waveform
    if isinstance(audio, str):
        wav, sr = torchaudio.load(audio)
    else:
        wav, sr = audio
    wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
    wav = wav.unsqueeze(0)

    # Extract discrete codes from EnCodec
    with torch.no_grad():
        encoded_frames = tokenizer.encode(wav)
    return encoded_frames


# @dataclass
# class AudioTokenConfig:
#     frame_shift: Seconds = 320.0 / 24000
#     num_quantizers: int = 8
#
#     def to_dict(self) -> Dict[str, Any]:
#         return asdict(self)
#
#     @staticmethod
#     def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
#         return AudioTokenConfig(**data)
#
#
# class AudioTokenExtractor(FeatureExtractor):
#     name = "encodec"
#     config_type = AudioTokenConfig
#
#     def __init__(self, config: Optional[Any] = None):
#         super(AudioTokenExtractor, self).__init__(config)
#         self.tokenizer = AudioTokenizer()
#
#     def extract(
#         self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
#     ) -> np.ndarray:
#         if not isinstance(samples, torch.Tensor):
#             samples = torch.from_numpy(samples)
#         if sampling_rate != self.tokenizer.sample_rate:
#             samples = convert_audio(
#                 samples,
#                 sampling_rate,
#                 self.tokenizer.sample_rate,
#                 self.tokenizer.channels,
#             )
#         if len(samples.shape) == 2:
#             samples = samples.unsqueeze(0)
#         else:
#             raise ValueError()
#
#         device = self.tokenizer.device
#         encoded_frames = self.tokenizer.encode(samples.detach().to(device))
#         codes = encoded_frames[0][0]  # [B, n_q, T]
#         if True:
#             duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
#             expected_num_frames = compute_num_frames(
#                 duration=duration,
#                 frame_shift=self.frame_shift,
#                 sampling_rate=sampling_rate,
#             )
#             assert abs(codes.shape[-1] - expected_num_frames) <= 1
#             codes = codes[..., :expected_num_frames]
#         return codes.cpu().squeeze(0).permute(1, 0).numpy()
#
#     @property
#     def frame_shift(self) -> Seconds:
#         return self.config.frame_shift
#
#     def feature_dim(self, sampling_rate: int) -> int:
#         return self.config.num_quantizers
#
#     def pad_tensor_list(self, tensor_list, device, padding_value=0):
#         # Compute the length of each tensor
#         lengths = [tensor.shape[0] for tensor in tensor_list]
#         # Pad with torch.nn.utils.rnn.pad_sequence
#         tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
#         padded_tensor = torch.nn.utils.rnn.pad_sequence(
#             tensor_list, batch_first=True, padding_value=padding_value
#         )
#         return padded_tensor, lengths
#
#     def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
#         samples = [wav.squeeze() for wav in samples]
#         device = self.tokenizer.device
#         samples, lengths = self.pad_tensor_list(samples, device)
#         samples = samples.unsqueeze(1)
#
#         if not isinstance(samples, torch.Tensor):
#             samples = torch.from_numpy(samples)
#         if len(samples.shape) != 3:
#             raise ValueError()
#         if sampling_rate != self.tokenizer.sample_rate:
#             samples = [
#                 convert_audio(
#                     wav,
#                     sampling_rate,
#                     self.tokenizer.sample_rate,
#                     self.tokenizer.channels,
#                 )
#                 for wav in samples
#             ]
#         # Extract discrete codes from EnCodec
#         with torch.no_grad():
#             encoded_frames = self.tokenizer.encode(samples.detach().to(device))
#         encoded_frames = encoded_frames[0][0]  # [B, n_q, T]
#         batch_codes = []
#         for b, length in enumerate(lengths):
#             codes = encoded_frames[b]
#             duration = round(length / sampling_rate, ndigits=12)
#             expected_num_frames = compute_num_frames(
#                 duration=duration,
#                 frame_shift=self.frame_shift,
#                 sampling_rate=sampling_rate,
#             )
#             batch_codes.append(codes[..., :expected_num_frames])
#         return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]


if __name__ == "__main__":
    model = EncodecModel.encodec_model_24khz()
    model.set_target_bandwidth(6.0)

    samples = torch.from_numpy(np.random.random([4, 1, 1600])).type(
        torch.float32
    )
    codes_raw = model.encode(samples)

    remove_encodec_weight_norm(model)
    codes_norm = model.encode(samples)

    assert torch.allclose(codes_raw[0][0], codes_norm[0][0])
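A minimal usage sketch for the two tokenizers in data/tokenizer.py (not part of the commit). It assumes espeak-ng is installed for the phonemizer backend and reuses prompts/en-2.wav from examples.py:

from data.tokenizer import AudioTokenizer, TextTokenizer, tokenize_audio, tokenize_text

# Text -> phoneme symbols; word boundaries become the "_" separator.
text_tokenizer = TextTokenizer(language="en-us", backend="espeak")
phonemes = tokenize_text(text_tokenizer, "Wow, look at that!")

# Audio -> discrete EnCodec codes (24 kHz model at 6 kbps, i.e. 8 quantizers).
audio_tokenizer = AudioTokenizer()
encoded_frames = tokenize_audio(audio_tokenizer, "prompts/en-2.wav")
codes = encoded_frames[0][0]  # (batch, num_quantizers, T)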
descriptions.py
ADDED
@@ -0,0 +1,31 @@
top_md = """
# VALL-E X
VALL-E X can synthesize high-quality personalized speech with only a 3-second enrolled recording of
an unseen speaker as an acoustic prompt, even in another language for a monolingual speaker.<br>
This implementation supports zero-shot, mono-lingual/cross-lingual text-to-speech for three languages (English, Chinese, Japanese).<br>
See this [demo](https://plachtaa.github.io/) page for more details.
"""

infer_from_audio_md = """
Upload a 3~10 second recording as the audio prompt and type in the text you'd like to synthesize.<br>
The model will synthesize the given text in the same voice as your audio prompt.<br>
The model also tends to preserve the emotion & acoustic environment of the given speech.<br>
For faster inference, please use **"Make prompt"** to get a `.npz` file as the encoded audio prompt, and use it via **"Infer from prompt"**.
"""

make_prompt_md = """
Upload a 3~10 second recording as the audio prompt.<br>
You will get a `.npz` file as the encoded audio prompt. Use it via **"Infer from prompt"**.
"""

infer_from_prompt_md = """
Faster than **"Infer from audio"**.<br>
You need to **"Make prompt"** first, then upload the encoded prompt (a `.npz` file).
"""

long_text_md = """
Very long text is chunked into several sentences, and each sentence is synthesized separately.<br>
Please make a prompt or use a preset prompt to infer long text.
"""

long_text_example = "Speech processing is a field in computer science and artificial intelligence that involves the analysis, processing, and understanding of human spoken language. The main goal of speech processing is to enable computers to recognize, analyze, and respond to human speech in a natural and efficient manner."
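A hypothetical sketch of the "Make prompt" step described above, assuming the .npz prompt simply stores the EnCodec codes produced by data/tokenizer.py. The field name audio_tokens and the output path are assumptions; the actual prompt format is whatever launch-ui.py writes:

import numpy as np

from data.tokenizer import AudioTokenizer, tokenize_audio

tokenizer = AudioTokenizer()
encoded_frames = tokenize_audio(tokenizer, "prompts/en-2.wav")
codes = encoded_frames[0][0].cpu().numpy()     # (1, num_quantizers, T)
np.savez("my_prompt.npz", audio_tokens=codes)  # later loaded by "Infer from prompt"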
examples.py
ADDED
@@ -0,0 +1,24 @@
infer_from_audio_examples = [
    ["This is how this machine has taken my voice.", 'English', 'no-accent', "prompts/en-2.wav", None, "Wow, look at that! That's no ordinary Teddy bear!"],
    ["我喜欢抽电子烟,尤其是锐刻五代。", '中文', 'no-accent', "prompts/zh-1.wav", None, "今天我很荣幸,"],
    ["私の声を真似するのはそんなに面白いですか?", '日本語', 'no-accent', "prompts/ja-2.ogg", None, "初めまして、朝武よしのです。"],
    ["你可以听得出来我有多困。", '中文', 'no-accent', "prompts/en-1.wav", None, ""],
    ["この文は、クロスリンガル合成の例です。", '日本語', 'no-accent', "prompts/zh-2.wav", None, ""],
    ["Actually, I can't speak English, but this machine helped me do it.", 'English', 'no-accent', "prompts/ja-1.wav", None, ""],
]

make_npz_prompt_examples = [
    ["Gem-trader", "prompts/en-2.wav", None, "Wow, look at that! That's no ordinary Teddy bear!"],
    ["Ding Zhen", "prompts/zh-1.wav", None, "今天我很荣幸,"],
    ["Yoshino", "prompts/ja-2.ogg", None, "初めまして、朝武よしのです。"],
    ["Sleepy-woman", "prompts/en-1.wav", None, ""],
    ["Yae", "prompts/zh-2.wav", None, ""],
    ["Cafe", "prompts/ja-1.wav", None, ""],
]

infer_from_prompt_examples = [
    ["A prompt contains voice, prosody and emotion information of a certain speaker.", "English", "no-accent", "vctk_1", None],
    ["This prompt is made with an audio of three seconds.", "English", "no-accent", "librispeech_1", None],
    ["This prompt is made with Chinese speech", "English", "no-accent", "seel", None],
]
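Each row above matches, in order, the input components of one UI tab. A hypothetical sketch of how such rows are typically wired into Gradio follows; the handler name and component layout are assumptions for illustration, and the real UI is defined in launch-ui.py:

import gradio as gr

from examples import infer_from_audio_examples

def infer_from_audio(text, language, accent, audio_prompt, npz_prompt, transcript):
    ...  # placeholder: the actual synthesis call lives in launch-ui.py

demo = gr.Interface(
    fn=infer_from_audio,
    inputs=[
        gr.Textbox(label="Text"),
        gr.Radio(["English", "中文", "日本語"], label="Language"),
        gr.Radio(["no-accent"], label="Accent"),
        gr.Audio(type="filepath", label="Audio prompt"),
        gr.File(label="npz prompt"),
        gr.Textbox(label="Prompt transcript"),
    ],
    outputs=gr.Audio(label="Synthesized speech"),
    examples=infer_from_audio_examples,  # each row fills the six inputs in order
)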
exp/valle_dev/log/log-train-2023-11-01-00-19-48
ADDED
@@ -0,0 +1,117 @@
1 |
+
2023-11-01 00:19:48,610 INFO [train.py:851] Training started
|
2 |
+
2023-11-01 00:19:48,611 INFO [train.py:870] Device: cuda:0
|
3 |
+
2023-11-01 00:19:48,611 INFO [train.py:871] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 100, 'reset_interval': 200, 'valid_interval': 10000, 'world_size': 1, 'master_port': 12354, 'tensorboard': True, 'num_epochs': 20, 'start_epoch': 1, 'start_batch': 0, 'exp_dir': PosixPath('exp/valle_dev'), 'optimizer_name': 'ScaledAdam', 'scheduler_name': 'Eden', 'base_lr': 0.005, 'warmup_steps': 200, 'seed': 42, 'inf_check': False, 'save_every_n': 10000, 'keep_last_k': 20, 'average_period': 0, 'accumulate_grad_steps': 1, 'dtype': 'float16', 'filter_min_duration': 0.0, 'filter_max_duration': 20.0, 'train_stage': 0, 'visualize': False, 'oom_check': True, 'train_dir': '/home/ubuntu/VALL-E-X/JS_Dataset/JS_Dataset/train_tune', 'valid_dir': '/home/ubuntu/VALL-E-X/JS_Dataset/JS_Dataset/valid_tune', 'model_name': 'VALL-E', 'decoder_dim': 1024, 'nhead': 16, 'num_decoder_layers': 12, 'scale_factor': 1.0, 'norm_first': True, 'add_prenet': False, 'prefix_mode': 0, 'share_embedding': True, 'prepend_bos': False, 'num_quantizers': 8, 'scaling_xformers': False}
|
4 |
+
2023-11-01 00:19:48,611 INFO [train.py:873] About to create model
|
5 |
+
2023-11-01 00:19:52,689 INFO [train.py:877] Number of model parameters: 370539524
|
6 |
+
2023-11-01 00:19:52,909 DEBUG [__init__.py:113] Building prefix dict from the default dictionary ...
|
7 |
+
2023-11-01 00:19:52,910 DEBUG [__init__.py:132] Loading model from cache /tmp/jieba.cache
|
8 |
+
2023-11-01 00:19:53,423 DEBUG [__init__.py:164] Loading model cost 0.513 seconds.
|
9 |
+
2023-11-01 00:19:53,423 DEBUG [__init__.py:166] Prefix dict has been built successfully.
|
10 |
+
2023-11-01 00:20:15,635 INFO [train.py:764] Epoch 1, batch 100, train_loss[loss=3.202, ArTop10Accuracy=0.7412, NarTop10Accuracy=0.6004, over 1306.00 frames. ], tot_loss[loss=3.398, ArTop10Accuracy=0.7055, NarTop10Accuracy=0.553, over 476.97 frames. ], batch size: 3, lr: 3.75e-03, grad_scale: 1.0
|
11 |
+
2023-11-01 00:20:37,125 INFO [train.py:764] Epoch 1, batch 200, train_loss[loss=3.531, ArTop10Accuracy=0.6921, NarTop10Accuracy=0.5406, over 1234.00 frames. ], tot_loss[loss=3.408, ArTop10Accuracy=0.7094, NarTop10Accuracy=0.5521, over 749.45 frames. ], batch size: 3, lr: 5.00e-03, grad_scale: 1.0
|
12 |
+
2023-11-01 00:20:59,073 INFO [train.py:764] Epoch 1, batch 300, train_loss[loss=3.609, ArTop10Accuracy=0.7015, NarTop10Accuracy=0.4503, over 995.00 frames. ], tot_loss[loss=3.443, ArTop10Accuracy=0.7109, NarTop10Accuracy=0.5387, over 935.67 frames. ], batch size: 2, lr: 5.00e-03, grad_scale: 1.0
|
13 |
+
2023-11-01 00:21:20,852 INFO [train.py:764] Epoch 1, batch 400, train_loss[loss=3.419, ArTop10Accuracy=0.6896, NarTop10Accuracy=0.5319, over 1234.00 frames. ], tot_loss[loss=3.463, ArTop10Accuracy=0.7133, NarTop10Accuracy=0.5285, over 1040.70 frames. ], batch size: 3, lr: 4.99e-03, grad_scale: 2.0
|
14 |
+
2023-11-01 00:21:42,406 INFO [train.py:764] Epoch 1, batch 500, train_loss[loss=3.293, ArTop10Accuracy=0.7238, NarTop10Accuracy=0.5892, over 1271.00 frames. ], tot_loss[loss=3.483, ArTop10Accuracy=0.7149, NarTop10Accuracy=0.5217, over 1094.80 frames. ], batch size: 3, lr: 4.99e-03, grad_scale: 2.0
|
15 |
+
2023-11-01 00:22:04,280 INFO [train.py:764] Epoch 1, batch 600, train_loss[loss=3.675, ArTop10Accuracy=0.7139, NarTop10Accuracy=0.4414, over 1496.00 frames. ], tot_loss[loss=3.461, ArTop10Accuracy=0.7191, NarTop10Accuracy=0.5267, over 1141.86 frames. ], batch size: 3, lr: 4.98e-03, grad_scale: 2.0
|
16 |
+
2023-11-01 00:22:25,961 INFO [train.py:764] Epoch 1, batch 700, train_loss[loss=3.25, ArTop10Accuracy=0.6822, NarTop10Accuracy=0.6536, over 1010.00 frames. ], tot_loss[loss=3.468, ArTop10Accuracy=0.7182, NarTop10Accuracy=0.5271, over 1167.56 frames. ], batch size: 2, lr: 4.98e-03, grad_scale: 2.0
|
17 |
+
2023-11-01 00:22:47,713 INFO [train.py:764] Epoch 1, batch 800, train_loss[loss=3.896, ArTop10Accuracy=0.6663, NarTop10Accuracy=0.4321, over 947.00 frames. ], tot_loss[loss=3.468, ArTop10Accuracy=0.7199, NarTop10Accuracy=0.5278, over 1184.08 frames. ], batch size: 2, lr: 4.97e-03, grad_scale: 4.0
|
18 |
+
2023-11-01 00:23:09,399 INFO [train.py:764] Epoch 1, batch 900, train_loss[loss=3.434, ArTop10Accuracy=0.7556, NarTop10Accuracy=0.5224, over 1195.00 frames. ], tot_loss[loss=3.461, ArTop10Accuracy=0.7214, NarTop10Accuracy=0.5295, over 1188.23 frames. ], batch size: 3, lr: 4.96e-03, grad_scale: 4.0
|
19 |
+
2023-11-01 00:23:31,117 INFO [train.py:764] Epoch 1, batch 1000, train_loss[loss=3.611, ArTop10Accuracy=0.7125, NarTop10Accuracy=0.5023, over 1339.00 frames. ], tot_loss[loss=3.458, ArTop10Accuracy=0.7229, NarTop10Accuracy=0.5286, over 1190.57 frames. ], batch size: 3, lr: 4.95e-03, grad_scale: 4.0
|
20 |
+
2023-11-01 00:23:31,281 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 3.204e+01 4.529e+01 5.014e+01 5.593e+01 9.312e+01, threshold=1.003e+02, percent-clipped=0.0
|
21 |
+
2023-11-01 00:23:53,034 INFO [train.py:764] Epoch 1, batch 1100, train_loss[loss=3.754, ArTop10Accuracy=0.7051, NarTop10Accuracy=0.4622, over 1007.00 frames. ], tot_loss[loss=3.465, ArTop10Accuracy=0.7233, NarTop10Accuracy=0.5252, over 1194.69 frames. ], batch size: 2, lr: 4.94e-03, grad_scale: 4.0
|
22 |
+
2023-11-01 00:24:15,002 INFO [train.py:764] Epoch 1, batch 1200, train_loss[loss=3.233, ArTop10Accuracy=0.7508, NarTop10Accuracy=0.5991, over 1228.00 frames. ], tot_loss[loss=3.456, ArTop10Accuracy=0.7254, NarTop10Accuracy=0.5279, over 1198.98 frames. ], batch size: 3, lr: 4.93e-03, grad_scale: 8.0
|
23 |
+
2023-11-01 00:24:36,955 INFO [train.py:764] Epoch 1, batch 1300, train_loss[loss=3.365, ArTop10Accuracy=0.7287, NarTop10Accuracy=0.5916, over 1054.00 frames. ], tot_loss[loss=3.464, ArTop10Accuracy=0.7257, NarTop10Accuracy=0.5265, over 1202.37 frames. ], batch size: 2, lr: 4.92e-03, grad_scale: 8.0
|
24 |
+
2023-11-01 00:24:58,700 INFO [train.py:764] Epoch 1, batch 1400, train_loss[loss=3.458, ArTop10Accuracy=0.731, NarTop10Accuracy=0.5298, over 1301.00 frames. ], tot_loss[loss=3.474, ArTop10Accuracy=0.7271, NarTop10Accuracy=0.5209, over 1198.94 frames. ], batch size: 3, lr: 4.91e-03, grad_scale: 8.0
|
25 |
+
2023-11-01 00:25:20,642 INFO [train.py:764] Epoch 1, batch 1500, train_loss[loss=3.331, ArTop10Accuracy=0.7508, NarTop10Accuracy=0.5712, over 1236.00 frames. ], tot_loss[loss=3.46, ArTop10Accuracy=0.7278, NarTop10Accuracy=0.5248, over 1198.26 frames. ], batch size: 3, lr: 4.89e-03, grad_scale: 8.0
|
26 |
+
2023-11-01 00:25:42,599 INFO [train.py:764] Epoch 1, batch 1600, train_loss[loss=3.627, ArTop10Accuracy=0.7136, NarTop10Accuracy=0.4668, over 1201.00 frames. ], tot_loss[loss=3.463, ArTop10Accuracy=0.7265, NarTop10Accuracy=0.5249, over 1204.65 frames. ], batch size: 3, lr: 4.88e-03, grad_scale: 8.0
|
27 |
+
2023-11-01 00:26:04,490 INFO [train.py:764] Epoch 1, batch 1700, train_loss[loss=3.399, ArTop10Accuracy=0.7477, NarTop10Accuracy=0.5198, over 650.00 frames. ], tot_loss[loss=3.46, ArTop10Accuracy=0.7275, NarTop10Accuracy=0.5243, over 1202.22 frames. ], batch size: 1, lr: 4.87e-03, grad_scale: 8.0
|
28 |
+
2023-11-01 00:26:26,353 INFO [train.py:764] Epoch 1, batch 1800, train_loss[loss=3.566, ArTop10Accuracy=0.7215, NarTop10Accuracy=0.4575, over 1253.00 frames. ], tot_loss[loss=3.46, ArTop10Accuracy=0.7281, NarTop10Accuracy=0.5244, over 1198.05 frames. ], batch size: 3, lr: 4.85e-03, grad_scale: 8.0
|
29 |
+
2023-11-01 00:26:48,377 INFO [train.py:764] Epoch 1, batch 1900, train_loss[loss=3.387, ArTop10Accuracy=0.8449, NarTop10Accuracy=0.4865, over 819.00 frames. ], tot_loss[loss=3.453, ArTop10Accuracy=0.7281, NarTop10Accuracy=0.527, over 1198.59 frames. ], batch size: 1, lr: 4.83e-03, grad_scale: 8.0
|
30 |
+
2023-11-01 00:27:10,331 INFO [train.py:764] Epoch 1, batch 2000, train_loss[loss=3.325, ArTop10Accuracy=0.7224, NarTop10Accuracy=0.5816, over 1286.00 frames. ], tot_loss[loss=3.442, ArTop10Accuracy=0.7305, NarTop10Accuracy=0.5298, over 1193.18 frames. ], batch size: 3, lr: 4.82e-03, grad_scale: 16.0
|
31 |
+
2023-11-01 00:27:10,506 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 2.916e+01 4.003e+01 4.277e+01 4.588e+01 1.290e+02, threshold=8.555e+01, percent-clipped=0.1
|
32 |
+
2023-11-01 00:27:32,462 INFO [train.py:764] Epoch 1, batch 2100, train_loss[loss=3.532, ArTop10Accuracy=0.7298, NarTop10Accuracy=0.5061, over 1114.00 frames. ], tot_loss[loss=3.458, ArTop10Accuracy=0.7291, NarTop10Accuracy=0.5252, over 1201.39 frames. ], batch size: 2, lr: 4.80e-03, grad_scale: 16.0
|
33 |
+
2023-11-01 00:27:54,843 INFO [train.py:764] Epoch 1, batch 2200, train_loss[loss=3.206, ArTop10Accuracy=0.8044, NarTop10Accuracy=0.5327, over 1084.00 frames. ], tot_loss[loss=3.452, ArTop10Accuracy=0.7295, NarTop10Accuracy=0.5281, over 1212.86 frames. ], batch size: 2, lr: 4.78e-03, grad_scale: 16.0
|
34 |
+
2023-11-01 00:28:17,007 INFO [train.py:764] Epoch 1, batch 2300, train_loss[loss=3.656, ArTop10Accuracy=0.7355, NarTop10Accuracy=0.4183, over 1478.00 frames. ], tot_loss[loss=3.457, ArTop10Accuracy=0.729, NarTop10Accuracy=0.525, over 1212.39 frames. ], batch size: 3, lr: 4.77e-03, grad_scale: 16.0
|
35 |
+
2023-11-01 00:28:39,010 INFO [train.py:764] Epoch 1, batch 2400, train_loss[loss=3.578, ArTop10Accuracy=0.7343, NarTop10Accuracy=0.4818, over 1325.00 frames. ], tot_loss[loss=3.43, ArTop10Accuracy=0.7311, NarTop10Accuracy=0.5341, over 1209.83 frames. ], batch size: 3, lr: 4.75e-03, grad_scale: 16.0
|
36 |
+
2023-11-01 00:29:00,770 INFO [train.py:764] Epoch 1, batch 2500, train_loss[loss=3.153, ArTop10Accuracy=0.7836, NarTop10Accuracy=0.6147, over 1280.00 frames. ], tot_loss[loss=3.419, ArTop10Accuracy=0.7334, NarTop10Accuracy=0.5348, over 1200.35 frames. ], batch size: 3, lr: 4.73e-03, grad_scale: 16.0
|
37 |
+
2023-11-01 00:29:22,695 INFO [train.py:764] Epoch 1, batch 2600, train_loss[loss=3.217, ArTop10Accuracy=0.7502, NarTop10Accuracy=0.5967, over 1321.00 frames. ], tot_loss[loss=3.417, ArTop10Accuracy=0.7336, NarTop10Accuracy=0.5363, over 1203.85 frames. ], batch size: 3, lr: 4.71e-03, grad_scale: 16.0
|
38 |
+
2023-11-01 00:29:44,570 INFO [train.py:764] Epoch 1, batch 2700, train_loss[loss=3.181, ArTop10Accuracy=0.7385, NarTop10Accuracy=0.624, over 1480.00 frames. ], tot_loss[loss=3.406, ArTop10Accuracy=0.7344, NarTop10Accuracy=0.5392, over 1201.18 frames. ], batch size: 3, lr: 4.69e-03, grad_scale: 16.0
|
39 |
+
2023-11-01 00:30:06,497 INFO [train.py:764] Epoch 1, batch 2800, train_loss[loss=3.198, ArTop10Accuracy=0.734, NarTop10Accuracy=0.6348, over 1297.00 frames. ], tot_loss[loss=3.411, ArTop10Accuracy=0.7353, NarTop10Accuracy=0.5371, over 1201.74 frames. ], batch size: 3, lr: 4.67e-03, grad_scale: 16.0
|
40 |
+
2023-11-01 00:30:28,351 INFO [train.py:764] Epoch 1, batch 2900, train_loss[loss=3.037, ArTop10Accuracy=0.7592, NarTop10Accuracy=0.6745, over 1387.00 frames. ], tot_loss[loss=3.422, ArTop10Accuracy=0.7355, NarTop10Accuracy=0.5335, over 1199.06 frames. ], batch size: 3, lr: 4.65e-03, grad_scale: 16.0
|
41 |
+
2023-11-01 00:30:50,226 INFO [train.py:764] Epoch 1, batch 3000, train_loss[loss=3.231, ArTop10Accuracy=0.7304, NarTop10Accuracy=0.6256, over 1261.00 frames. ], tot_loss[loss=3.429, ArTop10Accuracy=0.7346, NarTop10Accuracy=0.5312, over 1193.63 frames. ], batch size: 3, lr: 4.63e-03, grad_scale: 16.0
|
42 |
+
2023-11-01 00:30:50,404 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 2.518e+01 3.788e+01 4.043e+01 4.342e+01 9.222e+01, threshold=8.086e+01, percent-clipped=0.2
|
43 |
+
2023-11-01 00:31:12,558 INFO [train.py:764] Epoch 1, batch 3100, train_loss[loss=3.228, ArTop10Accuracy=0.7206, NarTop10Accuracy=0.6731, over 1070.00 frames. ], tot_loss[loss=3.421, ArTop10Accuracy=0.7344, NarTop10Accuracy=0.5346, over 1195.52 frames. ], batch size: 2, lr: 4.61e-03, grad_scale: 16.0
|
44 |
+
2023-11-01 00:31:34,821 INFO [train.py:764] Epoch 1, batch 3200, train_loss[loss=3.373, ArTop10Accuracy=0.7137, NarTop10Accuracy=0.5737, over 1289.00 frames. ], tot_loss[loss=3.418, ArTop10Accuracy=0.7353, NarTop10Accuracy=0.5343, over 1210.89 frames. ], batch size: 3, lr: 4.59e-03, grad_scale: 16.0
|
45 |
+
2023-11-01 00:31:56,704 INFO [train.py:764] Epoch 1, batch 3300, train_loss[loss=3.731, ArTop10Accuracy=0.7498, NarTop10Accuracy=0.4016, over 1331.00 frames. ], tot_loss[loss=3.428, ArTop10Accuracy=0.736, NarTop10Accuracy=0.5316, over 1199.90 frames. ], batch size: 3, lr: 4.57e-03, grad_scale: 16.0
|
46 |
+
2023-11-01 00:32:18,676 INFO [train.py:764] Epoch 1, batch 3400, train_loss[loss=3.212, ArTop10Accuracy=0.7604, NarTop10Accuracy=0.6446, over 1223.00 frames. ], tot_loss[loss=3.424, ArTop10Accuracy=0.737, NarTop10Accuracy=0.531, over 1204.36 frames. ], batch size: 3, lr: 4.55e-03, grad_scale: 16.0
|
47 |
+
2023-11-01 00:32:40,608 INFO [train.py:764] Epoch 1, batch 3500, train_loss[loss=3.223, ArTop10Accuracy=0.7422, NarTop10Accuracy=0.6264, over 1315.00 frames. ], tot_loss[loss=3.418, ArTop10Accuracy=0.7374, NarTop10Accuracy=0.5322, over 1203.43 frames. ], batch size: 3, lr: 4.53e-03, grad_scale: 16.0
|
48 |
+
2023-11-01 00:33:02,672 INFO [train.py:764] Epoch 1, batch 3600, train_loss[loss=3.237, ArTop10Accuracy=0.7415, NarTop10Accuracy=0.5996, over 1002.00 frames. ], tot_loss[loss=3.411, ArTop10Accuracy=0.7378, NarTop10Accuracy=0.5369, over 1200.89 frames. ], batch size: 2, lr: 4.50e-03, grad_scale: 16.0
|
49 |
+
2023-11-01 00:33:24,593 INFO [train.py:764] Epoch 1, batch 3700, train_loss[loss=3.049, ArTop10Accuracy=0.7433, NarTop10Accuracy=0.6294, over 1270.00 frames. ], tot_loss[loss=3.404, ArTop10Accuracy=0.7378, NarTop10Accuracy=0.5378, over 1199.71 frames. ], batch size: 3, lr: 4.48e-03, grad_scale: 16.0
|
50 |
+
2023-11-01 00:33:46,541 INFO [train.py:764] Epoch 1, batch 3800, train_loss[loss=3.443, ArTop10Accuracy=0.7345, NarTop10Accuracy=0.5163, over 953.00 frames. ], tot_loss[loss=3.423, ArTop10Accuracy=0.737, NarTop10Accuracy=0.5303, over 1206.54 frames. ], batch size: 2, lr: 4.46e-03, grad_scale: 16.0
|
51 |
+
2023-11-01 00:34:08,561 INFO [train.py:764] Epoch 1, batch 3900, train_loss[loss=3.271, ArTop10Accuracy=0.7608, NarTop10Accuracy=0.5504, over 1346.00 frames. ], tot_loss[loss=3.408, ArTop10Accuracy=0.7383, NarTop10Accuracy=0.5361, over 1214.09 frames. ], batch size: 3, lr: 4.44e-03, grad_scale: 16.0
|
52 |
+
2023-11-01 00:34:30,421 INFO [train.py:764] Epoch 1, batch 4000, train_loss[loss=3.177, ArTop10Accuracy=0.7882, NarTop10Accuracy=0.6005, over 1336.00 frames. ], tot_loss[loss=3.396, ArTop10Accuracy=0.739, NarTop10Accuracy=0.5396, over 1202.95 frames. ], batch size: 3, lr: 4.42e-03, grad_scale: 32.0
|
53 |
+
2023-11-01 00:34:30,586 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 2.781e+01 3.743e+01 4.001e+01 4.276e+01 1.199e+02, threshold=8.003e+01, percent-clipped=0.2
|
54 |
+
2023-11-01 00:34:52,163 INFO [train.py:764] Epoch 1, batch 4100, train_loss[loss=3.238, ArTop10Accuracy=0.74, NarTop10Accuracy=0.5985, over 1277.00 frames. ], tot_loss[loss=3.4, ArTop10Accuracy=0.7395, NarTop10Accuracy=0.5383, over 1197.35 frames. ], batch size: 3, lr: 4.40e-03, grad_scale: 32.0
|
55 |
+
2023-11-01 00:35:14,150 INFO [train.py:764] Epoch 1, batch 4200, train_loss[loss=3.329, ArTop10Accuracy=0.737, NarTop10Accuracy=0.5832, over 1228.00 frames. ], tot_loss[loss=3.395, ArTop10Accuracy=0.7388, NarTop10Accuracy=0.5409, over 1204.03 frames. ], batch size: 3, lr: 4.38e-03, grad_scale: 32.0
|
56 |
+
2023-11-01 00:35:36,263 INFO [train.py:764] Epoch 1, batch 4300, train_loss[loss=3.293, ArTop10Accuracy=0.7733, NarTop10Accuracy=0.5665, over 1125.00 frames. ], tot_loss[loss=3.393, ArTop10Accuracy=0.7399, NarTop10Accuracy=0.5412, over 1201.24 frames. ], batch size: 1, lr: 4.35e-03, grad_scale: 32.0
|
57 |
+
2023-11-01 00:35:58,283 INFO [train.py:764] Epoch 1, batch 4400, train_loss[loss=2.99, ArTop10Accuracy=0.7618, NarTop10Accuracy=0.7263, over 1297.00 frames. ], tot_loss[loss=3.387, ArTop10Accuracy=0.7425, NarTop10Accuracy=0.5403, over 1202.84 frames. ], batch size: 3, lr: 4.33e-03, grad_scale: 32.0
|
58 |
+
2023-11-01 00:36:20,271 INFO [train.py:764] Epoch 1, batch 4500, train_loss[loss=3.166, ArTop10Accuracy=0.7474, NarTop10Accuracy=0.6354, over 1271.00 frames. ], tot_loss[loss=3.375, ArTop10Accuracy=0.7434, NarTop10Accuracy=0.5462, over 1206.43 frames. ], batch size: 3, lr: 4.31e-03, grad_scale: 32.0
|
59 |
+
2023-11-01 00:36:42,281 INFO [train.py:764] Epoch 1, batch 4600, train_loss[loss=3.479, ArTop10Accuracy=0.7316, NarTop10Accuracy=0.5057, over 980.00 frames. ], tot_loss[loss=3.377, ArTop10Accuracy=0.7436, NarTop10Accuracy=0.5457, over 1199.69 frames. ], batch size: 2, lr: 4.29e-03, grad_scale: 32.0
|
60 |
+
2023-11-01 00:37:04,337 INFO [train.py:764] Epoch 1, batch 4700, train_loss[loss=3.177, ArTop10Accuracy=0.7289, NarTop10Accuracy=0.6232, over 1280.00 frames. ], tot_loss[loss=3.378, ArTop10Accuracy=0.7425, NarTop10Accuracy=0.5457, over 1200.61 frames. ], batch size: 3, lr: 4.27e-03, grad_scale: 32.0
|
61 |
+
2023-11-01 00:37:26,359 INFO [train.py:764] Epoch 1, batch 4800, train_loss[loss=3.718, ArTop10Accuracy=0.7287, NarTop10Accuracy=0.4516, over 1176.00 frames. ], tot_loss[loss=3.384, ArTop10Accuracy=0.7417, NarTop10Accuracy=0.5444, over 1204.83 frames. ], batch size: 2, lr: 4.25e-03, grad_scale: 32.0
|
62 |
+
2023-11-01 00:37:48,245 INFO [train.py:764] Epoch 1, batch 4900, train_loss[loss=3.189, ArTop10Accuracy=0.7412, NarTop10Accuracy=0.6122, over 962.00 frames. ], tot_loss[loss=3.395, ArTop10Accuracy=0.7405, NarTop10Accuracy=0.5393, over 1196.82 frames. ], batch size: 2, lr: 4.23e-03, grad_scale: 32.0
|
63 |
+
2023-11-01 00:38:10,398 INFO [train.py:764] Epoch 1, batch 5000, train_loss[loss=3.969, ArTop10Accuracy=0.6697, NarTop10Accuracy=0.3746, over 1002.00 frames. ], tot_loss[loss=3.4, ArTop10Accuracy=0.7401, NarTop10Accuracy=0.5376, over 1198.57 frames. ], batch size: 2, lr: 4.20e-03, grad_scale: 32.0
|
64 |
+
2023-11-01 00:38:10,556 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 2.722e+01 3.674e+01 3.909e+01 4.182e+01 1.163e+02, threshold=7.817e+01, percent-clipped=0.1
|
65 |
+
2023-11-01 00:38:32,372 INFO [train.py:764] Epoch 1, batch 5100, train_loss[loss=3.425, ArTop10Accuracy=0.7333, NarTop10Accuracy=0.535, over 1050.00 frames. ], tot_loss[loss=3.389, ArTop10Accuracy=0.7421, NarTop10Accuracy=0.5403, over 1202.12 frames. ], batch size: 2, lr: 4.18e-03, grad_scale: 32.0
|
66 |
+
2023-11-01 00:38:54,235 INFO [train.py:764] Epoch 1, batch 5200, train_loss[loss=3.11, ArTop10Accuracy=0.7507, NarTop10Accuracy=0.6442, over 1352.00 frames. ], tot_loss[loss=3.382, ArTop10Accuracy=0.7421, NarTop10Accuracy=0.5436, over 1199.90 frames. ], batch size: 3, lr: 4.16e-03, grad_scale: 32.0
|
67 |
+
2023-11-01 00:39:15,969 INFO [train.py:764] Epoch 1, batch 5300, train_loss[loss=3.178, ArTop10Accuracy=0.756, NarTop10Accuracy=0.5823, over 1217.00 frames. ], tot_loss[loss=3.4, ArTop10Accuracy=0.7417, NarTop10Accuracy=0.5379, over 1195.55 frames. ], batch size: 3, lr: 4.14e-03, grad_scale: 32.0
|
68 |
+
2023-11-01 00:39:37,856 INFO [train.py:764] Epoch 1, batch 5400, train_loss[loss=3.272, ArTop10Accuracy=0.7515, NarTop10Accuracy=0.5808, over 1300.00 frames. ], tot_loss[loss=3.385, ArTop10Accuracy=0.7428, NarTop10Accuracy=0.5426, over 1202.33 frames. ], batch size: 3, lr: 4.12e-03, grad_scale: 32.0
|
69 |
+
2023-11-01 00:39:59,844 INFO [train.py:764] Epoch 1, batch 5500, train_loss[loss=3.962, ArTop10Accuracy=0.7026, NarTop10Accuracy=0.362, over 1318.00 frames. ], tot_loss[loss=3.393, ArTop10Accuracy=0.7416, NarTop10Accuracy=0.5403, over 1204.76 frames. ], batch size: 3, lr: 4.10e-03, grad_scale: 32.0
|
70 |
+
2023-11-01 00:40:21,881 INFO [train.py:764] Epoch 1, batch 5600, train_loss[loss=3.23, ArTop10Accuracy=0.7509, NarTop10Accuracy=0.6012, over 1064.00 frames. ], tot_loss[loss=3.389, ArTop10Accuracy=0.742, NarTop10Accuracy=0.5422, over 1209.03 frames. ], batch size: 2, lr: 4.08e-03, grad_scale: 32.0
|
71 |
+
2023-11-01 00:40:44,136 INFO [train.py:764] Epoch 1, batch 5700, train_loss[loss=3.756, ArTop10Accuracy=0.7013, NarTop10Accuracy=0.43, over 1309.00 frames. ], tot_loss[loss=3.404, ArTop10Accuracy=0.7411, NarTop10Accuracy=0.5373, over 1206.83 frames. ], batch size: 3, lr: 4.06e-03, grad_scale: 32.0
|
72 |
+
2023-11-01 00:41:06,015 INFO [train.py:764] Epoch 1, batch 5800, train_loss[loss=3.22, ArTop10Accuracy=0.7976, NarTop10Accuracy=0.5996, over 1077.00 frames. ], tot_loss[loss=3.393, ArTop10Accuracy=0.7424, NarTop10Accuracy=0.5391, over 1198.69 frames. ], batch size: 2, lr: 4.04e-03, grad_scale: 32.0
|
73 |
+
2023-11-01 00:41:27,922 INFO [train.py:764] Epoch 1, batch 5900, train_loss[loss=3.255, ArTop10Accuracy=0.7603, NarTop10Accuracy=0.5729, over 1410.00 frames. ], tot_loss[loss=3.379, ArTop10Accuracy=0.7444, NarTop10Accuracy=0.5428, over 1195.35 frames. ], batch size: 2, lr: 4.02e-03, grad_scale: 32.0
|
74 |
+
2023-11-01 00:41:49,781 INFO [train.py:764] Epoch 1, batch 6000, train_loss[loss=3.04, ArTop10Accuracy=0.754, NarTop10Accuracy=0.6513, over 1264.00 frames. ], tot_loss[loss=3.367, ArTop10Accuracy=0.7448, NarTop10Accuracy=0.549, over 1191.87 frames. ], batch size: 3, lr: 4.00e-03, grad_scale: 64.0
|
75 |
+
2023-11-01 00:41:49,950 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 3.018e+01 3.667e+01 3.892e+01 4.174e+01 9.103e+01, threshold=7.785e+01, percent-clipped=0.2
|
76 |
+
2023-11-01 00:42:11,871 INFO [train.py:764] Epoch 1, batch 6100, train_loss[loss=3.224, ArTop10Accuracy=0.7514, NarTop10Accuracy=0.5748, over 1275.00 frames. ], tot_loss[loss=3.365, ArTop10Accuracy=0.7445, NarTop10Accuracy=0.5489, over 1200.31 frames. ], batch size: 3, lr: 3.98e-03, grad_scale: 64.0
|
77 |
+
2023-11-01 00:42:33,790 INFO [train.py:764] Epoch 1, batch 6200, train_loss[loss=3.631, ArTop10Accuracy=0.7441, NarTop10Accuracy=0.4927, over 1067.00 frames. ], tot_loss[loss=3.368, ArTop10Accuracy=0.7456, NarTop10Accuracy=0.5473, over 1197.29 frames. ], batch size: 2, lr: 3.96e-03, grad_scale: 64.0
|
78 |
+
2023-11-01 00:42:55,647 INFO [train.py:764] Epoch 1, batch 6300, train_loss[loss=2.863, ArTop10Accuracy=0.7722, NarTop10Accuracy=0.6714, over 1229.00 frames. ], tot_loss[loss=3.358, ArTop10Accuracy=0.7464, NarTop10Accuracy=0.549, over 1195.66 frames. ], batch size: 3, lr: 3.94e-03, grad_scale: 64.0
|
79 |
+
2023-11-01 00:43:17,732 INFO [train.py:764] Epoch 1, batch 6400, train_loss[loss=3.313, ArTop10Accuracy=0.7168, NarTop10Accuracy=0.6248, over 1342.00 frames. ], tot_loss[loss=3.364, ArTop10Accuracy=0.7454, NarTop10Accuracy=0.5488, over 1200.90 frames. ], batch size: 3, lr: 3.92e-03, grad_scale: 16.0
|
80 |
+
2023-11-01 00:43:39,683 INFO [train.py:764] Epoch 1, batch 6500, train_loss[loss=3.58, ArTop10Accuracy=0.7492, NarTop10Accuracy=0.4578, over 1324.00 frames. ], tot_loss[loss=3.365, ArTop10Accuracy=0.7452, NarTop10Accuracy=0.5472, over 1203.64 frames. ], batch size: 3, lr: 3.90e-03, grad_scale: 16.0
|
81 |
+
2023-11-01 00:44:01,658 INFO [train.py:764] Epoch 1, batch 6600, train_loss[loss=3.618, ArTop10Accuracy=0.7234, NarTop10Accuracy=0.4793, over 1157.00 frames. ], tot_loss[loss=3.375, ArTop10Accuracy=0.7447, NarTop10Accuracy=0.5439, over 1199.80 frames. ], batch size: 2, lr: 3.89e-03, grad_scale: 16.0
|
82 |
+
2023-11-01 00:44:23,811 INFO [train.py:764] Epoch 1, batch 6700, train_loss[loss=3.909, ArTop10Accuracy=0.7282, NarTop10Accuracy=0.3605, over 1512.00 frames. ], tot_loss[loss=3.374, ArTop10Accuracy=0.747, NarTop10Accuracy=0.5432, over 1206.18 frames. ], batch size: 2, lr: 3.87e-03, grad_scale: 16.0
|
83 |
+
2023-11-01 00:44:45,778 INFO [train.py:764] Epoch 1, batch 6800, train_loss[loss=3.756, ArTop10Accuracy=0.6656, NarTop10Accuracy=0.4652, over 1250.00 frames. ], tot_loss[loss=3.378, ArTop10Accuracy=0.7473, NarTop10Accuracy=0.5408, over 1211.75 frames. ], batch size: 3, lr: 3.85e-03, grad_scale: 16.0
|
84 |
+
2023-11-01 00:45:07,536 INFO [train.py:764] Epoch 1, batch 6900, train_loss[loss=3.302, ArTop10Accuracy=0.7848, NarTop10Accuracy=0.5048, over 1199.00 frames. ], tot_loss[loss=3.369, ArTop10Accuracy=0.7478, NarTop10Accuracy=0.5444, over 1203.83 frames. ], batch size: 3, lr: 3.83e-03, grad_scale: 16.0
|
85 |
+
2023-11-01 00:45:29,567 INFO [train.py:764] Epoch 1, batch 7000, train_loss[loss=3.614, ArTop10Accuracy=0.667, NarTop10Accuracy=0.5579, over 1054.00 frames. ], tot_loss[loss=3.379, ArTop10Accuracy=0.7464, NarTop10Accuracy=0.5433, over 1205.70 frames. ], batch size: 2, lr: 3.81e-03, grad_scale: 16.0
|
86 |
+
2023-11-01 00:45:30,189 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 2.893e+01 3.646e+01 3.891e+01 4.178e+01 1.168e+02, threshold=7.782e+01, percent-clipped=0.2
|
87 |
+
2023-11-01 00:45:51,356 INFO [train.py:764] Epoch 1, batch 7100, train_loss[loss=3.365, ArTop10Accuracy=0.7592, NarTop10Accuracy=0.539, over 1221.00 frames. ], tot_loss[loss=3.363, ArTop10Accuracy=0.7483, NarTop10Accuracy=0.5466, over 1200.17 frames. ], batch size: 3, lr: 3.79e-03, grad_scale: 16.0
|
88 |
+
2023-11-01 00:46:13,329 INFO [train.py:764] Epoch 1, batch 7200, train_loss[loss=3.422, ArTop10Accuracy=0.7782, NarTop10Accuracy=0.5144, over 1109.00 frames. ], tot_loss[loss=3.349, ArTop10Accuracy=0.7504, NarTop10Accuracy=0.5512, over 1200.81 frames. ], batch size: 2, lr: 3.78e-03, grad_scale: 16.0
|
89 |
+
2023-11-01 00:46:35,352 INFO [train.py:764] Epoch 1, batch 7300, train_loss[loss=3.123, ArTop10Accuracy=0.7526, NarTop10Accuracy=0.6595, over 1047.00 frames. ], tot_loss[loss=3.357, ArTop10Accuracy=0.7506, NarTop10Accuracy=0.549, over 1199.39 frames. ], batch size: 2, lr: 3.76e-03, grad_scale: 16.0
|
90 |
+
2023-11-01 00:46:57,454 INFO [train.py:764] Epoch 1, batch 7400, train_loss[loss=3.457, ArTop10Accuracy=0.7662, NarTop10Accuracy=0.4982, over 1142.00 frames. ], tot_loss[loss=3.35, ArTop10Accuracy=0.7506, NarTop10Accuracy=0.552, over 1204.30 frames. ], batch size: 2, lr: 3.74e-03, grad_scale: 16.0
|
91 |
+
2023-11-01 00:47:19,658 INFO [train.py:764] Epoch 1, batch 7500, train_loss[loss=3.83, ArTop10Accuracy=0.6948, NarTop10Accuracy=0.4401, over 1019.00 frames. ], tot_loss[loss=3.359, ArTop10Accuracy=0.7496, NarTop10Accuracy=0.5494, over 1211.85 frames. ], batch size: 2, lr: 3.72e-03, grad_scale: 16.0
|
92 |
+
2023-11-01 00:47:41,652 INFO [train.py:764] Epoch 1, batch 7600, train_loss[loss=3.049, ArTop10Accuracy=0.7516, NarTop10Accuracy=0.6564, over 1538.00 frames. ], tot_loss[loss=3.346, ArTop10Accuracy=0.7518, NarTop10Accuracy=0.5516, over 1203.46 frames. ], batch size: 3, lr: 3.71e-03, grad_scale: 16.0
|
93 |
+
2023-11-01 00:48:03,855 INFO [train.py:764] Epoch 1, batch 7700, train_loss[loss=3.38, ArTop10Accuracy=0.7527, NarTop10Accuracy=0.5145, over 1500.00 frames. ], tot_loss[loss=3.343, ArTop10Accuracy=0.7512, NarTop10Accuracy=0.553, over 1210.10 frames. ], batch size: 3, lr: 3.69e-03, grad_scale: 16.0
|
94 |
+
2023-11-01 00:48:25,842 INFO [train.py:764] Epoch 1, batch 7800, train_loss[loss=3.577, ArTop10Accuracy=0.7123, NarTop10Accuracy=0.5032, over 1300.00 frames. ], tot_loss[loss=3.344, ArTop10Accuracy=0.7498, NarTop10Accuracy=0.5536, over 1207.39 frames. ], batch size: 3, lr: 3.67e-03, grad_scale: 16.0
|
95 |
+
2023-11-01 00:48:47,914 INFO [train.py:764] Epoch 1, batch 7900, train_loss[loss=3.129, ArTop10Accuracy=0.7551, NarTop10Accuracy=0.6453, over 1323.00 frames. ], tot_loss[loss=3.348, ArTop10Accuracy=0.7489, NarTop10Accuracy=0.5529, over 1205.69 frames. ], batch size: 3, lr: 3.66e-03, grad_scale: 16.0
|
96 |
+
2023-11-01 00:49:10,108 INFO [train.py:764] Epoch 1, batch 8000, train_loss[loss=3.435, ArTop10Accuracy=0.7504, NarTop10Accuracy=0.5231, over 1314.00 frames. ], tot_loss[loss=3.344, ArTop10Accuracy=0.7484, NarTop10Accuracy=0.5538, over 1204.55 frames. ], batch size: 2, lr: 3.64e-03, grad_scale: 16.0
|
97 |
+
2023-11-01 00:49:10,719 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 2.574e+01 3.623e+01 3.864e+01 4.175e+01 9.270e+01, threshold=7.727e+01, percent-clipped=0.3
|
98 |
+
2023-11-01 00:49:32,315 INFO [train.py:764] Epoch 1, batch 8100, train_loss[loss=3.578, ArTop10Accuracy=0.7281, NarTop10Accuracy=0.5042, over 1453.00 frames. ], tot_loss[loss=3.369, ArTop10Accuracy=0.7464, NarTop10Accuracy=0.5469, over 1201.98 frames. ], batch size: 3, lr: 3.62e-03, grad_scale: 16.0
|
99 |
+
2023-11-01 00:49:54,412 INFO [train.py:764] Epoch 1, batch 8200, train_loss[loss=3.126, ArTop10Accuracy=0.7671, NarTop10Accuracy=0.6341, over 1198.00 frames. ], tot_loss[loss=3.377, ArTop10Accuracy=0.7451, NarTop10Accuracy=0.5455, over 1201.66 frames. ], batch size: 3, lr: 3.61e-03, grad_scale: 16.0
|
100 |
+
2023-11-01 00:50:16,462 INFO [train.py:764] Epoch 1, batch 8300, train_loss[loss=3.51, ArTop10Accuracy=0.7526, NarTop10Accuracy=0.4518, over 1354.00 frames. ], tot_loss[loss=3.363, ArTop10Accuracy=0.7486, NarTop10Accuracy=0.546, over 1200.11 frames. ], batch size: 3, lr: 3.59e-03, grad_scale: 16.0
|
101 |
+
2023-11-01 00:50:38,610 INFO [train.py:764] Epoch 1, batch 8400, train_loss[loss=3.781, ArTop10Accuracy=0.7058, NarTop10Accuracy=0.4102, over 1064.00 frames. ], tot_loss[loss=3.363, ArTop10Accuracy=0.7496, NarTop10Accuracy=0.545, over 1201.74 frames. ], batch size: 2, lr: 3.58e-03, grad_scale: 32.0
|
102 |
+
2023-11-01 00:51:00,590 INFO [train.py:764] Epoch 1, batch 8500, train_loss[loss=3.679, ArTop10Accuracy=0.72, NarTop10Accuracy=0.4794, over 1082.00 frames. ], tot_loss[loss=3.357, ArTop10Accuracy=0.7495, NarTop10Accuracy=0.5478, over 1189.28 frames. ], batch size: 2, lr: 3.56e-03, grad_scale: 32.0
|
103 |
+
2023-11-01 00:51:22,638 INFO [train.py:764] Epoch 1, batch 8600, train_loss[loss=3.35, ArTop10Accuracy=0.7558, NarTop10Accuracy=0.5426, over 1208.00 frames. ], tot_loss[loss=3.361, ArTop10Accuracy=0.7485, NarTop10Accuracy=0.546, over 1195.17 frames. ], batch size: 3, lr: 3.54e-03, grad_scale: 32.0
|
104 |
+
2023-11-01 00:51:44,857 INFO [train.py:764] Epoch 1, batch 8700, train_loss[loss=3.502, ArTop10Accuracy=0.7509, NarTop10Accuracy=0.4911, over 1076.00 frames. ], tot_loss[loss=3.341, ArTop10Accuracy=0.7513, NarTop10Accuracy=0.5539, over 1199.46 frames. ], batch size: 2, lr: 3.53e-03, grad_scale: 32.0
|
105 |
+
2023-11-01 00:52:06,705 INFO [train.py:764] Epoch 1, batch 8800, train_loss[loss=3.171, ArTop10Accuracy=0.7301, NarTop10Accuracy=0.6783, over 1178.00 frames. ], tot_loss[loss=3.339, ArTop10Accuracy=0.751, NarTop10Accuracy=0.5542, over 1196.85 frames. ], batch size: 3, lr: 3.51e-03, grad_scale: 16.0
|
106 |
+
2023-11-01 00:52:28,687 INFO [train.py:764] Epoch 1, batch 8900, train_loss[loss=3.443, ArTop10Accuracy=0.7263, NarTop10Accuracy=0.5612, over 1330.00 frames. ], tot_loss[loss=3.333, ArTop10Accuracy=0.7516, NarTop10Accuracy=0.5562, over 1197.63 frames. ], batch size: 3, lr: 3.50e-03, grad_scale: 16.0
|
107 |
+
2023-11-01 00:52:50,763 INFO [train.py:764] Epoch 1, batch 9000, train_loss[loss=3.046, ArTop10Accuracy=0.7712, NarTop10Accuracy=0.6817, over 1355.00 frames. ], tot_loss[loss=3.346, ArTop10Accuracy=0.7506, NarTop10Accuracy=0.5509, over 1206.31 frames. ], batch size: 3, lr: 3.48e-03, grad_scale: 16.0
|
108 |
+
2023-11-01 00:52:51,581 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 2.596e+01 3.632e+01 3.884e+01 4.203e+01 1.192e+02, threshold=7.767e+01, percent-clipped=0.1
|
109 |
+
2023-11-01 00:53:12,725 INFO [train.py:764] Epoch 1, batch 9100, train_loss[loss=3.108, ArTop10Accuracy=0.745, NarTop10Accuracy=0.6108, over 1255.00 frames. ], tot_loss[loss=3.325, ArTop10Accuracy=0.7506, NarTop10Accuracy=0.5588, over 1204.55 frames. ], batch size: 3, lr: 3.47e-03, grad_scale: 16.0
|
110 |
+
2023-11-01 00:53:34,455 INFO [train.py:764] Epoch 1, batch 9200, train_loss[loss=3.607, ArTop10Accuracy=0.7313, NarTop10Accuracy=0.4606, over 1243.00 frames. ], tot_loss[loss=3.323, ArTop10Accuracy=0.7512, NarTop10Accuracy=0.559, over 1199.74 frames. ], batch size: 3, lr: 3.46e-03, grad_scale: 16.0
|
111 |
+
2023-11-01 00:53:56,283 INFO [train.py:764] Epoch 1, batch 9300, train_loss[loss=3.341, ArTop10Accuracy=0.7796, NarTop10Accuracy=0.5181, over 1243.00 frames. ], tot_loss[loss=3.327, ArTop10Accuracy=0.7531, NarTop10Accuracy=0.5566, over 1196.75 frames. ], batch size: 3, lr: 3.44e-03, grad_scale: 16.0
|
112 |
+
2023-11-01 00:54:18,225 INFO [train.py:764] Epoch 1, batch 9400, train_loss[loss=3.287, ArTop10Accuracy=0.7442, NarTop10Accuracy=0.5898, over 1501.00 frames. ], tot_loss[loss=3.318, ArTop10Accuracy=0.7531, NarTop10Accuracy=0.5604, over 1201.10 frames. ], batch size: 3, lr: 3.43e-03, grad_scale: 16.0
|
113 |
+
2023-11-01 00:54:40,274 INFO [train.py:764] Epoch 1, batch 9500, train_loss[loss=3.335, ArTop10Accuracy=0.7351, NarTop10Accuracy=0.6181, over 1023.00 frames. ], tot_loss[loss=3.318, ArTop10Accuracy=0.7547, NarTop10Accuracy=0.5604, over 1206.75 frames. ], batch size: 2, lr: 3.41e-03, grad_scale: 16.0
|
114 |
+
2023-11-01 00:55:02,388 INFO [train.py:764] Epoch 1, batch 9600, train_loss[loss=3.33, ArTop10Accuracy=0.7322, NarTop10Accuracy=0.6035, over 1225.00 frames. ], tot_loss[loss=3.326, ArTop10Accuracy=0.7529, NarTop10Accuracy=0.5581, over 1210.68 frames. ], batch size: 3, lr: 3.40e-03, grad_scale: 16.0
|
115 |
+
2023-11-01 00:55:24,417 INFO [train.py:764] Epoch 1, batch 9700, train_loss[loss=3.696, ArTop10Accuracy=0.7275, NarTop10Accuracy=0.4242, over 1358.00 frames. ], tot_loss[loss=3.33, ArTop10Accuracy=0.752, NarTop10Accuracy=0.5553, over 1204.10 frames. ], batch size: 3, lr: 3.38e-03, grad_scale: 16.0
|
116 |
+
2023-11-01 00:55:46,226 INFO [train.py:764] Epoch 1, batch 9800, train_loss[loss=3.352, ArTop10Accuracy=0.7327, NarTop10Accuracy=0.607, over 1025.00 frames. ], tot_loss[loss=3.327, ArTop10Accuracy=0.7526, NarTop10Accuracy=0.5557, over 1201.65 frames. ], batch size: 2, lr: 3.37e-03, grad_scale: 16.0
|
117 |
+
2023-11-01 00:56:08,325 INFO [train.py:764] Epoch 1, batch 9900, train_loss[loss=3.73, ArTop10Accuracy=0.715, NarTop10Accuracy=0.4577, over 1130.00 frames. ], tot_loss[loss=3.33, ArTop10Accuracy=0.7521, NarTop10Accuracy=0.5557, over 1206.02 frames. ], batch size: 2, lr: 3.36e-03, grad_scale: 16.0
|
exp/valle_dev/log/log-train-2023-11-01-01-01-00
ADDED
@@ -0,0 +1,172 @@
1 |
+
2023-11-01 01:01:00,501 INFO [train.py:851] Training started
|
2 |
+
2023-11-01 01:01:00,502 INFO [train.py:870] Device: cuda:0
|
3 |
+
2023-11-01 01:01:00,503 INFO [train.py:871] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 100, 'reset_interval': 200, 'valid_interval': 10000, 'world_size': 1, 'master_port': 12354, 'tensorboard': True, 'num_epochs': 20, 'start_epoch': 1, 'start_batch': 0, 'exp_dir': PosixPath('exp/valle_dev'), 'optimizer_name': 'ScaledAdam', 'scheduler_name': 'Eden', 'base_lr': 0.005, 'warmup_steps': 200, 'seed': 42, 'inf_check': False, 'save_every_n': 10000, 'keep_last_k': 20, 'average_period': 0, 'accumulate_grad_steps': 1, 'dtype': 'float16', 'filter_min_duration': 0.0, 'filter_max_duration': 20.0, 'train_stage': 0, 'visualize': False, 'oom_check': True, 'train_dir': '/home/ubuntu/VALL-E-X/JS_Dataset/JS_Dataset/train_tune', 'valid_dir': '/home/ubuntu/VALL-E-X/JS_Dataset/JS_Dataset/valid_tune', 'model_name': 'VALL-E', 'decoder_dim': 1024, 'nhead': 16, 'num_decoder_layers': 12, 'scale_factor': 1.0, 'norm_first': True, 'add_prenet': False, 'prefix_mode': 0, 'share_embedding': True, 'prepend_bos': False, 'num_quantizers': 8, 'scaling_xformers': False}
|
4 |
+
2023-11-01 01:01:00,503 INFO [train.py:873] About to create model
|
5 |
+
2023-11-01 01:01:05,108 INFO [train.py:877] Number of model parameters: 370539524
|
6 |
+
2023-11-01 01:01:05,334 DEBUG [__init__.py:113] Building prefix dict from the default dictionary ...
|
7 |
+
2023-11-01 01:01:05,334 DEBUG [__init__.py:132] Loading model from cache /tmp/jieba.cache
|
8 |
+
2023-11-01 01:01:05,847 DEBUG [__init__.py:164] Loading model cost 0.513 seconds.
|
9 |
+
2023-11-01 01:01:05,847 DEBUG [__init__.py:166] Prefix dict has been built successfully.
|
10 |
+
2023-11-01 01:01:28,103 INFO [train.py:764] Epoch 1, batch 100, train_loss[loss=3.201, ArTop10Accuracy=0.7404, NarTop10Accuracy=0.6034, over 1306.00 frames. ], tot_loss[loss=3.398, ArTop10Accuracy=0.7056, NarTop10Accuracy=0.5532, over 476.97 frames. ], batch size: 3, lr: 3.75e-03, grad_scale: 1.0
|
11 |
+
2023-11-01 01:01:49,616 INFO [train.py:764] Epoch 1, batch 200, train_loss[loss=3.527, ArTop10Accuracy=0.6953, NarTop10Accuracy=0.5522, over 1234.00 frames. ], tot_loss[loss=3.408, ArTop10Accuracy=0.709, NarTop10Accuracy=0.552, over 749.45 frames. ], batch size: 3, lr: 5.00e-03, grad_scale: 1.0
|
12 |
+
2023-11-01 01:02:11,657 INFO [train.py:764] Epoch 1, batch 300, train_loss[loss=3.605, ArTop10Accuracy=0.7095, NarTop10Accuracy=0.4503, over 995.00 frames. ], tot_loss[loss=3.443, ArTop10Accuracy=0.7106, NarTop10Accuracy=0.5387, over 935.67 frames. ], batch size: 2, lr: 5.00e-03, grad_scale: 1.0
|
13 |
+
2023-11-01 01:02:33,731 INFO [train.py:764] Epoch 1, batch 400, train_loss[loss=3.412, ArTop10Accuracy=0.6864, NarTop10Accuracy=0.5319, over 1234.00 frames. ], tot_loss[loss=3.462, ArTop10Accuracy=0.7132, NarTop10Accuracy=0.5284, over 1040.70 frames. ], batch size: 3, lr: 4.99e-03, grad_scale: 2.0
|
14 |
+
2023-11-01 01:02:55,462 INFO [train.py:764] Epoch 1, batch 500, train_loss[loss=3.292, ArTop10Accuracy=0.727, NarTop10Accuracy=0.5829, over 1271.00 frames. ], tot_loss[loss=3.483, ArTop10Accuracy=0.7154, NarTop10Accuracy=0.5211, over 1094.80 frames. ], batch size: 3, lr: 4.99e-03, grad_scale: 2.0
|
15 |
+
2023-11-01 01:03:17,499 INFO [train.py:764] Epoch 1, batch 600, train_loss[loss=3.666, ArTop10Accuracy=0.7166, NarTop10Accuracy=0.4386, over 1496.00 frames. ], tot_loss[loss=3.46, ArTop10Accuracy=0.719, NarTop10Accuracy=0.5268, over 1141.86 frames. ], batch size: 3, lr: 4.98e-03, grad_scale: 2.0
|
16 |
+
2023-11-01 01:03:39,337 INFO [train.py:764] Epoch 1, batch 700, train_loss[loss=3.255, ArTop10Accuracy=0.6812, NarTop10Accuracy=0.6464, over 1010.00 frames. ], tot_loss[loss=3.468, ArTop10Accuracy=0.7186, NarTop10Accuracy=0.527, over 1167.56 frames. ], batch size: 2, lr: 4.98e-03, grad_scale: 2.0
|
17 |
+
2023-11-01 01:04:01,250 INFO [train.py:764] Epoch 1, batch 800, train_loss[loss=3.898, ArTop10Accuracy=0.6695, NarTop10Accuracy=0.4247, over 947.00 frames. ], tot_loss[loss=3.467, ArTop10Accuracy=0.7206, NarTop10Accuracy=0.527, over 1184.08 frames. ], batch size: 2, lr: 4.97e-03, grad_scale: 4.0
|
18 |
+
2023-11-01 01:04:23,115 INFO [train.py:764] Epoch 1, batch 900, train_loss[loss=3.424, ArTop10Accuracy=0.7506, NarTop10Accuracy=0.5276, over 1195.00 frames. ], tot_loss[loss=3.461, ArTop10Accuracy=0.7221, NarTop10Accuracy=0.5286, over 1188.23 frames. ], batch size: 3, lr: 4.96e-03, grad_scale: 4.0
|
19 |
+
2023-11-01 01:04:45,007 INFO [train.py:764] Epoch 1, batch 1000, train_loss[loss=3.606, ArTop10Accuracy=0.7147, NarTop10Accuracy=0.4905, over 1339.00 frames. ], tot_loss[loss=3.458, ArTop10Accuracy=0.7239, NarTop10Accuracy=0.5285, over 1190.57 frames. ], batch size: 3, lr: 4.95e-03, grad_scale: 4.0
|
20 |
+
2023-11-01 01:04:45,173 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 3.165e+01 4.527e+01 4.998e+01 5.591e+01 9.306e+01, threshold=9.997e+01, percent-clipped=0.0
|
21 |
+
2023-11-01 01:05:07,092 INFO [train.py:764] Epoch 1, batch 1100, train_loss[loss=3.756, ArTop10Accuracy=0.7031, NarTop10Accuracy=0.4571, over 1007.00 frames. ], tot_loss[loss=3.465, ArTop10Accuracy=0.7242, NarTop10Accuracy=0.5252, over 1194.69 frames. ], batch size: 2, lr: 4.94e-03, grad_scale: 4.0
|
22 |
+
2023-11-01 01:05:29,224 INFO [train.py:764] Epoch 1, batch 1200, train_loss[loss=3.244, ArTop10Accuracy=0.7549, NarTop10Accuracy=0.5967, over 1228.00 frames. ], tot_loss[loss=3.456, ArTop10Accuracy=0.7256, NarTop10Accuracy=0.5275, over 1198.98 frames. ], batch size: 3, lr: 4.93e-03, grad_scale: 8.0
|
23 |
+
2023-11-01 01:05:51,281 INFO [train.py:764] Epoch 1, batch 1300, train_loss[loss=3.368, ArTop10Accuracy=0.7277, NarTop10Accuracy=0.5906, over 1054.00 frames. ], tot_loss[loss=3.464, ArTop10Accuracy=0.7263, NarTop10Accuracy=0.526, over 1202.37 frames. ], batch size: 2, lr: 4.92e-03, grad_scale: 8.0
|
24 |
+
2023-11-01 01:06:13,118 INFO [train.py:764] Epoch 1, batch 1400, train_loss[loss=3.445, ArTop10Accuracy=0.7333, NarTop10Accuracy=0.5311, over 1301.00 frames. ], tot_loss[loss=3.474, ArTop10Accuracy=0.7277, NarTop10Accuracy=0.5206, over 1198.94 frames. ], batch size: 3, lr: 4.91e-03, grad_scale: 8.0
|
25 |
+
2023-11-01 01:06:35,108 INFO [train.py:764] Epoch 1, batch 1500, train_loss[loss=3.337, ArTop10Accuracy=0.754, NarTop10Accuracy=0.566, over 1236.00 frames. ], tot_loss[loss=3.46, ArTop10Accuracy=0.7282, NarTop10Accuracy=0.5245, over 1198.26 frames. ], batch size: 3, lr: 4.89e-03, grad_scale: 8.0
|
26 |
+
2023-11-01 01:06:57,122 INFO [train.py:764] Epoch 1, batch 1600, train_loss[loss=3.626, ArTop10Accuracy=0.7161, NarTop10Accuracy=0.4777, over 1201.00 frames. ], tot_loss[loss=3.464, ArTop10Accuracy=0.7262, NarTop10Accuracy=0.5244, over 1204.65 frames. ], batch size: 3, lr: 4.88e-03, grad_scale: 8.0
|
27 |
+
2023-11-01 01:07:19,077 INFO [train.py:764] Epoch 1, batch 1700, train_loss[loss=3.4, ArTop10Accuracy=0.7569, NarTop10Accuracy=0.499, over 650.00 frames. ], tot_loss[loss=3.46, ArTop10Accuracy=0.7276, NarTop10Accuracy=0.5238, over 1202.22 frames. ], batch size: 1, lr: 4.87e-03, grad_scale: 8.0
|
28 |
+
2023-11-01 01:07:40,966 INFO [train.py:764] Epoch 1, batch 1800, train_loss[loss=3.575, ArTop10Accuracy=0.7287, NarTop10Accuracy=0.4463, over 1253.00 frames. ], tot_loss[loss=3.46, ArTop10Accuracy=0.7279, NarTop10Accuracy=0.5244, over 1198.05 frames. ], batch size: 3, lr: 4.85e-03, grad_scale: 8.0
|
29 |
+
2023-11-01 01:08:02,894 INFO [train.py:764] Epoch 1, batch 1900, train_loss[loss=3.345, ArTop10Accuracy=0.8437, NarTop10Accuracy=0.5034, over 819.00 frames. ], tot_loss[loss=3.453, ArTop10Accuracy=0.7277, NarTop10Accuracy=0.5275, over 1198.59 frames. ], batch size: 1, lr: 4.83e-03, grad_scale: 8.0
|
30 |
+
2023-11-01 01:08:24,667 INFO [train.py:764] Epoch 1, batch 2000, train_loss[loss=3.321, ArTop10Accuracy=0.7224, NarTop10Accuracy=0.5903, over 1286.00 frames. ], tot_loss[loss=3.442, ArTop10Accuracy=0.7301, NarTop10Accuracy=0.53, over 1193.18 frames. ], batch size: 3, lr: 4.82e-03, grad_scale: 16.0
|
31 |
+
2023-11-01 01:08:24,840 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 2.935e+01 4.004e+01 4.278e+01 4.591e+01 1.283e+02, threshold=8.555e+01, percent-clipped=0.1
|
32 |
+
2023-11-01 01:08:46,652 INFO [train.py:764] Epoch 1, batch 2100, train_loss[loss=3.536, ArTop10Accuracy=0.7298, NarTop10Accuracy=0.5135, over 1114.00 frames. ], tot_loss[loss=3.459, ArTop10Accuracy=0.7287, NarTop10Accuracy=0.5254, over 1201.39 frames. ], batch size: 2, lr: 4.80e-03, grad_scale: 16.0
|
33 |
+
2023-11-01 01:09:08,857 INFO [train.py:764] Epoch 1, batch 2200, train_loss[loss=3.196, ArTop10Accuracy=0.8063, NarTop10Accuracy=0.545, over 1084.00 frames. ], tot_loss[loss=3.452, ArTop10Accuracy=0.73, NarTop10Accuracy=0.5287, over 1212.86 frames. ], batch size: 2, lr: 4.78e-03, grad_scale: 16.0
|
34 |
+
2023-11-01 01:09:30,803 INFO [train.py:764] Epoch 1, batch 2300, train_loss[loss=3.648, ArTop10Accuracy=0.7415, NarTop10Accuracy=0.4108, over 1478.00 frames. ], tot_loss[loss=3.457, ArTop10Accuracy=0.7292, NarTop10Accuracy=0.5254, over 1212.39 frames. ], batch size: 3, lr: 4.77e-03, grad_scale: 16.0
|
35 |
+
2023-11-01 01:09:52,725 INFO [train.py:764] Epoch 1, batch 2400, train_loss[loss=3.594, ArTop10Accuracy=0.7306, NarTop10Accuracy=0.4828, over 1325.00 frames. ], tot_loss[loss=3.43, ArTop10Accuracy=0.7311, NarTop10Accuracy=0.5345, over 1209.83 frames. ], batch size: 3, lr: 4.75e-03, grad_scale: 16.0
|
36 |
+
2023-11-01 01:10:14,502 INFO [train.py:764] Epoch 1, batch 2500, train_loss[loss=3.158, ArTop10Accuracy=0.7867, NarTop10Accuracy=0.6099, over 1280.00 frames. ], tot_loss[loss=3.419, ArTop10Accuracy=0.7338, NarTop10Accuracy=0.5354, over 1200.35 frames. ], batch size: 3, lr: 4.73e-03, grad_scale: 16.0
|
37 |
+
2023-11-01 01:10:36,448 INFO [train.py:764] Epoch 1, batch 2600, train_loss[loss=3.217, ArTop10Accuracy=0.7426, NarTop10Accuracy=0.6008, over 1321.00 frames. ], tot_loss[loss=3.417, ArTop10Accuracy=0.7343, NarTop10Accuracy=0.5362, over 1203.85 frames. ], batch size: 3, lr: 4.71e-03, grad_scale: 16.0
|
38 |
+
2023-11-01 01:10:58,352 INFO [train.py:764] Epoch 1, batch 2700, train_loss[loss=3.172, ArTop10Accuracy=0.748, NarTop10Accuracy=0.624, over 1480.00 frames. ], tot_loss[loss=3.406, ArTop10Accuracy=0.7351, NarTop10Accuracy=0.5391, over 1201.18 frames. ], batch size: 3, lr: 4.69e-03, grad_scale: 16.0
|
39 |
+
2023-11-01 01:11:20,304 INFO [train.py:764] Epoch 1, batch 2800, train_loss[loss=3.185, ArTop10Accuracy=0.734, NarTop10Accuracy=0.6313, over 1297.00 frames. ], tot_loss[loss=3.409, ArTop10Accuracy=0.7356, NarTop10Accuracy=0.5372, over 1201.74 frames. ], batch size: 3, lr: 4.67e-03, grad_scale: 16.0
|
40 |
+
2023-11-01 01:11:42,304 INFO [train.py:764] Epoch 1, batch 2900, train_loss[loss=3.039, ArTop10Accuracy=0.7433, NarTop10Accuracy=0.6679, over 1387.00 frames. ], tot_loss[loss=3.421, ArTop10Accuracy=0.7359, NarTop10Accuracy=0.5336, over 1199.06 frames. ], batch size: 3, lr: 4.65e-03, grad_scale: 16.0
|
41 |
+
2023-11-01 01:12:04,208 INFO [train.py:764] Epoch 1, batch 3000, train_loss[loss=3.229, ArTop10Accuracy=0.7407, NarTop10Accuracy=0.6181, over 1261.00 frames. ], tot_loss[loss=3.428, ArTop10Accuracy=0.7354, NarTop10Accuracy=0.5309, over 1193.63 frames. ], batch size: 3, lr: 4.63e-03, grad_scale: 16.0
|
42 |
+
2023-11-01 01:12:04,384 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 2.561e+01 3.787e+01 4.032e+01 4.340e+01 1.223e+02, threshold=8.065e+01, percent-clipped=0.1
|
43 |
+
2023-11-01 01:12:26,564 INFO [train.py:764] Epoch 1, batch 3100, train_loss[loss=3.249, ArTop10Accuracy=0.7327, NarTop10Accuracy=0.6701, over 1070.00 frames. ], tot_loss[loss=3.42, ArTop10Accuracy=0.7352, NarTop10Accuracy=0.5346, over 1195.52 frames. ], batch size: 2, lr: 4.61e-03, grad_scale: 16.0
|
44 |
+
2023-11-01 01:12:48,991 INFO [train.py:764] Epoch 1, batch 3200, train_loss[loss=3.381, ArTop10Accuracy=0.7176, NarTop10Accuracy=0.5645, over 1289.00 frames. ], tot_loss[loss=3.418, ArTop10Accuracy=0.7361, NarTop10Accuracy=0.5337, over 1210.89 frames. ], batch size: 3, lr: 4.59e-03, grad_scale: 16.0
|
45 |
+
2023-11-01 01:13:10,935 INFO [train.py:764] Epoch 1, batch 3300, train_loss[loss=3.741, ArTop10Accuracy=0.7543, NarTop10Accuracy=0.3948, over 1331.00 frames. ], tot_loss[loss=3.427, ArTop10Accuracy=0.7365, NarTop10Accuracy=0.5314, over 1199.90 frames. ], batch size: 3, lr: 4.57e-03, grad_scale: 16.0
|
46 |
+
2023-11-01 01:13:32,934 INFO [train.py:764] Epoch 1, batch 3400, train_loss[loss=3.206, ArTop10Accuracy=0.7596, NarTop10Accuracy=0.6492, over 1223.00 frames. ], tot_loss[loss=3.425, ArTop10Accuracy=0.7375, NarTop10Accuracy=0.5304, over 1204.36 frames. ], batch size: 3, lr: 4.55e-03, grad_scale: 16.0
|
47 |
+
2023-11-01 01:13:54,893 INFO [train.py:764] Epoch 1, batch 3500, train_loss[loss=3.203, ArTop10Accuracy=0.7498, NarTop10Accuracy=0.6391, over 1315.00 frames. ], tot_loss[loss=3.417, ArTop10Accuracy=0.7378, NarTop10Accuracy=0.5328, over 1203.43 frames. ], batch size: 3, lr: 4.53e-03, grad_scale: 16.0
|
48 |
+
2023-11-01 01:14:16,952 INFO [train.py:764] Epoch 1, batch 3600, train_loss[loss=3.249, ArTop10Accuracy=0.7305, NarTop10Accuracy=0.596, over 1002.00 frames. ], tot_loss[loss=3.41, ArTop10Accuracy=0.738, NarTop10Accuracy=0.5372, over 1200.89 frames. ], batch size: 2, lr: 4.50e-03, grad_scale: 16.0
|
49 |
+
2023-11-01 01:14:38,878 INFO [train.py:764] Epoch 1, batch 3700, train_loss[loss=3.047, ArTop10Accuracy=0.7559, NarTop10Accuracy=0.6173, over 1270.00 frames. ], tot_loss[loss=3.404, ArTop10Accuracy=0.7377, NarTop10Accuracy=0.5375, over 1199.71 frames. ], batch size: 3, lr: 4.48e-03, grad_scale: 16.0
|
50 |
+
2023-11-01 01:15:00,764 INFO [train.py:764] Epoch 1, batch 3800, train_loss[loss=3.431, ArTop10Accuracy=0.7261, NarTop10Accuracy=0.5124, over 953.00 frames. ], tot_loss[loss=3.422, ArTop10Accuracy=0.737, NarTop10Accuracy=0.5299, over 1206.54 frames. ], batch size: 2, lr: 4.46e-03, grad_scale: 16.0
|
51 |
+
2023-11-01 01:15:22,762 INFO [train.py:764] Epoch 1, batch 3900, train_loss[loss=3.266, ArTop10Accuracy=0.766, NarTop10Accuracy=0.5619, over 1346.00 frames. ], tot_loss[loss=3.409, ArTop10Accuracy=0.738, NarTop10Accuracy=0.5353, over 1214.09 frames. ], batch size: 3, lr: 4.44e-03, grad_scale: 16.0
|
52 |
+
2023-11-01 01:15:44,640 INFO [train.py:764] Epoch 1, batch 4000, train_loss[loss=3.197, ArTop10Accuracy=0.7829, NarTop10Accuracy=0.5988, over 1336.00 frames. ], tot_loss[loss=3.396, ArTop10Accuracy=0.7388, NarTop10Accuracy=0.539, over 1202.95 frames. ], batch size: 3, lr: 4.42e-03, grad_scale: 32.0
|
53 |
+
2023-11-01 01:15:44,805 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 2.733e+01 3.740e+01 3.969e+01 4.257e+01 1.029e+02, threshold=7.938e+01, percent-clipped=0.2
|
54 |
+
2023-11-01 01:16:06,416 INFO [train.py:764] Epoch 1, batch 4100, train_loss[loss=3.244, ArTop10Accuracy=0.7525, NarTop10Accuracy=0.5985, over 1277.00 frames. ], tot_loss[loss=3.4, ArTop10Accuracy=0.7391, NarTop10Accuracy=0.5378, over 1197.35 frames. ], batch size: 3, lr: 4.40e-03, grad_scale: 32.0
|
55 |
+
2023-11-01 01:16:28,445 INFO [train.py:764] Epoch 1, batch 4200, train_loss[loss=3.324, ArTop10Accuracy=0.7329, NarTop10Accuracy=0.5902, over 1228.00 frames. ], tot_loss[loss=3.395, ArTop10Accuracy=0.7386, NarTop10Accuracy=0.5409, over 1204.03 frames. ], batch size: 3, lr: 4.38e-03, grad_scale: 32.0
|
56 |
+
2023-11-01 01:16:50,432 INFO [train.py:764] Epoch 1, batch 4300, train_loss[loss=3.314, ArTop10Accuracy=0.7644, NarTop10Accuracy=0.5612, over 1125.00 frames. ], tot_loss[loss=3.401, ArTop10Accuracy=0.7396, NarTop10Accuracy=0.5402, over 1201.24 frames. ], batch size: 1, lr: 4.35e-03, grad_scale: 8.0
|
57 |
+
2023-11-01 01:17:12,468 INFO [train.py:764] Epoch 1, batch 4400, train_loss[loss=2.998, ArTop10Accuracy=0.754, NarTop10Accuracy=0.7196, over 1297.00 frames. ], tot_loss[loss=3.392, ArTop10Accuracy=0.7426, NarTop10Accuracy=0.54, over 1202.84 frames. ], batch size: 3, lr: 4.33e-03, grad_scale: 8.0
|
58 |
+
2023-11-01 01:17:34,486 INFO [train.py:764] Epoch 1, batch 4500, train_loss[loss=3.158, ArTop10Accuracy=0.7451, NarTop10Accuracy=0.6344, over 1271.00 frames. ], tot_loss[loss=3.377, ArTop10Accuracy=0.7434, NarTop10Accuracy=0.5462, over 1206.43 frames. ], batch size: 3, lr: 4.31e-03, grad_scale: 8.0
|
59 |
+
2023-11-01 01:17:56,444 INFO [train.py:764] Epoch 1, batch 4600, train_loss[loss=3.468, ArTop10Accuracy=0.7347, NarTop10Accuracy=0.4962, over 980.00 frames. ], tot_loss[loss=3.379, ArTop10Accuracy=0.7433, NarTop10Accuracy=0.5449, over 1199.69 frames. ], batch size: 2, lr: 4.29e-03, grad_scale: 8.0
|
60 |
+
2023-11-01 01:18:18,361 INFO [train.py:764] Epoch 1, batch 4700, train_loss[loss=3.19, ArTop10Accuracy=0.7266, NarTop10Accuracy=0.6189, over 1280.00 frames. ], tot_loss[loss=3.379, ArTop10Accuracy=0.7423, NarTop10Accuracy=0.545, over 1200.61 frames. ], batch size: 3, lr: 4.27e-03, grad_scale: 8.0
|
61 |
+
2023-11-01 01:18:40,342 INFO [train.py:764] Epoch 1, batch 4800, train_loss[loss=3.732, ArTop10Accuracy=0.7355, NarTop10Accuracy=0.4267, over 1176.00 frames. ], tot_loss[loss=3.384, ArTop10Accuracy=0.7417, NarTop10Accuracy=0.5435, over 1204.83 frames. ], batch size: 2, lr: 4.25e-03, grad_scale: 8.0
|
62 |
+
2023-11-01 01:19:02,286 INFO [train.py:764] Epoch 1, batch 4900, train_loss[loss=3.203, ArTop10Accuracy=0.7443, NarTop10Accuracy=0.6233, over 962.00 frames. ], tot_loss[loss=3.396, ArTop10Accuracy=0.7404, NarTop10Accuracy=0.5386, over 1196.82 frames. ], batch size: 2, lr: 4.23e-03, grad_scale: 8.0
|
63 |
+
2023-11-01 01:19:24,468 INFO [train.py:764] Epoch 1, batch 5000, train_loss[loss=3.96, ArTop10Accuracy=0.6766, NarTop10Accuracy=0.3849, over 1002.00 frames. ], tot_loss[loss=3.4, ArTop10Accuracy=0.74, NarTop10Accuracy=0.5378, over 1198.57 frames. ], batch size: 2, lr: 4.20e-03, grad_scale: 8.0
|
64 |
+
2023-11-01 01:19:25,064 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 2.820e+01 3.661e+01 3.895e+01 4.176e+01 5.538e+02, threshold=7.791e+01, percent-clipped=0.3
|
65 |
+
2023-11-01 01:19:46,481 INFO [train.py:764] Epoch 1, batch 5100, train_loss[loss=3.419, ArTop10Accuracy=0.7219, NarTop10Accuracy=0.5399, over 1050.00 frames. ], tot_loss[loss=3.389, ArTop10Accuracy=0.7419, NarTop10Accuracy=0.5407, over 1202.12 frames. ], batch size: 2, lr: 4.18e-03, grad_scale: 8.0
|
66 |
+
2023-11-01 01:20:08,364 INFO [train.py:764] Epoch 1, batch 5200, train_loss[loss=3.099, ArTop10Accuracy=0.7485, NarTop10Accuracy=0.6468, over 1352.00 frames. ], tot_loss[loss=3.381, ArTop10Accuracy=0.7422, NarTop10Accuracy=0.5441, over 1199.90 frames. ], batch size: 3, lr: 4.16e-03, grad_scale: 8.0
|
67 |
+
2023-11-01 01:20:30,120 INFO [train.py:764] Epoch 1, batch 5300, train_loss[loss=3.2, ArTop10Accuracy=0.7527, NarTop10Accuracy=0.573, over 1217.00 frames. ], tot_loss[loss=3.398, ArTop10Accuracy=0.7423, NarTop10Accuracy=0.5382, over 1195.55 frames. ], batch size: 3, lr: 4.14e-03, grad_scale: 8.0
|
68 |
+
2023-11-01 01:20:52,015 INFO [train.py:764] Epoch 1, batch 5400, train_loss[loss=3.275, ArTop10Accuracy=0.7423, NarTop10Accuracy=0.5832, over 1300.00 frames. ], tot_loss[loss=3.384, ArTop10Accuracy=0.7429, NarTop10Accuracy=0.5434, over 1202.33 frames. ], batch size: 3, lr: 4.12e-03, grad_scale: 8.0
|
69 |
+
2023-11-01 01:21:13,989 INFO [train.py:764] Epoch 1, batch 5500, train_loss[loss=3.955, ArTop10Accuracy=0.6973, NarTop10Accuracy=0.361, over 1318.00 frames. ], tot_loss[loss=3.391, ArTop10Accuracy=0.7416, NarTop10Accuracy=0.5408, over 1204.76 frames. ], batch size: 3, lr: 4.10e-03, grad_scale: 8.0
|
70 |
+
2023-11-01 01:21:36,015 INFO [train.py:764] Epoch 1, batch 5600, train_loss[loss=3.215, ArTop10Accuracy=0.7528, NarTop10Accuracy=0.6052, over 1064.00 frames. ], tot_loss[loss=3.389, ArTop10Accuracy=0.7419, NarTop10Accuracy=0.5422, over 1209.03 frames. ], batch size: 2, lr: 4.08e-03, grad_scale: 8.0
|
71 |
+
2023-11-01 01:21:58,246 INFO [train.py:764] Epoch 1, batch 5700, train_loss[loss=3.757, ArTop10Accuracy=0.7044, NarTop10Accuracy=0.4133, over 1309.00 frames. ], tot_loss[loss=3.403, ArTop10Accuracy=0.7408, NarTop10Accuracy=0.5376, over 1206.83 frames. ], batch size: 3, lr: 4.06e-03, grad_scale: 8.0
|
72 |
+
2023-11-01 01:22:20,084 INFO [train.py:764] Epoch 1, batch 5800, train_loss[loss=3.222, ArTop10Accuracy=0.7902, NarTop10Accuracy=0.5901, over 1077.00 frames. ], tot_loss[loss=3.393, ArTop10Accuracy=0.7415, NarTop10Accuracy=0.5391, over 1198.69 frames. ], batch size: 2, lr: 4.04e-03, grad_scale: 8.0
|
73 |
+
2023-11-01 01:22:41,986 INFO [train.py:764] Epoch 1, batch 5900, train_loss[loss=3.256, ArTop10Accuracy=0.7596, NarTop10Accuracy=0.5688, over 1410.00 frames. ], tot_loss[loss=3.379, ArTop10Accuracy=0.7435, NarTop10Accuracy=0.5425, over 1195.35 frames. ], batch size: 2, lr: 4.02e-03, grad_scale: 8.0
|
74 |
+
2023-11-01 01:23:03,855 INFO [train.py:764] Epoch 1, batch 6000, train_loss[loss=3.04, ArTop10Accuracy=0.7555, NarTop10Accuracy=0.6494, over 1264.00 frames. ], tot_loss[loss=3.367, ArTop10Accuracy=0.7443, NarTop10Accuracy=0.5488, over 1191.87 frames. ], batch size: 3, lr: 4.00e-03, grad_scale: 8.0
|
75 |
+
2023-11-01 01:23:04,464 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 2.955e+01 3.660e+01 3.898e+01 4.175e+01 1.106e+02, threshold=7.796e+01, percent-clipped=0.2
|
76 |
+
2023-11-01 01:23:25,947 INFO [train.py:764] Epoch 1, batch 6100, train_loss[loss=3.244, ArTop10Accuracy=0.7569, NarTop10Accuracy=0.5706, over 1275.00 frames. ], tot_loss[loss=3.366, ArTop10Accuracy=0.7445, NarTop10Accuracy=0.5487, over 1200.31 frames. ], batch size: 3, lr: 3.98e-03, grad_scale: 8.0
|
77 |
+
2023-11-01 01:23:47,870 INFO [train.py:764] Epoch 1, batch 6200, train_loss[loss=3.64, ArTop10Accuracy=0.7432, NarTop10Accuracy=0.4895, over 1067.00 frames. ], tot_loss[loss=3.368, ArTop10Accuracy=0.7455, NarTop10Accuracy=0.5473, over 1197.29 frames. ], batch size: 2, lr: 3.96e-03, grad_scale: 8.0
|
78 |
+
2023-11-01 01:24:09,740 INFO [train.py:764] Epoch 1, batch 6300, train_loss[loss=2.884, ArTop10Accuracy=0.7787, NarTop10Accuracy=0.656, over 1229.00 frames. ], tot_loss[loss=3.358, ArTop10Accuracy=0.7462, NarTop10Accuracy=0.5488, over 1195.66 frames. ], batch size: 3, lr: 3.94e-03, grad_scale: 16.0
|
79 |
+
2023-11-01 01:24:31,832 INFO [train.py:764] Epoch 1, batch 6400, train_loss[loss=3.337, ArTop10Accuracy=0.7198, NarTop10Accuracy=0.6022, over 1342.00 frames. ], tot_loss[loss=3.363, ArTop10Accuracy=0.7452, NarTop10Accuracy=0.5489, over 1200.90 frames. ], batch size: 3, lr: 3.92e-03, grad_scale: 16.0
|
80 |
+
2023-11-01 01:24:53,712 INFO [train.py:764] Epoch 1, batch 6500, train_loss[loss=3.571, ArTop10Accuracy=0.7508, NarTop10Accuracy=0.4587, over 1324.00 frames. ], tot_loss[loss=3.364, ArTop10Accuracy=0.7447, NarTop10Accuracy=0.5472, over 1203.64 frames. ], batch size: 3, lr: 3.90e-03, grad_scale: 16.0
|
81 |
+
2023-11-01 01:25:15,706 INFO [train.py:764] Epoch 1, batch 6600, train_loss[loss=3.618, ArTop10Accuracy=0.7208, NarTop10Accuracy=0.4864, over 1157.00 frames. ], tot_loss[loss=3.375, ArTop10Accuracy=0.7441, NarTop10Accuracy=0.544, over 1199.80 frames. ], batch size: 2, lr: 3.89e-03, grad_scale: 16.0
|
82 |
+
2023-11-01 01:25:37,824 INFO [train.py:764] Epoch 1, batch 6700, train_loss[loss=3.916, ArTop10Accuracy=0.7235, NarTop10Accuracy=0.3773, over 1512.00 frames. ], tot_loss[loss=3.373, ArTop10Accuracy=0.7464, NarTop10Accuracy=0.5436, over 1206.18 frames. ], batch size: 2, lr: 3.87e-03, grad_scale: 16.0
|
83 |
+
2023-11-01 01:25:59,844 INFO [train.py:764] Epoch 1, batch 6800, train_loss[loss=3.743, ArTop10Accuracy=0.6712, NarTop10Accuracy=0.473, over 1250.00 frames. ], tot_loss[loss=3.378, ArTop10Accuracy=0.747, NarTop10Accuracy=0.5415, over 1211.75 frames. ], batch size: 3, lr: 3.85e-03, grad_scale: 16.0
|
84 |
+
2023-11-01 01:26:21,673 INFO [train.py:764] Epoch 1, batch 6900, train_loss[loss=3.288, ArTop10Accuracy=0.794, NarTop10Accuracy=0.5091, over 1199.00 frames. ], tot_loss[loss=3.368, ArTop10Accuracy=0.748, NarTop10Accuracy=0.5445, over 1203.83 frames. ], batch size: 3, lr: 3.83e-03, grad_scale: 16.0
|
85 |
+
2023-11-01 01:26:43,753 INFO [train.py:764] Epoch 1, batch 7000, train_loss[loss=3.611, ArTop10Accuracy=0.6679, NarTop10Accuracy=0.559, over 1054.00 frames. ], tot_loss[loss=3.379, ArTop10Accuracy=0.7463, NarTop10Accuracy=0.5433, over 1205.70 frames. ], batch size: 2, lr: 3.81e-03, grad_scale: 16.0
|
86 |
+
2023-11-01 01:26:44,377 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 2.938e+01 3.640e+01 3.899e+01 4.197e+01 1.539e+02, threshold=7.798e+01, percent-clipped=0.2
|
87 |
+
2023-11-01 01:27:05,623 INFO [train.py:764] Epoch 1, batch 7100, train_loss[loss=3.34, ArTop10Accuracy=0.7641, NarTop10Accuracy=0.5495, over 1221.00 frames. ], tot_loss[loss=3.364, ArTop10Accuracy=0.7483, NarTop10Accuracy=0.5462, over 1200.17 frames. ], batch size: 3, lr: 3.79e-03, grad_scale: 16.0
|
88 |
+
2023-11-01 01:27:27,790 INFO [train.py:764] Epoch 1, batch 7200, train_loss[loss=3.402, ArTop10Accuracy=0.7728, NarTop10Accuracy=0.522, over 1109.00 frames. ], tot_loss[loss=3.349, ArTop10Accuracy=0.75, NarTop10Accuracy=0.551, over 1200.81 frames. ], batch size: 2, lr: 3.78e-03, grad_scale: 16.0
|
89 |
+
2023-11-01 01:27:49,946 INFO [train.py:764] Epoch 1, batch 7300, train_loss[loss=3.102, ArTop10Accuracy=0.7641, NarTop10Accuracy=0.645, over 1047.00 frames. ], tot_loss[loss=3.358, ArTop10Accuracy=0.7507, NarTop10Accuracy=0.5483, over 1199.39 frames. ], batch size: 2, lr: 3.76e-03, grad_scale: 16.0
|
90 |
+
2023-11-01 01:28:12,023 INFO [train.py:764] Epoch 1, batch 7400, train_loss[loss=3.492, ArTop10Accuracy=0.7566, NarTop10Accuracy=0.49, over 1142.00 frames. ], tot_loss[loss=3.35, ArTop10Accuracy=0.7503, NarTop10Accuracy=0.552, over 1204.30 frames. ], batch size: 2, lr: 3.74e-03, grad_scale: 16.0
|
91 |
+
2023-11-01 01:28:34,245 INFO [train.py:764] Epoch 1, batch 7500, train_loss[loss=3.834, ArTop10Accuracy=0.6968, NarTop10Accuracy=0.4323, over 1019.00 frames. ], tot_loss[loss=3.359, ArTop10Accuracy=0.7492, NarTop10Accuracy=0.5488, over 1211.85 frames. ], batch size: 2, lr: 3.72e-03, grad_scale: 16.0
|
92 |
+
2023-11-01 01:28:56,204 INFO [train.py:764] Epoch 1, batch 7600, train_loss[loss=3.051, ArTop10Accuracy=0.7471, NarTop10Accuracy=0.6682, over 1538.00 frames. ], tot_loss[loss=3.346, ArTop10Accuracy=0.7511, NarTop10Accuracy=0.5516, over 1203.46 frames. ], batch size: 3, lr: 3.71e-03, grad_scale: 16.0
|
93 |
+
2023-11-01 01:29:18,414 INFO [train.py:764] Epoch 1, batch 7700, train_loss[loss=3.394, ArTop10Accuracy=0.7467, NarTop10Accuracy=0.508, over 1500.00 frames. ], tot_loss[loss=3.342, ArTop10Accuracy=0.7508, NarTop10Accuracy=0.5528, over 1210.10 frames. ], batch size: 3, lr: 3.69e-03, grad_scale: 16.0
|
94 |
+
2023-11-01 01:29:40,478 INFO [train.py:764] Epoch 1, batch 7800, train_loss[loss=3.565, ArTop10Accuracy=0.7138, NarTop10Accuracy=0.5097, over 1300.00 frames. ], tot_loss[loss=3.345, ArTop10Accuracy=0.7497, NarTop10Accuracy=0.553, over 1207.39 frames. ], batch size: 3, lr: 3.67e-03, grad_scale: 16.0
|
95 |
+
2023-11-01 01:30:02,574 INFO [train.py:764] Epoch 1, batch 7900, train_loss[loss=3.114, ArTop10Accuracy=0.7528, NarTop10Accuracy=0.6468, over 1323.00 frames. ], tot_loss[loss=3.348, ArTop10Accuracy=0.7485, NarTop10Accuracy=0.5526, over 1205.69 frames. ], batch size: 3, lr: 3.66e-03, grad_scale: 16.0
|
96 |
+
2023-11-01 01:30:24,766 INFO [train.py:764] Epoch 1, batch 8000, train_loss[loss=3.416, ArTop10Accuracy=0.7549, NarTop10Accuracy=0.5289, over 1314.00 frames. ], tot_loss[loss=3.344, ArTop10Accuracy=0.7481, NarTop10Accuracy=0.553, over 1204.55 frames. ], batch size: 2, lr: 3.64e-03, grad_scale: 16.0
|
97 |
+
2023-11-01 01:30:25,379 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 2.581e+01 3.622e+01 3.870e+01 4.187e+01 1.276e+02, threshold=7.739e+01, percent-clipped=0.3
|
98 |
+
2023-11-01 01:30:47,002 INFO [train.py:764] Epoch 1, batch 8100, train_loss[loss=3.579, ArTop10Accuracy=0.7254, NarTop10Accuracy=0.493, over 1453.00 frames. ], tot_loss[loss=3.368, ArTop10Accuracy=0.7466, NarTop10Accuracy=0.5466, over 1201.98 frames. ], batch size: 3, lr: 3.62e-03, grad_scale: 16.0
|
99 |
+
2023-11-01 01:31:09,160 INFO [train.py:764] Epoch 1, batch 8200, train_loss[loss=3.123, ArTop10Accuracy=0.7755, NarTop10Accuracy=0.6341, over 1198.00 frames. ], tot_loss[loss=3.377, ArTop10Accuracy=0.7455, NarTop10Accuracy=0.5452, over 1201.66 frames. ], batch size: 3, lr: 3.61e-03, grad_scale: 16.0
|
100 |
+
2023-11-01 01:31:31,231 INFO [train.py:764] Epoch 1, batch 8300, train_loss[loss=3.518, ArTop10Accuracy=0.7555, NarTop10Accuracy=0.4488, over 1354.00 frames. ], tot_loss[loss=3.362, ArTop10Accuracy=0.7486, NarTop10Accuracy=0.5461, over 1200.11 frames. ], batch size: 3, lr: 3.59e-03, grad_scale: 32.0
|
101 |
+
2023-11-01 01:31:53,404 INFO [train.py:764] Epoch 1, batch 8400, train_loss[loss=3.782, ArTop10Accuracy=0.7124, NarTop10Accuracy=0.4265, over 1064.00 frames. ], tot_loss[loss=3.363, ArTop10Accuracy=0.7499, NarTop10Accuracy=0.5457, over 1201.74 frames. ], batch size: 2, lr: 3.58e-03, grad_scale: 32.0
|
102 |
+
2023-11-01 01:32:15,398 INFO [train.py:764] Epoch 1, batch 8500, train_loss[loss=3.664, ArTop10Accuracy=0.7246, NarTop10Accuracy=0.4731, over 1082.00 frames. ], tot_loss[loss=3.357, ArTop10Accuracy=0.7498, NarTop10Accuracy=0.5488, over 1189.28 frames. ], batch size: 2, lr: 3.56e-03, grad_scale: 32.0
|
103 |
+
2023-11-01 01:32:37,449 INFO [train.py:764] Epoch 1, batch 8600, train_loss[loss=3.343, ArTop10Accuracy=0.755, NarTop10Accuracy=0.5416, over 1208.00 frames. ], tot_loss[loss=3.362, ArTop10Accuracy=0.7486, NarTop10Accuracy=0.5467, over 1195.17 frames. ], batch size: 3, lr: 3.54e-03, grad_scale: 32.0
|
104 |
+
2023-11-01 01:32:59,631 INFO [train.py:764] Epoch 1, batch 8700, train_loss[loss=3.506, ArTop10Accuracy=0.7528, NarTop10Accuracy=0.4847, over 1076.00 frames. ], tot_loss[loss=3.341, ArTop10Accuracy=0.7514, NarTop10Accuracy=0.5544, over 1199.46 frames. ], batch size: 2, lr: 3.53e-03, grad_scale: 32.0
|
105 |
+
2023-11-01 01:33:21,552 INFO [train.py:764] Epoch 1, batch 8800, train_loss[loss=3.168, ArTop10Accuracy=0.7199, NarTop10Accuracy=0.6646, over 1178.00 frames. ], tot_loss[loss=3.338, ArTop10Accuracy=0.751, NarTop10Accuracy=0.5547, over 1196.85 frames. ], batch size: 3, lr: 3.51e-03, grad_scale: 32.0
|
106 |
+
2023-11-01 01:33:43,552 INFO [train.py:764] Epoch 1, batch 8900, train_loss[loss=3.424, ArTop10Accuracy=0.7301, NarTop10Accuracy=0.5693, over 1330.00 frames. ], tot_loss[loss=3.332, ArTop10Accuracy=0.7518, NarTop10Accuracy=0.5568, over 1197.63 frames. ], batch size: 3, lr: 3.50e-03, grad_scale: 32.0
|
107 |
+
2023-11-01 01:34:05,651 INFO [train.py:764] Epoch 1, batch 9000, train_loss[loss=3.032, ArTop10Accuracy=0.7749, NarTop10Accuracy=0.6663, over 1355.00 frames. ], tot_loss[loss=3.344, ArTop10Accuracy=0.7508, NarTop10Accuracy=0.5512, over 1206.31 frames. ], batch size: 3, lr: 3.48e-03, grad_scale: 32.0
|
108 |
+
2023-11-01 01:34:06,281 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 2.623e+01 3.625e+01 3.891e+01 4.198e+01 1.204e+02, threshold=7.783e+01, percent-clipped=0.2
|
109 |
+
2023-11-01 01:34:27,643 INFO [train.py:764] Epoch 1, batch 9100, train_loss[loss=3.115, ArTop10Accuracy=0.7458, NarTop10Accuracy=0.6156, over 1255.00 frames. ], tot_loss[loss=3.325, ArTop10Accuracy=0.7508, NarTop10Accuracy=0.5588, over 1204.55 frames. ], batch size: 3, lr: 3.47e-03, grad_scale: 32.0
|
110 |
+
2023-11-01 01:34:49,396 INFO [train.py:764] Epoch 1, batch 9200, train_loss[loss=3.593, ArTop10Accuracy=0.7426, NarTop10Accuracy=0.4606, over 1243.00 frames. ], tot_loss[loss=3.323, ArTop10Accuracy=0.7518, NarTop10Accuracy=0.5592, over 1199.74 frames. ], batch size: 3, lr: 3.46e-03, grad_scale: 32.0
|
111 |
+
2023-11-01 01:35:11,250 INFO [train.py:764] Epoch 1, batch 9300, train_loss[loss=3.347, ArTop10Accuracy=0.786, NarTop10Accuracy=0.5172, over 1243.00 frames. ], tot_loss[loss=3.328, ArTop10Accuracy=0.7537, NarTop10Accuracy=0.5565, over 1196.75 frames. ], batch size: 3, lr: 3.44e-03, grad_scale: 32.0
|
112 |
+
2023-11-01 01:35:33,210 INFO [train.py:764] Epoch 1, batch 9400, train_loss[loss=3.267, ArTop10Accuracy=0.7515, NarTop10Accuracy=0.5882, over 1501.00 frames. ], tot_loss[loss=3.318, ArTop10Accuracy=0.7535, NarTop10Accuracy=0.5602, over 1201.10 frames. ], batch size: 3, lr: 3.43e-03, grad_scale: 32.0
|
113 |
+
2023-11-01 01:35:55,287 INFO [train.py:764] Epoch 1, batch 9500, train_loss[loss=3.33, ArTop10Accuracy=0.741, NarTop10Accuracy=0.6289, over 1023.00 frames. ], tot_loss[loss=3.318, ArTop10Accuracy=0.7546, NarTop10Accuracy=0.5603, over 1206.75 frames. ], batch size: 2, lr: 3.41e-03, grad_scale: 32.0
|
114 |
+
2023-11-01 01:36:17,423 INFO [train.py:764] Epoch 1, batch 9600, train_loss[loss=3.327, ArTop10Accuracy=0.7469, NarTop10Accuracy=0.6008, over 1225.00 frames. ], tot_loss[loss=3.326, ArTop10Accuracy=0.7527, NarTop10Accuracy=0.5576, over 1210.68 frames. ], batch size: 3, lr: 3.40e-03, grad_scale: 32.0
|
115 |
+
2023-11-01 01:36:39,461 INFO [train.py:764] Epoch 1, batch 9700, train_loss[loss=3.695, ArTop10Accuracy=0.7334, NarTop10Accuracy=0.4212, over 1358.00 frames. ], tot_loss[loss=3.33, ArTop10Accuracy=0.7523, NarTop10Accuracy=0.5554, over 1204.10 frames. ], batch size: 3, lr: 3.38e-03, grad_scale: 32.0
|
116 |
+
2023-11-01 01:37:01,299 INFO [train.py:764] Epoch 1, batch 9800, train_loss[loss=3.36, ArTop10Accuracy=0.7327, NarTop10Accuracy=0.6126, over 1025.00 frames. ], tot_loss[loss=3.328, ArTop10Accuracy=0.7528, NarTop10Accuracy=0.5558, over 1201.65 frames. ], batch size: 2, lr: 3.37e-03, grad_scale: 32.0
|
117 |
+
2023-11-01 01:37:23,421 INFO [train.py:764] Epoch 1, batch 9900, train_loss[loss=3.719, ArTop10Accuracy=0.7257, NarTop10Accuracy=0.4459, over 1130.00 frames. ], tot_loss[loss=3.331, ArTop10Accuracy=0.7523, NarTop10Accuracy=0.5558, over 1206.02 frames. ], batch size: 2, lr: 3.36e-03, grad_scale: 32.0
|
118 |
+
2023-11-01 01:37:45,281 INFO [utils.py:237] Saving checkpoint to exp/valle_dev/checkpoint-10000.pt
|
119 |
+
2023-11-01 01:37:54,135 INFO [train.py:764] Epoch 1, batch 10000, train_loss[loss=3.626, ArTop10Accuracy=0.7504, NarTop10Accuracy=0.4714, over 1278.00 frames. ], tot_loss[loss=3.326, ArTop10Accuracy=0.7525, NarTop10Accuracy=0.5572, over 1196.17 frames. ], batch size: 3, lr: 3.34e-03, grad_scale: 32.0
|
120 |
+
2023-11-01 01:37:54,138 INFO [train.py:802] Computing validation loss
|
121 |
+
2023-11-01 01:41:43,550 INFO [train.py:810] Epoch 1, validation: loss=3.193, ArTop10Accuracy=0.7614, NarTop10Accuracy=0.5796, over 1739106.00 frames.
|
122 |
+
2023-11-01 01:41:43,550 INFO [train.py:813] Maximum memory allocated so far is 17387MB
|
123 |
+
2023-11-01 01:41:44,169 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 2.795e+01 3.605e+01 3.882e+01 4.179e+01 1.156e+02, threshold=7.765e+01, percent-clipped=0.2
|
124 |
+
2023-11-01 01:42:05,744 INFO [train.py:764] Epoch 1, batch 10100, train_loss[loss=3.277, ArTop10Accuracy=0.7554, NarTop10Accuracy=0.5287, over 1165.00 frames. ], tot_loss[loss=3.317, ArTop10Accuracy=0.754, NarTop10Accuracy=0.5596, over 1204.24 frames. ], batch size: 3, lr: 3.33e-03, grad_scale: 32.0
|
125 |
+
2023-11-01 01:42:27,939 INFO [train.py:764] Epoch 1, batch 10200, train_loss[loss=3.551, ArTop10Accuracy=0.7434, NarTop10Accuracy=0.4833, over 1243.00 frames. ], tot_loss[loss=3.316, ArTop10Accuracy=0.7536, NarTop10Accuracy=0.5606, over 1202.58 frames. ], batch size: 3, lr: 3.32e-03, grad_scale: 32.0
|
126 |
+
2023-11-01 01:42:50,826 INFO [train.py:764] Epoch 1, batch 10300, train_loss[loss=3.3, ArTop10Accuracy=0.7664, NarTop10Accuracy=0.508, over 1049.00 frames. ], tot_loss[loss=3.326, ArTop10Accuracy=0.7522, NarTop10Accuracy=0.5587, over 1209.79 frames. ], batch size: 2, lr: 3.30e-03, grad_scale: 64.0
|
127 |
+
2023-11-01 01:43:12,842 INFO [train.py:764] Epoch 1, batch 10400, train_loss[loss=3.616, ArTop10Accuracy=0.7065, NarTop10Accuracy=0.5428, over 1237.00 frames. ], tot_loss[loss=3.314, ArTop10Accuracy=0.7537, NarTop10Accuracy=0.5616, over 1195.83 frames. ], batch size: 3, lr: 3.29e-03, grad_scale: 64.0
|
128 |
+
2023-11-01 01:43:34,961 INFO [train.py:764] Epoch 1, batch 10500, train_loss[loss=3.53, ArTop10Accuracy=0.7334, NarTop10Accuracy=0.5249, over 1343.00 frames. ], tot_loss[loss=3.328, ArTop10Accuracy=0.7528, NarTop10Accuracy=0.5573, over 1201.99 frames. ], batch size: 3, lr: 3.28e-03, grad_scale: 64.0
|
129 |
+
2023-11-01 01:43:57,011 INFO [train.py:764] Epoch 1, batch 10600, train_loss[loss=3.209, ArTop10Accuracy=0.7266, NarTop10Accuracy=0.6411, over 1262.00 frames. ], tot_loss[loss=3.323, ArTop10Accuracy=0.7559, NarTop10Accuracy=0.5564, over 1204.76 frames. ], batch size: 3, lr: 3.27e-03, grad_scale: 16.0
|
130 |
+
2023-11-01 01:44:19,318 INFO [train.py:764] Epoch 1, batch 10700, train_loss[loss=3.45, ArTop10Accuracy=0.716, NarTop10Accuracy=0.5662, over 1067.00 frames. ], tot_loss[loss=3.324, ArTop10Accuracy=0.7551, NarTop10Accuracy=0.5563, over 1203.94 frames. ], batch size: 2, lr: 3.25e-03, grad_scale: 16.0
|
131 |
+
2023-11-01 01:44:42,212 INFO [train.py:764] Epoch 1, batch 10800, train_loss[loss=3.276, ArTop10Accuracy=0.7893, NarTop10Accuracy=0.5161, over 992.00 frames. ], tot_loss[loss=3.311, ArTop10Accuracy=0.7567, NarTop10Accuracy=0.5599, over 1206.94 frames. ], batch size: 2, lr: 3.24e-03, grad_scale: 16.0
|
132 |
+
2023-11-01 01:45:04,261 INFO [train.py:764] Epoch 1, batch 10900, train_loss[loss=3.293, ArTop10Accuracy=0.7395, NarTop10Accuracy=0.5362, over 833.00 frames. ], tot_loss[loss=3.312, ArTop10Accuracy=0.7552, NarTop10Accuracy=0.5621, over 1197.04 frames. ], batch size: 1, lr: 3.23e-03, grad_scale: 16.0
|
133 |
+
2023-11-01 01:45:26,606 INFO [train.py:764] Epoch 1, batch 11000, train_loss[loss=3.684, ArTop10Accuracy=0.7243, NarTop10Accuracy=0.4497, over 1023.00 frames. ], tot_loss[loss=3.319, ArTop10Accuracy=0.7556, NarTop10Accuracy=0.5566, over 1192.92 frames. ], batch size: 2, lr: 3.22e-03, grad_scale: 16.0
|
134 |
+
2023-11-01 01:45:27,631 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 2.656e+01 3.581e+01 3.876e+01 4.218e+01 2.198e+02, threshold=7.753e+01, percent-clipped=0.5
|
135 |
+
2023-11-01 01:45:48,570 INFO [train.py:764] Epoch 1, batch 11100, train_loss[loss=3.345, ArTop10Accuracy=0.7677, NarTop10Accuracy=0.5429, over 1218.00 frames. ], tot_loss[loss=3.328, ArTop10Accuracy=0.7565, NarTop10Accuracy=0.5523, over 1188.64 frames. ], batch size: 3, lr: 3.20e-03, grad_scale: 16.0
|
136 |
+
2023-11-01 01:46:11,057 INFO [train.py:764] Epoch 1, batch 11200, train_loss[loss=3.674, ArTop10Accuracy=0.7002, NarTop10Accuracy=0.4926, over 1451.00 frames. ], tot_loss[loss=3.328, ArTop10Accuracy=0.7558, NarTop10Accuracy=0.5533, over 1199.65 frames. ], batch size: 3, lr: 3.19e-03, grad_scale: 16.0
|
137 |
+
2023-11-01 01:46:33,210 INFO [train.py:764] Epoch 1, batch 11300, train_loss[loss=3.116, ArTop10Accuracy=0.7731, NarTop10Accuracy=0.6661, over 1516.00 frames. ], tot_loss[loss=3.333, ArTop10Accuracy=0.7557, NarTop10Accuracy=0.5515, over 1198.36 frames. ], batch size: 3, lr: 3.18e-03, grad_scale: 16.0
|
138 |
+
2023-11-01 01:46:55,584 INFO [train.py:764] Epoch 1, batch 11400, train_loss[loss=3.208, ArTop10Accuracy=0.7723, NarTop10Accuracy=0.6035, over 1010.00 frames. ], tot_loss[loss=3.317, ArTop10Accuracy=0.7566, NarTop10Accuracy=0.5581, over 1207.38 frames. ], batch size: 2, lr: 3.17e-03, grad_scale: 16.0
|
139 |
+
2023-11-01 01:47:18,045 INFO [train.py:764] Epoch 1, batch 11500, train_loss[loss=2.939, ArTop10Accuracy=0.7846, NarTop10Accuracy=0.6675, over 1114.00 frames. ], tot_loss[loss=3.314, ArTop10Accuracy=0.7566, NarTop10Accuracy=0.5598, over 1217.46 frames. ], batch size: 2, lr: 3.16e-03, grad_scale: 16.0
|
140 |
+
2023-11-01 01:47:39,987 INFO [train.py:764] Epoch 1, batch 11600, train_loss[loss=3.135, ArTop10Accuracy=0.7496, NarTop10Accuracy=0.6593, over 1178.00 frames. ], tot_loss[loss=3.321, ArTop10Accuracy=0.7559, NarTop10Accuracy=0.558, over 1200.17 frames. ], batch size: 3, lr: 3.15e-03, grad_scale: 16.0
|
141 |
+
2023-11-01 01:47:56,239 INFO [train.py:648] Reaches end of dataloader.
|
142 |
+
2023-11-01 01:47:56,242 INFO [utils.py:237] Saving checkpoint to exp/valle_dev/epoch-1.pt
|
143 |
+
2023-11-01 01:48:36,133 INFO [train.py:764] Epoch 2, batch 100, train_loss[loss=3.517, ArTop10Accuracy=0.7405, NarTop10Accuracy=0.5475, over 1291.00 frames. ], tot_loss[loss=3.278, ArTop10Accuracy=0.7682, NarTop10Accuracy=0.5618, over 475.14 frames. ], batch size: 3, lr: 3.08e-03, grad_scale: 16.0
|
144 |
+
2023-11-01 01:48:58,354 INFO [train.py:764] Epoch 2, batch 200, train_loss[loss=3.173, ArTop10Accuracy=0.7737, NarTop10Accuracy=0.5945, over 1264.00 frames. ], tot_loss[loss=3.284, ArTop10Accuracy=0.7679, NarTop10Accuracy=0.5589, over 764.18 frames. ], batch size: 3, lr: 3.07e-03, grad_scale: 16.0
|
145 |
+
2023-11-01 01:49:20,372 INFO [train.py:764] Epoch 2, batch 300, train_loss[loss=3.15, ArTop10Accuracy=0.7964, NarTop10Accuracy=0.5779, over 1238.00 frames. ], tot_loss[loss=3.279, ArTop10Accuracy=0.7677, NarTop10Accuracy=0.5628, over 933.46 frames. ], batch size: 3, lr: 3.06e-03, grad_scale: 16.0
|
146 |
+
2023-11-01 01:49:27,316 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 2.615e+01 3.644e+01 3.906e+01 4.267e+01 1.348e+02, threshold=7.812e+01, percent-clipped=0.3
|
147 |
+
2023-11-01 01:49:42,335 INFO [train.py:764] Epoch 2, batch 400, train_loss[loss=3.183, ArTop10Accuracy=0.7688, NarTop10Accuracy=0.6139, over 1289.00 frames. ], tot_loss[loss=3.263, ArTop10Accuracy=0.7705, NarTop10Accuracy=0.5659, over 1033.80 frames. ], batch size: 3, lr: 3.05e-03, grad_scale: 16.0
|
148 |
+
2023-11-01 01:50:04,654 INFO [train.py:764] Epoch 2, batch 500, train_loss[loss=3.224, ArTop10Accuracy=0.7834, NarTop10Accuracy=0.5896, over 974.00 frames. ], tot_loss[loss=3.266, ArTop10Accuracy=0.7696, NarTop10Accuracy=0.5658, over 1103.83 frames. ], batch size: 2, lr: 3.04e-03, grad_scale: 16.0
|
149 |
+
2023-11-01 01:50:26,894 INFO [train.py:764] Epoch 2, batch 600, train_loss[loss=3.192, ArTop10Accuracy=0.7686, NarTop10Accuracy=0.5814, over 1059.00 frames. ], tot_loss[loss=3.252, ArTop10Accuracy=0.7713, NarTop10Accuracy=0.5709, over 1142.77 frames. ], batch size: 2, lr: 3.02e-03, grad_scale: 16.0
|
150 |
+
2023-11-01 01:50:49,069 INFO [train.py:764] Epoch 2, batch 700, train_loss[loss=3.272, ArTop10Accuracy=0.7538, NarTop10Accuracy=0.6008, over 1174.00 frames. ], tot_loss[loss=3.253, ArTop10Accuracy=0.7714, NarTop10Accuracy=0.5695, over 1168.33 frames. ], batch size: 3, lr: 3.01e-03, grad_scale: 16.0
|
151 |
+
2023-11-01 01:51:11,287 INFO [train.py:764] Epoch 2, batch 800, train_loss[loss=3.539, ArTop10Accuracy=0.7563, NarTop10Accuracy=0.4397, over 1227.00 frames. ], tot_loss[loss=3.273, ArTop10Accuracy=0.7692, NarTop10Accuracy=0.5629, over 1178.72 frames. ], batch size: 3, lr: 3.00e-03, grad_scale: 16.0
|
152 |
+
2023-11-01 01:51:33,655 INFO [train.py:764] Epoch 2, batch 900, train_loss[loss=3.88, ArTop10Accuracy=0.6952, NarTop10Accuracy=0.4376, over 1004.00 frames. ], tot_loss[loss=3.277, ArTop10Accuracy=0.7692, NarTop10Accuracy=0.56, over 1183.23 frames. ], batch size: 2, lr: 2.99e-03, grad_scale: 32.0
|
153 |
+
2023-11-01 01:51:55,740 INFO [train.py:764] Epoch 2, batch 1000, train_loss[loss=3.467, ArTop10Accuracy=0.7846, NarTop10Accuracy=0.4436, over 1235.00 frames. ], tot_loss[loss=3.28, ArTop10Accuracy=0.768, NarTop10Accuracy=0.5604, over 1186.34 frames. ], batch size: 3, lr: 2.98e-03, grad_scale: 32.0
|
154 |
+
2023-11-01 01:52:17,749 INFO [train.py:764] Epoch 2, batch 1100, train_loss[loss=3.059, ArTop10Accuracy=0.8348, NarTop10Accuracy=0.57, over 1277.00 frames. ], tot_loss[loss=3.262, ArTop10Accuracy=0.7692, NarTop10Accuracy=0.5665, over 1186.39 frames. ], batch size: 3, lr: 2.97e-03, grad_scale: 32.0
|
155 |
+
2023-11-01 01:52:40,095 INFO [train.py:764] Epoch 2, batch 1200, train_loss[loss=3.164, ArTop10Accuracy=0.7892, NarTop10Accuracy=0.5847, over 1115.00 frames. ], tot_loss[loss=3.271, ArTop10Accuracy=0.7688, NarTop10Accuracy=0.5645, over 1196.82 frames. ], batch size: 2, lr: 2.96e-03, grad_scale: 32.0
|
156 |
+
2023-11-01 01:53:02,469 INFO [train.py:764] Epoch 2, batch 1300, train_loss[loss=3.172, ArTop10Accuracy=0.8073, NarTop10Accuracy=0.5677, over 1204.00 frames. ], tot_loss[loss=3.275, ArTop10Accuracy=0.768, NarTop10Accuracy=0.5649, over 1197.53 frames. ], batch size: 3, lr: 2.95e-03, grad_scale: 32.0
|
157 |
+
2023-11-01 01:53:09,537 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 2.571e+01 3.646e+01 3.940e+01 4.297e+01 1.030e+02, threshold=7.880e+01, percent-clipped=0.1
|
158 |
+
2023-11-01 01:53:24,675 INFO [train.py:764] Epoch 2, batch 1400, train_loss[loss=3.376, ArTop10Accuracy=0.7127, NarTop10Accuracy=0.5969, over 1065.00 frames. ], tot_loss[loss=3.259, ArTop10Accuracy=0.7695, NarTop10Accuracy=0.5666, over 1207.86 frames. ], batch size: 2, lr: 2.94e-03, grad_scale: 32.0
|
159 |
+
2023-11-01 01:53:46,768 INFO [train.py:764] Epoch 2, batch 1500, train_loss[loss=3.237, ArTop10Accuracy=0.7795, NarTop10Accuracy=0.5848, over 1256.00 frames. ], tot_loss[loss=3.255, ArTop10Accuracy=0.7699, NarTop10Accuracy=0.5686, over 1208.58 frames. ], batch size: 3, lr: 2.93e-03, grad_scale: 32.0
|
160 |
+
2023-11-01 01:54:08,665 INFO [train.py:764] Epoch 2, batch 1600, train_loss[loss=3.242, ArTop10Accuracy=0.726, NarTop10Accuracy=0.6307, over 989.00 frames. ], tot_loss[loss=3.244, ArTop10Accuracy=0.7713, NarTop10Accuracy=0.5715, over 1199.33 frames. ], batch size: 2, lr: 2.92e-03, grad_scale: 32.0
|
161 |
+
2023-11-01 01:54:30,744 INFO [train.py:764] Epoch 2, batch 1700, train_loss[loss=2.888, ArTop10Accuracy=0.7923, NarTop10Accuracy=0.6888, over 1305.00 frames. ], tot_loss[loss=3.25, ArTop10Accuracy=0.7712, NarTop10Accuracy=0.5707, over 1203.12 frames. ], batch size: 2, lr: 2.91e-03, grad_scale: 32.0
|
162 |
+
2023-11-01 01:54:53,004 INFO [train.py:764] Epoch 2, batch 1800, train_loss[loss=2.743, ArTop10Accuracy=0.7932, NarTop10Accuracy=0.6948, over 967.00 frames. ], tot_loss[loss=3.27, ArTop10Accuracy=0.7691, NarTop10Accuracy=0.5652, over 1210.34 frames. ], batch size: 2, lr: 2.90e-03, grad_scale: 32.0
|
163 |
+
2023-11-01 01:55:14,944 INFO [train.py:764] Epoch 2, batch 1900, train_loss[loss=3.528, ArTop10Accuracy=0.7713, NarTop10Accuracy=0.4586, over 1277.00 frames. ], tot_loss[loss=3.274, ArTop10Accuracy=0.7706, NarTop10Accuracy=0.5617, over 1207.27 frames. ], batch size: 3, lr: 2.90e-03, grad_scale: 32.0
|
164 |
+
2023-11-01 01:55:37,022 INFO [train.py:764] Epoch 2, batch 2000, train_loss[loss=3.354, ArTop10Accuracy=0.7612, NarTop10Accuracy=0.509, over 1030.00 frames. ], tot_loss[loss=3.267, ArTop10Accuracy=0.7721, NarTop10Accuracy=0.5622, over 1205.23 frames. ], batch size: 2, lr: 2.89e-03, grad_scale: 32.0
|
165 |
+
2023-11-01 01:55:59,142 INFO [train.py:764] Epoch 2, batch 2100, train_loss[loss=3.008, ArTop10Accuracy=0.7989, NarTop10Accuracy=0.6327, over 1253.00 frames. ], tot_loss[loss=3.278, ArTop10Accuracy=0.7713, NarTop10Accuracy=0.5594, over 1207.06 frames. ], batch size: 3, lr: 2.88e-03, grad_scale: 32.0
|
166 |
+
2023-11-01 01:56:21,029 INFO [train.py:764] Epoch 2, batch 2200, train_loss[loss=3.097, ArTop10Accuracy=0.7775, NarTop10Accuracy=0.6281, over 1272.00 frames. ], tot_loss[loss=3.266, ArTop10Accuracy=0.7719, NarTop10Accuracy=0.563, over 1199.67 frames. ], batch size: 3, lr: 2.87e-03, grad_scale: 32.0
|
167 |
+
2023-11-01 01:56:43,276 INFO [train.py:764] Epoch 2, batch 2300, train_loss[loss=3.314, ArTop10Accuracy=0.7717, NarTop10Accuracy=0.5371, over 1270.00 frames. ], tot_loss[loss=3.244, ArTop10Accuracy=0.7723, NarTop10Accuracy=0.5711, over 1213.70 frames. ], batch size: 3, lr: 2.86e-03, grad_scale: 32.0
|
168 |
+
2023-11-01 01:56:50,205 INFO [utils.py:877] Clipping_scale=2.0, grad-norm quartiles 2.446e+01 3.674e+01 3.960e+01 4.425e+01 1.163e+02, threshold=7.920e+01, percent-clipped=0.2
|
169 |
+
2023-11-01 01:57:05,137 INFO [train.py:764] Epoch 2, batch 2400, train_loss[loss=2.835, ArTop10Accuracy=0.8084, NarTop10Accuracy=0.6297, over 950.00 frames. ], tot_loss[loss=3.257, ArTop10Accuracy=0.7717, NarTop10Accuracy=0.5661, over 1204.28 frames. ], batch size: 2, lr: 2.85e-03, grad_scale: 32.0
|
170 |
+
2023-11-01 01:57:27,371 INFO [train.py:764] Epoch 2, batch 2500, train_loss[loss=3.096, ArTop10Accuracy=0.7811, NarTop10Accuracy=0.6451, over 1421.00 frames. ], tot_loss[loss=3.266, ArTop10Accuracy=0.771, NarTop10Accuracy=0.5629, over 1206.71 frames. ], batch size: 2, lr: 2.84e-03, grad_scale: 16.0
|
171 |
+
2023-11-01 01:57:49,515 INFO [train.py:764] Epoch 2, batch 2600, train_loss[loss=3.526, ArTop10Accuracy=0.7525, NarTop10Accuracy=0.499, over 1204.00 frames. ], tot_loss[loss=3.269, ArTop10Accuracy=0.7704, NarTop10Accuracy=0.5633, over 1207.73 frames. ], batch size: 3, lr: 2.83e-03, grad_scale: 16.0
|
172 |
+
2023-11-01 01:58:11,257 INFO [train.py:764] Epoch 2, batch 2700, train_loss[loss=2.971, ArTop10Accuracy=0.7834, NarTop10Accuracy=0.6313, over 1279.00 frames. ], tot_loss[loss=3.267, ArTop10Accuracy=0.7712, NarTop10Accuracy=0.5641, over 1196.88 frames. ], batch size: 3, lr: 2.82e-03, grad_scale: 16.0
|
exp/valle_dev/tensorboard/events.out.tfevents.1698769188.vallex1-4110961-iaas.58414.0
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d6675fdba042d5999c24e258a02434eee15b71eb9358e67df381d956010ee26e
|
3 |
+
size 66553
|
exp/valle_dev/tensorboard/events.out.tfevents.1698771660.vallex1-4110961-iaas.58697.0
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0e32669f11c346ac94101e4b83226e3a8792257df84bf93ba2dbd428118de47a
|
3 |
+
size 96315
|
images/vallex_framework.jpg
ADDED
infer.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
launch-ui.py
ADDED
@@ -0,0 +1,432 @@
1 |
+
# coding: utf-8
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import pathlib
|
5 |
+
import time
|
6 |
+
import tempfile
|
7 |
+
import platform
|
8 |
+
import webbrowser
|
9 |
+
import sys
|
10 |
+
|
11 |
+
print(f"default encoding is {sys.getdefaultencoding()},file system encoding is {sys.getfilesystemencoding()}")
|
12 |
+
print(f"You are using Python version {platform.python_version()}")
|
13 |
+
if (sys.version_info[0] < 3 or sys.version_info[1] < 7):
|
14 |
+
print("The Python version is too low and may cause problems")
|
15 |
+
|
16 |
+
if platform.system().lower() == 'windows':
|
17 |
+
temp = pathlib.PosixPath
|
18 |
+
pathlib.PosixPath = pathlib.WindowsPath
|
19 |
+
else:
|
20 |
+
temp = pathlib.WindowsPath
|
21 |
+
pathlib.WindowsPath = pathlib.PosixPath
|
22 |
+
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
23 |
+
|
24 |
+
import py3langid as langid
|
25 |
+
|
26 |
+
langid.set_languages(['en', 'zh', 'ja', 'vi'])
|
27 |
+
|
28 |
+
import nltk
|
29 |
+
|
30 |
+
nltk.data.path = nltk.data.path + [os.path.join(os.getcwd(), "nltk_data")]
|
31 |
+
|
32 |
+
import torch
|
33 |
+
import torchaudio
|
34 |
+
|
35 |
+
import numpy as np
|
36 |
+
|
37 |
+
from data.tokenizer import (
|
38 |
+
AudioTokenizer,
|
39 |
+
tokenize_audio,
|
40 |
+
)
|
41 |
+
from data.collation import get_text_token_collater
|
42 |
+
from models.vallex import VALLE
|
43 |
+
from utils.g2p import PhonemeBpeTokenizer
|
44 |
+
from descriptions import *
|
45 |
+
from macros import *
|
46 |
+
|
47 |
+
import gradio as gr
|
48 |
+
import whisper
|
49 |
+
from vocos import Vocos
|
50 |
+
import multiprocessing
|
51 |
+
|
52 |
+
thread_count = multiprocessing.cpu_count()
|
53 |
+
|
54 |
+
print("Use", thread_count, "cpu cores for computing")
|
55 |
+
|
56 |
+
torch.set_num_threads(thread_count)
|
57 |
+
torch.set_num_interop_threads(thread_count)
|
58 |
+
torch._C._jit_set_profiling_executor(False)
|
59 |
+
torch._C._jit_set_profiling_mode(False)
|
60 |
+
torch._C._set_graph_executor_optimize(False)
|
61 |
+
|
62 |
+
text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_175.json")
|
63 |
+
text_collater = get_text_token_collater()
|
64 |
+
|
65 |
+
device = torch.device("cpu")
|
66 |
+
if torch.cuda.is_available():
|
67 |
+
device = torch.device("cuda", 0)
|
68 |
+
|
69 |
+
# VALL-E-X model
|
70 |
+
model = VALLE(
|
71 |
+
N_DIM,
|
72 |
+
NUM_HEAD,
|
73 |
+
NUM_LAYERS,
|
74 |
+
norm_first=True,
|
75 |
+
add_prenet=False,
|
76 |
+
prefix_mode=PREFIX_MODE,
|
77 |
+
share_embedding=True,
|
78 |
+
nar_scale_factor=1.0,
|
79 |
+
prepend_bos=True,
|
80 |
+
num_quantizers=NUM_QUANTIZERS,
|
81 |
+
)
|
82 |
+
checkpoint = torch.load("./checkpoints/vallex-checkpoint.pt", map_location='cpu')
|
83 |
+
missing_keys, unexpected_keys = model.load_state_dict(
|
84 |
+
checkpoint["model"], strict=True
|
85 |
+
)
|
86 |
+
assert not missing_keys
|
87 |
model.eval()

# Encodec model
audio_tokenizer = AudioTokenizer(device)

# Vocos decoder
vocos = Vocos.from_pretrained('charactr/vocos-encodec-24khz').to(device)

# ASR
if not os.path.exists("./whisper/"): os.mkdir("./whisper/")
try:
    whisper_model = whisper.load_model("medium", download_root=os.path.join(os.getcwd(), "whisper")).cpu()
except Exception as e:
    logging.info(e)
    raise Exception(
        "\n Whisper download failed or damaged, please go to "
        "'https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt'"
        "\n manually download model and put it to {} .".format(os.getcwd() + "/whisper"))

# Voice Presets
preset_list = os.walk("./presets/").__next__()[2]
preset_list = [preset[:-4] for preset in preset_list if preset.endswith(".npz")]


def inference_encoded_frames(text_tokens, text_tokens_lens, audio_prompts, enroll_x_lens, lang_pr, langs, accent, lang):
    if lang_pr == vi_code:
        lang_pr = zh_code

    if lang == vi_code:
        lang = ja_code

    encoded_frames = model.inference(
        text_tokens.to(device),
        text_tokens_lens.to(device),
        audio_prompts,
        enroll_x_lens=enroll_x_lens,
        top_k=-100,
        temperature=1,
        prompt_language=lang_pr,
        text_language=langs if accent == "no-accent" else lang,
        best_of=5,
    )

    return encoded_frames


def inference_samples(text_tokens, text_tokens_lens, audio_prompts, enroll_x_lens, lang_pr, langs, accent, lang):
    encoded_frames = inference_encoded_frames(text_tokens, text_tokens_lens, audio_prompts, enroll_x_lens, lang_pr,
                                              langs,
                                              accent, lang)
    # Decode with Vocos
    frames = encoded_frames.permute(2, 0, 1)
    features = vocos.codes_to_features(frames)
    samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))

    return samples


def clear_prompts():
    try:
        path = tempfile.gettempdir()
        for eachfile in os.listdir(path):
            filename = os.path.join(path, eachfile)
            if os.path.isfile(filename) and filename.endswith(".npz"):
                lastmodifytime = os.stat(filename).st_mtime
                endfiletime = time.time() - 60
                if endfiletime > lastmodifytime:
                    os.remove(filename)
    except:
        return


def transcribe_one(model, audio_path):
    # load audio and pad/trim it to fit 30 seconds
    audio = whisper.load_audio(audio_path)
    audio = whisper.pad_or_trim(audio)

    # make log-Mel spectrogram and move to the same device as the model
    mel = whisper.log_mel_spectrogram(audio).to(model.device)

    # detect the spoken language
    _, probs = model.detect_language(mel)
    print(f"Detected language: {max(probs, key=probs.get)}")
    lang = max(probs, key=probs.get)
    # decode the audio
    options = whisper.DecodingOptions(temperature=1.0, best_of=5, fp16=False if device == torch.device("cpu") else True,
                                      sample_len=150)
    result = whisper.decode(model, mel, options)

    # print the recognized text
    print(result.text)

    text_pr = result.text
    if text_pr.strip(" ")[-1] not in "?!.,。,?!。、":
        text_pr += "."
    return lang, text_pr


def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
    global model, text_collater, text_tokenizer, audio_tokenizer
    clear_prompts()
    audio_prompt = uploaded_audio if uploaded_audio is not None else recorded_audio
    sr, wav_pr = audio_prompt
    if not isinstance(wav_pr, torch.FloatTensor):
        wav_pr = torch.FloatTensor(wav_pr)
    if wav_pr.abs().max() > 1:
        wav_pr /= wav_pr.abs().max()
    if wav_pr.size(-1) == 2:
        wav_pr = wav_pr[:, 0]
    if wav_pr.ndim == 1:
        wav_pr = wav_pr.unsqueeze(0)
    assert wav_pr.ndim and wav_pr.size(0) == 1

    if transcript_content == "":
        text_pr, lang_pr = make_prompt(name, wav_pr, sr, save=False)
    else:
        lang_pr = langid.classify(str(transcript_content))[0]
        lang_token = lang2token[lang_pr]
        text_pr = f"{lang_token}{str(transcript_content)}{lang_token}"
    # tokenize audio
    encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
    audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()

    # tokenize text
    phonemes, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
    text_tokens, enroll_x_lens = text_collater(
        [
            phonemes
        ]
    )

    message = f"Detected language: {lang_pr}\n Detected text {text_pr}\n"

    # save as npz file
    np.savez(os.path.join(tempfile.gettempdir(), f"{name}.npz"),
             audio_tokens=audio_tokens, text_tokens=text_tokens, lang_code=lang2code[lang_pr])
    return message, os.path.join(tempfile.gettempdir(), f"{name}.npz")


def make_prompt(name, wav, sr, save=True):
    global whisper_model
    whisper_model.to(device)
    if not isinstance(wav, torch.FloatTensor):
        wav = torch.tensor(wav)
    if wav.abs().max() > 1:
        wav /= wav.abs().max()
    if wav.size(-1) == 2:
        wav = wav.mean(-1, keepdim=False)
    if wav.ndim == 1:
        wav = wav.unsqueeze(0)
    assert wav.ndim and wav.size(0) == 1
    torchaudio.save(f"./prompts/{name}.wav", wav, sr)
    lang, text = transcribe_one(whisper_model, f"./prompts/{name}.wav")
    lang_token = lang2token[lang]
    text = lang_token + text + lang_token
    with open(f"./prompts/{name}.txt", 'w', encoding='utf-8') as f:
        f.write(text)
    if not save:
        os.remove(f"./prompts/{name}.wav")
        os.remove(f"./prompts/{name}.txt")

    whisper_model.cpu()
    torch.cuda.empty_cache()
    return text, lang


from utils.sentence_cutter import split_text_into_sentences


@torch.no_grad()
def infer_long_text(text, preset_prompt, prompt=None, language='auto', accent='no-accent'):
    """
    For long audio generation, two modes are available.
    fixed-prompt: This mode will keep using the same prompt the user has provided, and generate audio sentence by sentence.
    sliding-window: This mode will use the last sentence as the prompt for the next sentence, but has some concern on speaker maintenance.
    """
    mode = 'fixed-prompt'
    global model, audio_tokenizer, text_tokenizer, text_collater
    model.to(device)
    if (prompt is None or prompt == "") and preset_prompt == "":
        mode = 'sliding-window'  # If no prompt is given, use sliding-window mode
    sentences = split_text_into_sentences(text)
    # detect language
    if language == "auto-detect":
        language = langid.classify(text)[0]
    else:
        language = token2lang[langdropdown2token[language]]

    # if initial prompt is given, encode it
    if prompt is not None and prompt != "":
        # load prompt
        prompt_data = np.load(prompt.name)
        audio_prompts = prompt_data['audio_tokens']
        text_prompts = prompt_data['text_tokens']
        lang_pr = prompt_data['lang_code']
        lang_pr = code2lang[int(lang_pr)]

        # numpy to tensor
        audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
        text_prompts = torch.tensor(text_prompts).type(torch.int32)
    elif preset_prompt is not None and preset_prompt != "":
        prompt_data = np.load(os.path.join("./presets/", f"{preset_prompt}.npz"))
        audio_prompts = prompt_data['audio_tokens']
        text_prompts = prompt_data['text_tokens']
        lang_pr = prompt_data['lang_code']
        lang_pr = code2lang[int(lang_pr)]

        # numpy to tensor
        audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
        text_prompts = torch.tensor(text_prompts).type(torch.int32)
    else:
        audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
        text_prompts = torch.zeros([1, 0]).type(torch.int32)
        lang_pr = language if language != 'mix' else 'en'
    if mode == 'fixed-prompt':
        complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
        for text in sentences:
            text = text.replace("\n", "").strip(" ")
            if text == "":
                continue
            lang_token = lang2token[language]
            lang = token2lang[lang_token]
            text = lang_token + text + lang_token

            enroll_x_lens = text_prompts.shape[-1]
            logging.info(f"synthesize text: {text}")
            phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
            text_tokens, text_tokens_lens = text_collater(
                [
                    phone_tokens
                ]
            )
            text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
            text_tokens_lens += enroll_x_lens
            # accent control
            lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
            encoded_frames = inference_encoded_frames(text_tokens, text_tokens_lens, audio_prompts, enroll_x_lens,
                                                      lang_pr, langs, accent, lang)
            complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
        # Decode with Vocos
        frames = complete_tokens.permute(1, 0, 2)
        features = vocos.codes_to_features(frames)
        samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))

        model.to('cpu')
        message = f"Cut into {len(sentences)} sentences"
        return message, (24000, samples.squeeze(0).cpu().numpy())
    elif mode == "sliding-window":
        complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)
        original_audio_prompts = audio_prompts
        original_text_prompts = text_prompts
        for text in sentences:
            text = text.replace("\n", "").strip(" ")
            if text == "":
                continue
            lang_token = lang2token[language]
            lang = token2lang[lang_token]
            text = lang_token + text + lang_token

            enroll_x_lens = text_prompts.shape[-1]
            logging.info(f"synthesize text: {text}")
            phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
            text_tokens, text_tokens_lens = text_collater(
                [
                    phone_tokens
                ]
            )
            text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
            text_tokens_lens += enroll_x_lens
            # accent control
            lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
            encoded_frames = inference_encoded_frames(text_tokens, text_tokens_lens, audio_prompts, enroll_x_lens,
                                                      lang_pr, langs, accent, lang)
            complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1)
            if torch.rand(1) < 1.0:
                audio_prompts = encoded_frames[:, :, -NUM_QUANTIZERS:]
                text_prompts = text_tokens[:, enroll_x_lens:]
            else:
                audio_prompts = original_audio_prompts
                text_prompts = original_text_prompts
        # Decode with Vocos
        frames = complete_tokens.permute(1, 0, 2)
        features = vocos.codes_to_features(frames)
        samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))

        model.to('cpu')
        return 24000, samples.squeeze(0).cpu().numpy()
    else:
        raise ValueError(f"No such mode {mode}")


def main():
    app = gr.Blocks(title="TTS and Voice Clone")
    with app:
        with gr.Tab("Text to Speech"):
            with gr.Row():
                with gr.Column():
                    textbox_4 = gr.TextArea(label="Text",
                                            placeholder="Type your sentence here",
                                            value=long_text_example, elem_id=f"tts-input")
                    language_dropdown_4 = gr.Dropdown(choices=language_options,
                                                      value='auto-detect',
                                                      label='language')
                    accent_dropdown_4 = gr.Dropdown(choices=language_options,
                                                    value='no-accent',
                                                    label='accent')

                with gr.Column():
                    preset_dropdown_4 = gr.Dropdown(choices=preset_list, value=None, label='Voice preset')
                    prompt_file_4 = gr.File(file_count='single', file_types=['.npz'], interactive=True)
                    audio_output_4 = gr.Audio(label="Output Audio", elem_id="tts-audio")
                    btn_4 = gr.Button("Generate!")
                    btn_4.click(infer_long_text,
                                inputs=[textbox_4, preset_dropdown_4, prompt_file_4, language_dropdown_4,
                                        accent_dropdown_4],
                                outputs=[audio_output_4])
        with gr.Tab("Make prompt for voice clone"):
            with gr.Row():
                with gr.Column():
                    textbox2 = gr.TextArea(label="Prompt name",
                                           placeholder="Name your prompt here",
                                           value="prompt_1", elem_id=f"prompt-name")
                    textbox_transcript2 = gr.TextArea(label="Transcript",
                                                      placeholder="Write transcript here. (leave empty to use whisper)",
                                                      value="", elem_id=f"prompt-name")
                    upload_audio_prompt_2 = gr.Audio(label='uploaded audio prompt', source='upload', interactive=True)
                    record_audio_prompt_2 = gr.Audio(label='recorded audio prompt', source='microphone',
                                                     interactive=True)
                with gr.Column():
                    text_output_2 = gr.Textbox(label="Message")
                    prompt_output_2 = gr.File(interactive=False)
                    btn_2 = gr.Button("Make!")
                    btn_2.click(make_npz_prompt,
                                inputs=[textbox2, upload_audio_prompt_2, record_audio_prompt_2, textbox_transcript2],
                                outputs=[text_output_2, prompt_output_2])

    webbrowser.open("http://127.0.0.1:7860")
    app.launch()


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
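For reference, a minimal sketch of driving infer_long_text() above outside the Gradio UI. It assumes this file is importable as launch_ui, that a voice preset named "prompt_1" already exists under ./presets/, and that soundfile is installed; all three are assumptions, not guaranteed by this commit.

# Usage sketch (assumptions: module name `launch_ui`, preset "prompt_1", soundfile installed)
import soundfile as sf
import launch_ui

message, (sr, samples) = launch_ui.infer_long_text(
    "Hello, this is a quick smoke test of the fixed-prompt mode.",
    preset_prompt="prompt_1",        # name of a .npz file under ./presets/
    prompt=None,
    language="auto-detect",          # or one of the dropdown language options
    accent="no-accent",
)
print(message)                       # e.g. "Cut into N sentences"
sf.write("output.wav", samples, sr)  # sr is 24000 (EnCodec/Vocos sample rate)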
macros.py
ADDED
@@ -0,0 +1,51 @@
NUM_LAYERS = 12
NUM_HEAD = 16
N_DIM = 1024
PREFIX_MODE = 1
NUM_QUANTIZERS = 8
SAMPLE_RATE = 24000

lang2token = {
    'zh': "[ZH]",
    'ja': "[JA]",
    "en": "[EN]",
    "vi": "[VI]",
    'mix': "",
}

lang2code = {
    'zh': 0,
    'ja': 1,
    "en": 2,
    "vi": 3
}

token2lang = {
    '[ZH]': "zh",
    '[JA]': "ja",
    "[EN]": "en",
    "[VI]": "vi",
    "": "mix"
}

code2lang = {
    0: 'zh',
    1: 'ja',
    2: "en",
    3: "vi",
}

vi_code = 'vi'
zh_code = 'zh'
en_code = 'en'
ja_code = 'ja'

langdropdown2token = {
    'English': "[EN]",
    'Chinese': "[ZH]",
    'Japanese': "[JA]",
    'Vietnamese': "[VI]",
    'Mix': "",
}

language_options = ['no-accent', 'English', 'Chinese', 'Japanese', 'Vietnamese']
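The dictionaries above are meant to round-trip; a quick sanity-check sketch, assuming macros.py is importable from the project root:

# Sanity check (sketch): the language<->token and language<->code mappings invert each other.
from macros import lang2token, token2lang, lang2code, code2lang

assert all(token2lang[token] == lang for lang, token in lang2token.items())
assert all(code2lang[code] == lang for lang, code in lang2code.items())
print(lang2token["vi"], lang2code["vi"])  # -> [VI] 3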
makedata.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
model-card.md
ADDED
@@ -0,0 +1,33 @@
# Model Card: VALL-E X

**Author**: [Songting](https://github.com/Plachtaa).<br>
<br>
This is the official codebase for running the open-sourced VALL-E X.

The following is additional information about the models released here.

## Model Details

VALL-E X is a pair of transformer models that turn text into audio.

### Phoneme to acoustic tokens
- Input: IPAs converted from the input text by a rule-based G2P tool.
- Output: tokens from the first codebook of the [EnCodec Codec](https://github.com/facebookresearch/encodec) from Facebook

### Coarse to fine tokens
- Input: IPAs converted from the input text by a rule-based G2P tool & the first codebook from EnCodec
- Output: 8 codebooks from EnCodec

### Architecture
| Model                    | Parameters | Attention  | Output vocab size |
|:------------------------:|:----------:|------------|:-----------------:|
| G2P tool                 | -          | -          | 69                |
| Phoneme to coarse tokens | 150 M      | Causal     | 1x 1,024          |
| Coarse to fine tokens    | 150 M      | Non-causal | 7x 1,024          |

### Release date
August 2023

## Broader Implications
We anticipate that this model's text-to-audio capabilities can be used to improve accessibility tools in a variety of languages.
Straightforward improvements will allow models to run faster than real time, rendering them useful for applications such as virtual assistants.
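The two stages above correspond to the inference path in launch-ui.py earlier in this commit; a compressed sketch of that flow follows, with prompt and device handling stripped down (the zero-length acoustic prompt and the "en" language codes are illustrative assumptions).

# Text -> audio sketch following launch-ui.py; assumes model, text_tokenizer, text_collater,
# vocos and device are initialized as in that file.
import torch

phonemes, langs = text_tokenizer.tokenize(text="_[EN]Hello world.[EN]")
text_tokens, text_tokens_lens = text_collater([phonemes])

# Stage 1 emits the first EnCodec codebook autoregressively; stage 2 fills codebooks 2-8.
encoded_frames = model.inference(
    text_tokens.to(device), text_tokens_lens.to(device),
    torch.zeros([1, 0, 8], dtype=torch.int32, device=device),  # empty acoustic prompt
    enroll_x_lens=0, top_k=-100, temperature=1,
    prompt_language="en", text_language="en", best_of=5,
)

# Vocos converts the 8 codebooks back into a 24 kHz waveform.
features = vocos.codes_to_features(encoded_frames.permute(2, 0, 1))
samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))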
models/__init__.py
ADDED
@@ -0,0 +1,136 @@
import argparse

import torch.nn as nn
# from icefall.utils import AttributeDict, str2bool

from .macros import (
    NUM_AUDIO_TOKENS,
    NUM_MEL_BINS,
    NUM_SPEAKER_CLASSES,
    NUM_TEXT_TOKENS,
    SPEAKER_EMBEDDING_DIM,
)
from .transformer import Transformer
from .vallex import VALLE, VALLF
from .visualizer import visualize


def add_model_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--model-name",
        type=str,
        default="VALL-E",
        help="VALL-E, VALL-F, Transformer.",
    )
    parser.add_argument(
        "--decoder-dim",
        type=int,
        default=1024,
        help="Embedding dimension in the decoder model.",
    )
    parser.add_argument(
        "--nhead",
        type=int,
        default=16,
        help="Number of attention heads in the Decoder layers.",
    )
    parser.add_argument(
        "--num-decoder-layers",
        type=int,
        default=12,
        help="Number of Decoder layers.",
    )
    parser.add_argument(
        "--scale-factor",
        type=float,
        default=1.0,
        help="Model scale factor which will be assigned different meanings in different models.",
    )
    parser.add_argument(
        "--norm-first",
        type=bool,
        default=True,
        help="Pre or Post Normalization.",
    )
    parser.add_argument(
        "--add-prenet",
        type=bool,
        default=False,
        help="Whether add PreNet after Inputs.",
    )

    # VALL-E & F
    parser.add_argument(
        "--prefix-mode",
        type=int,
        default=1,
        help="The mode for how to prefix VALL-E NAR Decoder, "
        "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.",
    )
    parser.add_argument(
        "--share-embedding",
        type=bool,
        default=True,
        help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.",
    )
    parser.add_argument(
        "--prepend-bos",
        type=bool,
        default=False,
        help="Whether prepend <BOS> to the acoustic tokens -> AR Decoder inputs.",
    )
    parser.add_argument(
        "--num-quantizers",
        type=int,
        default=8,
        help="Number of Audio/Semantic quantization layers.",
    )

    # Transformer
    parser.add_argument(
        "--scaling-xformers",
        type=bool,
        default=False,
        help="Apply Reworked Conformer scaling on Transformers.",
    )


def get_model(params) -> nn.Module:
    if params.model_name.lower() in ["vall-f", "vallf"]:
        model = VALLF(
            params.decoder_dim,
            params.nhead,
            params.num_decoder_layers,
            norm_first=params.norm_first,
            add_prenet=params.add_prenet,
            prefix_mode=params.prefix_mode,
            share_embedding=params.share_embedding,
            nar_scale_factor=params.scale_factor,
            prepend_bos=params.prepend_bos,
            num_quantizers=params.num_quantizers,
        )
    elif params.model_name.lower() in ["vall-e", "valle"]:
        model = VALLE(
            params.decoder_dim,
            params.nhead,
            params.num_decoder_layers,
            norm_first=params.norm_first,
            add_prenet=params.add_prenet,
            prefix_mode=params.prefix_mode,
            share_embedding=params.share_embedding,
            nar_scale_factor=params.scale_factor,
            prepend_bos=params.prepend_bos,
            num_quantizers=params.num_quantizers,
        )
    else:
        assert params.model_name in ["Transformer"]
        model = Transformer(
            params.decoder_dim,
            params.nhead,
            params.num_decoder_layers,
            norm_first=params.norm_first,
            add_prenet=params.add_prenet,
            scaling_xformers=params.scaling_xformers,
        )

    return model
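A minimal sketch of exercising the two helpers above, using argparse defaults in place of the icefall-style params object (the parameter-count print is illustrative only):

# Sketch: build a VALL-E model from the CLI defaults defined above.
import argparse
from models import add_model_arguments, get_model

parser = argparse.ArgumentParser()
add_model_arguments(parser)
params = parser.parse_args([])   # defaults: VALL-E, 1024-dim, 16 heads, 12 decoder layers
model = get_model(params)        # returns a models.vallex.VALLE instance
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f} M parameters")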
models/macros.py
ADDED
@@ -0,0 +1,11 @@
# Text
NUM_TEXT_TOKENS = 2048

# Audio
NUM_AUDIO_TOKENS = 1024  # EnCodec RVQ bins
NUM_MEL_BINS = 100  # BigVGAN bigvgan_24khz_100band


# Speaker
NUM_SPEAKER_CLASSES = 4096
SPEAKER_EMBEDDING_DIM = 64
models/transformer.py
ADDED
@@ -0,0 +1,394 @@
1 |
+
# Copyright 2023 (authors: Feiteng Li)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from functools import partial
|
16 |
+
from typing import Any, Dict, List, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
# from icefall.utils import make_pad_mask
|
22 |
+
# from torchmetrics.classification import BinaryAccuracy
|
23 |
+
|
24 |
+
from models.vallex import Transpose
|
25 |
+
from modules.embedding import SinePositionalEmbedding, TokenEmbedding
|
26 |
+
from modules.scaling import BalancedDoubleSwish, ScaledLinear
|
27 |
+
from modules.transformer import (
|
28 |
+
BalancedBasicNorm,
|
29 |
+
IdentityNorm,
|
30 |
+
TransformerDecoderLayer,
|
31 |
+
TransformerEncoder,
|
32 |
+
TransformerEncoderLayer,
|
33 |
+
)
|
34 |
+
|
35 |
+
from .macros import NUM_MEL_BINS, NUM_TEXT_TOKENS
|
36 |
+
from .visualizer import visualize
|
37 |
+
|
38 |
+
IdentityNorm = IdentityNorm
|
39 |
+
|
40 |
+
|
41 |
+
class Transformer(nn.Module):
|
42 |
+
"""It implements seq2seq Transformer TTS for debug(No StopPredictor and SpeakerEmbeding)
|
43 |
+
Neural Speech Synthesis with Transformer Network
|
44 |
+
https://arxiv.org/abs/1809.08895
|
45 |
+
"""
|
46 |
+
|
47 |
+
def __init__(
|
48 |
+
self,
|
49 |
+
d_model: int,
|
50 |
+
nhead: int,
|
51 |
+
num_layers: int,
|
52 |
+
norm_first: bool = True,
|
53 |
+
add_prenet: bool = False,
|
54 |
+
scaling_xformers: bool = False,
|
55 |
+
):
|
56 |
+
"""
|
57 |
+
Args:
|
58 |
+
d_model:
|
59 |
+
The number of expected features in the input (required).
|
60 |
+
nhead:
|
61 |
+
The number of heads in the multiheadattention models (required).
|
62 |
+
num_layers:
|
63 |
+
The number of sub-decoder-layers in the decoder (required).
|
64 |
+
"""
|
65 |
+
super().__init__()
|
66 |
+
self.text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x
|
67 |
+
|
68 |
+
if add_prenet:
|
69 |
+
self.encoder_prenet = nn.Sequential(
|
70 |
+
Transpose(),
|
71 |
+
nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
|
72 |
+
nn.BatchNorm1d(d_model),
|
73 |
+
nn.ReLU(),
|
74 |
+
nn.Dropout(0.5),
|
75 |
+
nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
|
76 |
+
nn.BatchNorm1d(d_model),
|
77 |
+
nn.ReLU(),
|
78 |
+
nn.Dropout(0.5),
|
79 |
+
nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
|
80 |
+
nn.BatchNorm1d(d_model),
|
81 |
+
nn.ReLU(),
|
82 |
+
nn.Dropout(0.5),
|
83 |
+
Transpose(),
|
84 |
+
nn.Linear(d_model, d_model),
|
85 |
+
)
|
86 |
+
|
87 |
+
self.decoder_prenet = nn.Sequential(
|
88 |
+
nn.Linear(NUM_MEL_BINS, 256),
|
89 |
+
nn.ReLU(),
|
90 |
+
nn.Dropout(0.5),
|
91 |
+
nn.Linear(256, 256),
|
92 |
+
nn.ReLU(),
|
93 |
+
nn.Dropout(0.5),
|
94 |
+
nn.Linear(256, d_model),
|
95 |
+
)
|
96 |
+
|
97 |
+
assert scaling_xformers is False # TODO: update this block
|
98 |
+
else:
|
99 |
+
self.encoder_prenet = nn.Identity()
|
100 |
+
if scaling_xformers:
|
101 |
+
self.decoder_prenet = ScaledLinear(NUM_MEL_BINS, d_model)
|
102 |
+
else:
|
103 |
+
self.decoder_prenet = nn.Linear(NUM_MEL_BINS, d_model)
|
104 |
+
|
105 |
+
self.encoder_position = SinePositionalEmbedding(
|
106 |
+
d_model,
|
107 |
+
dropout=0.1,
|
108 |
+
scale=False,
|
109 |
+
)
|
110 |
+
self.decoder_position = SinePositionalEmbedding(
|
111 |
+
d_model, dropout=0.1, scale=False
|
112 |
+
)
|
113 |
+
|
114 |
+
if scaling_xformers:
|
115 |
+
self.encoder = TransformerEncoder(
|
116 |
+
TransformerEncoderLayer(
|
117 |
+
d_model,
|
118 |
+
nhead,
|
119 |
+
dim_feedforward=d_model * 4,
|
120 |
+
dropout=0.1,
|
121 |
+
batch_first=True,
|
122 |
+
norm_first=norm_first,
|
123 |
+
linear1_self_attention_cls=ScaledLinear,
|
124 |
+
linear2_self_attention_cls=partial(
|
125 |
+
ScaledLinear, initial_scale=0.01
|
126 |
+
),
|
127 |
+
linear1_feedforward_cls=ScaledLinear,
|
128 |
+
linear2_feedforward_cls=partial(
|
129 |
+
ScaledLinear, initial_scale=0.01
|
130 |
+
),
|
131 |
+
activation=partial(
|
132 |
+
BalancedDoubleSwish,
|
133 |
+
channel_dim=-1,
|
134 |
+
max_abs=10.0,
|
135 |
+
min_prob=0.25,
|
136 |
+
),
|
137 |
+
layer_norm_cls=IdentityNorm,
|
138 |
+
),
|
139 |
+
num_layers=num_layers,
|
140 |
+
norm=BalancedBasicNorm(d_model) if norm_first else None,
|
141 |
+
)
|
142 |
+
|
143 |
+
self.decoder = nn.TransformerDecoder(
|
144 |
+
TransformerDecoderLayer(
|
145 |
+
d_model,
|
146 |
+
nhead,
|
147 |
+
dim_feedforward=d_model * 4,
|
148 |
+
dropout=0.1,
|
149 |
+
batch_first=True,
|
150 |
+
norm_first=norm_first,
|
151 |
+
linear1_self_attention_cls=ScaledLinear,
|
152 |
+
linear2_self_attention_cls=partial(
|
153 |
+
ScaledLinear, initial_scale=0.01
|
154 |
+
),
|
155 |
+
linear1_feedforward_cls=ScaledLinear,
|
156 |
+
linear2_feedforward_cls=partial(
|
157 |
+
ScaledLinear, initial_scale=0.01
|
158 |
+
),
|
159 |
+
activation=partial(
|
160 |
+
BalancedDoubleSwish,
|
161 |
+
channel_dim=-1,
|
162 |
+
max_abs=10.0,
|
163 |
+
min_prob=0.25,
|
164 |
+
),
|
165 |
+
layer_norm_cls=IdentityNorm,
|
166 |
+
),
|
167 |
+
num_layers=num_layers,
|
168 |
+
norm=BalancedBasicNorm(d_model) if norm_first else None,
|
169 |
+
)
|
170 |
+
|
171 |
+
self.predict_layer = ScaledLinear(d_model, NUM_MEL_BINS)
|
172 |
+
self.stop_layer = nn.Linear(d_model, 1)
|
173 |
+
else:
|
174 |
+
self.encoder = nn.TransformerEncoder(
|
175 |
+
nn.TransformerEncoderLayer(
|
176 |
+
d_model,
|
177 |
+
nhead,
|
178 |
+
dim_feedforward=d_model * 4,
|
179 |
+
activation=F.relu,
|
180 |
+
dropout=0.1,
|
181 |
+
batch_first=True,
|
182 |
+
norm_first=norm_first,
|
183 |
+
),
|
184 |
+
num_layers=num_layers,
|
185 |
+
norm=nn.LayerNorm(d_model) if norm_first else None,
|
186 |
+
)
|
187 |
+
|
188 |
+
self.decoder = nn.TransformerDecoder(
|
189 |
+
nn.TransformerDecoderLayer(
|
190 |
+
d_model,
|
191 |
+
nhead,
|
192 |
+
dim_feedforward=d_model * 4,
|
193 |
+
activation=F.relu,
|
194 |
+
dropout=0.1,
|
195 |
+
batch_first=True,
|
196 |
+
norm_first=norm_first,
|
197 |
+
),
|
198 |
+
num_layers=num_layers,
|
199 |
+
norm=nn.LayerNorm(d_model) if norm_first else None,
|
200 |
+
)
|
201 |
+
|
202 |
+
self.predict_layer = nn.Linear(d_model, NUM_MEL_BINS)
|
203 |
+
self.stop_layer = nn.Linear(d_model, 1)
|
204 |
+
|
205 |
+
self.stop_accuracy_metric = BinaryAccuracy(
|
206 |
+
threshold=0.5, multidim_average="global"
|
207 |
+
)
|
208 |
+
|
209 |
+
# self.apply(self._init_weights)
|
210 |
+
|
211 |
+
# def _init_weights(self, module):
|
212 |
+
# if isinstance(module, (nn.Linear)):
|
213 |
+
# module.weight.data.normal_(mean=0.0, std=0.02)
|
214 |
+
# if isinstance(module, nn.Linear) and module.bias is not None:
|
215 |
+
# module.bias.data.zero_()
|
216 |
+
# elif isinstance(module, nn.LayerNorm):
|
217 |
+
# module.bias.data.zero_()
|
218 |
+
# module.weight.data.fill_(1.0)
|
219 |
+
# elif isinstance(module, nn.Embedding):
|
220 |
+
# module.weight.data.normal_(mean=0.0, std=0.02)
|
221 |
+
|
222 |
+
def forward(
|
223 |
+
self,
|
224 |
+
x: torch.Tensor,
|
225 |
+
x_lens: torch.Tensor,
|
226 |
+
y: torch.Tensor,
|
227 |
+
y_lens: torch.Tensor,
|
228 |
+
reduction: str = "sum",
|
229 |
+
train_stage: int = 0,
|
230 |
+
**kwargs,
|
231 |
+
) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
|
232 |
+
"""
|
233 |
+
Args:
|
234 |
+
x:
|
235 |
+
A 2-D tensor of shape (N, S).
|
236 |
+
x_lens:
|
237 |
+
A 1-D tensor of shape (N,). It contains the number of tokens in `x`
|
238 |
+
before padding.
|
239 |
+
y:
|
240 |
+
A 3-D tensor of shape (N, T, 8).
|
241 |
+
y_lens:
|
242 |
+
A 1-D tensor of shape (N,). It contains the number of tokens in `x`
|
243 |
+
before padding.
|
244 |
+
train_stage:
|
245 |
+
Not used in this model.
|
246 |
+
Returns:
|
247 |
+
Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy.
|
248 |
+
"""
|
249 |
+
del train_stage
|
250 |
+
|
251 |
+
assert x.ndim == 2, x.shape
|
252 |
+
assert x_lens.ndim == 1, x_lens.shape
|
253 |
+
assert y.ndim == 3, y.shape
|
254 |
+
assert y_lens.ndim == 1, y_lens.shape
|
255 |
+
|
256 |
+
assert torch.all(x_lens > 0)
|
257 |
+
|
258 |
+
# NOTE: x has been padded in TextTokenCollater
|
259 |
+
x_mask = make_pad_mask(x_lens).to(x.device)
|
260 |
+
|
261 |
+
x = self.text_embedding(x)
|
262 |
+
x = self.encoder_prenet(x)
|
263 |
+
x = self.encoder_position(x)
|
264 |
+
x = self.encoder(x, src_key_padding_mask=x_mask)
|
265 |
+
|
266 |
+
total_loss, metrics = 0.0, {}
|
267 |
+
|
268 |
+
y_mask = make_pad_mask(y_lens).to(y.device)
|
269 |
+
y_mask_float = y_mask.type(torch.float32)
|
270 |
+
data_mask = 1.0 - y_mask_float.unsqueeze(-1)
|
271 |
+
|
272 |
+
# Training
|
273 |
+
# AR Decoder
|
274 |
+
def pad_y(y):
|
275 |
+
y = F.pad(y, (0, 0, 1, 0, 0, 0), value=0).detach()
|
276 |
+
# inputs, targets
|
277 |
+
return y[:, :-1], y[:, 1:]
|
278 |
+
|
279 |
+
y, targets = pad_y(y * data_mask) # mask padding as zeros
|
280 |
+
|
281 |
+
y_emb = self.decoder_prenet(y)
|
282 |
+
y_pos = self.decoder_position(y_emb)
|
283 |
+
|
284 |
+
y_len = y_lens.max()
|
285 |
+
tgt_mask = torch.triu(
|
286 |
+
torch.ones(y_len, y_len, device=y.device, dtype=torch.bool),
|
287 |
+
diagonal=1,
|
288 |
+
)
|
289 |
+
y_dec = self.decoder(
|
290 |
+
y_pos,
|
291 |
+
x,
|
292 |
+
tgt_mask=tgt_mask,
|
293 |
+
memory_key_padding_mask=x_mask,
|
294 |
+
)
|
295 |
+
|
296 |
+
predict = self.predict_layer(y_dec)
|
297 |
+
# loss
|
298 |
+
total_loss = F.mse_loss(predict, targets, reduction=reduction)
|
299 |
+
|
300 |
+
logits = self.stop_layer(y_dec).squeeze(-1)
|
301 |
+
stop_loss = F.binary_cross_entropy_with_logits(
|
302 |
+
logits,
|
303 |
+
y_mask_float.detach(),
|
304 |
+
weight=1.0 + y_mask_float.detach() * 4.0,
|
305 |
+
reduction=reduction,
|
306 |
+
)
|
307 |
+
metrics["stop_loss"] = stop_loss.detach()
|
308 |
+
|
309 |
+
stop_accuracy = self.stop_accuracy_metric(
|
310 |
+
(torch.sigmoid(logits) >= 0.5).type(torch.int64),
|
311 |
+
y_mask.type(torch.int64),
|
312 |
+
)
|
313 |
+
# icefall MetricsTracker.norm_items()
|
314 |
+
metrics["stop_accuracy"] = stop_accuracy.item() * y_lens.sum().type(
|
315 |
+
torch.float32
|
316 |
+
)
|
317 |
+
|
318 |
+
return ((x, predict), total_loss + 100.0 * stop_loss, metrics)
|
319 |
+
|
320 |
+
def inference(
|
321 |
+
self,
|
322 |
+
x: torch.Tensor,
|
323 |
+
x_lens: torch.Tensor,
|
324 |
+
y: Any = None,
|
325 |
+
**kwargs,
|
326 |
+
) -> torch.Tensor:
|
327 |
+
"""
|
328 |
+
Args:
|
329 |
+
x:
|
330 |
+
A 2-D tensor of shape (1, S).
|
331 |
+
x_lens:
|
332 |
+
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
|
333 |
+
before padding.
|
334 |
+
Returns:
|
335 |
+
Return the predicted audio code matrix and cross-entropy loss.
|
336 |
+
"""
|
337 |
+
assert x.ndim == 2, x.shape
|
338 |
+
assert x_lens.ndim == 1, x_lens.shape
|
339 |
+
|
340 |
+
assert torch.all(x_lens > 0)
|
341 |
+
|
342 |
+
x_mask = make_pad_mask(x_lens).to(x.device)
|
343 |
+
|
344 |
+
x = self.text_embedding(x)
|
345 |
+
x = self.encoder_prenet(x)
|
346 |
+
x = self.encoder_position(x)
|
347 |
+
x = self.encoder(x, src_key_padding_mask=x_mask)
|
348 |
+
|
349 |
+
x_mask = make_pad_mask(x_lens).to(x.device)
|
350 |
+
|
351 |
+
# AR Decoder
|
352 |
+
# TODO: Managing decoder steps avoid repetitive computation
|
353 |
+
y = torch.zeros(
|
354 |
+
[x.shape[0], 1, NUM_MEL_BINS], dtype=torch.float32, device=x.device
|
355 |
+
)
|
356 |
+
while True:
|
357 |
+
y_emb = self.decoder_prenet(y)
|
358 |
+
y_pos = self.decoder_position(y_emb)
|
359 |
+
|
360 |
+
tgt_mask = torch.triu(
|
361 |
+
torch.ones(
|
362 |
+
y.shape[1], y.shape[1], device=y.device, dtype=torch.bool
|
363 |
+
),
|
364 |
+
diagonal=1,
|
365 |
+
)
|
366 |
+
|
367 |
+
y_dec = self.decoder(
|
368 |
+
y_pos,
|
369 |
+
x,
|
370 |
+
tgt_mask=tgt_mask,
|
371 |
+
memory_mask=None,
|
372 |
+
memory_key_padding_mask=x_mask,
|
373 |
+
)
|
374 |
+
predict = self.predict_layer(y_dec[:, -1:])
|
375 |
+
|
376 |
+
logits = self.stop_layer(y_dec[:, -1:]) > 0 # sigmoid(0.0) = 0.5
|
377 |
+
if y.shape[1] > x_lens.max() * 10 or all(logits.cpu().numpy()):
|
378 |
+
print(
|
379 |
+
f"TransformerTTS EOS [Text {x_lens[0]} -> Audio {y.shape[1]}]"
|
380 |
+
)
|
381 |
+
break
|
382 |
+
|
383 |
+
y = torch.concat([y, predict], dim=1)
|
384 |
+
|
385 |
+
return y[:, 1:]
|
386 |
+
|
387 |
+
def visualize(
|
388 |
+
self,
|
389 |
+
predicts: Tuple[torch.Tensor],
|
390 |
+
batch: Dict[str, Union[List, torch.Tensor]],
|
391 |
+
output_dir: str,
|
392 |
+
limit: int = 4,
|
393 |
+
) -> None:
|
394 |
+
visualize(predicts, batch, output_dir, limit=limit)
|
models/vallex.py
ADDED
@@ -0,0 +1,1353 @@
1 |
+
# Copyright 2023 (authors: Feiteng Li)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import random
|
16 |
+
from typing import Dict, Iterator, List, Tuple, Union
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.nn.functional as F
|
22 |
+
|
23 |
+
from data.input_strategies import PromptedFeatures
|
24 |
+
from modules.embedding import SinePositionalEmbedding, TokenEmbedding
|
25 |
+
from modules.transformer import (
|
26 |
+
AdaptiveLayerNorm,
|
27 |
+
LayerNorm,
|
28 |
+
TransformerDecoderLayer,
|
29 |
+
TransformerEncoder,
|
30 |
+
TransformerEncoderLayer,
|
31 |
+
)
|
32 |
+
|
33 |
+
from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS
|
34 |
+
from .visualizer import visualize
|
35 |
+
from train_utils.utils import make_pad_mask
|
36 |
+
from torchmetrics.classification import MulticlassAccuracy
|
37 |
+
|
38 |
+
|
39 |
+
class Transpose(nn.Identity):
|
40 |
+
"""(N, T, D) -> (N, D, T)"""
|
41 |
+
|
42 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
43 |
+
return input.transpose(1, 2)
|
44 |
+
|
45 |
+
|
46 |
+
# NOTE: There are two ways to implement the model
|
47 |
+
# 1) [VALL-F] standard TransformerDecoder, use x as memory
|
48 |
+
# 2) [VALL-E] modified TransformerDecoder like GPT-x(e.g. causal TransformerEncoder),
|
49 |
+
# use x as the prefix of decoder inputs
|
50 |
+
class VALLF(nn.Module):
|
51 |
+
"""It implements https://arxiv.org/abs/2301.02111
|
52 |
+
"Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
d_model: int,
|
58 |
+
nhead: int,
|
59 |
+
num_layers: int,
|
60 |
+
norm_first: bool = True,
|
61 |
+
add_prenet: bool = False,
|
62 |
+
decoder_cls: Union[
|
63 |
+
nn.TransformerDecoder, nn.TransformerEncoder
|
64 |
+
] = nn.TransformerDecoder,
|
65 |
+
decoder_layer_cls: Union[
|
66 |
+
TransformerDecoderLayer, TransformerEncoderLayer
|
67 |
+
] = TransformerDecoderLayer,
|
68 |
+
prefix_mode: int = 0,
|
69 |
+
share_embedding: bool = True,
|
70 |
+
nar_scale_factor: float = 1.0,
|
71 |
+
prepend_bos: bool = True,
|
72 |
+
num_quantizers: int = 8,
|
73 |
+
):
|
74 |
+
"""
|
75 |
+
Args:
|
76 |
+
d_model:
|
77 |
+
The number of expected features in the input (required).
|
78 |
+
nhead:
|
79 |
+
The number of heads in the multiheadattention models (required).
|
80 |
+
num_layers:
|
81 |
+
The number of sub-decoder-layers in the decoder (required).
|
82 |
+
"""
|
83 |
+
super().__init__()
|
84 |
+
nar_d_model = int(d_model * nar_scale_factor)
|
85 |
+
|
86 |
+
self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x
|
87 |
+
self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS)
|
88 |
+
|
89 |
+
# ID NUM_AUDIO_TOKENS -> PAD
|
90 |
+
# ID NUM_AUDIO_TOKENS + 1 -> BOS
|
91 |
+
self.ar_audio_prepend_bos = prepend_bos
|
92 |
+
self.ar_audio_embedding = TokenEmbedding(
|
93 |
+
d_model, NUM_AUDIO_TOKENS + 1 + int(prepend_bos)
|
94 |
+
)
|
95 |
+
|
96 |
+
# PreNet
|
97 |
+
if add_prenet:
|
98 |
+
self.ar_text_prenet = nn.Sequential(
|
99 |
+
Transpose(),
|
100 |
+
nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
|
101 |
+
nn.BatchNorm1d(d_model),
|
102 |
+
nn.ReLU(),
|
103 |
+
nn.Dropout(0.5),
|
104 |
+
nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
|
105 |
+
nn.BatchNorm1d(d_model),
|
106 |
+
nn.ReLU(),
|
107 |
+
nn.Dropout(0.5),
|
108 |
+
nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
|
109 |
+
nn.BatchNorm1d(d_model),
|
110 |
+
nn.ReLU(),
|
111 |
+
nn.Dropout(0.5),
|
112 |
+
Transpose(),
|
113 |
+
nn.Linear(d_model, d_model),
|
114 |
+
)
|
115 |
+
|
116 |
+
self.ar_audio_prenet = nn.Sequential(
|
117 |
+
nn.Linear(d_model, 256),
|
118 |
+
nn.ReLU(),
|
119 |
+
nn.Dropout(0.25),
|
120 |
+
nn.Linear(256, 256),
|
121 |
+
nn.ReLU(),
|
122 |
+
nn.Dropout(0.25),
|
123 |
+
nn.Linear(256, d_model),
|
124 |
+
)
|
125 |
+
else:
|
126 |
+
self.ar_text_prenet = nn.Identity()
|
127 |
+
self.ar_audio_prenet = nn.Identity()
|
128 |
+
|
129 |
+
self.ar_text_position = SinePositionalEmbedding(
|
130 |
+
d_model,
|
131 |
+
dropout=0.1,
|
132 |
+
scale=False,
|
133 |
+
alpha=True,
|
134 |
+
)
|
135 |
+
self.ar_audio_position = SinePositionalEmbedding(
|
136 |
+
d_model,
|
137 |
+
dropout=0.1,
|
138 |
+
scale=False,
|
139 |
+
alpha=True,
|
140 |
+
)
|
141 |
+
|
142 |
+
self.ar_decoder = decoder_cls(
|
143 |
+
decoder_layer_cls(
|
144 |
+
d_model,
|
145 |
+
nhead,
|
146 |
+
dim_feedforward=d_model * 4,
|
147 |
+
dropout=0.1,
|
148 |
+
batch_first=True,
|
149 |
+
norm_first=norm_first,
|
150 |
+
),
|
151 |
+
num_layers=num_layers,
|
152 |
+
norm=LayerNorm(d_model) if norm_first else None,
|
153 |
+
)
|
154 |
+
self.ar_predict_layer = nn.Linear(
|
155 |
+
d_model, NUM_AUDIO_TOKENS + 1, bias=False
|
156 |
+
)
|
157 |
+
|
158 |
+
self.ar_accuracy_metric = MulticlassAccuracy(
|
159 |
+
NUM_AUDIO_TOKENS + 1,
|
160 |
+
top_k=10,
|
161 |
+
average="micro",
|
162 |
+
multidim_average="global",
|
163 |
+
ignore_index=NUM_AUDIO_TOKENS,
|
164 |
+
)
|
165 |
+
|
166 |
+
self.rng = random.Random(0)
|
167 |
+
self.num_heads = nhead
|
168 |
+
self.prefix_mode = prefix_mode
|
169 |
+
self.num_quantizers = num_quantizers
|
170 |
+
|
171 |
+
assert num_quantizers >= 1
|
172 |
+
if num_quantizers > 1:
|
173 |
+
self.nar_audio_embeddings = nn.ModuleList(
|
174 |
+
[TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)]
|
175 |
+
+ [
|
176 |
+
TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS)
|
177 |
+
for i in range(num_quantizers - 1)
|
178 |
+
]
|
179 |
+
) # W_a
|
180 |
+
|
181 |
+
# PreNet
|
182 |
+
if add_prenet:
|
183 |
+
self.nar_text_prenet = nn.Sequential(
|
184 |
+
Transpose(),
|
185 |
+
nn.Conv1d(
|
186 |
+
nar_d_model, nar_d_model, kernel_size=5, padding="same"
|
187 |
+
),
|
188 |
+
nn.BatchNorm1d(nar_d_model),
|
189 |
+
nn.ReLU(),
|
190 |
+
nn.Dropout(0.5),
|
191 |
+
nn.Conv1d(
|
192 |
+
nar_d_model, nar_d_model, kernel_size=5, padding="same"
|
193 |
+
),
|
194 |
+
nn.BatchNorm1d(nar_d_model),
|
195 |
+
nn.ReLU(),
|
196 |
+
nn.Dropout(0.5),
|
197 |
+
nn.Conv1d(
|
198 |
+
nar_d_model, nar_d_model, kernel_size=5, padding="same"
|
199 |
+
),
|
200 |
+
nn.BatchNorm1d(nar_d_model),
|
201 |
+
nn.ReLU(),
|
202 |
+
nn.Dropout(0.5),
|
203 |
+
Transpose(),
|
204 |
+
nn.Linear(nar_d_model, nar_d_model),
|
205 |
+
)
|
206 |
+
self.nar_audio_prenet = nn.Sequential(
|
207 |
+
nn.Linear(nar_d_model, 256),
|
208 |
+
nn.ReLU(),
|
209 |
+
nn.Dropout(0.25),
|
210 |
+
nn.Linear(256, 256),
|
211 |
+
nn.ReLU(),
|
212 |
+
nn.Dropout(0.25),
|
213 |
+
nn.Linear(256, nar_d_model),
|
214 |
+
)
|
215 |
+
else:
|
216 |
+
self.nar_text_prenet = nn.Identity()
|
217 |
+
self.nar_audio_prenet = nn.Identity()
|
218 |
+
|
219 |
+
self.nar_text_position = SinePositionalEmbedding(
|
220 |
+
nar_d_model,
|
221 |
+
dropout=0.0,
|
222 |
+
scale=False,
|
223 |
+
alpha=False,
|
224 |
+
)
|
225 |
+
self.nar_audio_position = SinePositionalEmbedding(
|
226 |
+
nar_d_model,
|
227 |
+
dropout=0.1,
|
228 |
+
scale=False,
|
229 |
+
alpha=False,
|
230 |
+
)
|
231 |
+
|
232 |
+
self.nar_decoder = decoder_cls(
|
233 |
+
decoder_layer_cls(
|
234 |
+
nar_d_model,
|
235 |
+
int(nhead * nar_scale_factor),
|
236 |
+
dim_feedforward=nar_d_model * 4,
|
237 |
+
dropout=0.1,
|
238 |
+
batch_first=True,
|
239 |
+
norm_first=norm_first,
|
240 |
+
adaptive_layer_norm=True,
|
241 |
+
),
|
242 |
+
num_layers=int(num_layers * nar_scale_factor),
|
243 |
+
norm=AdaptiveLayerNorm(
|
244 |
+
nar_d_model, norm=nn.LayerNorm(nar_d_model)
|
245 |
+
)
|
246 |
+
if norm_first
|
247 |
+
else None,
|
248 |
+
)
|
249 |
+
self.nar_predict_layers = nn.ModuleList(
|
250 |
+
[
|
251 |
+
nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False)
|
252 |
+
for i in range(num_quantizers - 1)
|
253 |
+
]
|
254 |
+
)
|
255 |
+
self.nar_stage_embeddings = nn.ModuleList(
|
256 |
+
[
|
257 |
+
TokenEmbedding(nar_d_model, 1)
|
258 |
+
for i in range(num_quantizers - 1)
|
259 |
+
]
|
260 |
+
)
|
261 |
+
|
262 |
+
if share_embedding:
|
263 |
+
# We share the parameters of the output projection layer with the parameters of the acoustic embedding Wa
|
264 |
+
# NOTE(Feiteng): In the experiment, this undermines accuracy
|
265 |
+
# self.ar_predict_layer.weight = self.ar_audio_embedding.weight
|
266 |
+
|
267 |
+
# We also share the parameters of the acoustic embedding layer and the output prediction layer,
|
268 |
+
# which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer.
|
269 |
+
for j in range(0, num_quantizers - 2):
|
270 |
+
self.nar_predict_layers[
|
271 |
+
j
|
272 |
+
].weight = self.nar_audio_embeddings[j + 2].weight
|
273 |
+
|
274 |
+
self.nar_accuracy_metric = MulticlassAccuracy(
|
275 |
+
NUM_AUDIO_TOKENS + 1,
|
276 |
+
top_k=10,
|
277 |
+
average="micro",
|
278 |
+
multidim_average="global",
|
279 |
+
ignore_index=NUM_AUDIO_TOKENS,
|
280 |
+
)
|
281 |
+
|
282 |
+
def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
|
283 |
+
assert stage > 0
|
284 |
+
if stage == 1:
|
285 |
+
for name, param in self.named_parameters():
|
286 |
+
if name.startswith("ar_"):
|
287 |
+
print(f" AR parameter: {name}")
|
288 |
+
yield param
|
289 |
+
|
290 |
+
if stage == 2:
|
291 |
+
for name, param in self.named_parameters():
|
292 |
+
if name.startswith("nar_"):
|
293 |
+
print(f"NAR parameter: {name}")
|
294 |
+
yield param
|
295 |
+
|
296 |
+
def stage_named_parameters(
|
297 |
+
self, stage: int = 1
|
298 |
+
) -> Iterator[Tuple[str, nn.Parameter]]:
|
299 |
+
assert stage > 0
|
300 |
+
if stage == 1:
|
301 |
+
for pair in self.named_parameters():
|
302 |
+
if pair[0].startswith("ar_"):
|
303 |
+
yield pair
|
304 |
+
|
305 |
+
if stage == 2:
|
306 |
+
for pair in self.named_parameters():
|
307 |
+
if pair[0].startswith("nar_"):
|
308 |
+
yield pair
|
309 |
+
|
310 |
+
def pad_y_eos(self, y, y_mask_int, eos_id):
|
311 |
+
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
|
312 |
+
y_mask_int, (0, 1), value=1
|
313 |
+
)
|
314 |
+
# inputs, targets
|
315 |
+
if self.ar_audio_prepend_bos:
|
316 |
+
return (
|
317 |
+
F.pad(targets[:, :-1], (1, 0), value=NUM_AUDIO_TOKENS + 1),
|
318 |
+
targets,
|
319 |
+
)
|
320 |
+
|
321 |
+
return targets[:, :-1], targets[:, 1:]
|
322 |
+
|
323 |
+
def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes, prefix_mode):
|
324 |
+
# 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds
|
325 |
+
# from the same utterance.
|
326 |
+
# We implement this differently.
|
327 |
+
if prefix_mode == 0:
|
328 |
+
# no prefix
|
329 |
+
prefix_len = 0
|
330 |
+
y_emb = self.nar_audio_embeddings[0](y)
|
331 |
+
for j in range(1, nar_stage):
|
332 |
+
# Formula (4) (5)
|
333 |
+
y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
|
334 |
+
elif prefix_mode == 1:
|
335 |
+
# prefix at begining
|
336 |
+
int_low = (0.25 * y_lens.min()).type(torch.int64).item()
|
337 |
+
prefix_len = torch.randint(0, int_low * 2, size=()).item()
|
338 |
+
prefix_len = min(prefix_len, 225) # 24000/320 * 3s = 225 frames
|
339 |
+
|
340 |
+
y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
|
341 |
+
y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
|
342 |
+
for j in range(1, self.num_quantizers):
|
343 |
+
y_prompts += self.nar_audio_embeddings[j](
|
344 |
+
codes[:, :prefix_len, j]
|
345 |
+
)
|
346 |
+
if j < nar_stage:
|
347 |
+
y_emb += self.nar_audio_embeddings[j](
|
348 |
+
codes[:, prefix_len:, j]
|
349 |
+
)
|
350 |
+
y_emb = torch.concat([y_prompts, y_emb], axis=1)
|
351 |
+
elif prefix_mode in [2, 4]:
|
352 |
+
if prefix_mode == 2:
|
353 |
+
# random prefix
|
354 |
+
prefix_len = min(225, int(0.25 * y_lens.min().item()))
|
355 |
+
|
356 |
+
y_prompts_codes = []
|
357 |
+
for b in range(codes.shape[0]):
|
358 |
+
start = self.rng.randint(0, y_lens[b].item() - prefix_len)
|
359 |
+
y_prompts_codes.append(
|
360 |
+
torch.clone(codes[b, start : start + prefix_len])
|
361 |
+
)
|
362 |
+
codes[
|
363 |
+
b, start : start + prefix_len, nar_stage
|
364 |
+
] = NUM_AUDIO_TOKENS
|
365 |
+
y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
|
366 |
+
else:
|
367 |
+
prefix_len = y_prompts_codes.shape[1]
|
368 |
+
|
369 |
+
y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
|
370 |
+
y_emb = self.nar_audio_embeddings[0](y)
|
371 |
+
for j in range(1, self.num_quantizers):
|
372 |
+
y_prompts += self.nar_audio_embeddings[j](
|
373 |
+
y_prompts_codes[..., j]
|
374 |
+
)
|
375 |
+
if j < nar_stage:
|
376 |
+
y_emb += self.nar_audio_embeddings[j](codes[..., j])
|
377 |
+
y_emb = torch.concat([y_prompts, y_emb], axis=1)
|
378 |
+
else:
|
379 |
+
raise ValueError
|
380 |
+
|
381 |
+
return y_emb, prefix_len
|
382 |
+
|
383 |
+
def forward(
|
384 |
+
self,
|
385 |
+
x: torch.Tensor,
|
386 |
+
x_lens: torch.Tensor,
|
387 |
+
y: Union[torch.Tensor, PromptedFeatures],
|
388 |
+
y_lens: Union[torch.Tensor, PromptedFeatures],
|
389 |
+
reduction: str = "sum",
|
390 |
+
train_stage: int = 0,
|
391 |
+
**kwargs,
|
392 |
+
) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
|
393 |
+
"""
|
394 |
+
Args:
|
395 |
+
x:
|
396 |
+
A 2-D tensor of shape (N, S).
|
397 |
+
x_lens:
|
398 |
+
A 1-D tensor of shape (N,). It contains the number of tokens in `x`
|
399 |
+
before padding.
|
400 |
+
y:
|
401 |
+
A 3-D tensor of shape (N, T, 8).
|
402 |
+
y_lens:
|
403 |
+
A 1-D tensor of shape (N,). It contains the number of tokens in `x`
|
404 |
+
before padding.
|
405 |
+
train_stage:
|
406 |
+
0: AR & NAR modules, 1: AR modules, 2: NAR modules
|
407 |
+
Returns:
|
408 |
+
Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy.
|
409 |
+
"""
|
410 |
+
assert x.ndim == 2, x.shape
|
411 |
+
assert x_lens.ndim == 1, x_lens.shape
|
412 |
+
|
413 |
+
y_prompts_codes = None
|
414 |
+
if isinstance(y, PromptedFeatures):
|
415 |
+
y_prompts_codes, y = y.data
|
416 |
+
prompts_len, y_lens = y_lens.data
|
417 |
+
assert prompts_len.min() == prompts_len.max()
|
418 |
+
assert self.prefix_mode == 4
|
419 |
+
y_prompts_codes = y_prompts_codes.type(torch.int64)
|
420 |
+
|
421 |
+
assert y.ndim == 3, y.shape
|
422 |
+
assert y_lens.ndim == 1, y_lens.shape
|
423 |
+
|
424 |
+
# NOTE: x has been padded in TextTokenCollater
|
425 |
+
x_mask = make_pad_mask(x_lens).to(x.device)
|
426 |
+
|
427 |
+
text = x
|
428 |
+
x = self.ar_text_embedding(text)
|
429 |
+
x = self.ar_text_prenet(x)
|
430 |
+
x = self.ar_text_position(x)
|
431 |
+
|
432 |
+
total_loss, metrics = 0.0, {}
|
433 |
+
|
434 |
+
y_mask = make_pad_mask(y_lens).to(y.device)
|
435 |
+
y_mask_int = y_mask.type(torch.int64)
|
436 |
+
|
437 |
+
codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1))
|
438 |
+
|
439 |
+
# Training
|
440 |
+
# AR Decoder
|
441 |
+
y, targets = self.pad_y_eos(
|
442 |
+
codes[..., 0], y_mask_int, eos_id=NUM_AUDIO_TOKENS
|
443 |
+
)
|
444 |
+
|
445 |
+
if train_stage in [0, 1]:
|
446 |
+
y_emb = self.ar_audio_embedding(y)
|
447 |
+
y_emb = self.ar_audio_prenet(y_emb)
|
448 |
+
y_pos = self.ar_audio_position(y_emb)
|
449 |
+
|
450 |
+
ar_y_mask = y_mask
|
451 |
+
if self.ar_audio_prepend_bos:
|
452 |
+
ar_y_mask = F.pad(y_mask, (1, 0), value=False)
|
453 |
+
|
454 |
+
y_len = y_lens.max() + int(self.ar_audio_prepend_bos)
|
455 |
+
tgt_mask = torch.triu(
|
456 |
+
torch.ones(y_len, y_len, device=y.device, dtype=torch.bool),
|
457 |
+
diagonal=1,
|
458 |
+
)
|
459 |
+
y_dec, _ = self.ar_decoder(
|
460 |
+
(y_pos, None),
|
461 |
+
x,
|
462 |
+
tgt_mask=tgt_mask,
|
463 |
+
tgt_key_padding_mask=ar_y_mask,
|
464 |
+
memory_mask=None,
|
465 |
+
memory_key_padding_mask=x_mask,
|
466 |
+
)
|
467 |
+
logits = self.ar_predict_layer(y_dec).permute(0, 2, 1)
|
468 |
+
# loss
|
469 |
+
total_loss = F.cross_entropy(logits, targets, reduction=reduction)
|
470 |
+
metrics["ArTop10Accuracy"] = self.ar_accuracy_metric(
|
471 |
+
logits.detach(), targets
|
472 |
+
).item() * y_lens.sum().type(torch.float32)
|
473 |
+
|
474 |
+
if self.num_quantizers == 1:
|
475 |
+
return ((x, codes), total_loss, metrics)
|
476 |
+
|
477 |
+
# Non-AR Decoders
|
478 |
+
if self.ar_audio_prepend_bos:
|
479 |
+
y = y[:, 1:]
|
480 |
+
|
481 |
+
if train_stage in [0, 2]:
|
482 |
+
num_nar_layers = self.num_quantizers - 1
|
483 |
+
nar_stage = self.rng.choices(
|
484 |
+
[_k for _k in range(1, self.num_quantizers)],
|
485 |
+
weights=[1.0 / num_nar_layers] * num_nar_layers,
|
486 |
+
k=1,
|
487 |
+
)[0]
|
488 |
+
|
489 |
+
x = self.nar_text_embedding(text)
|
490 |
+
x = self.nar_text_prenet(x)
|
491 |
+
x = self.nar_text_position(x)
|
492 |
+
|
493 |
+
y_emb, prefix_len = self._prepare_prompts(
|
494 |
+
y, y_lens, codes, nar_stage, y_prompts_codes, self.prefix_mode
|
495 |
+
)
|
496 |
+
|
497 |
+
y_len = y_lens.max()
|
498 |
+
targets = codes[..., nar_stage] + NUM_AUDIO_TOKENS * y_mask_int
|
499 |
+
if self.prefix_mode in [2, 4]:
|
500 |
+
targets = targets
|
501 |
+
y_mask = F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False)
|
502 |
+
elif self.prefix_mode == 1:
|
503 |
+
targets = targets[:, prefix_len:]
|
504 |
+
else:
|
505 |
+
assert prefix_len == 0
|
506 |
+
|
507 |
+
y_pos = self.nar_audio_prenet(y_emb)
|
508 |
+
y_pos = self.nar_audio_position(y_pos)
|
509 |
+
|
510 |
+
y_dec, _ = self.nar_decoder(
|
511 |
+
(y_pos, self.nar_stage_embeddings[nar_stage - 1].weight),
|
512 |
+
x,
|
513 |
+
tgt_mask=None,
|
514 |
+
tgt_key_padding_mask=y_mask,
|
515 |
+
memory_mask=None,
|
516 |
+
memory_key_padding_mask=x_mask,
|
517 |
+
)
|
518 |
+
if self.prefix_mode != 0:
|
519 |
+
y_dec = y_dec[:, prefix_len:]
|
520 |
+
if self.prefix_mode == 4:
|
521 |
+
prefix_len = 0 # reset for Top10Accuracy metric
|
522 |
+
|
523 |
+
logits = self.nar_predict_layers[nar_stage - 1](y_dec).permute(
|
524 |
+
0, 2, 1
|
525 |
+
)
|
526 |
+
# loss
|
527 |
+
total_length = (y_lens).sum().type(torch.float32)
|
528 |
+
total_loss += (
|
529 |
+
F.cross_entropy(
|
530 |
+
logits,
|
531 |
+
targets,
|
532 |
+
ignore_index=NUM_AUDIO_TOKENS,
|
533 |
+
reduction=reduction,
|
534 |
+
)
|
535 |
+
* (total_length / (total_length - prefix_len * x.shape[0]))
|
536 |
+
)
|
537 |
+
metrics["NarTop10Accuracy"] = (
|
538 |
+
self.nar_accuracy_metric(
|
539 |
+
F.pad(
|
540 |
+
logits.detach(),
|
541 |
+
(0, 0, 0, 1, 0, 0),
|
542 |
+
value=logits.min().cpu().item(),
|
543 |
+
),
|
544 |
+
targets,
|
545 |
+
).item()
|
546 |
+
* total_length
|
547 |
+
)
|
548 |
+
|
549 |
+
if train_stage == 0:
|
550 |
+
total_loss = total_loss / 2.0
|
551 |
+
print("total_loss:", total_loss)
|
552 |
+
|
553 |
+
return ((x, codes), total_loss, metrics)
|
554 |
+
|
555 |
+
def inference(
|
556 |
+
self,
|
557 |
+
x: torch.Tensor,
|
558 |
+
x_lens: torch.Tensor,
|
559 |
+
y: torch.Tensor,
|
560 |
+
enroll_x_lens: Union[torch.Tensor, None] = None,
|
561 |
+
top_k: int = -100,
|
562 |
+
temperature: float = 1.0,
|
563 |
+
) -> torch.Tensor:
|
564 |
+
"""
|
565 |
+
Args:
|
566 |
+
x:
|
567 |
+
A 2-D tensor of shape (1, S).
|
568 |
+
x_lens:
|
569 |
+
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
|
570 |
+
before padding.
|
571 |
+
y:
|
572 |
+
A 3-D tensor of shape (1, T, 8).
|
573 |
+
top_k: (`optional`) int
|
574 |
+
The number of highest probability tokens to keep for top-k-filtering. Default to -100.
|
575 |
+
temperature: (`optional`) float
|
576 |
+
The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
|
577 |
+
Returns:
|
578 |
+
Return the predicted audio code matrix.
|
579 |
+
"""
|
580 |
+
assert x.ndim == 2, x.shape
|
581 |
+
assert x_lens.ndim == 1, x_lens.shape
|
582 |
+
assert y.ndim == 3, y.shape
|
583 |
+
assert y.shape[0] == 1, y.shape
|
584 |
+
|
585 |
+
assert torch.all(x_lens > 0)
|
586 |
+
|
587 |
+
text = x
|
588 |
+
x = self.ar_text_embedding(text)
|
589 |
+
x = self.ar_text_prenet(x)
|
590 |
+
x = self.ar_text_position(x)
|
591 |
+
# NOTE: x has been padded in TextTokenCollater
|
592 |
+
x_mask = make_pad_mask(x_lens).to(x.device)
|
593 |
+
|
594 |
+
prompts = y
|
595 |
+
prefix_len = y.shape[1]
|
596 |
+
|
597 |
+
# AR Decoder
|
598 |
+
# TODO: Manage decoder steps to avoid repetitive computation
|
599 |
+
y = prompts[..., 0]
|
600 |
+
if self.ar_audio_prepend_bos:
|
601 |
+
y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1)
|
602 |
+
|
603 |
+
while True:
|
604 |
+
y_emb = self.ar_audio_embedding(y)
|
605 |
+
y_emb = self.ar_audio_prenet(y_emb)
|
606 |
+
y_pos = self.ar_audio_position(y_emb)
|
607 |
+
|
608 |
+
tgt_mask = torch.triu(
|
609 |
+
torch.ones(
|
610 |
+
y.shape[1], y.shape[1], device=y.device, dtype=torch.bool
|
611 |
+
),
|
612 |
+
diagonal=1,
|
613 |
+
)
|
614 |
+
|
615 |
+
y_dec, _ = self.ar_decoder(
|
616 |
+
(y_pos, None),
|
617 |
+
x,
|
618 |
+
tgt_mask=tgt_mask,
|
619 |
+
memory_mask=None,
|
620 |
+
memory_key_padding_mask=x_mask,
|
621 |
+
)
|
622 |
+
logits = self.ar_predict_layer(y_dec[:, -1])
|
623 |
+
samples = topk_sampling(
|
624 |
+
logits, top_k=top_k, top_p=1.0, temperature=temperature
|
625 |
+
)
|
626 |
+
|
627 |
+
if (
|
628 |
+
torch.argmax(logits, dim=-1)[0] == NUM_AUDIO_TOKENS
|
629 |
+
or samples[0, 0] == NUM_AUDIO_TOKENS
|
630 |
+
or (y.shape[1] - prefix_len) > x_lens.max() * 16
|
631 |
+
):
|
632 |
+
if prompts.shape[1] == y.shape[1]:
|
633 |
+
raise SyntaxError(
|
634 |
+
"well trained model shouldn't reach here."
|
635 |
+
)
|
636 |
+
|
637 |
+
print(f"VALL-F EOS [{prefix_len} -> {y.shape[1]}]")
|
638 |
+
break
|
639 |
+
|
640 |
+
y = torch.concat([y, samples], dim=1)
|
641 |
+
|
642 |
+
codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
|
643 |
+
if self.num_quantizers == 1:
|
644 |
+
return torch.stack(codes, dim=-1)
|
645 |
+
|
646 |
+
# Non-AR Decoders
|
647 |
+
y_emb = self.nar_audio_embeddings[0](
|
648 |
+
y[:, int(self.ar_audio_prepend_bos) :]
|
649 |
+
)
|
650 |
+
if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes
|
651 |
+
enrolled_len = enroll_x_lens.max().item()
|
652 |
+
# SOS + Synthesis Text + EOS
|
653 |
+
text = torch.concat(
|
654 |
+
[
|
655 |
+
text[:, :1],
|
656 |
+
text[:, enrolled_len - 1 :],
|
657 |
+
],
|
658 |
+
dim=1,
|
659 |
+
)
|
660 |
+
assert text.shape[0] == 1
|
661 |
+
|
662 |
+
x = self.nar_text_embedding(text)
|
663 |
+
x = self.nar_text_prenet(x)
|
664 |
+
x = self.nar_text_position(x)
|
665 |
+
|
666 |
+
if self.prefix_mode != 0:
|
667 |
+
for j in range(1, self.num_quantizers):
|
668 |
+
y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
|
669 |
+
prompts[..., j]
|
670 |
+
)
|
671 |
+
|
672 |
+
for i, (predict_layer, embedding_layer) in enumerate(
|
673 |
+
zip(
|
674 |
+
self.nar_predict_layers,
|
675 |
+
self.nar_audio_embeddings[1:],
|
676 |
+
)
|
677 |
+
):
|
678 |
+
y_pos = self.nar_audio_prenet(y_emb)
|
679 |
+
y_pos = self.nar_audio_position(y_pos)
|
680 |
+
y_dec, _ = self.nar_decoder(
|
681 |
+
(y_pos, self.nar_stage_embeddings[i].weight),
|
682 |
+
x,
|
683 |
+
tgt_mask=None,
|
684 |
+
memory_mask=None,
|
685 |
+
memory_key_padding_mask=None,
|
686 |
+
)
|
687 |
+
logits = predict_layer(y_dec[:, prefix_len:])
|
688 |
+
samples = torch.argmax(logits, dim=-1)
|
689 |
+
codes.append(samples)
|
690 |
+
# Formula (4) (5)
|
691 |
+
if i < 6:
|
692 |
+
if self.prefix_mode == 0:
|
693 |
+
y_emb[:, :prefix_len] += embedding_layer(
|
694 |
+
prompts[..., i + 1]
|
695 |
+
)
|
696 |
+
y_emb[:, prefix_len:] += embedding_layer(samples)
|
697 |
+
|
698 |
+
assert len(codes) == self.num_quantizers
|
699 |
+
return torch.stack(codes, dim=-1)
|
700 |
+
|
701 |
+
def visualize(
|
702 |
+
self,
|
703 |
+
predicts: Tuple[torch.Tensor],
|
704 |
+
batch: Dict[str, Union[List, torch.Tensor]],
|
705 |
+
output_dir: str,
|
706 |
+
limit: int = 4,
|
707 |
+
) -> None:
|
708 |
+
visualize(predicts, batch, output_dir, limit=limit)
|
709 |
+
|
710 |
+
|
711 |
+
class VALLE(VALLF):
|
712 |
+
"""It implements https://arxiv.org/abs/2301.02111
|
713 |
+
"Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
|
714 |
+
"""
|
715 |
+
|
716 |
+
def __init__(
|
717 |
+
self,
|
718 |
+
d_model: int,
|
719 |
+
nhead: int,
|
720 |
+
num_layers: int,
|
721 |
+
norm_first: bool = True,
|
722 |
+
add_prenet: bool = False,
|
723 |
+
prefix_mode: int = 0,
|
724 |
+
share_embedding: bool = True,
|
725 |
+
nar_scale_factor: float = 1.0,
|
726 |
+
**kwargs,
|
727 |
+
):
|
728 |
+
"""
|
729 |
+
Args:
|
730 |
+
d_model:
|
731 |
+
The number of expected features in the input (required).
|
732 |
+
nhead:
|
733 |
+
The number of heads in the multiheadattention models (required).
|
734 |
+
num_layers:
|
735 |
+
The number of sub-decoder-layers in the decoder (required).
|
736 |
+
"""
|
737 |
+
super(VALLE, self).__init__(
|
738 |
+
d_model,
|
739 |
+
nhead,
|
740 |
+
num_layers,
|
741 |
+
norm_first=norm_first,
|
742 |
+
add_prenet=add_prenet,
|
743 |
+
decoder_cls=TransformerEncoder,
|
744 |
+
decoder_layer_cls=TransformerEncoderLayer,
|
745 |
+
prefix_mode=prefix_mode,
|
746 |
+
share_embedding=share_embedding,
|
747 |
+
nar_scale_factor=nar_scale_factor,
|
748 |
+
**kwargs,
|
749 |
+
)
|
750 |
+
self.language_ID = {
|
751 |
+
'en': 0,
|
752 |
+
'zh': 1,
|
753 |
+
'ja': 2,
|
754 |
+
'vi': 3
|
755 |
+
}
|
756 |
+
self.ar_language_embedding = TokenEmbedding(d_model, 3)
|
757 |
+
self.nar_language_embedding = TokenEmbedding(d_model, 3)
|
758 |
+
|
759 |
+
def forward(
|
760 |
+
self,
|
761 |
+
x: torch.Tensor,
|
762 |
+
x_lens: torch.Tensor,
|
763 |
+
y: Union[torch.Tensor, PromptedFeatures],
|
764 |
+
y_lens: Union[torch.Tensor, PromptedFeatures],
|
765 |
+
reduction: str = "sum",
|
766 |
+
train_stage: int = 0,
|
767 |
+
**kwargs,
|
768 |
+
) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
|
769 |
+
"""
|
770 |
+
Args:
|
771 |
+
x:
|
772 |
+
A 2-D tensor of shape (N, S).
|
773 |
+
x_lens:
|
774 |
+
A 1-D tensor of shape (N,). It contains the number of tokens in `x`
|
775 |
+
before padding.
|
776 |
+
y:
|
777 |
+
A 3-D tensor of shape (N, T, 8).
|
778 |
+
y_lens:
|
779 |
+
A 1-D tensor of shape (N,). It contains the number of tokens in `y`
|
780 |
+
before padding.
|
781 |
+
train_stage:
|
782 |
+
0: AR & NAR modules, 1: AR modules, 2: NAR modules
|
783 |
+
Returns:
|
784 |
+
Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy.
|
785 |
+
"""
|
786 |
+
assert x.ndim == 2, x.shape
|
787 |
+
assert x_lens.ndim == 1, x_lens.shape
|
788 |
+
|
789 |
+
y_prompts_codes = None
|
790 |
+
if isinstance(y, PromptedFeatures):
|
791 |
+
y_prompts_codes, y = y.data
|
792 |
+
prompts_len, y_lens = y_lens.data
|
793 |
+
assert prompts_len.min() == prompts_len.max()
|
794 |
+
assert self.prefix_mode == 4
|
795 |
+
y_prompts_codes = y_prompts_codes.type(torch.int64)
|
796 |
+
|
797 |
+
assert y.ndim == 3, y.shape
|
798 |
+
assert y_lens.ndim == 1, y_lens.shape
|
799 |
+
|
800 |
+
# NOTE: x has been padded in TextTokenCollater
|
801 |
+
x_mask = make_pad_mask(x_lens).to(x.device)
|
802 |
+
y_mask = make_pad_mask(y_lens).to(y.device)
|
803 |
+
y_mask_int = y_mask.type(torch.int64)
|
804 |
+
|
805 |
+
text = x
|
806 |
+
codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1))
|
807 |
+
|
808 |
+
y, targets = self.pad_y_eos(
|
809 |
+
codes[..., 0], y_mask_int, eos_id=NUM_AUDIO_TOKENS
|
810 |
+
)
|
811 |
+
|
812 |
+
x_len = x_lens.max()
|
813 |
+
|
814 |
+
metrics = {}
|
815 |
+
total_loss = 0.0
|
816 |
+
|
817 |
+
xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
|
818 |
+
if self.ar_audio_prepend_bos:
|
819 |
+
ar_xy_padding_mask = torch.concat(
|
820 |
+
[x_mask, F.pad(y_mask, (1, 0), value=False)], dim=1
|
821 |
+
)
|
822 |
+
else:
|
823 |
+
ar_xy_padding_mask = xy_padding_mask
|
824 |
+
# AR Decoder
|
825 |
+
if train_stage in [0, 1]:
|
826 |
+
x = self.ar_text_embedding(text)
|
827 |
+
x = self.ar_text_prenet(x)
|
828 |
+
x = self.ar_text_position(x)
|
829 |
+
|
830 |
+
y_len = y_lens.max() + int(self.ar_audio_prepend_bos)
|
831 |
+
|
832 |
+
x_attn_mask = F.pad(
|
833 |
+
torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
|
834 |
+
(0, y_len),
|
835 |
+
value=True,
|
836 |
+
)
|
837 |
+
y_attn_mask = F.pad(
|
838 |
+
torch.triu(
|
839 |
+
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
|
840 |
+
diagonal=1,
|
841 |
+
),
|
842 |
+
(x_len, 0),
|
843 |
+
value=False,
|
844 |
+
)
|
845 |
+
xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
|
846 |
+
|
847 |
+
# merge key padding and attention masks
|
848 |
+
bsz, src_len = x.shape[0], x_len + y_len
|
849 |
+
_xy_padding_mask = (
|
850 |
+
ar_xy_padding_mask.view(bsz, 1, 1, src_len)
|
851 |
+
.expand(-1, self.num_heads, -1, -1)
|
852 |
+
.reshape(bsz * self.num_heads, 1, src_len)
|
853 |
+
)
|
854 |
+
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
|
855 |
+
|
856 |
+
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
|
857 |
+
new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
|
858 |
+
xy_attn_mask = new_attn_mask
|
859 |
+
|
860 |
+
y_emb = self.ar_audio_embedding(y)
|
861 |
+
y_emb = self.ar_audio_prenet(y_emb)
|
862 |
+
y_pos = self.ar_audio_position(y_emb)
|
863 |
+
|
864 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
865 |
+
|
866 |
+
xy_dec, _ = self.ar_decoder(
|
867 |
+
(xy_pos, None),
|
868 |
+
mask=xy_attn_mask,
|
869 |
+
# src_key_padding_mask=xy_padding_mask,
|
870 |
+
# is_causal=True,
|
871 |
+
)
|
872 |
+
logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
|
873 |
+
# loss
|
874 |
+
total_loss = F.cross_entropy(logits, targets, reduction=reduction)
|
875 |
+
|
876 |
+
metrics["ArTop10Accuracy"] = self.ar_accuracy_metric(
|
877 |
+
logits.detach(), targets
|
878 |
+
).item() * y_lens.sum().type(torch.float32)
|
879 |
+
|
880 |
+
if self.num_quantizers == 1:
|
881 |
+
return ((x, codes), total_loss, metrics)
|
882 |
+
|
883 |
+
# Non-AR Decoders
|
884 |
+
if self.ar_audio_prepend_bos:
|
885 |
+
y = y[:, 1:]
|
886 |
+
if train_stage in [0, 2]:
|
887 |
+
num_nar_layers = self.num_quantizers - 1
|
888 |
+
nar_stage = self.rng.choices(
|
889 |
+
[_k for _k in range(1, self.num_quantizers)],
|
890 |
+
weights=[1.0 / num_nar_layers] * num_nar_layers,
|
891 |
+
k=1,
|
892 |
+
)[0]
|
893 |
+
|
894 |
+
x = self.nar_text_embedding(text)
|
895 |
+
x = self.nar_text_prenet(x)
|
896 |
+
x = self.nar_text_position(x)
|
897 |
+
|
898 |
+
y_emb, prefix_len = self._prepare_prompts(
|
899 |
+
y, y_lens, codes, nar_stage, y_prompts_codes, self.prefix_mode
|
900 |
+
)
|
901 |
+
|
902 |
+
y_len = y_lens.max()
|
903 |
+
targets = codes[..., nar_stage] + NUM_AUDIO_TOKENS * y_mask_int
|
904 |
+
if self.prefix_mode in [2, 4]:
|
905 |
+
xy_padding_mask = torch.concat(
|
906 |
+
[
|
907 |
+
x_mask,
|
908 |
+
F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False),
|
909 |
+
],
|
910 |
+
dim=1,
|
911 |
+
)
|
912 |
+
elif self.prefix_mode == 1:
|
913 |
+
targets = targets[:, prefix_len:]
|
914 |
+
|
915 |
+
y_pos = self.nar_audio_prenet(y_emb)
|
916 |
+
y_pos = self.nar_audio_position(y_pos)
|
917 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
918 |
+
xy_dec, _ = self.nar_decoder(
|
919 |
+
(xy_pos, self.nar_stage_embeddings[nar_stage - 1].weight),
|
920 |
+
src_key_padding_mask=xy_padding_mask,
|
921 |
+
# is_causal=False,
|
922 |
+
)
|
923 |
+
xy_dec = xy_dec[:, x_lens.max() + prefix_len :]
|
924 |
+
if self.prefix_mode == 4:
|
925 |
+
prefix_len = 0 # reset for Top10Accuracy metric
|
926 |
+
logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(
|
927 |
+
0, 2, 1
|
928 |
+
)
|
929 |
+
|
930 |
+
# loss
|
931 |
+
total_length = (y_lens).sum().type(torch.float32)
|
932 |
+
total_loss += (
|
933 |
+
F.cross_entropy(
|
934 |
+
logits,
|
935 |
+
targets,
|
936 |
+
ignore_index=NUM_AUDIO_TOKENS,
|
937 |
+
reduction=reduction,
|
938 |
+
)
|
939 |
+
* (total_length / (total_length - prefix_len * x.shape[0]))
|
940 |
+
)
|
941 |
+
metrics["NarTop10Accuracy"] = (
|
942 |
+
self.nar_accuracy_metric(
|
943 |
+
F.pad(
|
944 |
+
logits.detach(),
|
945 |
+
(0, 0, 0, 1, 0, 0),
|
946 |
+
value=logits.min().cpu().item(),
|
947 |
+
),
|
948 |
+
targets,
|
949 |
+
).item()
|
950 |
+
* total_length
|
951 |
+
)
|
952 |
+
|
953 |
+
if train_stage == 0:
|
954 |
+
total_loss = total_loss / 2.0
|
955 |
+
|
956 |
+
return ((x, codes), total_loss, metrics)
|
957 |
+
|
958 |
+
def inference(
|
959 |
+
self,
|
960 |
+
x: torch.Tensor,
|
961 |
+
x_lens: torch.Tensor,
|
962 |
+
y: torch.Tensor,
|
963 |
+
enroll_x_lens: torch.Tensor,
|
964 |
+
top_k: int = -100,
|
965 |
+
temperature: float = 1.0,
|
966 |
+
prompt_language: str = None,
|
967 |
+
text_language: str = None,
|
968 |
+
best_of: int = 1,
|
969 |
+
length_penalty: float = 1.0,
|
970 |
+
return_worst: bool = False,
|
971 |
+
) -> torch.Tensor:
|
972 |
+
"""
|
973 |
+
Args:
|
974 |
+
x:
|
975 |
+
A 2-D tensor of shape (1, S).
|
976 |
+
x_lens:
|
977 |
+
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
|
978 |
+
before padding.
|
979 |
+
y:
|
980 |
+
A 3-D tensor of shape (1, T, 8).
|
981 |
+
top_k: (`optional`) int
|
982 |
+
The number of highest probability tokens to keep for top-k-filtering. Default to -100.
|
983 |
+
temperature: (`optional`) float
|
984 |
+
The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
|
985 |
+
Returns:
|
986 |
+
Return the predicted audio code matrix.
|
987 |
+
"""
|
988 |
+
assert x.ndim == 2, x.shape
|
989 |
+
assert x_lens.ndim == 1, x_lens.shape
|
990 |
+
assert y.ndim == 3, y.shape
|
991 |
+
assert y.shape[0] == 1, y.shape
|
992 |
+
|
993 |
+
assert torch.all(x_lens > 0)
|
994 |
+
|
995 |
+
# NOTE: x has been padded in TextTokenCollater
|
996 |
+
text = x
|
997 |
+
x = self.ar_text_embedding(text)
|
998 |
+
# Add language embedding
|
999 |
+
prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
|
1000 |
+
if isinstance(text_language, str):
|
1001 |
+
text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
|
1002 |
+
elif isinstance(text_language, List):
|
1003 |
+
text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
|
1004 |
+
x[:, :enroll_x_lens, :] += self.ar_language_embedding(prompt_language_id)
|
1005 |
+
x[:, enroll_x_lens:, :] += self.ar_language_embedding(text_language_id)
|
1006 |
+
x = self.ar_text_prenet(x)
|
1007 |
+
x = self.ar_text_position(x)
|
1008 |
+
|
1009 |
+
text_len = x_lens.max()
|
1010 |
+
prompts = y
|
1011 |
+
prefix_len = y.shape[1]
|
1012 |
+
|
1013 |
+
# AR Decoder
|
1014 |
+
# TODO: Manage decoder steps to avoid repetitive computation
|
1015 |
+
y = prompts[..., 0]
|
1016 |
+
if self.ar_audio_prepend_bos:
|
1017 |
+
y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1)
|
1018 |
+
|
1019 |
+
x_len = x_lens.max()
|
1020 |
+
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
1021 |
+
|
1022 |
+
kv_cache = None
|
1023 |
+
use_kv_caching = True
|
1024 |
+
|
1025 |
+
sum_logprobs = torch.zeros(best_of, device=y.device) # implement batch decoding here
|
1026 |
+
x = x.repeat(best_of, 1, 1)
|
1027 |
+
y = y.repeat(best_of, 1)
|
1028 |
+
while True:
|
1029 |
+
y_emb = self.ar_audio_embedding(y)
|
1030 |
+
y_emb = self.ar_audio_prenet(y_emb)
|
1031 |
+
y_pos = self.ar_audio_position(y_emb)
|
1032 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
1033 |
+
|
1034 |
+
y_len = y.shape[1]
|
1035 |
+
x_attn_mask_pad = F.pad(
|
1036 |
+
x_attn_mask,
|
1037 |
+
(0, y_len),
|
1038 |
+
value=True,
|
1039 |
+
)
|
1040 |
+
y_attn_mask = F.pad(
|
1041 |
+
torch.triu(
|
1042 |
+
torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
|
1043 |
+
),
|
1044 |
+
(x_len, 0),
|
1045 |
+
value=False,
|
1046 |
+
)
|
1047 |
+
xy_attn_mask = torch.concat(
|
1048 |
+
[x_attn_mask_pad, y_attn_mask], dim=0
|
1049 |
+
).to(y.device)
|
1050 |
+
|
1051 |
+
|
1052 |
+
if use_kv_caching and kv_cache is not None:
|
1053 |
+
xy_pos = xy_pos[:, [-1]]
|
1054 |
+
else:
|
1055 |
+
pass
|
1056 |
+
|
1057 |
+
xy_dec, kv_cache = self.ar_decoder.infer(
|
1058 |
+
xy_pos,
|
1059 |
+
mask=xy_attn_mask,
|
1060 |
+
past_kv=kv_cache,
|
1061 |
+
use_cache=use_kv_caching,
|
1062 |
+
)
|
1063 |
+
# xy_dec, _ = self.ar_decoder(
|
1064 |
+
# (xy_pos, None),
|
1065 |
+
# mask=xy_attn_mask,
|
1066 |
+
# )
|
1067 |
+
|
1068 |
+
logits = self.ar_predict_layer(xy_dec[:, -1])
|
1069 |
+
samples, current_logprobs = topk_sampling(
|
1070 |
+
logits, top_k=top_k, top_p=1, temperature=temperature
|
1071 |
+
)
|
1072 |
+
sum_logprobs += current_logprobs * (y[:, -1] != NUM_AUDIO_TOKENS)
|
1073 |
+
samples[y[:, -1] == NUM_AUDIO_TOKENS] = NUM_AUDIO_TOKENS
|
1074 |
+
completed = (samples[:, -1] == NUM_AUDIO_TOKENS).all()
|
1075 |
+
if (
|
1076 |
+
completed
|
1077 |
+
or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
|
1078 |
+
):
|
1079 |
+
if prompts.shape[1] == y.shape[1]:
|
1080 |
+
raise SyntaxError(
|
1081 |
+
"well trained model shouldn't reach here."
|
1082 |
+
)
|
1083 |
+
lengths = torch.sum(y != NUM_AUDIO_TOKENS, dim=1)
|
1084 |
+
avg_logprobs = sum_logprobs / lengths ** length_penalty
|
1085 |
+
# choose the best beam according to sum_logprobs
|
1086 |
+
best_beam = y[torch.argmax(avg_logprobs), :]
|
1087 |
+
worst_beam = y[torch.argmin(avg_logprobs), :]
|
1088 |
+
# strip all eos tokens
|
1089 |
+
best_beam = best_beam[best_beam != NUM_AUDIO_TOKENS]
|
1090 |
+
worst_beam = worst_beam[worst_beam != NUM_AUDIO_TOKENS]
|
1091 |
+
if return_worst:
|
1092 |
+
y = worst_beam.unsqueeze(0)
|
1093 |
+
else:
|
1094 |
+
y = best_beam.unsqueeze(0)
|
1095 |
+
print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
|
1096 |
+
break
|
1097 |
+
|
1098 |
+
y = torch.concat([y, samples], dim=1)
|
1099 |
+
|
1100 |
+
codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
|
1101 |
+
if self.num_quantizers == 1:
|
1102 |
+
return torch.stack(codes, dim=-1)
|
1103 |
+
|
1104 |
+
# Non-AR Decoders
|
1105 |
+
y_emb = self.nar_audio_embeddings[0](
|
1106 |
+
y[:, int(self.ar_audio_prepend_bos) :]
|
1107 |
+
)
|
1108 |
+
|
1109 |
+
if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes
|
1110 |
+
enrolled_len = enroll_x_lens.max().item()
|
1111 |
+
# SOS + Synthesis Text + EOS
|
1112 |
+
text = torch.concat(
|
1113 |
+
[
|
1114 |
+
text[:, :1],
|
1115 |
+
text[:, enrolled_len - 1 :],
|
1116 |
+
],
|
1117 |
+
dim=1,
|
1118 |
+
)
|
1119 |
+
text_len = text_len - (enrolled_len - 2)
|
1120 |
+
assert text.shape[0] == 1
|
1121 |
+
|
1122 |
+
x = self.nar_text_embedding(text)
|
1123 |
+
# Add language embedding
|
1124 |
+
prompt_language_id = torch.LongTensor(np.array([self.language_ID[prompt_language]])).to(x.device)
|
1125 |
+
if isinstance(text_language, str):
|
1126 |
+
text_language_id = torch.LongTensor(np.array([self.language_ID[text_language]])).to(x.device)
|
1127 |
+
elif isinstance(text_language, List):
|
1128 |
+
text_language_id = torch.LongTensor(np.array([self.language_ID[tl] for tl in text_language])).to(x.device)
|
1129 |
+
x[:, :enroll_x_lens, :] += self.nar_language_embedding(prompt_language_id)
|
1130 |
+
x[:, enroll_x_lens:, :] += self.nar_language_embedding(text_language_id)
|
1131 |
+
x = self.nar_text_prenet(x)
|
1132 |
+
x = self.nar_text_position(x)
|
1133 |
+
|
1134 |
+
if self.prefix_mode == 0:
|
1135 |
+
for i, (predict_layer, embedding_layer) in enumerate(
|
1136 |
+
zip(
|
1137 |
+
self.nar_predict_layers,
|
1138 |
+
self.nar_audio_embeddings[1:],
|
1139 |
+
)
|
1140 |
+
):
|
1141 |
+
y_pos = self.nar_audio_prenet(y_emb)
|
1142 |
+
y_pos = self.nar_audio_position(y_pos)
|
1143 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
1144 |
+
|
1145 |
+
xy_dec, _ = self.nar_decoder(
|
1146 |
+
(xy_pos, self.nar_stage_embeddings[i].weight)
|
1147 |
+
)
|
1148 |
+
logits = predict_layer(xy_dec[:, text_len + prefix_len :])
|
1149 |
+
|
1150 |
+
samples = torch.argmax(logits, dim=-1)
|
1151 |
+
codes.append(samples)
|
1152 |
+
|
1153 |
+
if i < self.num_quantizers - 2:
|
1154 |
+
y_emb[:, :prefix_len] += embedding_layer(
|
1155 |
+
prompts[..., i + 1]
|
1156 |
+
)
|
1157 |
+
y_emb[:, prefix_len:] += embedding_layer(samples)
|
1158 |
+
else:
|
1159 |
+
for j in range(1, self.num_quantizers):
|
1160 |
+
y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
|
1161 |
+
prompts[..., j]
|
1162 |
+
)
|
1163 |
+
|
1164 |
+
for i, (predict_layer, embedding_layer) in enumerate(
|
1165 |
+
zip(
|
1166 |
+
self.nar_predict_layers,
|
1167 |
+
self.nar_audio_embeddings[1:],
|
1168 |
+
)
|
1169 |
+
):
|
1170 |
+
y_pos = self.nar_audio_prenet(y_emb)
|
1171 |
+
y_pos = self.nar_audio_position(y_pos)
|
1172 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
1173 |
+
|
1174 |
+
xy_dec, _ = self.nar_decoder(
|
1175 |
+
(xy_pos, self.nar_stage_embeddings[i].weight)
|
1176 |
+
)
|
1177 |
+
logits = predict_layer(xy_dec[:, text_len + prefix_len :])
|
1178 |
+
|
1179 |
+
samples = torch.argmax(logits, dim=-1)
|
1180 |
+
codes.append(samples)
|
1181 |
+
|
1182 |
+
if i < self.num_quantizers - 2:
|
1183 |
+
y_emb[:, prefix_len:] += embedding_layer(samples)
|
1184 |
+
|
1185 |
+
assert len(codes) == self.num_quantizers
|
1186 |
+
return torch.stack(codes, dim=-1)
|
1187 |
+
|
1188 |
+
def continual(
|
1189 |
+
self,
|
1190 |
+
x: torch.Tensor,
|
1191 |
+
x_lens: torch.Tensor,
|
1192 |
+
y: torch.Tensor,
|
1193 |
+
) -> torch.Tensor:
|
1194 |
+
"""
|
1195 |
+
Args:
|
1196 |
+
x:
|
1197 |
+
A 2-D tensor of shape (1, S).
|
1198 |
+
x_lens:
|
1199 |
+
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
|
1200 |
+
before padding.
|
1201 |
+
y:
|
1202 |
+
A 3-D tensor of shape (1, T, 8).
|
1203 |
+
Returns:
|
1204 |
+
Return the predicted audio code matrix.
|
1205 |
+
"""
|
1206 |
+
assert x.ndim == 2, x.shape
|
1207 |
+
assert x_lens.ndim == 1, x_lens.shape
|
1208 |
+
assert y.ndim == 3, y.shape
|
1209 |
+
assert y.shape[0] == 1, y.shape
|
1210 |
+
|
1211 |
+
assert torch.all(x_lens > 0)
|
1212 |
+
assert self.num_quantizers == 8
|
1213 |
+
|
1214 |
+
# NOTE: x has been padded in TextTokenCollater
|
1215 |
+
text = x
|
1216 |
+
x = self.ar_text_embedding(text)
|
1217 |
+
x = self.ar_text_prenet(x)
|
1218 |
+
x = self.ar_text_position(x)
|
1219 |
+
|
1220 |
+
text_len = x_lens.max()
|
1221 |
+
|
1222 |
+
prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)
|
1223 |
+
|
1224 |
+
# AR Decoder
|
1225 |
+
prompts = y[:, :prefix_len]
|
1226 |
+
|
1227 |
+
codes = [y[:, prefix_len:, 0]]
|
1228 |
+
# Non-AR Decoders
|
1229 |
+
x = self.nar_text_embedding(text)
|
1230 |
+
x = self.nar_text_prenet(x)
|
1231 |
+
x = self.nar_text_position(x)
|
1232 |
+
|
1233 |
+
y_emb = self.nar_audio_embeddings[0](y[..., 0])
|
1234 |
+
|
1235 |
+
if self.prefix_mode == 0:
|
1236 |
+
for i, (predict_layer, embedding_layer) in enumerate(
|
1237 |
+
zip(
|
1238 |
+
self.nar_predict_layers,
|
1239 |
+
self.nar_audio_embeddings[1:],
|
1240 |
+
)
|
1241 |
+
):
|
1242 |
+
y_pos = self.nar_audio_position(y_emb)
|
1243 |
+
y_pos = self.nar_audio_prenet(y_pos)
|
1244 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
1245 |
+
|
1246 |
+
xy_dec, _ = self.nar_decoder(
|
1247 |
+
(xy_pos, self.nar_stage_embeddings[i].weight)
|
1248 |
+
)
|
1249 |
+
logits = predict_layer(xy_dec[:, text_len + prefix_len :])
|
1250 |
+
|
1251 |
+
samples = torch.argmax(logits, dim=-1)
|
1252 |
+
codes.append(samples)
|
1253 |
+
|
1254 |
+
if i < 6:
|
1255 |
+
y_emb[:, :prefix_len] += embedding_layer(
|
1256 |
+
prompts[..., i + 1]
|
1257 |
+
)
|
1258 |
+
y_emb[:, prefix_len:] += embedding_layer(samples)
|
1259 |
+
else:
|
1260 |
+
for j in range(1, 8):
|
1261 |
+
y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
|
1262 |
+
prompts[..., j]
|
1263 |
+
)
|
1264 |
+
|
1265 |
+
for i, (predict_layer, embedding_layer) in enumerate(
|
1266 |
+
zip(
|
1267 |
+
self.nar_predict_layers,
|
1268 |
+
self.nar_audio_embeddings[1:],
|
1269 |
+
)
|
1270 |
+
):
|
1271 |
+
y_pos = self.nar_audio_prenet(y_emb)
|
1272 |
+
y_pos = self.nar_audio_position(y_pos)
|
1273 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
1274 |
+
|
1275 |
+
xy_dec, _ = self.nar_decoder(
|
1276 |
+
(xy_pos, self.nar_stage_embeddings[i].weight)
|
1277 |
+
)
|
1278 |
+
logits = predict_layer(xy_dec[:, text_len + prefix_len :])
|
1279 |
+
|
1280 |
+
samples = torch.argmax(logits, dim=-1)
|
1281 |
+
codes.append(samples)
|
1282 |
+
|
1283 |
+
if i < 6:
|
1284 |
+
y_emb[:, prefix_len:] += embedding_layer(samples)
|
1285 |
+
|
1286 |
+
assert len(codes) == 8
|
1287 |
+
return torch.stack(codes, dim=-1)
|
1288 |
+
|
1289 |
+
|
1290 |
+
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
|
1291 |
+
def top_k_top_p_filtering(
|
1292 |
+
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
|
1293 |
+
):
|
1294 |
+
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
1295 |
+
Args:
|
1296 |
+
logits: logits distribution shape (batch size, vocabulary size)
|
1297 |
+
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
1298 |
+
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
1299 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
1300 |
+
Make sure we keep at least min_tokens_to_keep per batch example in the output
|
1301 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
1302 |
+
"""
|
1303 |
+
if top_k > 0:
|
1304 |
+
top_k = min(
|
1305 |
+
max(top_k, min_tokens_to_keep), logits.size(-1)
|
1306 |
+
) # Safety check
|
1307 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
1308 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
1309 |
+
logits[indices_to_remove] = filter_value
|
1310 |
+
|
1311 |
+
if top_p < 1.0:
|
1312 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
1313 |
+
cumulative_probs = torch.cumsum(
|
1314 |
+
F.softmax(sorted_logits, dim=-1), dim=-1
|
1315 |
+
)
|
1316 |
+
|
1317 |
+
# Remove tokens with cumulative probability above the threshold (tokens whose mask value is 0 are kept)
|
1318 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
1319 |
+
if min_tokens_to_keep > 1:
|
1320 |
+
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
1321 |
+
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
|
1322 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
1323 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
|
1324 |
+
..., :-1
|
1325 |
+
].clone()
|
1326 |
+
sorted_indices_to_remove[..., 0] = 0
|
1327 |
+
|
1328 |
+
# scatter sorted tensors to original indexing
|
1329 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
1330 |
+
1, sorted_indices, sorted_indices_to_remove
|
1331 |
+
)
|
1332 |
+
logits[indices_to_remove] = filter_value
|
1333 |
+
return logits
|
1334 |
+
|
1335 |
+
|
1336 |
+
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
|
1337 |
+
# temperature: (`optional`) float
|
1338 |
+
# The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
|
1339 |
+
# top_k: (`optional`) int
|
1340 |
+
# The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
|
1341 |
+
# top_p: (`optional`) float
|
1342 |
+
# The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
|
1343 |
+
|
1344 |
+
# Temperature (higher temperature => more likely to sample low probability tokens)
|
1345 |
+
if temperature != 1.0:
|
1346 |
+
logits = logits / temperature
|
1347 |
+
# Top-p/top-k filtering
|
1348 |
+
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
1349 |
+
# Sample
|
1350 |
+
token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
|
1351 |
+
logprobs = F.log_softmax(logits.float(), dim=-1)
|
1352 |
+
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), token.squeeze(1)]
|
1353 |
+
return token, current_logprobs
|
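A minimal usage sketch for the two sampling helpers defined above (top_k_top_p_filtering and topk_sampling). The batch size, vocabulary size, and the top_k/top_p/temperature values below are illustrative assumptions, not values taken from this repository:

import torch
import torch.nn.functional as F

from models.vallex import top_k_top_p_filtering, topk_sampling

# Fake AR logits for a batch of 2 over 1025 classes (codebook entries + EOS).
logits = torch.randn(2, 1025)

# top_k_top_p_filtering edits the logits in place, so work on a copy.
filtered = top_k_top_p_filtering(logits.clone(), top_k=50, top_p=0.95)
probs = F.softmax(filtered, dim=-1)
token = torch.multinomial(probs, num_samples=1)  # shape (2, 1)

# topk_sampling wraps the same filtering, adds temperature scaling,
# and also returns the log-probability of each sampled token.
token, logprob = topk_sampling(logits.clone(), top_k=50, top_p=0.95, temperature=1.0)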
models/visualizer.py
ADDED
@@ -0,0 +1,106 @@
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright 2023 (authors: Feiteng Li)
|
3 |
+
#
|
4 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
|
19 |
+
from typing import Dict, List, Tuple, Union
|
20 |
+
|
21 |
+
import matplotlib.pyplot as plt
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
|
25 |
+
|
26 |
+
def visualize(
|
27 |
+
predicts: Tuple[torch.Tensor],
|
28 |
+
batch: Dict[str, Union[List, torch.Tensor]],
|
29 |
+
output_dir: str,
|
30 |
+
limit: int = 4,
|
31 |
+
) -> None:
|
32 |
+
text_tokens = batch["text_tokens"].to("cpu").detach().numpy()
|
33 |
+
text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy()
|
34 |
+
audio_features = batch["audio_features"].to("cpu").detach().numpy()
|
35 |
+
audio_features_lens = (
|
36 |
+
batch["audio_features_lens"].to("cpu").detach().numpy()
|
37 |
+
)
|
38 |
+
assert text_tokens.ndim == 2
|
39 |
+
|
40 |
+
utt_ids, texts = batch["utt_id"], batch["text"]
|
41 |
+
|
42 |
+
encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
|
43 |
+
decoder_outputs = predicts[1]
|
44 |
+
if isinstance(decoder_outputs, list):
|
45 |
+
decoder_outputs = decoder_outputs[-1]
|
46 |
+
decoder_outputs = (
|
47 |
+
decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
|
48 |
+
)
|
49 |
+
|
50 |
+
vmin, vmax = 0, 1024 # Encodec
|
51 |
+
if decoder_outputs.dtype == np.float32:
|
52 |
+
vmin, vmax = -6, 0 # Fbank
|
53 |
+
|
54 |
+
num_figures = 3
|
55 |
+
for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
|
56 |
+
_ = plt.figure(figsize=(14, 8 * num_figures))
|
57 |
+
|
58 |
+
S = text_tokens_lens[b]
|
59 |
+
T = audio_features_lens[b]
|
60 |
+
|
61 |
+
# encoder
|
62 |
+
plt.subplot(num_figures, 1, 1)
|
63 |
+
plt.title(f"Text: {text}")
|
64 |
+
plt.imshow(
|
65 |
+
X=np.transpose(encoder_outputs[b]),
|
66 |
+
cmap=plt.get_cmap("jet"),
|
67 |
+
aspect="auto",
|
68 |
+
interpolation="nearest",
|
69 |
+
)
|
70 |
+
plt.gca().invert_yaxis()
|
71 |
+
plt.axvline(x=S - 0.4, linewidth=2, color="r")
|
72 |
+
plt.xlabel("Encoder Output")
|
73 |
+
plt.colorbar()
|
74 |
+
|
75 |
+
# decoder
|
76 |
+
plt.subplot(num_figures, 1, 2)
|
77 |
+
plt.imshow(
|
78 |
+
X=np.transpose(decoder_outputs[b]),
|
79 |
+
cmap=plt.get_cmap("jet"),
|
80 |
+
aspect="auto",
|
81 |
+
interpolation="nearest",
|
82 |
+
vmin=vmin,
|
83 |
+
vmax=vmax,
|
84 |
+
)
|
85 |
+
plt.gca().invert_yaxis()
|
86 |
+
plt.axvline(x=T - 0.4, linewidth=2, color="r")
|
87 |
+
plt.xlabel("Decoder Output")
|
88 |
+
plt.colorbar()
|
89 |
+
|
90 |
+
# target
|
91 |
+
plt.subplot(num_figures, 1, 3)
|
92 |
+
plt.imshow(
|
93 |
+
X=np.transpose(audio_features[b]),
|
94 |
+
cmap=plt.get_cmap("jet"),
|
95 |
+
aspect="auto",
|
96 |
+
interpolation="nearest",
|
97 |
+
vmin=vmin,
|
98 |
+
vmax=vmax,
|
99 |
+
)
|
100 |
+
plt.gca().invert_yaxis()
|
101 |
+
plt.axvline(x=T - 0.4, linewidth=2, color="r")
|
102 |
+
plt.xlabel("Decoder Target")
|
103 |
+
plt.colorbar()
|
104 |
+
|
105 |
+
plt.savefig(f"{output_dir}/{utt_id}.png")
|
106 |
+
plt.close()
|
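A rough sketch of how the visualize() helper above can be driven with dummy tensors; the shapes, the utterance id, and the output directory are made-up assumptions for illustration only:

import os
import torch

from models.visualizer import visualize

# Minimal fake batch: 1 utterance, 10 text tokens, 50 audio frames with 8 codec codebooks.
batch = {
    "utt_id": ["utt_0001"],                          # hypothetical utterance id
    "text": ["hello world"],
    "text_tokens": torch.randint(0, 100, (1, 10)),   # (N, S)
    "text_tokens_lens": torch.tensor([10]),
    "audio_features": torch.rand(1, 50, 8),          # (N, T, 8) decoder targets
    "audio_features_lens": torch.tensor([50]),
}
# predicts = (encoder outputs, decoder outputs), mirroring what VALLF/VALLE.forward() returns.
predicts = (torch.rand(1, 10, 1024), torch.rand(1, 50, 8))

os.makedirs("exp/valle_dev/figures", exist_ok=True)  # assumed output path
visualize(predicts, batch, "exp/valle_dev/figures", limit=1)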
modules/__init__.py
ADDED
File without changes
|
modules/activation.py
ADDED
@@ -0,0 +1,612 @@
|
1 |
+
from typing import Optional, Tuple, List
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import Tensor
|
6 |
+
from torch.nn import Linear, Module
|
7 |
+
from torch.nn import functional as F
|
8 |
+
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
|
9 |
+
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
10 |
+
from torch.nn.parameter import Parameter
|
11 |
+
|
12 |
+
def _in_projection_packed(
|
13 |
+
q: Tensor,
|
14 |
+
k: Tensor,
|
15 |
+
v: Tensor,
|
16 |
+
w: Tensor,
|
17 |
+
b: Optional[Tensor] = None,
|
18 |
+
) -> List[Tensor]:
|
19 |
+
r"""
|
20 |
+
Performs the in-projection step of the attention operation, using packed weights.
|
21 |
+
Output is a triple containing projection tensors for query, key and value.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
q, k, v: query, key and value tensors to be projected. For self-attention,
|
25 |
+
these are typically the same tensor; for encoder-decoder attention,
|
26 |
+
k and v are typically the same tensor. (We take advantage of these
|
27 |
+
identities for performance if they are present.) Regardless, q, k and v
|
28 |
+
must share a common embedding dimension; otherwise their shapes may vary.
|
29 |
+
w: projection weights for q, k and v, packed into a single tensor. Weights
|
30 |
+
are packed along dimension 0, in q, k, v order.
|
31 |
+
b: optional projection biases for q, k and v, packed into a single tensor
|
32 |
+
in q, k, v order.
|
33 |
+
|
34 |
+
Shape:
|
35 |
+
Inputs:
|
36 |
+
- q: :math:`(..., E)` where E is the embedding dimension
|
37 |
+
- k: :math:`(..., E)` where E is the embedding dimension
|
38 |
+
- v: :math:`(..., E)` where E is the embedding dimension
|
39 |
+
- w: :math:`(E * 3, E)` where E is the embedding dimension
|
40 |
+
- b: :math:`E * 3` where E is the embedding dimension
|
41 |
+
|
42 |
+
Output:
|
43 |
+
- in output list :math:`[q', k', v']`, each output tensor will have the
|
44 |
+
same shape as the corresponding input tensor.
|
45 |
+
"""
|
46 |
+
E = q.size(-1)
|
47 |
+
if k is v:
|
48 |
+
if q is k:
|
49 |
+
# self-attention
|
50 |
+
return F.linear(q, w, b).chunk(3, dim=-1)
|
51 |
+
else:
|
52 |
+
# encoder-decoder attention
|
53 |
+
w_q, w_kv = w.split([E, E * 2])
|
54 |
+
if b is None:
|
55 |
+
b_q = b_kv = None
|
56 |
+
else:
|
57 |
+
b_q, b_kv = b.split([E, E * 2])
|
58 |
+
return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
|
59 |
+
else:
|
60 |
+
w_q, w_k, w_v = w.chunk(3)
|
61 |
+
if b is None:
|
62 |
+
b_q = b_k = b_v = None
|
63 |
+
else:
|
64 |
+
b_q, b_k, b_v = b.chunk(3)
|
65 |
+
return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
|
66 |
+
|
67 |
+
def _scaled_dot_product_attention(
|
68 |
+
q: Tensor,
|
69 |
+
k: Tensor,
|
70 |
+
v: Tensor,
|
71 |
+
attn_mask: Optional[Tensor] = None,
|
72 |
+
dropout_p: float = 0.0,
|
73 |
+
) -> Tuple[Tensor, Tensor]:
|
74 |
+
r"""
|
75 |
+
Computes scaled dot product attention on query, key and value tensors, using
|
76 |
+
an optional attention mask if passed, and applying dropout if a probability
|
77 |
+
greater than 0.0 is specified.
|
78 |
+
Returns a tensor pair containing attended values and attention weights.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
q, k, v: query, key and value tensors. See Shape section for shape details.
|
82 |
+
attn_mask: optional tensor containing mask values to be added to calculated
|
83 |
+
attention. May be 2D or 3D; see Shape section for details.
|
84 |
+
dropout_p: dropout probability. If greater than 0.0, dropout is applied.
|
85 |
+
|
86 |
+
Shape:
|
87 |
+
- q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
|
88 |
+
and E is embedding dimension.
|
89 |
+
- key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
|
90 |
+
and E is embedding dimension.
|
91 |
+
- value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
|
92 |
+
and E is embedding dimension.
|
93 |
+
- attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
|
94 |
+
shape :math:`(Nt, Ns)`.
|
95 |
+
|
96 |
+
- Output: attention values have shape :math:`(B, Nt, E)`; attention weights
|
97 |
+
have shape :math:`(B, Nt, Ns)`
|
98 |
+
"""
|
99 |
+
B, Nt, E = q.shape
|
100 |
+
q = q / math.sqrt(E)
|
101 |
+
# (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
|
102 |
+
if attn_mask is not None:
|
103 |
+
attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
|
104 |
+
else:
|
105 |
+
attn = torch.bmm(q, k.transpose(-2, -1))
|
106 |
+
|
107 |
+
attn = F.softmax(attn, dim=-1)
|
108 |
+
if dropout_p > 0.0:
|
109 |
+
attn = F.dropout(attn, p=dropout_p)
|
110 |
+
# (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
|
111 |
+
output = torch.bmm(attn, v)
|
112 |
+
return output, attn
|
113 |
+
|
114 |
+
def multi_head_attention_forward(
|
115 |
+
x,
|
116 |
+
ipw,
|
117 |
+
ipb,
|
118 |
+
opw,
|
119 |
+
opb,
|
120 |
+
n_head,
|
121 |
+
attn_mask,
|
122 |
+
past_kv=None,
|
123 |
+
use_cache=False,
|
124 |
+
):
|
125 |
+
# x = x.transpose(1, 0)
|
126 |
+
# tgt_len, bsz, embed_dim = x.shape
|
127 |
+
# head_dim = embed_dim // n_head
|
128 |
+
# q, k, v = _in_projection_packed(x, x, x, ipw, ipb)
|
129 |
+
# q = q.contiguous().view(tgt_len, bsz * n_head, head_dim).transpose(0, 1)
|
130 |
+
# k = k.contiguous().view(k.shape[0], bsz * n_head, head_dim).transpose(0, 1)
|
131 |
+
# v = v.contiguous().view(v.shape[0], bsz * n_head, head_dim).transpose(0, 1)
|
132 |
+
|
133 |
+
# new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
|
134 |
+
# new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
135 |
+
# attn_mask = new_attn_mask
|
136 |
+
#
|
137 |
+
# attn_output, attn_output_weights = _scaled_dot_product_attention(q, k, v, attn_mask, 0.0)
|
138 |
+
# attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
|
139 |
+
# attn_output = torch._C._nn.linear(attn_output, opw, opb)
|
140 |
+
# attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
|
141 |
+
|
142 |
+
B, T, C = x.size()
|
143 |
+
|
144 |
+
q, k, v = torch._C._nn.linear(x, ipw, ipb).chunk(3, dim=-1)
|
145 |
+
k = k.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs)
|
146 |
+
q = q.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs)
|
147 |
+
v = v.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs)
|
148 |
+
if past_kv is not None:
|
149 |
+
past_key = past_kv[0]
|
150 |
+
past_value = past_kv[1]
|
151 |
+
k = torch.cat((past_key, k), dim=-2)
|
152 |
+
v = torch.cat((past_value, v), dim=-2)
|
153 |
+
|
154 |
+
FULL_T = k.shape[-2]
|
155 |
+
|
156 |
+
if use_cache is True:
|
157 |
+
present = (k, v)
|
158 |
+
else:
|
159 |
+
present = None
|
160 |
+
|
161 |
+
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
162 |
+
att = att.masked_fill(attn_mask[FULL_T - T:FULL_T, :FULL_T], float('-inf'))
|
163 |
+
att = F.softmax(att, dim=-1)
|
164 |
+
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
165 |
+
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
|
166 |
+
y = torch._C._nn.linear(y, opw, opb)
|
167 |
+
return (y, present)
|
168 |
+
|
169 |
+
|
170 |
+
class MultiheadAttention(Module):
|
171 |
+
r"""Allows the model to jointly attend to information
|
172 |
+
from different representation subspaces as described in the paper:
|
173 |
+
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
|
174 |
+
|
175 |
+
Multi-Head Attention is defined as:
|
176 |
+
|
177 |
+
.. math::
|
178 |
+
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
179 |
+
|
180 |
+
where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
|
181 |
+
|
182 |
+
``forward()`` will use a special optimized implementation if all of the following
|
183 |
+
conditions are met:
|
184 |
+
|
185 |
+
- self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
|
186 |
+
restriction will be loosened in the future.)
|
187 |
+
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
|
188 |
+
- training is disabled (using ``.eval()``)
|
189 |
+
- dropout is 0
|
190 |
+
- ``add_bias_kv`` is ``False``
|
191 |
+
- ``add_zero_attn`` is ``False``
|
192 |
+
- ``batch_first`` is ``True`` and the input is batched
|
193 |
+
- ``kdim`` and ``vdim`` are equal to ``embed_dim``
|
194 |
+
- at most one of ``key_padding_mask`` or ``attn_mask`` is passed
|
195 |
+
- if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
|
196 |
+
nor ``attn_mask`` is passed
|
197 |
+
|
198 |
+
If the optimized implementation is in use, a
|
199 |
+
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
|
200 |
+
``query``/``key``/``value`` to represent padding more efficiently than using a
|
201 |
+
padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
|
202 |
+
will be returned, and an additional speedup proportional to the fraction of the input
|
203 |
+
that is padding can be expected.
|
204 |
+
|
205 |
+
Args:
|
206 |
+
embed_dim: Total dimension of the model.
|
207 |
+
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
|
208 |
+
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
|
209 |
+
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
|
210 |
+
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
|
211 |
+
add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
|
212 |
+
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
|
213 |
+
Default: ``False``.
|
214 |
+
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
|
215 |
+
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
|
216 |
+
batch_first: If ``True``, then the input and output tensors are provided
|
217 |
+
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
218 |
+
|
219 |
+
Examples::
|
220 |
+
|
221 |
+
>>> # xdoctest: +SKIP
|
222 |
+
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
223 |
+
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
224 |
+
|
225 |
+
"""
|
226 |
+
__constants__ = ["batch_first"]
|
227 |
+
bias_k: Optional[torch.Tensor]
|
228 |
+
bias_v: Optional[torch.Tensor]
|
229 |
+
|
230 |
+
def __init__(
|
231 |
+
self,
|
232 |
+
embed_dim,
|
233 |
+
num_heads,
|
234 |
+
dropout=0.0,
|
235 |
+
bias=True,
|
236 |
+
add_bias_kv=False,
|
237 |
+
add_zero_attn=False,
|
238 |
+
kdim=None,
|
239 |
+
vdim=None,
|
240 |
+
batch_first=False,
|
241 |
+
linear1_cls=Linear,
|
242 |
+
linear2_cls=Linear,
|
243 |
+
device=None,
|
244 |
+
dtype=None,
|
245 |
+
) -> None:
|
246 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
247 |
+
super(MultiheadAttention, self).__init__()
|
248 |
+
self.embed_dim = embed_dim
|
249 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
250 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
251 |
+
self._qkv_same_embed_dim = (
|
252 |
+
self.kdim == embed_dim and self.vdim == embed_dim
|
253 |
+
)
|
254 |
+
|
255 |
+
self.num_heads = num_heads
|
256 |
+
self.dropout = dropout
|
257 |
+
self.batch_first = batch_first
|
258 |
+
self.head_dim = embed_dim // num_heads
|
259 |
+
assert (
|
260 |
+
self.head_dim * num_heads == self.embed_dim
|
261 |
+
), "embed_dim must be divisible by num_heads"
|
262 |
+
|
263 |
+
if add_bias_kv:
|
264 |
+
self.bias_k = Parameter(
|
265 |
+
torch.empty((1, 1, embed_dim), **factory_kwargs)
|
266 |
+
)
|
267 |
+
self.bias_v = Parameter(
|
268 |
+
torch.empty((1, 1, embed_dim), **factory_kwargs)
|
269 |
+
)
|
270 |
+
else:
|
271 |
+
self.bias_k = self.bias_v = None
|
272 |
+
|
273 |
+
if linear1_cls == Linear:
|
274 |
+
if not self._qkv_same_embed_dim:
|
275 |
+
self.q_proj_weight = Parameter(
|
276 |
+
torch.empty((embed_dim, embed_dim), **factory_kwargs)
|
277 |
+
)
|
278 |
+
self.k_proj_weight = Parameter(
|
279 |
+
torch.empty((embed_dim, self.kdim), **factory_kwargs)
|
280 |
+
)
|
281 |
+
self.v_proj_weight = Parameter(
|
282 |
+
torch.empty((embed_dim, self.vdim), **factory_kwargs)
|
283 |
+
)
|
284 |
+
self.register_parameter("in_proj_weight", None)
|
285 |
+
else:
|
286 |
+
self.in_proj_weight = Parameter(
|
287 |
+
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
|
288 |
+
)
|
289 |
+
self.register_parameter("q_proj_weight", None)
|
290 |
+
self.register_parameter("k_proj_weight", None)
|
291 |
+
self.register_parameter("v_proj_weight", None)
|
292 |
+
|
293 |
+
if bias:
|
294 |
+
self.in_proj_bias = Parameter(
|
295 |
+
torch.empty(3 * embed_dim, **factory_kwargs)
|
296 |
+
)
|
297 |
+
else:
|
298 |
+
self.register_parameter("in_proj_bias", None)
|
299 |
+
self.out_proj = NonDynamicallyQuantizableLinear(
|
300 |
+
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
301 |
+
)
|
302 |
+
|
303 |
+
self._reset_parameters()
|
304 |
+
else:
|
305 |
+
if not self._qkv_same_embed_dim:
|
306 |
+
raise NotImplementedError
|
307 |
+
else:
|
308 |
+
self.in_proj_linear = linear1_cls(
|
309 |
+
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
|
310 |
+
)
|
311 |
+
self.in_proj_weight = self.in_proj_linear.weight
|
312 |
+
|
313 |
+
self.register_parameter("q_proj_weight", None)
|
314 |
+
self.register_parameter("k_proj_weight", None)
|
315 |
+
self.register_parameter("v_proj_weight", None)
|
316 |
+
|
317 |
+
if bias:
|
318 |
+
self.in_proj_bias = self.in_proj_linear.bias
|
319 |
+
else:
|
320 |
+
self.register_parameter("in_proj_bias", None)
|
321 |
+
|
322 |
+
self.out_proj = linear2_cls(
|
323 |
+
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
324 |
+
)
|
325 |
+
|
326 |
+
if self.bias_k is not None:
|
327 |
+
xavier_normal_(self.bias_k)
|
328 |
+
if self.bias_v is not None:
|
329 |
+
xavier_normal_(self.bias_v)
|
330 |
+
|
331 |
+
self.add_zero_attn = add_zero_attn
|
332 |
+
|
333 |
+
def _reset_parameters(self):
|
334 |
+
if self._qkv_same_embed_dim:
|
335 |
+
xavier_uniform_(self.in_proj_weight)
|
336 |
+
else:
|
337 |
+
xavier_uniform_(self.q_proj_weight)
|
338 |
+
xavier_uniform_(self.k_proj_weight)
|
339 |
+
xavier_uniform_(self.v_proj_weight)
|
340 |
+
|
341 |
+
if self.in_proj_bias is not None:
|
342 |
+
constant_(self.in_proj_bias, 0.0)
|
343 |
+
constant_(self.out_proj.bias, 0.0)
|
344 |
+
|
345 |
+
if self.bias_k is not None:
|
346 |
+
xavier_normal_(self.bias_k)
|
347 |
+
if self.bias_v is not None:
|
348 |
+
xavier_normal_(self.bias_v)
|
349 |
+
|
350 |
+
def __setstate__(self, state):
|
351 |
+
# Support loading old MultiheadAttention checkpoints generated by v1.1.0
|
352 |
+
if "_qkv_same_embed_dim" not in state:
|
353 |
+
state["_qkv_same_embed_dim"] = True
|
354 |
+
|
355 |
+
super(MultiheadAttention, self).__setstate__(state)
|
356 |
+
|
357 |
+
def forward(
|
358 |
+
self,
|
359 |
+
query: Tensor,
|
360 |
+
key: Tensor,
|
361 |
+
value: Tensor,
|
362 |
+
key_padding_mask: Optional[Tensor] = None,
|
363 |
+
need_weights: bool = True,
|
364 |
+
attn_mask: Optional[Tensor] = None,
|
365 |
+
average_attn_weights: bool = True,
|
366 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
367 |
+
r"""
|
368 |
+
Args:
|
369 |
+
query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
|
370 |
+
or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
|
371 |
+
:math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
|
372 |
+
Queries are compared against key-value pairs to produce the output.
|
373 |
+
See "Attention Is All You Need" for more details.
|
374 |
+
key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
|
375 |
+
or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
|
376 |
+
:math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
|
377 |
+
See "Attention Is All You Need" for more details.
|
378 |
+
value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
|
379 |
+
``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
|
380 |
+
sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
|
381 |
+
See "Attention Is All You Need" for more details.
|
382 |
+
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
|
383 |
+
to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
|
384 |
+
Binary and byte masks are supported.
|
385 |
+
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
|
386 |
+
the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
|
387 |
+
need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
|
388 |
+
Default: ``True``.
|
389 |
+
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
|
390 |
+
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
|
391 |
+
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
|
392 |
+
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
|
393 |
+
Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
|
394 |
+
corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
|
395 |
+
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
|
396 |
+
the attention weight.
|
397 |
+
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
|
398 |
+
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
|
399 |
+
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
|
400 |
+
|
401 |
+
Outputs:
|
402 |
+
- **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
|
403 |
+
:math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
|
404 |
+
where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
|
405 |
+
embedding dimension ``embed_dim``.
|
406 |
+
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
|
407 |
+
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
|
408 |
+
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
|
409 |
+
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
|
410 |
+
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
|
411 |
+
|
412 |
+
.. note::
|
413 |
+
`batch_first` argument is ignored for unbatched inputs.
|
414 |
+
"""
|
415 |
+
is_batched = query.dim() == 3
|
416 |
+
if key_padding_mask is not None:
|
417 |
+
_kpm_dtype = key_padding_mask.dtype
|
418 |
+
if _kpm_dtype != torch.bool and not torch.is_floating_point(
|
419 |
+
key_padding_mask
|
420 |
+
):
|
421 |
+
raise AssertionError(
|
422 |
+
"only bool and floating types of key_padding_mask are supported"
|
423 |
+
)
|
424 |
+
why_not_fast_path = ""
|
425 |
+
if not is_batched:
|
426 |
+
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
|
427 |
+
elif query is not key or key is not value:
|
428 |
+
# When lifting this restriction, don't forget to either
|
429 |
+
# enforce that the dtypes all match or test cases where
|
430 |
+
# they don't!
|
431 |
+
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
|
432 |
+
elif (
|
433 |
+
self.in_proj_bias is not None
|
434 |
+
and query.dtype != self.in_proj_bias.dtype
|
435 |
+
):
|
436 |
+
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
|
437 |
+
elif (
|
438 |
+
self.in_proj_weight is not None
|
439 |
+
and query.dtype != self.in_proj_weight.dtype
|
440 |
+
):
|
441 |
+
# this case will fail anyway, but at least they'll get a useful error message.
|
442 |
+
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
|
443 |
+
elif self.training:
|
444 |
+
why_not_fast_path = "training is enabled"
|
445 |
+
elif not self.batch_first:
|
446 |
+
why_not_fast_path = "batch_first was not True"
|
447 |
+
elif self.bias_k is not None:
|
448 |
+
why_not_fast_path = "self.bias_k was not None"
|
449 |
+
elif self.bias_v is not None:
|
450 |
+
why_not_fast_path = "self.bias_v was not None"
|
451 |
+
elif self.dropout:
|
452 |
+
why_not_fast_path = f"dropout was {self.dropout}, required zero"
|
453 |
+
elif self.add_zero_attn:
|
454 |
+
why_not_fast_path = "add_zero_attn was enabled"
|
455 |
+
elif not self._qkv_same_embed_dim:
|
456 |
+
why_not_fast_path = "_qkv_same_embed_dim was not True"
|
457 |
+
elif attn_mask is not None:
|
458 |
+
why_not_fast_path = "attn_mask was not None"
|
459 |
+
elif query.is_nested and key_padding_mask is not None:
|
460 |
+
why_not_fast_path = (
|
461 |
+
"key_padding_mask is not supported with NestedTensor input"
|
462 |
+
)
|
463 |
+
elif self.num_heads % 2 == 1:
|
464 |
+
why_not_fast_path = "num_heads is odd"
|
465 |
+
elif torch.is_autocast_enabled():
|
466 |
+
why_not_fast_path = "autocast is enabled"
|
467 |
+
|
468 |
+
if not why_not_fast_path:
|
469 |
+
tensor_args = (
|
470 |
+
query,
|
471 |
+
key,
|
472 |
+
value,
|
473 |
+
self.in_proj_weight,
|
474 |
+
self.in_proj_bias,
|
475 |
+
self.out_proj.weight,
|
476 |
+
self.out_proj.bias,
|
477 |
+
)
|
478 |
+
# We have to use list comprehensions below because TorchScript does not support
|
479 |
+
# generator expressions.
|
480 |
+
if torch.overrides.has_torch_function(tensor_args):
|
481 |
+
why_not_fast_path = "some Tensor argument has_torch_function"
|
482 |
+
elif not all(
|
483 |
+
[
|
484 |
+
(x is None or x.is_cuda or "cpu" in str(x.device))
|
485 |
+
for x in tensor_args
|
486 |
+
]
|
487 |
+
):
|
488 |
+
why_not_fast_path = (
|
489 |
+
"some Tensor argument is neither CUDA nor CPU"
|
490 |
+
)
|
491 |
+
elif torch.is_grad_enabled() and any(
|
492 |
+
[x is not None and x.requires_grad for x in tensor_args]
|
493 |
+
):
|
494 |
+
why_not_fast_path = (
|
495 |
+
"grad is enabled and at least one of query or the "
|
496 |
+
"input/output projection weights or biases requires_grad"
|
497 |
+
)
|
498 |
+
if not why_not_fast_path:
|
499 |
+
return torch._native_multi_head_attention(
|
500 |
+
query,
|
501 |
+
key,
|
502 |
+
value,
|
503 |
+
self.embed_dim,
|
504 |
+
self.num_heads,
|
505 |
+
self.in_proj_weight,
|
506 |
+
self.in_proj_bias,
|
507 |
+
self.out_proj.weight,
|
508 |
+
self.out_proj.bias,
|
509 |
+
key_padding_mask
|
510 |
+
if key_padding_mask is not None
|
511 |
+
else attn_mask,
|
512 |
+
need_weights,
|
513 |
+
average_attn_weights,
|
514 |
+
1
|
515 |
+
if key_padding_mask is not None
|
516 |
+
else 0
|
517 |
+
if attn_mask is not None
|
518 |
+
else None,
|
519 |
+
)
|
520 |
+
|
521 |
+
any_nested = query.is_nested or key.is_nested or value.is_nested
|
522 |
+
assert not any_nested, (
|
523 |
+
"MultiheadAttention does not support NestedTensor outside of its fast path. "
|
524 |
+
+ f"The fast path was not hit because {why_not_fast_path}"
|
525 |
+
)
|
526 |
+
|
527 |
+
if self.batch_first and is_batched:
|
528 |
+
# make sure that the transpose op does not affect the "is" property
|
529 |
+
if key is value:
|
530 |
+
if query is key:
|
531 |
+
query = key = value = query.transpose(1, 0)
|
532 |
+
else:
|
533 |
+
query, key = [x.transpose(1, 0) for x in (query, key)]
|
534 |
+
value = key
|
535 |
+
else:
|
536 |
+
query, key, value = [
|
537 |
+
x.transpose(1, 0) for x in (query, key, value)
|
538 |
+
]
|
539 |
+
|
540 |
+
if not self._qkv_same_embed_dim:
|
541 |
+
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
542 |
+
query,
|
543 |
+
key,
|
544 |
+
value,
|
545 |
+
self.embed_dim,
|
546 |
+
self.num_heads,
|
547 |
+
self.in_proj_weight,
|
548 |
+
self.in_proj_bias,
|
549 |
+
self.bias_k,
|
550 |
+
self.bias_v,
|
551 |
+
self.add_zero_attn,
|
552 |
+
self.dropout,
|
553 |
+
self.out_proj.weight,
|
554 |
+
self.out_proj.bias,
|
555 |
+
training=self.training,
|
556 |
+
key_padding_mask=key_padding_mask,
|
557 |
+
need_weights=need_weights,
|
558 |
+
attn_mask=attn_mask,
|
559 |
+
use_separate_proj_weight=True,
|
560 |
+
q_proj_weight=self.q_proj_weight,
|
561 |
+
k_proj_weight=self.k_proj_weight,
|
562 |
+
v_proj_weight=self.v_proj_weight,
|
563 |
+
average_attn_weights=average_attn_weights,
|
564 |
+
)
|
565 |
+
else:
|
566 |
+
attn_output, attn_output_weights = F.multi_head_attention_forward(
|
567 |
+
query,
|
568 |
+
key,
|
569 |
+
value,
|
570 |
+
self.embed_dim,
|
571 |
+
self.num_heads,
|
572 |
+
self.in_proj_weight,
|
573 |
+
self.in_proj_bias,
|
574 |
+
self.bias_k,
|
575 |
+
self.bias_v,
|
576 |
+
self.add_zero_attn,
|
577 |
+
self.dropout,
|
578 |
+
self.out_proj.weight,
|
579 |
+
self.out_proj.bias,
|
580 |
+
training=self.training,
|
581 |
+
key_padding_mask=key_padding_mask,
|
582 |
+
need_weights=need_weights,
|
583 |
+
attn_mask=attn_mask,
|
584 |
+
average_attn_weights=average_attn_weights,
|
585 |
+
)
|
586 |
+
if self.batch_first and is_batched:
|
587 |
+
return attn_output.transpose(1, 0), attn_output_weights
|
588 |
+
else:
|
589 |
+
return attn_output, attn_output_weights
|
590 |
+
|
591 |
+
def infer(self,
|
592 |
+
x: Tensor,
|
593 |
+
key_padding_mask: Optional[Tensor] = None,
|
594 |
+
need_weights: bool = True,
|
595 |
+
attn_mask: Optional[Tensor] = None,
|
596 |
+
average_attn_weights: bool = True,
|
597 |
+
past_kv = None,
|
598 |
+
use_cache = False
|
599 |
+
):
|
600 |
+
# x = x.transpose(1, 0)
|
601 |
+
y, kv = multi_head_attention_forward(
|
602 |
+
x=x,
|
603 |
+
ipw=self.in_proj_weight,
|
604 |
+
ipb=self.in_proj_bias,
|
605 |
+
opw=self.out_proj.weight,
|
606 |
+
opb=self.out_proj.bias,
|
607 |
+
n_head=self.num_heads,
|
608 |
+
attn_mask=attn_mask,
|
609 |
+
past_kv=past_kv,
|
610 |
+
use_cache=use_cache,
|
611 |
+
)
|
612 |
+
return (y, kv)
|
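For context, here is a minimal usage sketch of the customized MultiheadAttention shown above (not part of the commit). The constructor arguments, the batch-first (B, T, E) input shape for infer(), and the KV-cache semantics are assumptions inferred from the code, not the repository's documented API; treat this as a sketch rather than a verified snippet.

# Minimal sketch; assumes modules.activation.MultiheadAttention mirrors torch.nn.MultiheadAttention's
# constructor and exposes the infer(x, ..., past_kv, use_cache) method defined above.
import torch
from modules.activation import MultiheadAttention  # assumed import path

mha = MultiheadAttention(512, 8, batch_first=True).eval()
x = torch.randn(2, 10, 512)  # (batch, time, embed_dim)

# Regular self-attention call; in eval mode, with batch_first=True, zero dropout,
# no masks and identical q/k/v, the checks above allow the native fast path.
with torch.no_grad():
    y, _ = mha(x, x, x, need_weights=False)

# Incremental decoding via the custom infer() path (assumed semantics): feed one
# step at a time and carry the returned key/value cache forward.
with torch.no_grad():
    out1, past_kv = mha.infer(x[:, :1], use_cache=True)
    out2, past_kv = mha.infer(x[:, 1:2], past_kv=past_kv, use_cache=True)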
modules/embedding.py
ADDED
@@ -0,0 +1,97 @@
# Copyright 2023 (authors: Feiteng Li)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch
import torch.nn as nn


class TokenEmbedding(nn.Module):
    def __init__(
        self,
        dim_model: int,
        vocab_size: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim_model = dim_model

        self.dropout = torch.nn.Dropout(p=dropout)
        self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)

    @property
    def weight(self) -> torch.Tensor:
        return self.word_embeddings.weight

    def embedding(self, index: int) -> torch.Tensor:
        return self.word_embeddings.weight[index : index + 1]

    def forward(self, x: torch.Tensor):
        X = self.word_embeddings(x)
        X = self.dropout(X)

        return X


class SinePositionalEmbedding(nn.Module):
    def __init__(
        self,
        dim_model: int,
        dropout: float = 0.0,
        scale: bool = False,
        alpha: bool = False,
    ):
        super().__init__()
        self.dim_model = dim_model
        self.x_scale = math.sqrt(dim_model) if scale else 1.0
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.dropout = torch.nn.Dropout(p=dropout)

        self.reverse = False
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, 4000))

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.dim_model)
        if self.reverse:
            position = torch.arange(
                x.size(1) - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(
                0, x.size(1), dtype=torch.float32
            ).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dim_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.dim_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype).detach()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.extend_pe(x)
        output = x.unsqueeze(-1) if x.ndim == 2 else x
        output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(output)
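A short usage sketch of the two modules above (not part of the commit; the modules.embedding import path is an assumption): TokenEmbedding maps token ids to vectors, and SinePositionalEmbedding adds sinusoidal position encodings scaled by the learnable alpha.

# Minimal sketch, assuming the file is importable as modules.embedding.
import torch
from modules.embedding import TokenEmbedding, SinePositionalEmbedding

token_emb = TokenEmbedding(dim_model=512, vocab_size=1024, dropout=0.1)
pos_emb = SinePositionalEmbedding(dim_model=512, dropout=0.1, scale=False, alpha=True)

tokens = torch.randint(0, 1024, (2, 100))  # (batch, seq_len) token ids
x = token_emb(tokens)                       # -> (2, 100, 512)
x = pos_emb(x)                              # adds pe[:, :100] weighted by the learned alpha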
modules/optim.py
ADDED
@@ -0,0 +1,1105 @@
1 |
+
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
2 |
+
#
|
3 |
+
# See ../LICENSE for clarification regarding multiple authors
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import contextlib
|
18 |
+
import logging
|
19 |
+
import random
|
20 |
+
from collections import defaultdict
|
21 |
+
from typing import List, Optional, Tuple, Union
|
22 |
+
|
23 |
+
import torch
|
24 |
+
from lhotse.utils import fix_random_seed
|
25 |
+
from torch import Tensor
|
26 |
+
from torch.optim import Optimizer
|
27 |
+
|
28 |
+
|
29 |
+
class BatchedOptimizer(Optimizer):
|
30 |
+
"""
|
31 |
+
This class adds to class Optimizer the capability to optimize parameters in batches:
|
32 |
+
it will stack the parameters and their grads for you so the optimizer can work
|
33 |
+
on tensors with an extra leading dimension. This is intended for speed with GPUs,
|
34 |
+
as it reduces the number of kernels launched in the optimizer.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
params:
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self, params, defaults):
|
41 |
+
super(BatchedOptimizer, self).__init__(params, defaults)
|
42 |
+
|
43 |
+
@contextlib.contextmanager
|
44 |
+
def batched_params(self, param_group, group_params_names):
|
45 |
+
"""
|
46 |
+
This function returns (technically, yields) a list
|
47 |
+
of tuples (p, state), where
|
48 |
+
p is a `fake` parameter that is stacked (over axis 0) from real parameters
|
49 |
+
that share the same shape, and its gradient is also stacked;
|
50 |
+
`state` is the state corresponding to this batch of parameters
|
51 |
+
(it will be physically located in the "state" for one of the real
|
52 |
+
parameters, the last one that has any particular shape and dtype).
|
53 |
+
|
54 |
+
This function is decorated as a context manager so that it can
|
55 |
+
write parameters back to their "real" locations.
|
56 |
+
|
57 |
+
The idea is, instead of doing:
|
58 |
+
<code>
|
59 |
+
for p in group["params"]:
|
60 |
+
state = self.state[p]
|
61 |
+
...
|
62 |
+
</code>
|
63 |
+
you can do:
|
64 |
+
<code>
|
65 |
+
with self.batched_params(group["params"]) as batches:
|
66 |
+
for p, state, p_names in batches:
|
67 |
+
...
|
68 |
+
</code>
|
69 |
+
|
70 |
+
Args:
|
71 |
+
group: a parameter group, which is a list of parameters; should be
|
72 |
+
one of self.param_groups.
|
73 |
+
group_params_names: name for each parameter in group,
|
74 |
+
which is List[str].
|
75 |
+
"""
|
76 |
+
batches = defaultdict(
|
77 |
+
list
|
78 |
+
) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
|
79 |
+
batches_names = defaultdict(
|
80 |
+
list
|
81 |
+
) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
|
82 |
+
|
83 |
+
assert len(param_group) == len(group_params_names)
|
84 |
+
for p, named_p in zip(param_group, group_params_names):
|
85 |
+
key = (str(p.dtype), *p.shape)
|
86 |
+
batches[key].append(p)
|
87 |
+
batches_names[key].append(named_p)
|
88 |
+
|
89 |
+
batches_names_keys = list(batches_names.keys())
|
90 |
+
sorted_idx = sorted(
|
91 |
+
range(len(batches_names)), key=lambda i: batches_names_keys[i]
|
92 |
+
)
|
93 |
+
batches_names = [
|
94 |
+
batches_names[batches_names_keys[idx]] for idx in sorted_idx
|
95 |
+
]
|
96 |
+
batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
|
97 |
+
|
98 |
+
stacked_params_dict = dict()
|
99 |
+
|
100 |
+
# turn batches into a list, in deterministic order.
|
101 |
+
# tuples will contain tuples of (stacked_param, state, stacked_params_names),
|
102 |
+
# one for each batch in `batches`.
|
103 |
+
tuples = []
|
104 |
+
|
105 |
+
for batch, batch_names in zip(batches, batches_names):
|
106 |
+
p = batch[0]
|
107 |
+
# we arbitrarily store the state in the
|
108 |
+
# state corresponding to the 1st parameter in the
|
109 |
+
# group. class Optimizer will take care of saving/loading state.
|
110 |
+
state = self.state[p]
|
111 |
+
p_stacked = torch.stack(batch)
|
112 |
+
grad = torch.stack(
|
113 |
+
[
|
114 |
+
torch.zeros_like(p) if p.grad is None else p.grad
|
115 |
+
for p in batch
|
116 |
+
]
|
117 |
+
)
|
118 |
+
p_stacked.grad = grad
|
119 |
+
stacked_params_dict[key] = p_stacked
|
120 |
+
tuples.append((p_stacked, state, batch_names))
|
121 |
+
|
122 |
+
yield tuples # <-- calling code will do the actual optimization here!
|
123 |
+
|
124 |
+
for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
|
125 |
+
for i, p in enumerate(batch): # batch is list of Parameter
|
126 |
+
p.copy_(stacked_params[i])
|
127 |
+
|
128 |
+
|
129 |
+
class ScaledAdam(BatchedOptimizer):
|
130 |
+
"""
|
131 |
+
Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
|
132 |
+
proportional to the norm of that parameter; and also learn the scale of the parameter,
|
133 |
+
in log space, subject to upper and lower limits (as if we had factored each parameter as
|
134 |
+
param = underlying_param * log_scale.exp())
|
135 |
+
|
136 |
+
|
137 |
+
Args:
|
138 |
+
params: The parameters or param_groups to optimize (like other Optimizer subclasses)
|
139 |
+
lr: The learning rate. We will typically use a learning rate schedule that starts
|
140 |
+
at 0.03 and decreases over time, i.e. much higher than other common
|
141 |
+
optimizers.
|
142 |
+
clipping_scale: (e.g. 2.0)
|
143 |
+
A scale for gradient-clipping: if specified, the normalized gradients
|
144 |
+
over the whole model will be clipped to have 2-norm equal to
|
145 |
+
`clipping_scale` times the median 2-norm over the most recent period
|
146 |
+
of `clipping_update_period` minibatches. By "normalized gradients",
|
147 |
+
we mean after multiplying by the rms parameter value for this tensor
|
148 |
+
[for non-scalars]; this is appropriate because our update is scaled
|
149 |
+
by this quantity.
|
150 |
+
betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
|
151 |
+
Must satisfy 0 < beta1 <= beta2 < 1.
|
152 |
+
scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
|
153 |
+
scale of each parameter tensor and scalar parameters of the model.
|
154 |
+
If each parameter were decomposed
|
155 |
+
as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
|
156 |
+
would be the scaling factor on the learning rate of p_scale.
|
157 |
+
eps: A general-purpose epsilon to prevent division by zero
|
158 |
+
param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
|
159 |
+
learning the scale on the parameters (we'll constrain the rms of each non-scalar
|
160 |
+
parameter tensor to be >= this value)
|
161 |
+
param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
|
162 |
+
learning the scale on the parameters (we'll constrain the rms of each non-scalar
|
163 |
+
parameter tensor to be <= this value)
|
164 |
+
scalar_max: Maximum absolute value for scalar parameters (applicable if your
|
165 |
+
model has any parameters with numel() == 1).
|
166 |
+
size_update_period: The periodicity, in steps, with which we update the size (scale)
|
167 |
+
of the parameter tensor. This is provided to save a little time
|
168 |
+
in the update.
|
169 |
+
clipping_update_period: if clipping_scale is specified, this is the period
|
170 |
+
"""
|
171 |
+
|
172 |
+
def __init__(
|
173 |
+
self,
|
174 |
+
params,
|
175 |
+
lr=3e-02,
|
176 |
+
clipping_scale=None,
|
177 |
+
betas=(0.9, 0.98),
|
178 |
+
scalar_lr_scale=0.1,
|
179 |
+
eps=1.0e-08,
|
180 |
+
param_min_rms=1.0e-05,
|
181 |
+
param_max_rms=3.0,
|
182 |
+
scalar_max=10.0,
|
183 |
+
size_update_period=4,
|
184 |
+
clipping_update_period=100,
|
185 |
+
parameters_names=None,
|
186 |
+
show_dominant_parameters=True,
|
187 |
+
):
|
188 |
+
|
189 |
+
assert parameters_names is not None, (
|
190 |
+
"Please prepare parameters_names,"
|
191 |
+
"which is a List[List[str]]. Each List[str] is for a group"
|
192 |
+
"and each str is for a parameter"
|
193 |
+
)
|
194 |
+
defaults = dict(
|
195 |
+
lr=lr,
|
196 |
+
clipping_scale=clipping_scale,
|
197 |
+
betas=betas,
|
198 |
+
scalar_lr_scale=scalar_lr_scale,
|
199 |
+
eps=eps,
|
200 |
+
param_min_rms=param_min_rms,
|
201 |
+
param_max_rms=param_max_rms,
|
202 |
+
scalar_max=scalar_max,
|
203 |
+
size_update_period=size_update_period,
|
204 |
+
clipping_update_period=clipping_update_period,
|
205 |
+
)
|
206 |
+
|
207 |
+
super(ScaledAdam, self).__init__(params, defaults)
|
208 |
+
assert len(self.param_groups) == len(parameters_names)
|
209 |
+
self.parameters_names = parameters_names
|
210 |
+
self.show_dominant_parameters = show_dominant_parameters
|
211 |
+
|
212 |
+
def __setstate__(self, state):
|
213 |
+
super(ScaledAdam, self).__setstate__(state)
|
214 |
+
|
215 |
+
@torch.no_grad()
|
216 |
+
def step(self, closure=None):
|
217 |
+
"""Performs a single optimization step.
|
218 |
+
|
219 |
+
Arguments:
|
220 |
+
closure (callable, optional): A closure that reevaluates the model
|
221 |
+
and returns the loss.
|
222 |
+
"""
|
223 |
+
loss = None
|
224 |
+
if closure is not None:
|
225 |
+
with torch.enable_grad():
|
226 |
+
loss = closure()
|
227 |
+
|
228 |
+
batch = True
|
229 |
+
|
230 |
+
for group, group_params_names in zip(
|
231 |
+
self.param_groups, self.parameters_names
|
232 |
+
):
|
233 |
+
|
234 |
+
with self.batched_params(
|
235 |
+
group["params"], group_params_names
|
236 |
+
) as batches:
|
237 |
+
|
238 |
+
# batches is list of pairs (stacked_param, state). stacked_param is like
|
239 |
+
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
|
240 |
+
# a stacking dim, it is not a real dim.
|
241 |
+
|
242 |
+
if (
|
243 |
+
len(batches[0][1]) == 0
|
244 |
+
): # if len(first state) == 0: not yet initialized
|
245 |
+
clipping_scale = 1
|
246 |
+
else:
|
247 |
+
clipping_scale = self._get_clipping_scale(group, batches)
|
248 |
+
|
249 |
+
for p, state, _ in batches:
|
250 |
+
# Perform optimization step.
|
251 |
+
# grad is not going to be None, we handled that when creating the batches.
|
252 |
+
grad = p.grad
|
253 |
+
if grad.is_sparse:
|
254 |
+
raise RuntimeError(
|
255 |
+
"ScaledAdam optimizer does not support sparse gradients"
|
256 |
+
)
|
257 |
+
# State initialization
|
258 |
+
if len(state) == 0:
|
259 |
+
self._init_state(group, p, state)
|
260 |
+
|
261 |
+
self._step_one_batch(group, p, state, clipping_scale)
|
262 |
+
|
263 |
+
return loss
|
264 |
+
|
265 |
+
def _init_state(self, group: dict, p: Tensor, state: dict):
|
266 |
+
"""
|
267 |
+
Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
|
268 |
+
is actually the batch dimension, corresponding to batched-together
|
269 |
+
parameters of a given shape.
|
270 |
+
|
271 |
+
|
272 |
+
Args:
|
273 |
+
group: Dict to look up configuration values.
|
274 |
+
p: The parameter that we are initializing the state for
|
275 |
+
state: Dict from string to whatever state we are initializing
|
276 |
+
"""
|
277 |
+
size_update_period = group["size_update_period"]
|
278 |
+
|
279 |
+
state["step"] = 0
|
280 |
+
|
281 |
+
kwargs = {"device": p.device, "dtype": p.dtype}
|
282 |
+
|
283 |
+
# 'delta' implements conventional momentum. There are
|
284 |
+
# several different kinds of update going on, so rather than
|
285 |
+
# compute "exp_avg" like in Adam, we store and decay a
|
286 |
+
# parameter-change "delta", which combines all forms of
|
287 |
+
# update. this is equivalent to how it's done in Adam,
|
288 |
+
# except for the first few steps.
|
289 |
+
state["delta"] = torch.zeros_like(
|
290 |
+
p, memory_format=torch.preserve_format
|
291 |
+
)
|
292 |
+
|
293 |
+
batch_size = p.shape[0]
|
294 |
+
numel = p.numel() // batch_size
|
295 |
+
numel = p.numel()
|
296 |
+
|
297 |
+
if numel > 1:
|
298 |
+
# "param_rms" just periodically records the scalar root-mean-square value of
|
299 |
+
# the parameter tensor.
|
300 |
+
# it has a shape like (batch_size, 1, 1, 1, 1)
|
301 |
+
param_rms = (
|
302 |
+
(p ** 2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
|
303 |
+
)
|
304 |
+
state["param_rms"] = param_rms
|
305 |
+
|
306 |
+
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
|
307 |
+
state["scale_grads"] = torch.zeros(
|
308 |
+
size_update_period, *param_rms.shape, **kwargs
|
309 |
+
)
|
310 |
+
|
311 |
+
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
|
312 |
+
state["exp_avg_sq"] = torch.zeros_like(
|
313 |
+
p, memory_format=torch.preserve_format
|
314 |
+
)
|
315 |
+
|
316 |
+
def _get_clipping_scale(
|
317 |
+
self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
|
318 |
+
) -> float:
|
319 |
+
"""
|
320 |
+
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
|
321 |
+
by this amount before applying the rest of the update.
|
322 |
+
|
323 |
+
Args:
|
324 |
+
group: the parameter group, an item in self.param_groups
|
325 |
+
tuples: a list of tuples of (param, state, param_names)
|
326 |
+
where param is a batched set of parameters,
|
327 |
+
with a .grad (1st dim is batch dim)
|
328 |
+
and state is the state-dict where optimization parameters are kept.
|
329 |
+
param_names is a List[str] while each str is name for a parameter
|
330 |
+
in batched set of parameters "param".
|
331 |
+
"""
|
332 |
+
assert len(tuples) >= 1
|
333 |
+
clipping_scale = group["clipping_scale"]
|
334 |
+
(first_p, first_state, _) = tuples[0]
|
335 |
+
step = first_state["step"]
|
336 |
+
if clipping_scale is None or step == 0:
|
337 |
+
# no clipping. return early on step == 0 because the other
|
338 |
+
# parameters' state won't have been initialized yet.
|
339 |
+
return 1.0
|
340 |
+
clipping_update_period = group["clipping_update_period"]
|
341 |
+
|
342 |
+
tot_sumsq = torch.tensor(0.0, device=first_p.device)
|
343 |
+
for (p, state, param_names) in tuples:
|
344 |
+
grad = p.grad
|
345 |
+
if grad.is_sparse:
|
346 |
+
raise RuntimeError(
|
347 |
+
"ScaledAdam optimizer does not support sparse gradients"
|
348 |
+
)
|
349 |
+
if p.numel() == p.shape[0]: # a batch of scalars
|
350 |
+
tot_sumsq += (
|
351 |
+
grad ** 2
|
352 |
+
).sum() # sum() to change shape [1] to []
|
353 |
+
else:
|
354 |
+
tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
|
355 |
+
|
356 |
+
tot_norm = tot_sumsq.sqrt()
|
357 |
+
if "model_norms" not in first_state:
|
358 |
+
first_state["model_norms"] = torch.zeros(
|
359 |
+
clipping_update_period, device=p.device
|
360 |
+
)
|
361 |
+
first_state["model_norms"][step % clipping_update_period] = tot_norm
|
362 |
+
|
363 |
+
if step % clipping_update_period == 0:
|
364 |
+
# Print some stats.
|
365 |
+
# We don't reach here if step == 0 because we would have returned
|
366 |
+
# above.
|
367 |
+
sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
|
368 |
+
quartiles = []
|
369 |
+
for n in range(0, 5):
|
370 |
+
index = min(
|
371 |
+
clipping_update_period - 1,
|
372 |
+
(clipping_update_period // 4) * n,
|
373 |
+
)
|
374 |
+
quartiles.append(sorted_norms[index].item())
|
375 |
+
|
376 |
+
median = quartiles[2]
|
377 |
+
threshold = clipping_scale * median
|
378 |
+
first_state["model_norm_threshold"] = threshold
|
379 |
+
percent_clipped = (
|
380 |
+
first_state["num_clipped"] * 100.0 / clipping_update_period
|
381 |
+
if "num_clipped" in first_state
|
382 |
+
else 0.0
|
383 |
+
)
|
384 |
+
first_state["num_clipped"] = 0
|
385 |
+
quartiles = " ".join(["%.3e" % x for x in quartiles])
|
386 |
+
logging.info(
|
387 |
+
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
|
388 |
+
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
|
389 |
+
)
|
390 |
+
|
391 |
+
if step < clipping_update_period:
|
392 |
+
return 1.0 # We have not yet estimated a norm to clip to.
|
393 |
+
else:
|
394 |
+
try:
|
395 |
+
model_norm_threshold = first_state["model_norm_threshold"]
|
396 |
+
except KeyError:
|
397 |
+
logging.info(
|
398 |
+
"Warning: model_norm_threshold not in state: possibly "
|
399 |
+
"you changed config when restarting, adding clipping_scale option?"
|
400 |
+
)
|
401 |
+
return 1.0
|
402 |
+
ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
|
403 |
+
if ans < 1.0:
|
404 |
+
first_state["num_clipped"] += 1
|
405 |
+
if ans < 0.1:
|
406 |
+
logging.warn(
|
407 |
+
f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
|
408 |
+
)
|
409 |
+
if self.show_dominant_parameters:
|
410 |
+
assert p.shape[0] == len(param_names)
|
411 |
+
self._show_gradient_dominating_parameter(tuples, tot_sumsq)
|
412 |
+
return ans
|
413 |
+
|
414 |
+
def _show_gradient_dominating_parameter(
|
415 |
+
self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
|
416 |
+
):
|
417 |
+
"""
|
418 |
+
Show information about the parameter which dominates tot_sumsq.
|
419 |
+
|
420 |
+
Args:
|
421 |
+
tuples: a list of tuples of (param, state, param_names)
|
422 |
+
where param is a batched set of parameters,
|
423 |
+
with a .grad (1st dim is batch dim)
|
424 |
+
and state is the state-dict where optimization parameters are kept.
|
425 |
+
param_names is a List[str] while each str is name for a parameter
|
426 |
+
in batched set of parameters "param".
|
427 |
+
tot_sumsq: sumsq of all parameters. Though it could be calculated
|
428 |
+
from tuples, we still pass it to save some time.
|
429 |
+
"""
|
430 |
+
all_sumsq_orig = {}
|
431 |
+
for (p, state, batch_param_names) in tuples:
|
432 |
+
# p is a stacked batch parameters.
|
433 |
+
batch_grad = p.grad
|
434 |
+
if p.numel() == p.shape[0]: # a batch of scalars
|
435 |
+
batch_sumsq_orig = batch_grad ** 2
|
436 |
+
# Dummy values used by the following `zip` statement.
|
437 |
+
batch_rms_orig = torch.ones(p.shape[0])
|
438 |
+
else:
|
439 |
+
batch_rms_orig = state["param_rms"]
|
440 |
+
batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
|
441 |
+
dim=list(range(1, batch_grad.ndim))
|
442 |
+
)
|
443 |
+
|
444 |
+
for name, sumsq_orig, rms, grad in zip(
|
445 |
+
batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
|
446 |
+
):
|
447 |
+
|
448 |
+
proportion_orig = sumsq_orig / tot_sumsq
|
449 |
+
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
|
450 |
+
|
451 |
+
assert torch.isclose(
|
452 |
+
sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
|
453 |
+
torch.tensor(1.0),
|
454 |
+
)
|
455 |
+
sorted_by_proportion = {
|
456 |
+
k: v
|
457 |
+
for k, v in sorted(
|
458 |
+
all_sumsq_orig.items(),
|
459 |
+
key=lambda item: item[1][0],
|
460 |
+
reverse=True,
|
461 |
+
)
|
462 |
+
}
|
463 |
+
dominant_param_name = next(iter(sorted_by_proportion))
|
464 |
+
(
|
465 |
+
dominant_proportion,
|
466 |
+
dominant_sumsq,
|
467 |
+
dominant_rms,
|
468 |
+
dominant_grad,
|
469 |
+
) = sorted_by_proportion[dominant_param_name]
|
470 |
+
logging.info(
|
471 |
+
f"Parameter Dominanting tot_sumsq {dominant_param_name}"
|
472 |
+
f" with proportion {dominant_proportion:.2f},"
|
473 |
+
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
|
474 |
+
f"={dominant_sumsq:.3e},"
|
475 |
+
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
|
476 |
+
f" orig_rms_sq={(dominant_rms**2).item():.3e}"
|
477 |
+
)
|
478 |
+
|
479 |
+
def _step_one_batch(
|
480 |
+
self, group: dict, p: Tensor, state: dict, clipping_scale: float
|
481 |
+
):
|
482 |
+
"""
|
483 |
+
Do the step for one parameter, which is actually going to be a batch of
|
484 |
+
`real` parameters, with dim 0 as the batch dim.
|
485 |
+
Args:
|
486 |
+
group: dict to look up configuration values
|
487 |
+
p: parameter to update (actually multiple parameters stacked together
|
488 |
+
as a batch)
|
489 |
+
state: state-dict for p, to look up the optimizer state
|
490 |
+
"""
|
491 |
+
lr = group["lr"]
|
492 |
+
size_update_period = group["size_update_period"]
|
493 |
+
beta1 = group["betas"][0]
|
494 |
+
|
495 |
+
grad = p.grad
|
496 |
+
if clipping_scale != 1.0:
|
497 |
+
grad = grad * clipping_scale
|
498 |
+
step = state["step"]
|
499 |
+
delta = state["delta"]
|
500 |
+
|
501 |
+
delta.mul_(beta1)
|
502 |
+
batch_size = p.shape[0]
|
503 |
+
numel = p.numel() // batch_size
|
504 |
+
if numel > 1:
|
505 |
+
# Update the size/scale of p, and set param_rms
|
506 |
+
scale_grads = state["scale_grads"]
|
507 |
+
scale_grads[step % size_update_period] = (p * grad).sum(
|
508 |
+
dim=list(range(1, p.ndim)), keepdim=True
|
509 |
+
)
|
510 |
+
if step % size_update_period == size_update_period - 1:
|
511 |
+
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
|
512 |
+
param_rms.copy_(
|
513 |
+
(p ** 2)
|
514 |
+
.mean(dim=list(range(1, p.ndim)), keepdim=True)
|
515 |
+
.sqrt()
|
516 |
+
)
|
517 |
+
if step > 0:
|
518 |
+
# self._size_update() learns the overall scale on the
|
519 |
+
# parameter, by shrinking or expanding it.
|
520 |
+
self._size_update(group, scale_grads, p, state)
|
521 |
+
|
522 |
+
if numel == 1:
|
523 |
+
# For parameters with 1 element we just use regular Adam.
|
524 |
+
# Updates delta.
|
525 |
+
self._step_scalar(group, p, state)
|
526 |
+
else:
|
527 |
+
self._step(group, p, state)
|
528 |
+
|
529 |
+
state["step"] = step + 1
|
530 |
+
|
531 |
+
def _size_update(
|
532 |
+
self, group: dict, scale_grads: Tensor, p: Tensor, state: dict
|
533 |
+
) -> None:
|
534 |
+
"""
|
535 |
+
Called only where p.numel() > 1, this updates the scale of the parameter.
|
536 |
+
If we imagine: p = underlying_param * scale.exp(), and we are doing
|
537 |
+
gradient descent on underlying param and on scale, this function does the update
|
538 |
+
on `scale`.
|
539 |
+
|
540 |
+
Args:
|
541 |
+
group: dict to look up configuration values
|
542 |
+
scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
|
543 |
+
grads w.r.t. the scales.
|
544 |
+
p: The parameter to update
|
545 |
+
state: The state-dict of p
|
546 |
+
"""
|
547 |
+
|
548 |
+
param_rms = state["param_rms"]
|
549 |
+
beta1, beta2 = group["betas"]
|
550 |
+
size_lr = group["lr"] * group["scalar_lr_scale"]
|
551 |
+
param_min_rms = group["param_min_rms"]
|
552 |
+
param_max_rms = group["param_max_rms"]
|
553 |
+
eps = group["eps"]
|
554 |
+
step = state["step"]
|
555 |
+
batch_size = p.shape[0]
|
556 |
+
|
557 |
+
size_update_period = scale_grads.shape[0]
|
558 |
+
# correct beta2 for the size update period: we will have
|
559 |
+
# faster decay at this level.
|
560 |
+
beta2_corr = beta2 ** size_update_period
|
561 |
+
|
562 |
+
scale_exp_avg_sq = state[
|
563 |
+
"scale_exp_avg_sq"
|
564 |
+
] # shape: (batch_size, 1, 1, ..)
|
565 |
+
scale_exp_avg_sq.mul_(beta2_corr).add_(
|
566 |
+
(scale_grads ** 2).mean(
|
567 |
+
dim=0
|
568 |
+
), # mean over dim `size_update_period`
|
569 |
+
alpha=1 - beta2_corr,
|
570 |
+
) # shape is (batch_size, 1, 1, ...)
|
571 |
+
|
572 |
+
# The 1st time we reach here is when size_step == 1.
|
573 |
+
size_step = (step + 1) // size_update_period
|
574 |
+
bias_correction2 = 1 - beta2_corr ** size_step
|
575 |
+
# we don't bother with bias_correction1; this will help prevent divergence
|
576 |
+
# at the start of training.
|
577 |
+
|
578 |
+
denom = scale_exp_avg_sq.sqrt() + eps
|
579 |
+
|
580 |
+
scale_step = (
|
581 |
+
-size_lr
|
582 |
+
* (bias_correction2 ** 0.5)
|
583 |
+
* scale_grads.sum(dim=0)
|
584 |
+
/ denom
|
585 |
+
)
|
586 |
+
|
587 |
+
is_too_small = param_rms < param_min_rms
|
588 |
+
is_too_large = param_rms > param_max_rms
|
589 |
+
|
590 |
+
# when the param gets too small, just don't shrink it any further.
|
591 |
+
scale_step.masked_fill_(is_too_small, 0.0)
|
592 |
+
# when it gets too large, stop it from getting any larger.
|
593 |
+
scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
|
594 |
+
delta = state["delta"]
|
595 |
+
# the factor of (1-beta1) relates to momentum.
|
596 |
+
delta.add_(p * scale_step, alpha=(1 - beta1))
|
597 |
+
|
598 |
+
def _step(self, group: dict, p: Tensor, state: dict):
|
599 |
+
"""
|
600 |
+
This function does the core update of self.step(), in the case where the members of
|
601 |
+
the batch have more than 1 element.
|
602 |
+
|
603 |
+
Args:
|
604 |
+
group: A dict which will be used to look up configuration values
|
605 |
+
p: The parameter to be updated
|
606 |
+
grad: The grad of p
|
607 |
+
state: The state-dict corresponding to parameter p
|
608 |
+
|
609 |
+
This function modifies p.
|
610 |
+
"""
|
611 |
+
grad = p.grad
|
612 |
+
lr = group["lr"]
|
613 |
+
beta1, beta2 = group["betas"]
|
614 |
+
eps = group["eps"]
|
615 |
+
param_min_rms = group["param_min_rms"]
|
616 |
+
step = state["step"]
|
617 |
+
|
618 |
+
exp_avg_sq = state["exp_avg_sq"]
|
619 |
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
|
620 |
+
|
621 |
+
this_step = state["step"] - (
|
622 |
+
state["zero_step"] if "zero_step" in state else 0
|
623 |
+
)
|
624 |
+
bias_correction2 = 1 - beta2 ** (this_step + 1)
|
625 |
+
if bias_correction2 < 0.99:
|
626 |
+
# note: not in-place.
|
627 |
+
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
|
628 |
+
|
629 |
+
denom = exp_avg_sq.sqrt()
|
630 |
+
denom += eps
|
631 |
+
grad = grad / denom
|
632 |
+
|
633 |
+
alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
|
634 |
+
|
635 |
+
delta = state["delta"]
|
636 |
+
delta.add_(grad * alpha)
|
637 |
+
p.add_(delta)
|
638 |
+
|
639 |
+
def _step_scalar(self, group: dict, p: Tensor, state: dict):
|
640 |
+
"""
|
641 |
+
A simplified form of the core update for scalar tensors, where we cannot get a good
|
642 |
+
estimate of the parameter rms.
|
643 |
+
"""
|
644 |
+
beta1, beta2 = group["betas"]
|
645 |
+
scalar_max = group["scalar_max"]
|
646 |
+
eps = group["eps"]
|
647 |
+
lr = group["lr"] * group["scalar_lr_scale"]
|
648 |
+
grad = p.grad
|
649 |
+
|
650 |
+
exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
|
651 |
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
652 |
+
|
653 |
+
# bias_correction2 is like in Adam. Don't bother with bias_correction1;
|
654 |
+
# slower update at the start will help stability anyway.
|
655 |
+
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
|
656 |
+
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
|
657 |
+
|
658 |
+
delta = state["delta"]
|
659 |
+
delta.add_(grad / denom, alpha=-lr * (1 - beta1))
|
660 |
+
p.clamp_(min=-scalar_max, max=scalar_max)
|
661 |
+
p.add_(delta)
|
662 |
+
|
663 |
+
|
664 |
+
class LRScheduler(object):
|
665 |
+
"""
|
666 |
+
Base-class for learning rate schedulers where the learning-rate depends on both the
|
667 |
+
batch and the epoch.
|
668 |
+
"""
|
669 |
+
|
670 |
+
def __init__(self, optimizer: Optimizer, verbose: bool = False):
|
671 |
+
# Attach optimizer
|
672 |
+
if not isinstance(optimizer, Optimizer):
|
673 |
+
raise TypeError(
|
674 |
+
"{} is not an Optimizer".format(type(optimizer).__name__)
|
675 |
+
)
|
676 |
+
self.optimizer = optimizer
|
677 |
+
self.verbose = verbose
|
678 |
+
|
679 |
+
for group in optimizer.param_groups:
|
680 |
+
group.setdefault("base_lr", group["lr"])
|
681 |
+
|
682 |
+
self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]
|
683 |
+
|
684 |
+
self.epoch = 0
|
685 |
+
self.batch = 0
|
686 |
+
|
687 |
+
def state_dict(self):
|
688 |
+
"""Returns the state of the scheduler as a :class:`dict`.
|
689 |
+
|
690 |
+
It contains an entry for every variable in self.__dict__ which
|
691 |
+
is not the optimizer.
|
692 |
+
"""
|
693 |
+
return {
|
694 |
+
"base_lrs": self.base_lrs,
|
695 |
+
"epoch": self.epoch,
|
696 |
+
"batch": self.batch,
|
697 |
+
}
|
698 |
+
|
699 |
+
def load_state_dict(self, state_dict):
|
700 |
+
"""Loads the schedulers state.
|
701 |
+
|
702 |
+
Args:
|
703 |
+
state_dict (dict): scheduler state. Should be an object returned
|
704 |
+
from a call to :meth:`state_dict`.
|
705 |
+
"""
|
706 |
+
self.__dict__.update(state_dict)
|
707 |
+
|
708 |
+
def get_last_lr(self) -> List[float]:
|
709 |
+
"""Return last computed learning rate by current scheduler. Will be a list of float."""
|
710 |
+
return self._last_lr
|
711 |
+
|
712 |
+
def get_lr(self):
|
713 |
+
# Compute list of learning rates from self.epoch and self.batch and
|
714 |
+
# self.base_lrs; this must be overloaded by the user.
|
715 |
+
# e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
|
716 |
+
raise NotImplementedError
|
717 |
+
|
718 |
+
def step_batch(self, batch: Optional[int] = None) -> None:
|
719 |
+
# Step the batch index, or just set it. If `batch` is specified, it
|
720 |
+
# must be the batch index from the start of training, i.e. summed over
|
721 |
+
# all epochs.
|
722 |
+
# You can call this in any order; if you don't provide 'batch', it should
|
723 |
+
# of course be called once per batch.
|
724 |
+
if batch is not None:
|
725 |
+
self.batch = batch
|
726 |
+
else:
|
727 |
+
self.batch = self.batch + 1
|
728 |
+
self._set_lrs()
|
729 |
+
|
730 |
+
def step_epoch(self, epoch: Optional[int] = None):
|
731 |
+
# Step the epoch index, or just set it. If you provide the 'epoch' arg,
|
732 |
+
# you should call this at the start of the epoch; if you don't provide the 'epoch'
|
733 |
+
# arg, you should call it at the end of the epoch.
|
734 |
+
if epoch is not None:
|
735 |
+
self.epoch = epoch
|
736 |
+
else:
|
737 |
+
self.epoch = self.epoch + 1
|
738 |
+
self._set_lrs()
|
739 |
+
|
740 |
+
def _set_lrs(self):
|
741 |
+
values = self.get_lr()
|
742 |
+
assert len(values) == len(self.optimizer.param_groups)
|
743 |
+
|
744 |
+
for i, data in enumerate(zip(self.optimizer.param_groups, values)):
|
745 |
+
param_group, lr = data
|
746 |
+
param_group["lr"] = lr
|
747 |
+
self.print_lr(self.verbose, i, lr)
|
748 |
+
self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
|
749 |
+
|
750 |
+
def print_lr(self, is_verbose, group, lr):
|
751 |
+
"""Display the current learning rate."""
|
752 |
+
if is_verbose:
|
753 |
+
logging.info(
|
754 |
+
f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
|
755 |
+
f" of group {group} to {lr:.4e}."
|
756 |
+
)
|
757 |
+
|
758 |
+
|
759 |
+
class Eden(LRScheduler):
|
760 |
+
"""
|
761 |
+
Eden scheduler.
|
762 |
+
The basic formula (before warmup) is:
|
763 |
+
lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
|
764 |
+
(((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
|
765 |
+
where `warmup` increases linearly from 0.5 to 1 over `warmup_batches` batches
|
766 |
+
and then stays constant at 1.
|
767 |
+
|
768 |
+
|
769 |
+
E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
|
770 |
+
|
771 |
+
Args:
|
772 |
+
optimizer: the optimizer to change the learning rates on
|
773 |
+
lr_batches: the number of batches after which we start significantly
|
774 |
+
decreasing the learning rate, suggest 5000.
|
775 |
+
lr_epochs: the number of epochs after which we start significantly
|
776 |
+
decreasing the learning rate, suggest 6 if you plan to do e.g.
|
777 |
+
20 to 40 epochs, but may need smaller number if dataset is huge
|
778 |
+
and you will do few epochs.
|
779 |
+
"""
|
780 |
+
|
781 |
+
def __init__(
|
782 |
+
self,
|
783 |
+
optimizer: Optimizer,
|
784 |
+
lr_batches: Union[int, float],
|
785 |
+
lr_epochs: Union[int, float],
|
786 |
+
warmup_batches: Union[int, float] = 500.0,
|
787 |
+
verbose: bool = False,
|
788 |
+
):
|
789 |
+
super(Eden, self).__init__(optimizer, verbose)
|
790 |
+
self.lr_batches = lr_batches
|
791 |
+
self.lr_epochs = lr_epochs
|
792 |
+
self.warmup_batches = warmup_batches
|
793 |
+
|
794 |
+
def get_lr(self):
|
795 |
+
factor = (
|
796 |
+
(self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2
|
797 |
+
) ** -0.25 * (
|
798 |
+
((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2)
|
799 |
+
** -0.25
|
800 |
+
)
|
801 |
+
warmup_factor = (
|
802 |
+
1.0
|
803 |
+
if self.batch >= self.warmup_batches
|
804 |
+
else 0.5 + 0.5 * (self.batch / self.warmup_batches)
|
805 |
+
)
|
806 |
+
|
807 |
+
return [x * factor * warmup_factor for x in self.base_lrs]
|
808 |
+
|
809 |
+
|
810 |
+
def _test_eden():
|
811 |
+
m = torch.nn.Linear(100, 100)
|
812 |
+
optim = ScaledAdam(m.parameters(), lr=0.03)
|
813 |
+
|
814 |
+
scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)
|
815 |
+
|
816 |
+
for epoch in range(10):
|
817 |
+
scheduler.step_epoch(epoch) # sets epoch to `epoch`
|
818 |
+
|
819 |
+
for step in range(20):
|
820 |
+
x = torch.randn(200, 100).detach()
|
821 |
+
x.requires_grad = True
|
822 |
+
y = m(x)
|
823 |
+
dy = torch.randn(200, 100).detach()
|
824 |
+
f = (y * dy).sum()
|
825 |
+
f.backward()
|
826 |
+
|
827 |
+
optim.step()
|
828 |
+
scheduler.step_batch()
|
829 |
+
optim.zero_grad()
|
830 |
+
|
831 |
+
logging.info(f"last lr = {scheduler.get_last_lr()}")
|
832 |
+
logging.info(f"state dict = {scheduler.state_dict()}")
|
833 |
+
|
834 |
+
|
835 |
+
# This is included mostly as a baseline for ScaledAdam.
|
836 |
+
class Eve(Optimizer):
|
837 |
+
"""
|
838 |
+
Implements Eve algorithm. This is a modified version of AdamW with a special
|
839 |
+
way of setting the weight-decay / shrinkage-factor, which is designed to make the
|
840 |
+
rms of the parameters approach a particular target_rms (default: 0.1). This is
|
841 |
+
for use with networks with 'scaled' versions of modules (see scaling.py), which
|
842 |
+
will be close to invariant to the absolute scale on the parameter matrix.
|
843 |
+
|
844 |
+
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
|
845 |
+
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
|
846 |
+
Eve is unpublished so far.
|
847 |
+
|
848 |
+
Arguments:
|
849 |
+
params (iterable): iterable of parameters to optimize or dicts defining
|
850 |
+
parameter groups
|
851 |
+
lr (float, optional): learning rate (default: 1e-3)
|
852 |
+
betas (Tuple[float, float], optional): coefficients used for computing
|
853 |
+
running averages of gradient and its square (default: (0.9, 0.999))
|
854 |
+
eps (float, optional): term added to the denominator to improve
|
855 |
+
numerical stability (default: 1e-8)
|
856 |
+
weight_decay (float, optional): weight decay coefficient (default: 3e-4;
|
857 |
+
this value means that the weight would decay significantly after
|
858 |
+
about 3k minibatches. Is not multiplied by learning rate, but
|
859 |
+
is conditional on RMS-value of parameter being > target_rms.
|
860 |
+
target_rms (float, optional): target root-mean-square value of
|
861 |
+
parameters, if they fall below this we will stop applying weight decay.
|
862 |
+
|
863 |
+
|
864 |
+
.. _Adam: A Method for Stochastic Optimization:
|
865 |
+
https://arxiv.org/abs/1412.6980
|
866 |
+
.. _Decoupled Weight Decay Regularization:
|
867 |
+
https://arxiv.org/abs/1711.05101
|
868 |
+
.. _On the Convergence of Adam and Beyond:
|
869 |
+
https://openreview.net/forum?id=ryQu7f-RZ
|
870 |
+
"""
|
871 |
+
|
872 |
+
def __init__(
|
873 |
+
self,
|
874 |
+
params,
|
875 |
+
lr=1e-3,
|
876 |
+
betas=(0.9, 0.98),
|
877 |
+
eps=1e-8,
|
878 |
+
weight_decay=1e-3,
|
879 |
+
target_rms=0.1,
|
880 |
+
):
|
881 |
+
if not 0.0 <= lr:
|
882 |
+
raise ValueError("Invalid learning rate: {}".format(lr))
|
883 |
+
if not 0.0 <= eps:
|
884 |
+
raise ValueError("Invalid epsilon value: {}".format(eps))
|
885 |
+
if not 0.0 <= betas[0] < 1.0:
|
886 |
+
raise ValueError(
|
887 |
+
"Invalid beta parameter at index 0: {}".format(betas[0])
|
888 |
+
)
|
889 |
+
if not 0.0 <= betas[1] < 1.0:
|
890 |
+
raise ValueError(
|
891 |
+
"Invalid beta parameter at index 1: {}".format(betas[1])
|
892 |
+
)
|
893 |
+
if not 0 <= weight_decay <= 0.1:
|
894 |
+
raise ValueError(
|
895 |
+
"Invalid weight_decay value: {}".format(weight_decay)
|
896 |
+
)
|
897 |
+
if not 0 < target_rms <= 10.0:
|
898 |
+
raise ValueError("Invalid target_rms value: {}".format(target_rms))
|
899 |
+
defaults = dict(
|
900 |
+
lr=lr,
|
901 |
+
betas=betas,
|
902 |
+
eps=eps,
|
903 |
+
weight_decay=weight_decay,
|
904 |
+
target_rms=target_rms,
|
905 |
+
)
|
906 |
+
super(Eve, self).__init__(params, defaults)
|
907 |
+
|
908 |
+
def __setstate__(self, state):
|
909 |
+
super(Eve, self).__setstate__(state)
|
910 |
+
|
911 |
+
@torch.no_grad()
|
912 |
+
def step(self, closure=None):
|
913 |
+
"""Performs a single optimization step.
|
914 |
+
|
915 |
+
Arguments:
|
916 |
+
closure (callable, optional): A closure that reevaluates the model
|
917 |
+
and returns the loss.
|
918 |
+
"""
|
919 |
+
loss = None
|
920 |
+
if closure is not None:
|
921 |
+
with torch.enable_grad():
|
922 |
+
loss = closure()
|
923 |
+
|
924 |
+
for group in self.param_groups:
|
925 |
+
for p in group["params"]:
|
926 |
+
if p.grad is None:
|
927 |
+
continue
|
928 |
+
|
929 |
+
# Perform optimization step
|
930 |
+
grad = p.grad
|
931 |
+
if grad.is_sparse:
|
932 |
+
raise RuntimeError(
|
933 |
+
"AdamW does not support sparse gradients"
|
934 |
+
)
|
935 |
+
|
936 |
+
state = self.state[p]
|
937 |
+
|
938 |
+
# State initialization
|
939 |
+
if len(state) == 0:
|
940 |
+
state["step"] = 0
|
941 |
+
# Exponential moving average of gradient values
|
942 |
+
state["exp_avg"] = torch.zeros_like(
|
943 |
+
p, memory_format=torch.preserve_format
|
944 |
+
)
|
945 |
+
# Exponential moving average of squared gradient values
|
946 |
+
state["exp_avg_sq"] = torch.zeros_like(
|
947 |
+
p, memory_format=torch.preserve_format
|
948 |
+
)
|
949 |
+
|
950 |
+
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
951 |
+
|
952 |
+
beta1, beta2 = group["betas"]
|
953 |
+
|
954 |
+
state["step"] += 1
|
955 |
+
bias_correction1 = 1 - beta1 ** state["step"]
|
956 |
+
bias_correction2 = 1 - beta2 ** state["step"]
|
957 |
+
|
958 |
+
# Decay the first and second moment running average coefficient
|
959 |
+
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
960 |
+
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
961 |
+
denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
|
962 |
+
group["eps"]
|
963 |
+
)
|
964 |
+
|
965 |
+
step_size = group["lr"] / bias_correction1
|
966 |
+
target_rms = group["target_rms"]
|
967 |
+
weight_decay = group["weight_decay"]
|
968 |
+
|
969 |
+
if p.numel() > 1:
|
970 |
+
# avoid applying this weight-decay on "scaling factors"
|
971 |
+
# (which are scalar).
|
972 |
+
is_above_target_rms = p.norm() > (
|
973 |
+
target_rms * (p.numel() ** 0.5)
|
974 |
+
)
|
975 |
+
p.mul_(1 - (weight_decay * is_above_target_rms))
|
976 |
+
|
977 |
+
p.addcdiv_(exp_avg, denom, value=-step_size)
|
978 |
+
|
979 |
+
# if random.random() < 0.0005:
|
980 |
+
# step = (exp_avg / denom) * step_size
|
981 |
+
# logging.info(
|
982 |
+
# f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}"
|
983 |
+
# )
|
984 |
+
|
985 |
+
return loss
|
986 |
+
|
987 |
+
|
988 |
+
def _test_scaled_adam(hidden_dim: int):
|
989 |
+
import timeit
|
990 |
+
|
991 |
+
from scaling import ScaledLinear
|
992 |
+
|
993 |
+
E = 100
|
994 |
+
B = 4
|
995 |
+
T = 2
|
996 |
+
logging.info("in test_eve_cain")
|
997 |
+
# device = torch.device('cuda')
|
998 |
+
device = torch.device("cpu")
|
999 |
+
dtype = torch.float32
|
1000 |
+
|
1001 |
+
fix_random_seed(42)
|
1002 |
+
# these input_magnitudes and output_magnitudes are to test that
|
1003 |
+
# Abel is working as we expect and is able to adjust scales of
|
1004 |
+
# different dims differently.
|
1005 |
+
input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
|
1006 |
+
output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
|
1007 |
+
|
1008 |
+
for iter in [1, 0]:
|
1009 |
+
fix_random_seed(42)
|
1010 |
+
Linear = torch.nn.Linear if iter == 0 else ScaledLinear
|
1011 |
+
|
1012 |
+
m = torch.nn.Sequential(
|
1013 |
+
Linear(E, hidden_dim),
|
1014 |
+
torch.nn.PReLU(),
|
1015 |
+
Linear(hidden_dim, hidden_dim),
|
1016 |
+
torch.nn.PReLU(),
|
1017 |
+
Linear(hidden_dim, E),
|
1018 |
+
).to(device)
|
1019 |
+
|
1020 |
+
train_pairs = [
|
1021 |
+
(
|
1022 |
+
100.0
|
1023 |
+
* torch.randn(B, T, E, device=device, dtype=dtype)
|
1024 |
+
* input_magnitudes,
|
1025 |
+
torch.randn(B, T, E, device=device, dtype=dtype)
|
1026 |
+
* output_magnitudes,
|
1027 |
+
)
|
1028 |
+
for _ in range(20)
|
1029 |
+
]
|
1030 |
+
|
1031 |
+
if iter == 0:
|
1032 |
+
optim = Eve(m.parameters(), lr=0.003)
|
1033 |
+
elif iter == 1:
|
1034 |
+
optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
|
1035 |
+
scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
|
1036 |
+
|
1037 |
+
start = timeit.default_timer()
|
1038 |
+
avg_loss = 0.0
|
1039 |
+
for epoch in range(180):
|
1040 |
+
scheduler.step_epoch()
|
1041 |
+
# if epoch == 100 and iter in [2,3]:
|
1042 |
+
# optim.reset_speedup() # check it doesn't crash.
|
1043 |
+
|
1044 |
+
# if epoch == 130:
|
1045 |
+
# opts = diagnostics.TensorDiagnosticOptions(
|
1046 |
+
# 2 ** 22
|
1047 |
+
# ) # allow 4 megabytes per sub-module
|
1048 |
+
# diagnostic = diagnostics.attach_diagnostics(m, opts)
|
1049 |
+
|
1050 |
+
for n, (x, y) in enumerate(train_pairs):
|
1051 |
+
y_out = m(x)
|
1052 |
+
loss = ((y_out - y) ** 2).mean() * 100.0
|
1053 |
+
if epoch == 0 and n == 0:
|
1054 |
+
avg_loss = loss.item()
|
1055 |
+
else:
|
1056 |
+
avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
|
1057 |
+
if n == 0 and epoch % 5 == 0:
|
1058 |
+
# norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
|
1059 |
+
# norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
|
1060 |
+
# norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
|
1061 |
+
# norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
|
1062 |
+
# scale1 = '%.2e' % (m[0].weight_scale.exp().item())
|
1063 |
+
# scale1b = '%.2e' % (m[0].bias_scale.exp().item())
|
1064 |
+
# scale2 = '%.2e' % (m[2].weight_scale.exp().item())
|
1065 |
+
# scale2b = '%.2e' % (m[2].bias_scale.exp().item())
|
1066 |
+
lr = scheduler.get_last_lr()[0]
|
1067 |
+
logging.info(
|
1068 |
+
f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}"
|
1069 |
+
) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
|
1070 |
+
loss.log().backward()
|
1071 |
+
optim.step()
|
1072 |
+
optim.zero_grad()
|
1073 |
+
scheduler.step_batch()
|
1074 |
+
|
1075 |
+
# diagnostic.print_diagnostics()
|
1076 |
+
|
1077 |
+
stop = timeit.default_timer()
|
1078 |
+
logging.info(f"Iter={iter}, Time taken: {stop - start}")
|
1079 |
+
|
1080 |
+
logging.info(f"last lr = {scheduler.get_last_lr()}")
|
1081 |
+
# logging.info("state dict = ", scheduler.state_dict())
|
1082 |
+
# logging.info("optim state_dict = ", optim.state_dict())
|
1083 |
+
logging.info(f"input_magnitudes = {input_magnitudes}")
|
1084 |
+
logging.info(f"output_magnitudes = {output_magnitudes}")
|
1085 |
+
|
1086 |
+
|
1087 |
+
if __name__ == "__main__":
|
1088 |
+
torch.set_num_threads(1)
|
1089 |
+
torch.set_num_interop_threads(1)
|
1090 |
+
logging.getLogger().setLevel(logging.INFO)
|
1091 |
+
import subprocess
|
1092 |
+
|
1093 |
+
s = subprocess.check_output(
|
1094 |
+
"git status -uno .; git log -1; git diff HEAD .", shell=True
|
1095 |
+
)
|
1096 |
+
logging.info(s)
|
1097 |
+
import sys
|
1098 |
+
|
1099 |
+
if len(sys.argv) > 1:
|
1100 |
+
hidden_dim = int(sys.argv[1])
|
1101 |
+
else:
|
1102 |
+
hidden_dim = 200
|
1103 |
+
|
1104 |
+
_test_scaled_adam(hidden_dim)
|
1105 |
+
_test_eden()
|
modules/scaling.py
ADDED
@@ -0,0 +1,1401 @@
|
1 |
+
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
2 |
+
#
|
3 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
|
18 |
+
import collections
|
19 |
+
import logging
|
20 |
+
import random
|
21 |
+
import math
|
22 |
+
from functools import reduce
|
23 |
+
from itertools import repeat
|
24 |
+
from typing import Optional, Tuple, Union
|
25 |
+
|
26 |
+
import torch
|
27 |
+
import torch.nn as nn
|
28 |
+
import torch.nn.functional as F
|
29 |
+
from torch import Tensor
|
30 |
+
from torch.nn import Embedding as ScaledEmbedding
|
31 |
+
|
32 |
+
from utils import Transpose
|
33 |
+
|
34 |
+
|
35 |
+
class ActivationBalancerFunction(torch.autograd.Function):
|
36 |
+
@staticmethod
|
37 |
+
def forward(
|
38 |
+
ctx,
|
39 |
+
x: Tensor,
|
40 |
+
scale_factor: Tensor,
|
41 |
+
sign_factor: Optional[Tensor],
|
42 |
+
channel_dim: int,
|
43 |
+
) -> Tensor:
|
44 |
+
if channel_dim < 0:
|
45 |
+
channel_dim += x.ndim
|
46 |
+
ctx.channel_dim = channel_dim
|
47 |
+
xgt0 = x > 0
|
48 |
+
if sign_factor is None:
|
49 |
+
ctx.save_for_backward(xgt0, scale_factor)
|
50 |
+
else:
|
51 |
+
ctx.save_for_backward(xgt0, scale_factor, sign_factor)
|
52 |
+
return x
|
53 |
+
|
54 |
+
@staticmethod
|
55 |
+
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
|
56 |
+
if len(ctx.saved_tensors) == 3:
|
57 |
+
xgt0, scale_factor, sign_factor = ctx.saved_tensors
|
58 |
+
for _ in range(ctx.channel_dim, x_grad.ndim - 1):
|
59 |
+
scale_factor = scale_factor.unsqueeze(-1)
|
60 |
+
sign_factor = sign_factor.unsqueeze(-1)
|
61 |
+
factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
|
62 |
+
else:
|
63 |
+
xgt0, scale_factor = ctx.saved_tensors
|
64 |
+
for _ in range(ctx.channel_dim, x_grad.ndim - 1):
|
65 |
+
scale_factor = scale_factor.unsqueeze(-1)
|
66 |
+
factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
|
67 |
+
neg_delta_grad = x_grad.abs() * factor
|
68 |
+
return (
|
69 |
+
x_grad - neg_delta_grad,
|
70 |
+
None,
|
71 |
+
None,
|
72 |
+
None,
|
73 |
+
)
|
74 |
+
|
75 |
+
|
76 |
+
def _compute_scale_factor(
|
77 |
+
x: Tensor,
|
78 |
+
channel_dim: int,
|
79 |
+
min_abs: float,
|
80 |
+
max_abs: float,
|
81 |
+
gain_factor: float,
|
82 |
+
max_factor: float,
|
83 |
+
) -> Tensor:
|
84 |
+
if channel_dim < 0:
|
85 |
+
channel_dim += x.ndim
|
86 |
+
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
|
87 |
+
x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
|
88 |
+
|
89 |
+
if min_abs == 0.0:
|
90 |
+
below_threshold = 0.0
|
91 |
+
else:
|
92 |
+
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
|
93 |
+
# x_abs_mean < min_abs.
|
94 |
+
below_threshold = (
|
95 |
+
(min_abs - x_abs_mean) * (gain_factor / min_abs)
|
96 |
+
).clamp(min=0, max=max_factor)
|
97 |
+
|
98 |
+
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
|
99 |
+
min=0, max=max_factor
|
100 |
+
)
|
101 |
+
|
102 |
+
return below_threshold - above_threshold
|
103 |
+
|
104 |
+
|
105 |
+
def _compute_sign_factor(
|
106 |
+
x: Tensor,
|
107 |
+
channel_dim: int,
|
108 |
+
min_positive: float,
|
109 |
+
max_positive: float,
|
110 |
+
gain_factor: float,
|
111 |
+
max_factor: float,
|
112 |
+
) -> Tensor:
|
113 |
+
if channel_dim < 0:
|
114 |
+
channel_dim += x.ndim
|
115 |
+
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
|
116 |
+
proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
|
117 |
+
if min_positive == 0.0:
|
118 |
+
factor1 = 0.0
|
119 |
+
else:
|
120 |
+
# 0 if proportion_positive >= min_positive, else can be
|
121 |
+
# as large as max_factor.
|
122 |
+
factor1 = (
|
123 |
+
(min_positive - proportion_positive) * (gain_factor / min_positive)
|
124 |
+
).clamp_(min=0, max=max_factor)
|
125 |
+
|
126 |
+
if max_positive == 1.0:
|
127 |
+
factor2 = 0.0
|
128 |
+
else:
|
129 |
+
# 0 if self.proportion_positive <= max_positive, else can be
|
130 |
+
# as large as -max_factor.
|
131 |
+
factor2 = (
|
132 |
+
(proportion_positive - max_positive)
|
133 |
+
* (gain_factor / (1.0 - max_positive))
|
134 |
+
).clamp_(min=0, max=max_factor)
|
135 |
+
sign_factor = factor1 - factor2
|
136 |
+
# require min_positive != 0 or max_positive != 1:
|
137 |
+
assert not isinstance(sign_factor, float)
|
138 |
+
return sign_factor
|
139 |
+
|
140 |
+
|
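A tiny illustrative check of `_compute_sign_factor` (a sketch only; the values follow directly from the formula above, and the import path is an assumption based on this repo's layout).

import torch
from modules.scaling import _compute_sign_factor  # assumed import path

# channel 0 is entirely negative, channel 1 is exactly half positive, so only
# channel 0 should receive a positive (sign-encouraging) factor.
x = torch.tensor([[-1.0, 1.0], [-2.0, -1.0], [-3.0, 2.0], [-0.5, -0.5]])
f = _compute_sign_factor(
    x, channel_dim=1, min_positive=0.3, max_positive=0.7,
    gain_factor=0.01, max_factor=0.04,
)
# f[0] is about 0.01 (proportion positive 0.0 < 0.3); f[1] is 0.0 (proportion 0.5 is in range)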
141 |
+
class ActivationScaleBalancerFunction(torch.autograd.Function):
|
142 |
+
"""
|
143 |
+
This object is used in class ActivationBalancer when the user specified
|
144 |
+
min_positive=0, max_positive=1, so there are no constraints on the signs
|
145 |
+
of the activations and only the absolute value has a constraint.
|
146 |
+
"""
|
147 |
+
|
148 |
+
@staticmethod
|
149 |
+
def forward(
|
150 |
+
ctx,
|
151 |
+
x: Tensor,
|
152 |
+
sign_factor: Tensor,
|
153 |
+
scale_factor: Tensor,
|
154 |
+
channel_dim: int,
|
155 |
+
) -> Tensor:
|
156 |
+
if channel_dim < 0:
|
157 |
+
channel_dim += x.ndim
|
158 |
+
ctx.channel_dim = channel_dim
|
159 |
+
xgt0 = x > 0
|
160 |
+
ctx.save_for_backward(xgt0, sign_factor, scale_factor)
|
161 |
+
return x
|
162 |
+
|
163 |
+
@staticmethod
|
164 |
+
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
|
165 |
+
xgt0, sign_factor, scale_factor = ctx.saved_tensors
|
166 |
+
for _ in range(ctx.channel_dim, x_grad.ndim - 1):
|
167 |
+
sign_factor = sign_factor.unsqueeze(-1)
|
168 |
+
scale_factor = scale_factor.unsqueeze(-1)
|
169 |
+
|
170 |
+
factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
|
171 |
+
neg_delta_grad = x_grad.abs() * factor
|
172 |
+
return (
|
173 |
+
x_grad - neg_delta_grad,
|
174 |
+
None,
|
175 |
+
None,
|
176 |
+
None,
|
177 |
+
)
|
178 |
+
|
179 |
+
|
180 |
+
class RandomClampFunction(torch.autograd.Function):
|
181 |
+
@staticmethod
|
182 |
+
def forward(
|
183 |
+
ctx,
|
184 |
+
x: Tensor,
|
185 |
+
min: Optional[float],
|
186 |
+
max: Optional[float],
|
187 |
+
prob: float,
|
188 |
+
reflect: float,
|
189 |
+
) -> Tensor:
|
190 |
+
x_clamped = torch.clamp(x, min=min, max=max)
|
191 |
+
mask = torch.rand_like(x) < prob
|
192 |
+
ans = torch.where(mask, x_clamped, x)
|
193 |
+
if x.requires_grad:
|
194 |
+
ctx.save_for_backward(ans == x)
|
195 |
+
ctx.reflect = reflect
|
196 |
+
if reflect != 0.0:
|
197 |
+
ans = ans * (1.0 + reflect) - (x * reflect)
|
198 |
+
return ans
|
199 |
+
|
200 |
+
@staticmethod
|
201 |
+
def backward(
|
202 |
+
ctx, ans_grad: Tensor
|
203 |
+
) -> Tuple[Tensor, None, None, None, None]:
|
204 |
+
(is_same,) = ctx.saved_tensors
|
205 |
+
x_grad = ans_grad * is_same.to(ans_grad.dtype)
|
206 |
+
reflect = ctx.reflect
|
207 |
+
if reflect != 0.0:
|
208 |
+
x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
|
209 |
+
return x_grad, None, None, None, None
|
210 |
+
|
211 |
+
|
212 |
+
def random_clamp(
|
213 |
+
x: Tensor,
|
214 |
+
min: Optional[float] = None,
|
215 |
+
max: Optional[float] = None,
|
216 |
+
prob: float = 0.5,
|
217 |
+
reflect: float = 0.0,
|
218 |
+
):
|
219 |
+
return RandomClampFunction.apply(x, min, max, prob, reflect)
|
220 |
+
|
221 |
+
|
222 |
+
def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
|
223 |
+
"""
|
224 |
+
A randomized way of casting a floating point value to half precision.
|
225 |
+
"""
|
226 |
+
if x.dtype == torch.float16:
|
227 |
+
return x
|
228 |
+
x_abs = x.abs()
|
229 |
+
is_too_small = x_abs < min_abs
|
230 |
+
# for elements where is_too_small is true, random_val will contain +-min_abs with
|
231 |
+
# probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
|
232 |
+
# for those elements].
|
233 |
+
random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
|
234 |
+
return torch.where(is_too_small, random_val, x).to(torch.float16)
|
235 |
+
|
236 |
+
|
237 |
+
class RandomGradFunction(torch.autograd.Function):
|
238 |
+
"""
|
239 |
+
Does nothing in forward pass; in backward pass, gets rid of very small grads using
|
240 |
+
a randomized approach that preserves expectations (intended to reduce roundoff).
|
241 |
+
"""
|
242 |
+
|
243 |
+
@staticmethod
|
244 |
+
def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
|
245 |
+
ctx.min_abs = min_abs
|
246 |
+
return x
|
247 |
+
|
248 |
+
@staticmethod
|
249 |
+
def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
|
250 |
+
if ans_grad.dtype == torch.float16:
|
251 |
+
return (
|
252 |
+
random_cast_to_half(
|
253 |
+
ans_grad.to(torch.float32), min_abs=ctx.min_abs
|
254 |
+
),
|
255 |
+
None,
|
256 |
+
)
|
257 |
+
else:
|
258 |
+
return ans_grad, None
|
259 |
+
|
260 |
+
|
261 |
+
class RandomGrad(torch.nn.Module):
|
262 |
+
"""
|
263 |
+
Gets rid of very small gradients using an expectation-preserving method, intended to increase
|
264 |
+
accuracy of training when using amp (automatic mixed precision)
|
265 |
+
"""
|
266 |
+
|
267 |
+
def __init__(self, min_abs: float = 5.0e-06):
|
268 |
+
super(RandomGrad, self).__init__()
|
269 |
+
self.min_abs = min_abs
|
270 |
+
|
271 |
+
def forward(self, x: Tensor):
|
272 |
+
if (
|
273 |
+
torch.jit.is_scripting()
|
274 |
+
or not self.training
|
275 |
+
or torch.jit.is_tracing()
|
276 |
+
):
|
277 |
+
return x
|
278 |
+
else:
|
279 |
+
return RandomGradFunction.apply(x, self.min_abs)
|
280 |
+
|
281 |
+
|
282 |
+
class SoftmaxFunction(torch.autograd.Function):
|
283 |
+
"""
|
284 |
+
Tries to handle half-precision derivatives in a randomized way that should
|
285 |
+
be more accurate for training than the default behavior.
|
286 |
+
"""
|
287 |
+
|
288 |
+
@staticmethod
|
289 |
+
def forward(ctx, x: Tensor, dim: int):
|
290 |
+
ans = x.softmax(dim=dim)
|
291 |
+
# if x dtype is float16, x.softmax() returns a float32 because
|
292 |
+
# (presumably) that op does not support float16, and autocast
|
293 |
+
# is enabled.
|
294 |
+
if torch.is_autocast_enabled():
|
295 |
+
ans = ans.to(torch.float16)
|
296 |
+
ctx.save_for_backward(ans)
|
297 |
+
ctx.x_dtype = x.dtype
|
298 |
+
ctx.dim = dim
|
299 |
+
return ans
|
300 |
+
|
301 |
+
@staticmethod
|
302 |
+
def backward(ctx, ans_grad: Tensor):
|
303 |
+
(ans,) = ctx.saved_tensors
|
304 |
+
with torch.cuda.amp.autocast(enabled=False):
|
305 |
+
ans_grad = ans_grad.to(torch.float32)
|
306 |
+
ans = ans.to(torch.float32)
|
307 |
+
x_grad = ans_grad * ans
|
308 |
+
x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
|
309 |
+
return x_grad, None
|
310 |
+
|
311 |
+
|
312 |
+
def softmax(x: Tensor, dim: int):
|
313 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
314 |
+
return x.softmax(dim)
|
315 |
+
|
316 |
+
return SoftmaxFunction.apply(x, dim)
|
317 |
+
|
318 |
+
|
319 |
+
class MaxEigLimiterFunction(torch.autograd.Function):
|
320 |
+
@staticmethod
|
321 |
+
def forward(
|
322 |
+
ctx,
|
323 |
+
x: Tensor,
|
324 |
+
coeffs: Tensor,
|
325 |
+
direction: Tensor,
|
326 |
+
channel_dim: int,
|
327 |
+
grad_scale: float,
|
328 |
+
) -> Tensor:
|
329 |
+
ctx.channel_dim = channel_dim
|
330 |
+
ctx.grad_scale = grad_scale
|
331 |
+
ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
|
332 |
+
return x
|
333 |
+
|
334 |
+
@staticmethod
|
335 |
+
def backward(ctx, x_grad, *args):
|
336 |
+
with torch.enable_grad():
|
337 |
+
(x_orig, coeffs, new_direction) = ctx.saved_tensors
|
338 |
+
x_orig.requires_grad = True
|
339 |
+
num_channels = x_orig.shape[ctx.channel_dim]
|
340 |
+
x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
|
341 |
+
new_direction.requires_grad = False
|
342 |
+
x = x - x.mean(dim=0)
|
343 |
+
x_var = (x ** 2).mean()
|
344 |
+
x_residual = x - coeffs * new_direction
|
345 |
+
x_residual_var = (x_residual ** 2).mean()
|
346 |
+
# `variance_proportion` is the proportion of the variance accounted for
|
347 |
+
# by the top eigen-direction. This is to be minimized.
|
348 |
+
variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
|
349 |
+
variance_proportion.backward()
|
350 |
+
x_orig_grad = x_orig.grad
|
351 |
+
x_extra_grad = (
|
352 |
+
x_orig.grad
|
353 |
+
* ctx.grad_scale
|
354 |
+
* x_grad.norm()
|
355 |
+
/ (x_orig_grad.norm() + 1.0e-20)
|
356 |
+
)
|
357 |
+
return x_grad + x_extra_grad.detach(), None, None, None, None
|
358 |
+
|
359 |
+
|
360 |
+
class BasicNorm(torch.nn.Module):
|
361 |
+
"""
|
362 |
+
This is intended to be a simpler, and hopefully cheaper, replacement for
|
363 |
+
LayerNorm. The observation this is based on, is that Transformer-type
|
364 |
+
networks, especially with pre-norm, sometimes seem to set one of the
|
365 |
+
feature dimensions to a large constant value (e.g. 50), which "defeats"
|
366 |
+
the LayerNorm because the output magnitude is then not strongly dependent
|
367 |
+
on the other (useful) features. Presumably the weight and bias of the
|
368 |
+
LayerNorm are required to allow it to do this.
|
369 |
+
|
370 |
+
So the idea is to introduce this large constant value as an explicit
|
371 |
+
parameter, that takes the role of the "eps" in LayerNorm, so the network
|
372 |
+
doesn't have to do this trick. We make the "eps" learnable.
|
373 |
+
|
374 |
+
Args:
|
375 |
+
num_channels: the number of channels, e.g. 512.
|
376 |
+
channel_dim: the axis/dimension corresponding to the channel,
|
377 |
+
interpreted as an offset from the input's ndim if negative.
|
378 |
+
this is NOT the num_channels; it should typically be one of
|
379 |
+
{-2, -1, 0, 1, 2, 3}.
|
380 |
+
eps: the initial "epsilon" that we add as ballast in:
|
381 |
+
scale = ((input_vec**2).mean() + epsilon)**-0.5
|
382 |
+
Note: our epsilon is actually large, but we keep the name
|
383 |
+
to indicate the connection with conventional LayerNorm.
|
384 |
+
learn_eps: if true, we learn epsilon; if false, we keep it
|
385 |
+
at the initial value.
|
386 |
+
eps_min: float
|
387 |
+
eps_max: float
|
388 |
+
"""
|
389 |
+
|
390 |
+
def __init__(
|
391 |
+
self,
|
392 |
+
num_channels: int,
|
393 |
+
channel_dim: int = -1, # CAUTION: see documentation.
|
394 |
+
eps: float = 0.25,
|
395 |
+
learn_eps: bool = True,
|
396 |
+
eps_min: float = -3.0,
|
397 |
+
eps_max: float = 3.0,
|
398 |
+
) -> None:
|
399 |
+
super(BasicNorm, self).__init__()
|
400 |
+
self.num_channels = num_channels
|
401 |
+
self.channel_dim = channel_dim
|
402 |
+
if learn_eps:
|
403 |
+
self.eps = nn.Parameter(torch.tensor(eps).log().detach())
|
404 |
+
else:
|
405 |
+
self.register_buffer("eps", torch.tensor(eps).log().detach())
|
406 |
+
self.eps_min = eps_min
|
407 |
+
self.eps_max = eps_max
|
408 |
+
|
409 |
+
def forward(self, x: Tensor) -> Tensor:
|
410 |
+
assert x.shape[self.channel_dim] == self.num_channels
|
411 |
+
eps = self.eps
|
412 |
+
if self.training and random.random() < 0.25:
|
413 |
+
# with probability 0.25, in training mode, clamp eps between the min
|
414 |
+
# and max; this will encourage it to learn parameters within the
|
415 |
+
# allowed range by making parameters that are outside the allowed
|
416 |
+
# range noisy.
|
417 |
+
|
418 |
+
# gradients to allow the parameter to get back into the allowed
|
419 |
+
# region if it happens to exit it.
|
420 |
+
eps = eps.clamp(min=self.eps_min, max=self.eps_max)
|
421 |
+
scales = (
|
422 |
+
torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp()
|
423 |
+
) ** -0.5
|
424 |
+
return x * scales
|
425 |
+
|
426 |
+
|
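A short sketch of the rule in the BasicNorm docstring above (illustrative only; the import path is an assumption based on this repo's layout).

import torch
from modules.scaling import BasicNorm  # assumed import path

# Each position is rescaled by ((x**2).mean(channel_dim) + exp(eps)) ** -0.5,
# so output magnitudes are pushed toward 1 unless the learned eps dominates.
norm = BasicNorm(num_channels=4, channel_dim=-1, eps=0.25)
x = 5.0 * torch.randn(2, 3, 4)
y = norm(x)
print(y.pow(2).mean(dim=-1))  # roughly of order 1, much smaller than x.pow(2).mean(-1)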
427 |
+
def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
|
428 |
+
"""
|
429 |
+
Behaves like a constructor of a modified version of nn.Linear
|
430 |
+
that gives an easy way to set the default initial parameter scale.
|
431 |
+
|
432 |
+
Args:
|
433 |
+
Accepts the standard args and kwargs that nn.Linear accepts
|
434 |
+
e.g. in_features, out_features, bias=False.
|
435 |
+
|
436 |
+
initial_scale: you can override this if you want to increase
|
437 |
+
or decrease the initial magnitude of the module's output
|
438 |
+
(it scales the initial weight values and the bias init range).
|
439 |
+
Another option, if you want to do something like this, is
|
440 |
+
to re-initialize the parameters.
|
441 |
+
"""
|
442 |
+
ans = nn.Linear(*args, **kwargs)
|
443 |
+
with torch.no_grad():
|
444 |
+
ans.weight[:] *= initial_scale
|
445 |
+
if ans.bias is not None:
|
446 |
+
torch.nn.init.uniform_(
|
447 |
+
ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
|
448 |
+
)
|
449 |
+
return ans
|
450 |
+
|
451 |
+
|
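For example (a sketch using only the constructor defined above; the import path is an assumption):

import torch
from modules.scaling import ScaledLinear  # assumed import path

# initial_scale only changes initialization: the module is a plain nn.Linear
# whose initial weights are scaled down and whose bias starts near zero.
lin = ScaledLinear(512, 512, initial_scale=0.25)
ref = torch.nn.Linear(512, 512)
x = torch.randn(8, 512)
print(lin(x).abs().mean().item(), "vs", ref(x).abs().mean().item())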
452 |
+
def ScaledConv1d(
|
453 |
+
*args,
|
454 |
+
initial_scale: float = 1.0,
|
455 |
+
kernel_size: int = 3,
|
456 |
+
padding: str = "same",
|
457 |
+
**kwargs,
|
458 |
+
) -> nn.Conv1d:
|
459 |
+
"""
|
460 |
+
Behaves like a constructor of a modified version of nn.Conv1d
|
461 |
+
that gives an easy way to set the default initial parameter scale.
|
462 |
+
|
463 |
+
Args:
|
464 |
+
Accepts the standard args and kwargs that nn.Conv1d accepts
|
465 |
+
e.g. in_channels, out_channels, bias=False.
|
466 |
+
|
467 |
+
initial_scale: you can override this if you want to increase
|
468 |
+
or decrease the initial magnitude of the module's output
|
469 |
+
(it scales the initial weight values and the bias init range).
|
470 |
+
Another option, if you want to do something like this, is
|
471 |
+
to re-initialize the parameters.
|
472 |
+
"""
|
473 |
+
ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
|
474 |
+
with torch.no_grad():
|
475 |
+
ans.weight[:] *= initial_scale
|
476 |
+
if ans.bias is not None:
|
477 |
+
torch.nn.init.uniform_(
|
478 |
+
ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
|
479 |
+
)
|
480 |
+
return ans
|
481 |
+
|
482 |
+
|
483 |
+
def TransposeScaledConv1d(
|
484 |
+
*args,
|
485 |
+
initial_scale: float = 1.0,
|
486 |
+
kernel_size: int = 3,
|
487 |
+
padding: str = "same",
|
488 |
+
**kwargs,
|
489 |
+
) -> nn.Sequential:
|
490 |
+
"""
|
491 |
+
Transpose -> ScaledConv1d
|
492 |
+
"""
|
493 |
+
return nn.Sequential(
|
494 |
+
Transpose(),
|
495 |
+
ScaledConv1d(
|
496 |
+
*args,
|
497 |
+
initial_scale=initial_scale,
|
498 |
+
kernel_size=kernel_size,
|
499 |
+
padding=padding,
|
500 |
+
**kwargs,
|
501 |
+
),
|
502 |
+
)
|
503 |
+
|
504 |
+
|
505 |
+
def ScaledConv1dTranspose(
|
506 |
+
*args,
|
507 |
+
initial_scale: float = 1.0,
|
508 |
+
kernel_size: int = 3,
|
509 |
+
padding: str = "same",
|
510 |
+
**kwargs,
|
511 |
+
) -> nn.Sequential:
|
512 |
+
"""
|
513 |
+
ScaledConv1d -> Transpose
|
514 |
+
"""
|
515 |
+
return nn.Sequential(
|
516 |
+
ScaledConv1d(
|
517 |
+
*args,
|
518 |
+
initial_scale=initial_scale,
|
519 |
+
kernel_size=kernel_size,
|
520 |
+
padding=padding,
|
521 |
+
**kwargs,
|
522 |
+
),
|
523 |
+
Transpose(),
|
524 |
+
)
|
525 |
+
|
526 |
+
|
527 |
+
def TransposeConv1d(
|
528 |
+
*args, kernel_size: int = 3, padding: str = "same", **kwargs
|
529 |
+
) -> nn.Sequential:
|
530 |
+
"""
|
531 |
+
Transpose -> Conv1d
|
532 |
+
"""
|
533 |
+
return nn.Sequential(
|
534 |
+
Transpose(),
|
535 |
+
nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
|
536 |
+
)
|
537 |
+
|
538 |
+
|
539 |
+
def Conv1dTranspose(
|
540 |
+
*args, kernel_size: int = 3, padding: str = "same", **kwargs
|
541 |
+
) -> nn.Sequential:
|
542 |
+
"""
|
543 |
+
Conv1d -> Transpose
|
544 |
+
"""
|
545 |
+
return nn.Sequential(
|
546 |
+
nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
|
547 |
+
Transpose(),
|
548 |
+
)
|
549 |
+
|
550 |
+
|
551 |
+
class SRLinear(nn.Linear):
|
552 |
+
"""https://arxiv.org/abs/2303.06296
|
553 |
+
Stabilizing Transformer Training by Preventing Attention Entropy Collapse
|
554 |
+
"""
|
555 |
+
|
556 |
+
def __init__(self, in_features, out_features, bias=True, **kwargs):
|
557 |
+
super().__init__(in_features, out_features, bias=bias, **kwargs)
|
558 |
+
self.register_buffer(
|
559 |
+
"u", nn.functional.normalize(torch.randn(in_features), dim=0)
|
560 |
+
)
|
561 |
+
with torch.no_grad():
|
562 |
+
sigma = self.get_sigma()
|
563 |
+
self.register_buffer("spectral_norm", sigma)
|
564 |
+
self.sigma = nn.Parameter(torch.ones(1))
|
565 |
+
|
566 |
+
def get_sigma(self):
|
567 |
+
with torch.no_grad():
|
568 |
+
u = self.u
|
569 |
+
v = self.weight.mv(u)
|
570 |
+
v = nn.functional.normalize(v, dim=0)
|
571 |
+
u = self.weight.T.mv(v)
|
572 |
+
u = nn.functional.normalize(u, dim=0)
|
573 |
+
self.u.data.copy_(u)
|
574 |
+
return torch.einsum("c,cd,d->", v, self.weight, u)
|
575 |
+
|
576 |
+
def get_weight(self):
|
577 |
+
sigma = self.get_sigma()
|
578 |
+
if self.training:
|
579 |
+
self.spectral_norm.data.copy_(sigma)
|
580 |
+
weight = (self.sigma / sigma) * self.weight
|
581 |
+
return weight
|
582 |
+
|
583 |
+
def forward(self, x):
|
584 |
+
return nn.functional.linear(x, self.get_weight(), self.bias)
|
585 |
+
|
586 |
+
|
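A brief sketch of how the spectral reparameterization above behaves (illustrative only; import path assumed):

import torch
from modules.scaling import SRLinear  # assumed import path

# SRLinear keeps a power-iteration estimate of the weight's largest singular
# value and rescales the weight by (learned sigma / estimated sigma) each forward.
sr = SRLinear(16, 32)
x = torch.randn(4, 16)
y = sr(x)                     # forward() refreshes the power-iteration vector u
print(sr.get_sigma().item())  # approximate spectral norm of sr.weight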
587 |
+
class SRConv1d(SRLinear):
|
588 |
+
def __init__(
|
589 |
+
self,
|
590 |
+
in_features,
|
591 |
+
out_features,
|
592 |
+
kernel_size,
|
593 |
+
stride: int = 1,
|
594 |
+
padding: str = "same",
|
595 |
+
bias: bool = True,
|
596 |
+
**kwargs,
|
597 |
+
):
|
598 |
+
in_features = in_features * kernel_size
|
599 |
+
super().__init__(in_features, out_features, bias=bias, **kwargs)
|
600 |
+
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
601 |
+
self.kernel_size = kernel_size
|
602 |
+
self.stride = stride
|
603 |
+
self.padding = padding
|
604 |
+
|
605 |
+
def forward(self, x):
|
606 |
+
in_features = self.in_features // self.kernel_size
|
607 |
+
weight = self.get_weight().view(
|
608 |
+
self.out_features, in_features, self.kernel_size
|
609 |
+
)
|
610 |
+
return nn.functional.conv1d(
|
611 |
+
x, weight, bias=self.bias, stride=self.stride, padding=self.padding
|
612 |
+
)
|
613 |
+
|
614 |
+
|
615 |
+
def TransposeSRConv1d(
|
616 |
+
*args, kernel_size: int = 3, padding: str = "same", **kwargs
|
617 |
+
) -> nn.Sequential:
|
618 |
+
"""
|
619 |
+
Transpose -> SRConv1d
|
620 |
+
"""
|
621 |
+
return nn.Sequential(
|
622 |
+
Transpose(),
|
623 |
+
SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
|
624 |
+
)
|
625 |
+
|
626 |
+
|
627 |
+
def SRConv1dTranspose(
|
628 |
+
*args, kernel_size: int = 3, padding: str = "same", **kwargs
|
629 |
+
) -> nn.Sequential:
|
630 |
+
"""
|
631 |
+
SRConv1d -> Transpose
|
632 |
+
"""
|
633 |
+
return nn.Sequential(
|
634 |
+
SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
|
635 |
+
Transpose(),
|
636 |
+
)
|
637 |
+
|
638 |
+
|
639 |
+
class ActivationBalancer(torch.nn.Module):
|
640 |
+
"""
|
641 |
+
Modifies the backpropped derivatives of a function to try to encourage, for
|
642 |
+
each channel, that it is positive at least a proportion `threshold` of the
|
643 |
+
time. It does this by multiplying negative derivative values by up to
|
644 |
+
(1+max_factor), and positive derivative values by up to (1-max_factor),
|
645 |
+
interpolated from 1 at the threshold to those extremal values when none
|
646 |
+
of the inputs are positive.
|
647 |
+
|
648 |
+
Args:
|
649 |
+
num_channels: the number of channels
|
650 |
+
channel_dim: the dimension/axis corresponding to the channel, e.g.
|
651 |
+
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
|
652 |
+
min_positive: the minimum, per channel, of the proportion of the time
|
653 |
+
that (x > 0), below which we start to modify the derivatives.
|
654 |
+
max_positive: the maximum, per channel, of the proportion of the time
|
655 |
+
that (x > 0), above which we start to modify the derivatives.
|
656 |
+
max_factor: the maximum factor by which we modify the derivatives for
|
657 |
+
either the sign constraint or the magnitude constraint;
|
658 |
+
e.g. with max_factor=0.02, the derivatives would be multiplied by
|
659 |
+
values in the range [0.98..1.02].
|
660 |
+
sign_gain_factor: determines the 'gain' with which we increase the
|
661 |
+
change in gradient once the constraints on min_positive and max_positive
|
662 |
+
are violated.
|
663 |
+
scale_gain_factor: determines the 'gain' with which we increase the
|
664 |
+
change in gradient once the constraints on min_abs and max_abs
|
665 |
+
are violated.
|
666 |
+
min_abs: the minimum average-absolute-value difference from the mean
|
667 |
+
value per channel, which we allow, before we start to modify
|
668 |
+
the derivatives to prevent this.
|
669 |
+
max_abs: the maximum average-absolute-value difference from the mean
|
670 |
+
value per channel, which we allow, before we start to modify
|
671 |
+
the derivatives to prevent this.
|
672 |
+
min_prob: determines the minimum probability with which we modify the
|
673 |
+
gradients for the {min,max}_positive and {min,max}_abs constraints,
|
674 |
+
on each forward(). This is done randomly to prevent all layers
|
675 |
+
from doing it at the same time. Early in training we may use
|
676 |
+
higher probabilities than this; it will decay to this value.
|
677 |
+
"""
|
678 |
+
|
679 |
+
def __init__(
|
680 |
+
self,
|
681 |
+
num_channels: int,
|
682 |
+
channel_dim: int,
|
683 |
+
min_positive: float = 0.05,
|
684 |
+
max_positive: float = 0.95,
|
685 |
+
max_factor: float = 0.04,
|
686 |
+
sign_gain_factor: float = 0.01,
|
687 |
+
scale_gain_factor: float = 0.02,
|
688 |
+
min_abs: float = 0.2,
|
689 |
+
max_abs: float = 100.0,
|
690 |
+
min_prob: float = 0.1,
|
691 |
+
):
|
692 |
+
super(ActivationBalancer, self).__init__()
|
693 |
+
self.num_channels = num_channels
|
694 |
+
self.channel_dim = channel_dim
|
695 |
+
self.min_positive = min_positive
|
696 |
+
self.max_positive = max_positive
|
697 |
+
self.max_factor = max_factor
|
698 |
+
self.min_abs = min_abs
|
699 |
+
self.max_abs = max_abs
|
700 |
+
self.min_prob = min_prob
|
701 |
+
self.sign_gain_factor = sign_gain_factor
|
702 |
+
self.scale_gain_factor = scale_gain_factor
|
703 |
+
|
704 |
+
# count measures how many times the forward() function has been called.
|
705 |
+
# We occasionally sync this to a tensor called `count`, that exists to
|
706 |
+
# make sure it is synced to disk when we load and save the model.
|
707 |
+
self.cpu_count = 0
|
708 |
+
self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
|
709 |
+
|
710 |
+
def forward(self, x: Tensor) -> Tensor:
|
711 |
+
if (
|
712 |
+
torch.jit.is_scripting()
|
713 |
+
or not x.requires_grad
|
714 |
+
or torch.jit.is_tracing()
|
715 |
+
):
|
716 |
+
return _no_op(x)
|
717 |
+
|
718 |
+
count = self.cpu_count
|
719 |
+
self.cpu_count += 1
|
720 |
+
|
721 |
+
if random.random() < 0.01:
|
722 |
+
# Occasionally sync self.cpu_count with self.count.
|
723 |
+
# count affects the decay of 'prob'. don't do this on every iter,
|
724 |
+
# because syncing with the GPU is slow.
|
725 |
+
self.cpu_count = max(self.cpu_count, self.count.item())
|
726 |
+
self.count.fill_(self.cpu_count)
|
727 |
+
|
728 |
+
# the prob of doing some work exponentially decreases from 0.5 till it hits
|
729 |
+
# a floor at min_prob (==0.1, by default)
|
730 |
+
prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
|
731 |
+
|
732 |
+
if random.random() < prob:
|
733 |
+
sign_gain_factor = 0.5
|
734 |
+
if self.min_positive != 0.0 or self.max_positive != 1.0:
|
735 |
+
sign_factor = _compute_sign_factor(
|
736 |
+
x,
|
737 |
+
self.channel_dim,
|
738 |
+
self.min_positive,
|
739 |
+
self.max_positive,
|
740 |
+
gain_factor=self.sign_gain_factor / prob,
|
741 |
+
max_factor=self.max_factor,
|
742 |
+
)
|
743 |
+
else:
|
744 |
+
sign_factor = None
|
745 |
+
|
746 |
+
scale_factor = _compute_scale_factor(
|
747 |
+
x.detach(),
|
748 |
+
self.channel_dim,
|
749 |
+
min_abs=self.min_abs,
|
750 |
+
max_abs=self.max_abs,
|
751 |
+
gain_factor=self.scale_gain_factor / prob,
|
752 |
+
max_factor=self.max_factor,
|
753 |
+
)
|
754 |
+
return ActivationBalancerFunction.apply(
|
755 |
+
x,
|
756 |
+
scale_factor,
|
757 |
+
sign_factor,
|
758 |
+
self.channel_dim,
|
759 |
+
)
|
760 |
+
else:
|
761 |
+
return _no_op(x)
|
762 |
+
|
763 |
+
|
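A usage sketch for the class above (placing the balancer before a nonlinearity is just one plausible pattern, not a prescription; the import path is an assumption):

import torch
from modules.scaling import ActivationBalancer  # assumed import path

# The balancer is an identity in the forward pass; in the backward pass it
# occasionally rescales gradients to keep per-channel sign/magnitude stats in range.
balancer = ActivationBalancer(
    num_channels=256, channel_dim=-1,
    min_positive=0.05, max_positive=0.95, max_abs=6.0,
)
layer = torch.nn.Sequential(torch.nn.Linear(64, 256), balancer, torch.nn.ReLU())
x = torch.randn(10, 64)
layer(x).sum().backward()  # gradients through `balancer` may be slightly modified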
764 |
+
def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
|
765 |
+
"""
|
766 |
+
Returns x unmodified, but in backprop will put a penalty for the excess of
|
767 |
+
the absolute values of elements of x over the limit "limit". E.g. if
|
768 |
+
limit == 10.0, then if x has any values over 10 it will get a penalty.
|
769 |
+
|
770 |
+
Caution: the value of this penalty will be affected by grad scaling used
|
771 |
+
in automatic mixed precision training. For this reasons we use this,
|
772 |
+
it shouldn't really matter, or may even be helpful; we just use this
|
773 |
+
to disallow really implausible values of scores to be given to softmax.
|
774 |
+
"""
|
775 |
+
x_sign = x.sign()
|
776 |
+
over_limit = (x.abs() - limit) > 0
|
777 |
+
# The following is a memory efficient way to penalize the absolute values of
|
778 |
+
# x that's over the limit. (The memory efficiency comes when you think
|
779 |
+
# about which items torch needs to cache for the autograd, and which ones it
|
780 |
+
# can throw away). The numerical value of aux_loss as computed here will
|
781 |
+
# actually be larger than it should be, by limit * over_limit.sum(), but it
|
782 |
+
# has the same derivative as the real aux_loss which is penalty * (x.abs() -
|
783 |
+
# limit).relu().
|
784 |
+
aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
|
785 |
+
# note: we don't do sum() here on aux_loss, but it's as if we had done
|
786 |
+
# sum() due to how with_loss() works.
|
787 |
+
x = with_loss(x, aux_loss)
|
788 |
+
# you must use x for something, or this will be ineffective.
|
789 |
+
return x
|
790 |
+
|
791 |
+
|
792 |
+
def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
|
793 |
+
if x.ndim == 2:
|
794 |
+
return x.diag()
|
795 |
+
else:
|
796 |
+
(batch, dim, dim) = x.shape
|
797 |
+
x = x.reshape(batch, dim * dim)
|
798 |
+
x = x[:, :: dim + 1]
|
799 |
+
assert x.shape == (batch, dim)
|
800 |
+
return x
|
801 |
+
|
802 |
+
|
803 |
+
def _whitening_metric(x: Tensor, num_groups: int):
|
804 |
+
"""
|
805 |
+
Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
|
806 |
+
the centered feature covariance are the same within each group's covariance matrix
|
807 |
+
and also between groups.
|
808 |
+
Args:
|
809 |
+
x: a Tensor of shape (*, num_channels)
|
810 |
+
num_groups: the number of groups of channels, a number >=1 that divides num_channels
|
811 |
+
Returns:
|
812 |
+
Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
|
813 |
+
greater than 1.0 otherwise.
|
814 |
+
"""
|
815 |
+
assert x.dtype != torch.float16
|
816 |
+
x = x.reshape(-1, x.shape[-1])
|
817 |
+
(num_frames, num_channels) = x.shape
|
818 |
+
assert num_channels % num_groups == 0
|
819 |
+
channels_per_group = num_channels // num_groups
|
820 |
+
x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
|
821 |
+
# x now has shape (num_groups, num_frames, channels_per_group)
|
822 |
+
# subtract the mean so we use the centered, not uncentered, covariance.
|
823 |
+
# My experience has been that when we "mess with the gradients" like this,
|
824 |
+
# it's better not do anything that tries to move the mean around, because
|
825 |
+
# that can easily cause instability.
|
826 |
+
x = x - x.mean(dim=1, keepdim=True)
|
827 |
+
# x_covar: (num_groups, channels_per_group, channels_per_group)
|
828 |
+
x_covar = torch.matmul(x.transpose(1, 2), x)
|
829 |
+
x_covar_mean_diag = _diag(x_covar).mean()
|
830 |
+
# the following expression is what we'd get if we took the matrix product
|
831 |
+
# of each covariance and measured the mean of its trace, i.e.
|
832 |
+
# the same as _diag(torch.matmul(x_covar, x_covar)).mean().
|
833 |
+
x_covarsq_mean_diag = (x_covar ** 2).sum() / (
|
834 |
+
num_groups * channels_per_group
|
835 |
+
)
|
836 |
+
# this metric will be >= 1.0; the larger it is, the less 'white' the data was.
|
837 |
+
metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20)
|
838 |
+
return metric
|
839 |
+
|
840 |
+
|
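A quick numerical illustration of the metric above (a sketch; import path assumed):

import torch
from modules.scaling import _whitening_metric  # assumed import path

white = torch.randn(10000, 64)
print(_whitening_metric(white, num_groups=1))       # close to 1.0

rank1 = torch.randn(10000, 1) * torch.randn(64)     # one dominant direction
print(_whitening_metric(rank1 + 0.01 * white, num_groups=1))  # much larger than 1.0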
841 |
+
class WhiteningPenaltyFunction(torch.autograd.Function):
|
842 |
+
@staticmethod
|
843 |
+
def forward(
|
844 |
+
ctx,
|
845 |
+
x: Tensor,
|
846 |
+
num_groups: int,
|
847 |
+
whitening_limit: float,
|
848 |
+
grad_scale: float,
|
849 |
+
) -> Tensor:
|
850 |
+
ctx.save_for_backward(x)
|
851 |
+
ctx.num_groups = num_groups
|
852 |
+
ctx.whitening_limit = whitening_limit
|
853 |
+
ctx.grad_scale = grad_scale
|
854 |
+
return x
|
855 |
+
|
856 |
+
@staticmethod
|
857 |
+
def backward(ctx, x_grad: Tensor):
|
858 |
+
(x_orig,) = ctx.saved_tensors
|
859 |
+
with torch.enable_grad():
|
860 |
+
with torch.cuda.amp.autocast(enabled=False):
|
861 |
+
x_detached = x_orig.to(torch.float32).detach()
|
862 |
+
x_detached.requires_grad = True
|
863 |
+
|
864 |
+
metric = _whitening_metric(x_detached, ctx.num_groups)
|
865 |
+
|
866 |
+
if random.random() < 0.005 or __name__ == "__main__":
|
867 |
+
logging.info(
|
868 |
+
f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
|
869 |
+
f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
|
870 |
+
)
|
871 |
+
|
872 |
+
(metric - ctx.whitening_limit).relu().backward()
|
873 |
+
penalty_grad = x_detached.grad
|
874 |
+
scale = ctx.grad_scale * (
|
875 |
+
x_grad.to(torch.float32).norm()
|
876 |
+
/ (penalty_grad.norm() + 1.0e-20)
|
877 |
+
)
|
878 |
+
penalty_grad = penalty_grad * scale
|
879 |
+
return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
|
880 |
+
|
881 |
+
|
882 |
+
class Whiten(nn.Module):
|
883 |
+
def __init__(
|
884 |
+
self,
|
885 |
+
num_groups: int,
|
886 |
+
whitening_limit: float,
|
887 |
+
prob: Union[float, Tuple[float, float]],
|
888 |
+
grad_scale: float,
|
889 |
+
):
|
890 |
+
"""
|
891 |
+
Args:
|
892 |
+
num_groups: the number of groups to divide the channel dim into before
|
893 |
+
whitening. We will attempt to make the feature covariance
|
894 |
+
within each group, after mean subtraction, as "white" as possible,
|
895 |
+
while having the same trace across all groups.
|
896 |
+
whitening_limit: a value greater than 1.0, that dictates how much
|
897 |
+
freedom we have to violate the constraints. 1.0 would mean perfectly
|
898 |
+
white, with exactly the same trace across groups; larger values
|
899 |
+
give more freedom. E.g. 2.0.
|
900 |
+
prob: the probability with which we apply the gradient modification
|
901 |
+
(also affects the grad scale). May be supplied as a float,
|
902 |
+
or as a pair (min_prob, max_prob)
|
903 |
+
|
904 |
+
grad_scale: determines the scale on the gradient term from this object,
|
905 |
+
relative to the rest of the gradient on the attention weights.
|
906 |
+
E.g. 0.02 (you may want to use smaller values than this if prob is large)
|
907 |
+
"""
|
908 |
+
super(Whiten, self).__init__()
|
909 |
+
assert num_groups >= 1
|
910 |
+
assert whitening_limit >= 1
|
911 |
+
assert grad_scale >= 0
|
912 |
+
self.num_groups = num_groups
|
913 |
+
self.whitening_limit = whitening_limit
|
914 |
+
if isinstance(prob, float):
|
915 |
+
assert 0 < prob <= 1
|
916 |
+
self.prob = prob
|
917 |
+
else:
|
918 |
+
(self.min_prob, self.max_prob) = prob
|
919 |
+
assert 0 < self.min_prob < self.max_prob <= 1
|
920 |
+
self.prob = self.max_prob
|
921 |
+
|
922 |
+
self.grad_scale = grad_scale
|
923 |
+
|
924 |
+
def forward(self, x: Tensor) -> Tensor:
|
925 |
+
"""
|
926 |
+
In the forward pass, this function just returns the input unmodified.
|
927 |
+
In the backward pass, it will modify the gradients to ensure that the
|
928 |
+
distribution in each group has close to (lambda times I) as the covariance
|
929 |
+
after mean subtraction, with the same lambda across groups.
|
930 |
+
For whitening_limit > 1, there will be more freedom to violate this
|
931 |
+
constraint.
|
932 |
+
|
933 |
+
Args:
|
934 |
+
x: the input of shape (*, num_channels)
|
935 |
+
|
936 |
+
Returns:
|
937 |
+
x, unmodified. You should make sure
|
938 |
+
you use the returned value, or the graph will be freed
|
939 |
+
and nothing will happen in backprop.
|
940 |
+
"""
|
941 |
+
if (
|
942 |
+
not x.requires_grad
|
943 |
+
or random.random() > self.prob
|
944 |
+
or self.grad_scale == 0
|
945 |
+
):
|
946 |
+
return _no_op(x)
|
947 |
+
else:
|
948 |
+
if hasattr(self, "min_prob") and random.random() < 0.25:
|
949 |
+
# occasionally switch between min_prob and max_prob, based on whether
|
950 |
+
# we are above or below the threshold.
|
951 |
+
if (
|
952 |
+
_whitening_metric(x.to(torch.float32), self.num_groups)
|
953 |
+
> self.whitening_limit
|
954 |
+
):
|
955 |
+
# there would be a change to the grad.
|
956 |
+
self.prob = self.max_prob
|
957 |
+
else:
|
958 |
+
self.prob = self.min_prob
|
959 |
+
|
960 |
+
return WhiteningPenaltyFunction.apply(
|
961 |
+
x, self.num_groups, self.whitening_limit, self.grad_scale
|
962 |
+
)
|
963 |
+
|
964 |
+
|
965 |
+
class WithLoss(torch.autograd.Function):
|
966 |
+
@staticmethod
|
967 |
+
def forward(ctx, x: Tensor, y: Tensor):
|
968 |
+
ctx.y_shape = y.shape
|
969 |
+
return x
|
970 |
+
|
971 |
+
@staticmethod
|
972 |
+
def backward(ctx, ans_grad: Tensor):
|
973 |
+
return ans_grad, torch.ones(
|
974 |
+
ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
|
975 |
+
)
|
976 |
+
|
977 |
+
|
978 |
+
def with_loss(x, y):
|
979 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
980 |
+
return x
|
981 |
+
# returns x but adds y.sum() to the loss function.
|
982 |
+
return WithLoss.apply(x, y)
|
983 |
+
|
984 |
+
|
985 |
+
def _no_op(x: Tensor) -> Tensor:
|
986 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
987 |
+
return x
|
988 |
+
else:
|
989 |
+
# a no-op function that will have a node in the autograd graph,
|
990 |
+
# to avoid certain bugs relating to backward hooks
|
991 |
+
return x.chunk(1, dim=-1)[0]
|
992 |
+
|
993 |
+
|
994 |
+
class Identity(torch.nn.Module):
|
995 |
+
def __init__(self):
|
996 |
+
super(Identity, self).__init__()
|
997 |
+
|
998 |
+
def forward(self, x):
|
999 |
+
return _no_op(x)
|
1000 |
+
|
1001 |
+
|
1002 |
+
class MaxEig(torch.nn.Module):
|
1003 |
+
"""
|
1004 |
+
Modifies the backpropped derivatives of a function to try to discourage
|
1005 |
+
that any given direction in activation space accounts for more than
|
1006 |
+
a specified proportion of the covariance (e.g. 0.2).
|
1007 |
+
|
1008 |
+
|
1009 |
+
Args:
|
1010 |
+
num_channels: the number of channels
|
1011 |
+
channel_dim: the dimension/axis corresponding to the channel, e.g.
|
1012 |
+
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
|
1013 |
+
max_var_per_eig: the maximum proportion of the variance of the
|
1014 |
+
features/channels, after mean subtraction, that can come from
|
1015 |
+
any given eigenvalue.
|
1016 |
+
min_prob: the minimum probability with which we apply this during any invocation
|
1017 |
+
of forward(), assuming last time we applied the constraint it was
|
1018 |
+
not active; supplied for speed.
|
1019 |
+
scale: determines the scale with which we modify the gradients, relative
|
1020 |
+
to the existing / unmodified gradients
|
1021 |
+
"""
|
1022 |
+
|
1023 |
+
def __init__(
|
1024 |
+
self,
|
1025 |
+
num_channels: int,
|
1026 |
+
channel_dim: int,
|
1027 |
+
max_var_per_eig: float = 0.2,
|
1028 |
+
min_prob: float = 0.01,
|
1029 |
+
scale: float = 0.01,
|
1030 |
+
):
|
1031 |
+
super(MaxEig, self).__init__()
|
1032 |
+
self.num_channels = num_channels
|
1033 |
+
self.channel_dim = channel_dim
|
1034 |
+
self.scale = scale
|
1035 |
+
assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
|
1036 |
+
self.max_var_per_eig = max_var_per_eig
|
1037 |
+
|
1038 |
+
# we figure out the dominant direction using the power method: starting with
|
1039 |
+
# a random vector, keep multiplying by the covariance and renormalizing.
|
1040 |
+
with torch.no_grad():
|
1041 |
+
# arbitrary.. would use randn() but want to leave the rest of the model's
|
1042 |
+
# random parameters unchanged for comparison
|
1043 |
+
direction = torch.arange(num_channels).to(torch.float)
|
1044 |
+
direction = direction / direction.norm()
|
1045 |
+
self.register_buffer("max_eig_direction", direction)
|
1046 |
+
|
1047 |
+
self.min_prob = min_prob
|
1048 |
+
# cur_prob is the current probability we'll use to apply the MaxEig constraint.
|
1049 |
+
# We'll regress this towards prob, each time we try to apply it and it is not
|
1050 |
+
# active.
|
1051 |
+
self.cur_prob = 1.0
|
1052 |
+
|
1053 |
+
def forward(self, x: Tensor) -> Tensor:
|
1054 |
+
if (
|
1055 |
+
torch.jit.is_scripting()
|
1056 |
+
or self.max_var_per_eig <= 0
|
1057 |
+
or random.random() > self.cur_prob
|
1058 |
+
or torch.jit.is_tracing()
|
1059 |
+
):
|
1060 |
+
return _no_op(x)
|
1061 |
+
|
1062 |
+
with torch.cuda.amp.autocast(enabled=False):
|
1063 |
+
eps = 1.0e-20
|
1064 |
+
orig_x = x
|
1065 |
+
x = x.to(torch.float32)
|
1066 |
+
with torch.no_grad():
|
1067 |
+
x = x.transpose(self.channel_dim, -1).reshape(
|
1068 |
+
-1, self.num_channels
|
1069 |
+
)
|
1070 |
+
x = x - x.mean(dim=0)
|
1071 |
+
new_direction, coeffs = self._find_direction_coeffs(
|
1072 |
+
x, self.max_eig_direction
|
1073 |
+
)
|
1074 |
+
x_var = (x ** 2).mean()
|
1075 |
+
x_residual = x - coeffs * new_direction
|
1076 |
+
x_residual_var = (x_residual ** 2).mean()
|
1077 |
+
|
1078 |
+
# `variance_proportion` is the proportion of the variance accounted for
|
1079 |
+
# by the top eigen-direction.
|
1080 |
+
variance_proportion = (x_var - x_residual_var) / (
|
1081 |
+
x_var + 1.0e-20
|
1082 |
+
)
|
1083 |
+
|
1084 |
+
# ensure new direction is nonzero even if x == 0, by including `direction`.
|
1085 |
+
self._set_direction(
|
1086 |
+
0.1 * self.max_eig_direction + new_direction
|
1087 |
+
)
|
1088 |
+
|
1089 |
+
if random.random() < 0.01 or __name__ == "__main__":
|
1090 |
+
logging.info(
|
1091 |
+
f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
|
1092 |
+
)
|
1093 |
+
|
1094 |
+
if variance_proportion >= self.max_var_per_eig:
|
1095 |
+
# The constraint is active. Note, we should quite rarely
|
1096 |
+
# reach here, only near the beginning of training if we are
|
1097 |
+
# starting to diverge, should this constraint be active.
|
1098 |
+
cur_prob = self.cur_prob
|
1099 |
+
self.cur_prob = (
|
1100 |
+
1.0 # next time, do the update with probability 1.0.
|
1101 |
+
)
|
1102 |
+
return MaxEigLimiterFunction.apply(
|
1103 |
+
orig_x, coeffs, new_direction, self.channel_dim, self.scale
|
1104 |
+
)
|
1105 |
+
else:
|
1106 |
+
# let self.cur_prob exponentially approach self.min_prob, as
|
1107 |
+
# long as the constraint is inactive.
|
1108 |
+
self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
|
1109 |
+
return orig_x
|
1110 |
+
|
1111 |
+
def _set_direction(self, direction: Tensor):
|
1112 |
+
"""
|
1113 |
+
Sets self.max_eig_direction to a normalized version of `direction`
|
1114 |
+
"""
|
1115 |
+
direction = direction.detach()
|
1116 |
+
direction = direction / direction.norm()
|
1117 |
+
direction_sum = direction.sum().item()
|
1118 |
+
if direction_sum - direction_sum == 0: # no inf/nan
|
1119 |
+
self.max_eig_direction[:] = direction
|
1120 |
+
else:
|
1121 |
+
logging.info(
|
1122 |
+
f"Warning: sum of direction in MaxEig is {direction_sum}, "
|
1123 |
+
"num_channels={self.num_channels}, channel_dim={self.channel_dim}"
|
1124 |
+
)
|
1125 |
+
|
1126 |
+
def _find_direction_coeffs(
|
1127 |
+
self, x: Tensor, prev_direction: Tensor
|
1128 |
+
) -> Tuple[Tensor, Tensor]:
|
1129 |
+
"""
|
1130 |
+
Figure out (an approximation to) the proportion of the variance of a set of
|
1131 |
+
feature vectors that can be attributed to the top eigen-direction.
|
1132 |
+
Args:
|
1133 |
+
x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
|
1134 |
+
prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
|
1135 |
+
of the top eigen-direction, or a random direction if this is the first
|
1136 |
+
iteration. Does not have to be normalized, but should be nonzero.
|
1137 |
+
|
1138 |
+
Returns: (cur_direction, coeffs), where:
|
1139 |
+
cur_direction: a Tensor of shape (num_channels,) that is the current
|
1140 |
+
estimate of the top eigen-direction.
|
1141 |
+
coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
|
1142 |
+
approximately minimizes, (x - coeffs * cur_direction).norm()
|
1143 |
+
"""
|
1144 |
+
(num_frames, num_channels) = x.shape
|
1145 |
+
assert num_channels > 1 and num_frames > 1
|
1146 |
+
assert prev_direction.shape == (num_channels,)
|
1147 |
+
# `coeffs` are the coefficients of `prev_direction` in x.
|
1148 |
+
# actually represent the coeffs up to a constant positive factor.
|
1149 |
+
coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
|
1150 |
+
cur_direction = (x * coeffs).sum(dim=0) / (
|
1151 |
+
(coeffs ** 2).sum() + 1.0e-20
|
1152 |
+
)
|
1153 |
+
return cur_direction, coeffs
|
1154 |
+
|
1155 |
+
|
1156 |
+
class DoubleSwishFunction(torch.autograd.Function):
|
1157 |
+
"""
|
1158 |
+
double_swish(x) = x * torch.sigmoid(x-1)
|
1159 |
+
This is a definition, originally motivated by its close numerical
|
1160 |
+
similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).
|
1161 |
+
|
1162 |
+
Memory-efficient derivative computation:
|
1163 |
+
double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
|
1164 |
+
double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
|
1165 |
+
Now, s'(x) = s(x) * (1-s(x)).
|
1166 |
+
double_swish'(x) = x * s'(x) + s(x).
|
1167 |
+
= x * s(x) * (1-s(x)) + s(x).
|
1168 |
+
= double_swish(x) * (1-s(x)) + s(x)
|
1169 |
+
... so we just need to remember s(x) but not x itself.
|
1170 |
+
"""
|
1171 |
+
|
1172 |
+
@staticmethod
|
1173 |
+
def forward(ctx, x: Tensor) -> Tensor:
|
1174 |
+
requires_grad = x.requires_grad
|
1175 |
+
x_dtype = x.dtype
|
1176 |
+
if x.dtype == torch.float16:
|
1177 |
+
x = x.to(torch.float32)
|
1178 |
+
|
1179 |
+
s = torch.sigmoid(x - 1.0)
|
1180 |
+
y = x * s
|
1181 |
+
|
1182 |
+
if requires_grad:
|
1183 |
+
deriv = y * (1 - s) + s
|
1184 |
+
# notes on derivative of x * sigmoid(x - 1):
|
1185 |
+
# https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
|
1186 |
+
# min \simeq -0.043638. Take floor as -0.043637 so it's a lower bound
|
1187 |
+
# max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
|
1188 |
+
# the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
|
1189 |
+
# floors), should be expectation-preserving.
|
1190 |
+
floor = -0.043637
|
1191 |
+
ceil = 1.2
|
1192 |
+
d_scaled = (deriv - floor) * (
|
1193 |
+
255.0 / (ceil - floor)
|
1194 |
+
) + torch.rand_like(deriv)
|
1195 |
+
if __name__ == "__main__":
|
1196 |
+
# for self-testing only.
|
1197 |
+
assert d_scaled.min() >= 0.0
|
1198 |
+
assert d_scaled.max() < 256.0
|
1199 |
+
d_int = d_scaled.to(torch.uint8)
|
1200 |
+
ctx.save_for_backward(d_int)
|
1201 |
+
if x.dtype == torch.float16 or torch.is_autocast_enabled():
|
1202 |
+
y = y.to(torch.float16)
|
1203 |
+
return y
|
1204 |
+
|
1205 |
+
@staticmethod
|
1206 |
+
def backward(ctx, y_grad: Tensor) -> Tensor:
|
1207 |
+
(d,) = ctx.saved_tensors
|
1208 |
+
# the same constants as used in forward pass.
|
1209 |
+
floor = -0.043637
|
1210 |
+
ceil = 1.2
|
1211 |
+
d = d * ((ceil - floor) / 255.0) + floor
|
1212 |
+
return y_grad * d
|
1213 |
+
|
1214 |
+
|
1215 |
+
class DoubleSwish(torch.nn.Module):
|
1216 |
+
def forward(self, x: Tensor) -> Tensor:
|
1217 |
+
"""Return double-swish activation function which is an approximation to Swish(Swish(x)),
|
1218 |
+
that we approximate closely with x * sigmoid(x-1).
|
1219 |
+
"""
|
1220 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
1221 |
+
return x * torch.sigmoid(x - 1.0)
|
1222 |
+
return DoubleSwishFunction.apply(x)
|
1223 |
+
|
1224 |
+
|
1225 |
+
def BalancedDoubleSwish(
|
1226 |
+
d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
|
1227 |
+
) -> nn.Sequential:
|
1228 |
+
"""
|
1229 |
+
ActivationBalancer -> DoubleSwish
|
1230 |
+
"""
|
1231 |
+
balancer = ActivationBalancer(
|
1232 |
+
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
|
1233 |
+
)
|
1234 |
+
return nn.Sequential(
|
1235 |
+
balancer,
|
1236 |
+
DoubleSwish(),
|
1237 |
+
)
|
1238 |
+
|
1239 |
+
|
1240 |
+
def _test_max_eig():
|
1241 |
+
for proportion in [0.1, 0.5, 10.0]:
|
1242 |
+
logging.info(f"proportion = {proportion}")
|
1243 |
+
x = torch.randn(100, 128)
|
1244 |
+
direction = torch.randn(128)
|
1245 |
+
coeffs = torch.randn(100, 1)
|
1246 |
+
x += proportion * direction * coeffs
|
1247 |
+
|
1248 |
+
x.requires_grad = True
|
1249 |
+
|
1250 |
+
num_channels = 128
|
1251 |
+
m = MaxEig(
|
1252 |
+
num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig
|
1253 |
+
) # grad_scale
|
1254 |
+
|
1255 |
+
for _ in range(4):
|
1256 |
+
y = m(x)
|
1257 |
+
|
1258 |
+
y_grad = torch.randn_like(x)
|
1259 |
+
y.backward(gradient=y_grad)
|
1260 |
+
|
1261 |
+
if proportion < 0.2:
|
1262 |
+
assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
|
1263 |
+
elif proportion > 1.0:
|
1264 |
+
assert not torch.allclose(x.grad, y_grad)
|
1265 |
+
|
1266 |
+
|
1267 |
+
def _test_whiten():
|
1268 |
+
for proportion in [0.1, 0.5, 10.0]:
|
1269 |
+
logging.info(f"_test_whiten(): proportion = {proportion}")
|
1270 |
+
x = torch.randn(100, 128)
|
1271 |
+
direction = torch.randn(128)
|
1272 |
+
coeffs = torch.randn(100, 1)
|
1273 |
+
x += proportion * direction * coeffs
|
1274 |
+
|
1275 |
+
x.requires_grad = True
|
1276 |
+
|
1277 |
+
num_channels = 128
|
1278 |
+
m = Whiten(
|
1279 |
+
1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit,
|
1280 |
+
) # grad_scale
|
1281 |
+
|
1282 |
+
for _ in range(4):
|
1283 |
+
y = m(x)
|
1284 |
+
|
1285 |
+
y_grad = torch.randn_like(x)
|
1286 |
+
y.backward(gradient=y_grad)
|
1287 |
+
|
1288 |
+
if proportion < 0.2:
|
1289 |
+
assert torch.allclose(x.grad, y_grad)
|
1290 |
+
elif proportion > 1.0:
|
1291 |
+
assert not torch.allclose(x.grad, y_grad)
|
1292 |
+
|
1293 |
+
|
1294 |
+
def _test_activation_balancer_sign():
|
1295 |
+
probs = torch.arange(0, 1, 0.01)
|
1296 |
+
N = 1000
|
1297 |
+
x = 1.0 * (
|
1298 |
+
(2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0
|
1299 |
+
)
|
1300 |
+
x = x.detach()
|
1301 |
+
x.requires_grad = True
|
1302 |
+
m = ActivationBalancer(
|
1303 |
+
probs.numel(),
|
1304 |
+
channel_dim=0,
|
1305 |
+
min_positive=0.05,
|
1306 |
+
max_positive=0.95,
|
1307 |
+
max_factor=0.2,
|
1308 |
+
min_abs=0.0,
|
1309 |
+
)
|
1310 |
+
|
1311 |
+
y_grad = torch.sign(torch.randn(probs.numel(), N))
|
1312 |
+
|
1313 |
+
y = m(x)
|
1314 |
+
y.backward(gradient=y_grad)
|
1315 |
+
print("_test_activation_balancer_sign: x = ", x)
|
1316 |
+
print("_test_activation_balancer_sign: y grad = ", y_grad)
|
1317 |
+
print("_test_activation_balancer_sign: x grad = ", x.grad)
|
1318 |
+
|
1319 |
+
|
1320 |
+
def _test_activation_balancer_magnitude():
|
1321 |
+
magnitudes = torch.arange(0, 1, 0.01)
|
1322 |
+
N = 1000
|
1323 |
+
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
|
1324 |
+
-1
|
1325 |
+
)
|
1326 |
+
x = x.detach()
|
1327 |
+
x.requires_grad = True
|
1328 |
+
m = ActivationBalancer(
|
1329 |
+
magnitudes.numel(),
|
1330 |
+
channel_dim=0,
|
1331 |
+
min_positive=0.0,
|
1332 |
+
max_positive=1.0,
|
1333 |
+
max_factor=0.2,
|
1334 |
+
min_abs=0.2,
|
1335 |
+
max_abs=0.8,
|
1336 |
+
min_prob=1.0,
|
1337 |
+
)
|
1338 |
+
|
1339 |
+
y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
|
1340 |
+
|
1341 |
+
y = m(x)
|
1342 |
+
y.backward(gradient=y_grad)
|
1343 |
+
print("_test_activation_balancer_magnitude: x = ", x)
|
1344 |
+
print("_test_activation_balancer_magnitude: y grad = ", y_grad)
|
1345 |
+
print("_test_activation_balancer_magnitude: x grad = ", x.grad)
|
1346 |
+
|
1347 |
+
|
1348 |
+
def _test_basic_norm():
|
1349 |
+
num_channels = 128
|
1350 |
+
m = BasicNorm(num_channels=num_channels, channel_dim=1)
|
1351 |
+
|
1352 |
+
x = torch.randn(500, num_channels)
|
1353 |
+
|
1354 |
+
y = m(x)
|
1355 |
+
|
1356 |
+
assert y.shape == x.shape
|
1357 |
+
x_rms = (x ** 2).mean().sqrt()
|
1358 |
+
y_rms = (y ** 2).mean().sqrt()
|
1359 |
+
print("x rms = ", x_rms)
|
1360 |
+
print("y rms = ", y_rms)
|
1361 |
+
assert y_rms < x_rms
|
1362 |
+
assert y_rms > 0.5 * x_rms
|
1363 |
+
|
1364 |
+
|
1365 |
+
def _test_double_swish_deriv():
|
1366 |
+
x = torch.randn(10, 12, dtype=torch.double) * 3.0
|
1367 |
+
x.requires_grad = True
|
1368 |
+
m = DoubleSwish()
|
1369 |
+
|
1370 |
+
tol = (1.2 - (-0.043637)) / 255.0
|
1371 |
+
torch.autograd.gradcheck(m, x, atol=tol)
|
1372 |
+
|
1373 |
+
# for self-test.
|
1374 |
+
x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
|
1375 |
+
x.requires_grad = True
|
1376 |
+
y = m(x)
|
1377 |
+
|
1378 |
+
|
1379 |
+
def _test_softmax():
|
1380 |
+
a = torch.randn(2, 10, dtype=torch.float64)
|
1381 |
+
b = a.clone()
|
1382 |
+
a.requires_grad = True
|
1383 |
+
b.requires_grad = True
|
1384 |
+
a.softmax(dim=1)[:, 0].sum().backward()
|
1385 |
+
print("a grad = ", a.grad)
|
1386 |
+
softmax(b, dim=1)[:, 0].sum().backward()
|
1387 |
+
print("b grad = ", b.grad)
|
1388 |
+
assert torch.allclose(a.grad, b.grad)
|
1389 |
+
|
1390 |
+
|
1391 |
+
if __name__ == "__main__":
|
1392 |
+
logging.getLogger().setLevel(logging.INFO)
|
1393 |
+
torch.set_num_threads(1)
|
1394 |
+
torch.set_num_interop_threads(1)
|
1395 |
+
_test_softmax()
|
1396 |
+
_test_whiten()
|
1397 |
+
_test_max_eig()
|
1398 |
+
_test_activation_balancer_sign()
|
1399 |
+
_test_activation_balancer_magnitude()
|
1400 |
+
_test_basic_norm()
|
1401 |
+
_test_double_swish_deriv()
|
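The backward trick in DoubleSwishFunction above rests on the identity double_swish'(x) = double_swish(x) * (1 - s(x)) + s(x), which is what lets the backward pass keep only a uint8-quantized derivative instead of x itself. A minimal sketch that checks this identity against autograd, using plain PyTorch only (it does not exercise the quantized custom Function):

import torch

x = torch.randn(1000, dtype=torch.double, requires_grad=True)
s = torch.sigmoid(x - 1.0)
y = x * s                              # double_swish(x) = x * sigmoid(x - 1)
y.sum().backward()

# closed-form derivative from the docstring: double_swish(x) * (1 - s) + s
closed_form = (y * (1.0 - s) + s).detach()
assert torch.allclose(x.grad, closed_form)

# The custom Function stores the derivative quantized to uint8 over [-0.043637, 1.2],
# so its gradients are exact only up to about (1.2 + 0.043637) / 255 ~= 4.9e-3,
# the same tolerance used by _test_double_swish_deriv above.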
modules/scheduler.py
ADDED
@@ -0,0 +1,78 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
# Copyright 2023 (authors: Feiteng Li)
|
3 |
+
#
|
4 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
|
19 |
+
import torch
|
20 |
+
|
21 |
+
from modules.optim import Eden
|
22 |
+
|
23 |
+
|
24 |
+
def calc_lr(step, dim_embed, warmup_steps):
|
25 |
+
return dim_embed ** (-0.5) * min(
|
26 |
+
step ** (-0.5), step * warmup_steps ** (-1.5)
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
class NoamScheduler(torch.optim.lr_scheduler._LRScheduler):
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
base_lr: float,
|
34 |
+
optimizer: torch.optim.Optimizer,
|
35 |
+
dim_embed: int,
|
36 |
+
warmup_steps: int,
|
37 |
+
last_epoch: int = -1,
|
38 |
+
verbose: bool = False,
|
39 |
+
) -> None:
|
40 |
+
|
41 |
+
self.dim_embed = dim_embed
|
42 |
+
self.base_lr = base_lr
|
43 |
+
self.warmup_steps = warmup_steps
|
44 |
+
self.num_param_groups = len(optimizer.param_groups)
|
45 |
+
|
46 |
+
super().__init__(optimizer, last_epoch, verbose)
|
47 |
+
|
48 |
+
def get_lr(self) -> float:
|
49 |
+
lr = self.base_lr * calc_lr(
|
50 |
+
self._step_count, self.dim_embed, self.warmup_steps
|
51 |
+
)
|
52 |
+
return [lr] * self.num_param_groups
|
53 |
+
|
54 |
+
def set_step(self, step: int):
|
55 |
+
self._step_count = step
|
56 |
+
|
57 |
+
|
58 |
+
def get_scheduler(params, optimizer):
|
59 |
+
if params.scheduler_name.lower() == "eden":
|
60 |
+
scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps)
|
61 |
+
elif params.scheduler_name.lower() == "noam":
|
62 |
+
scheduler = NoamScheduler(
|
63 |
+
params.base_lr,
|
64 |
+
optimizer,
|
65 |
+
params.decoder_dim,
|
66 |
+
warmup_steps=params.warmup_steps,
|
67 |
+
)
|
68 |
+
# scheduler.set_step(params.start_batch or params.batch_idx_train)
|
69 |
+
elif params.scheduler_name.lower() == "cosine":
|
70 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
71 |
+
optimizer,
|
72 |
+
params.warmup_steps,  # T_max
|
73 |
+
eta_min=params.base_lr,
|
74 |
+
)
|
75 |
+
else:
|
76 |
+
raise NotImplementedError(f"{params.scheduler_name}")
|
77 |
+
|
78 |
+
return scheduler
|
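For reference, calc_lr above is the standard Noam / inverse-square-root schedule: the learning rate grows roughly linearly for warmup_steps steps and then decays as step ** -0.5, all scaled by dim_embed ** -0.5 and base_lr. A small usage sketch with illustrative values (not the training configuration used elsewhere in this repo), assuming it is run from the repository root so that modules/scheduler.py is importable:

import torch

from modules.scheduler import NoamScheduler

model = torch.nn.Linear(8, 8)          # toy model, only needed to build an optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)
scheduler = NoamScheduler(
    base_lr=0.05, optimizer=optimizer, dim_embed=1024, warmup_steps=200
)

for step in range(1, 401):
    optimizer.step()
    scheduler.step()
    if step in (1, 100, 200, 400):
        # rises until step 200, then decays as step ** -0.5
        print(step, scheduler.get_last_lr()[0])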
modules/transformer.py
ADDED
@@ -0,0 +1,683 @@
1 |
+
import copy
|
2 |
+
import numbers
|
3 |
+
from functools import partial
|
4 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from torch import Tensor, nn
|
8 |
+
from torch.nn import functional as F
|
9 |
+
|
10 |
+
from .activation import MultiheadAttention
|
11 |
+
from .scaling import ActivationBalancer, BalancedDoubleSwish
|
12 |
+
from .scaling import BasicNorm as _BasicNorm
|
13 |
+
|
14 |
+
_shape_t = Union[int, List[int], torch.Size]
|
15 |
+
|
16 |
+
|
17 |
+
class LayerNorm(nn.Module):
|
18 |
+
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
|
19 |
+
normalized_shape: Tuple[int, ...]
|
20 |
+
eps: float
|
21 |
+
elementwise_affine: bool
|
22 |
+
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
normalized_shape: _shape_t,
|
26 |
+
eps: float = 1e-5,
|
27 |
+
elementwise_affine: bool = True,
|
28 |
+
device=None,
|
29 |
+
dtype=None,
|
30 |
+
) -> None:
|
31 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
32 |
+
super(LayerNorm, self).__init__()
|
33 |
+
if isinstance(normalized_shape, numbers.Integral):
|
34 |
+
# mypy error: incompatible types in assignment
|
35 |
+
normalized_shape = (normalized_shape,) # type: ignore[assignment]
|
36 |
+
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
|
37 |
+
self.eps = eps
|
38 |
+
self.elementwise_affine = elementwise_affine
|
39 |
+
if self.elementwise_affine:
|
40 |
+
self.weight = nn.Parameter(
|
41 |
+
torch.empty(self.normalized_shape, **factory_kwargs)
|
42 |
+
)
|
43 |
+
self.bias = nn.Parameter(
|
44 |
+
torch.empty(self.normalized_shape, **factory_kwargs)
|
45 |
+
)
|
46 |
+
else:
|
47 |
+
self.register_parameter("weight", None)
|
48 |
+
self.register_parameter("bias", None)
|
49 |
+
|
50 |
+
self.reset_parameters()
|
51 |
+
|
52 |
+
def reset_parameters(self) -> None:
|
53 |
+
if self.elementwise_affine:
|
54 |
+
nn.init.ones_(self.weight)
|
55 |
+
nn.init.zeros_(self.bias)
|
56 |
+
|
57 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
58 |
+
if isinstance(input, tuple):
|
59 |
+
input, embedding = input
|
60 |
+
return (
|
61 |
+
F.layer_norm(
|
62 |
+
input,
|
63 |
+
self.normalized_shape,
|
64 |
+
self.weight,
|
65 |
+
self.bias,
|
66 |
+
self.eps,
|
67 |
+
),
|
68 |
+
embedding,
|
69 |
+
)
|
70 |
+
|
71 |
+
assert embedding is None
|
72 |
+
return F.layer_norm(
|
73 |
+
input, self.normalized_shape, self.weight, self.bias, self.eps
|
74 |
+
)
|
75 |
+
|
76 |
+
def extra_repr(self) -> str:
|
77 |
+
return (
|
78 |
+
"{normalized_shape}, eps={eps}, "
|
79 |
+
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
80 |
+
)
|
81 |
+
|
82 |
+
|
83 |
+
class AdaptiveLayerNorm(nn.Module):
|
84 |
+
r"""Adaptive Layer Normalization"""
|
85 |
+
|
86 |
+
def __init__(self, d_model, norm) -> None:
|
87 |
+
super(AdaptiveLayerNorm, self).__init__()
|
88 |
+
self.project_layer = nn.Linear(d_model, 2 * d_model)
|
89 |
+
self.norm = norm
|
90 |
+
self.d_model = d_model
|
91 |
+
self.eps = self.norm.eps
|
92 |
+
|
93 |
+
def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
|
94 |
+
if isinstance(input, tuple):
|
95 |
+
input, embedding = input
|
96 |
+
weight, bias = torch.split(
|
97 |
+
self.project_layer(embedding),
|
98 |
+
split_size_or_sections=self.d_model,
|
99 |
+
dim=-1,
|
100 |
+
)
|
101 |
+
return (weight * self.norm(input) + bias, embedding)
|
102 |
+
|
103 |
+
weight, bias = torch.split(
|
104 |
+
self.project_layer(embedding),
|
105 |
+
split_size_or_sections=self.d_model,
|
106 |
+
dim=-1,
|
107 |
+
)
|
108 |
+
return weight * self.norm(input) + bias
|
109 |
+
|
110 |
+
|
111 |
+
class BasicNorm(_BasicNorm):
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
d_model: int,
|
115 |
+
eps: float = 1e-5,
|
116 |
+
device=None,
|
117 |
+
dtype=None,
|
118 |
+
):
|
119 |
+
super(BasicNorm, self).__init__(d_model, eps=eps)
|
120 |
+
|
121 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
122 |
+
if isinstance(input, tuple):
|
123 |
+
input, embedding = input
|
124 |
+
return (
|
125 |
+
super(BasicNorm, self).forward(input),
|
126 |
+
embedding,
|
127 |
+
)
|
128 |
+
|
129 |
+
assert embedding is None
|
130 |
+
return super(BasicNorm, self).forward(input)
|
131 |
+
|
132 |
+
|
133 |
+
class BalancedBasicNorm(nn.Module):
|
134 |
+
def __init__(
|
135 |
+
self,
|
136 |
+
d_model: int,
|
137 |
+
eps: float = 1e-5,
|
138 |
+
device=None,
|
139 |
+
dtype=None,
|
140 |
+
):
|
141 |
+
super(BalancedBasicNorm, self).__init__()
|
142 |
+
self.balancer = ActivationBalancer(
|
143 |
+
d_model,
|
144 |
+
channel_dim=-1,
|
145 |
+
min_positive=0.45,
|
146 |
+
max_positive=0.55,
|
147 |
+
max_abs=6.0,
|
148 |
+
)
|
149 |
+
self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)
|
150 |
+
|
151 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
152 |
+
if isinstance(input, tuple):
|
153 |
+
input, embedding = input
|
154 |
+
return self.norm((self.balancer(input), embedding))
|
155 |
+
|
156 |
+
assert embedding is None
|
157 |
+
return self.norm(self.balancer(input))
|
158 |
+
|
159 |
+
|
160 |
+
class IdentityNorm(nn.Module):
|
161 |
+
def __init__(
|
162 |
+
self,
|
163 |
+
d_model: int,
|
164 |
+
eps: float = 1e-5,
|
165 |
+
device=None,
|
166 |
+
dtype=None,
|
167 |
+
) -> None:
|
168 |
+
super(IdentityNorm, self).__init__()
|
169 |
+
|
170 |
+
def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
|
171 |
+
if isinstance(input, tuple):
|
172 |
+
return input
|
173 |
+
|
174 |
+
assert embedding is None
|
175 |
+
return input
|
176 |
+
|
177 |
+
|
178 |
+
class TransformerEncoderLayer(nn.Module):
|
179 |
+
__constants__ = ["batch_first", "norm_first"]
|
180 |
+
|
181 |
+
def __init__(
|
182 |
+
self,
|
183 |
+
d_model: int,
|
184 |
+
nhead: int,
|
185 |
+
dim_feedforward: int = 2048,
|
186 |
+
dropout: float = 0.1,
|
187 |
+
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
188 |
+
batch_first: bool = False,
|
189 |
+
norm_first: bool = False,
|
190 |
+
device=None,
|
191 |
+
dtype=None,
|
192 |
+
linear1_self_attention_cls: nn.Module = nn.Linear,
|
193 |
+
linear2_self_attention_cls: nn.Module = nn.Linear,
|
194 |
+
linear1_feedforward_cls: nn.Module = nn.Linear,
|
195 |
+
linear2_feedforward_cls: nn.Module = nn.Linear,
|
196 |
+
layer_norm_cls: nn.Module = LayerNorm,
|
197 |
+
layer_norm_eps: float = 1e-5,
|
198 |
+
adaptive_layer_norm=False,
|
199 |
+
) -> None:
|
200 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
201 |
+
super(TransformerEncoderLayer, self).__init__()
|
202 |
+
self.self_attn = MultiheadAttention(
|
203 |
+
d_model,
|
204 |
+
nhead,
|
205 |
+
dropout=dropout,
|
206 |
+
batch_first=batch_first,
|
207 |
+
linear1_cls=linear1_self_attention_cls,
|
208 |
+
linear2_cls=linear2_self_attention_cls,
|
209 |
+
**factory_kwargs,
|
210 |
+
)
|
211 |
+
|
212 |
+
# Implementation of Feedforward model
|
213 |
+
self.linear1 = linear1_feedforward_cls(
|
214 |
+
d_model, dim_feedforward, **factory_kwargs
|
215 |
+
)
|
216 |
+
self.dropout = nn.Dropout(dropout)
|
217 |
+
self.linear2 = linear2_feedforward_cls(
|
218 |
+
dim_feedforward, d_model, **factory_kwargs
|
219 |
+
)
|
220 |
+
|
221 |
+
self.norm_first = norm_first
|
222 |
+
self.dropout1 = nn.Dropout(dropout)
|
223 |
+
self.dropout2 = nn.Dropout(dropout)
|
224 |
+
|
225 |
+
# Legacy string support for activation function.
|
226 |
+
if isinstance(activation, str):
|
227 |
+
activation = _get_activation_fn(activation)
|
228 |
+
elif isinstance(activation, partial):
|
229 |
+
activation = activation(d_model)
|
230 |
+
elif activation == BalancedDoubleSwish:
|
231 |
+
activation = BalancedDoubleSwish(d_model)
|
232 |
+
|
233 |
+
# # We can't test self.activation in forward() in TorchScript,
|
234 |
+
# # so stash some information about it instead.
|
235 |
+
# if activation is F.relu or isinstance(activation, torch.nn.ReLU):
|
236 |
+
# self.activation_relu_or_gelu = 1
|
237 |
+
# elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
|
238 |
+
# self.activation_relu_or_gelu = 2
|
239 |
+
# else:
|
240 |
+
# self.activation_relu_or_gelu = 0
|
241 |
+
self.activation = activation
|
242 |
+
|
243 |
+
norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
|
244 |
+
if layer_norm_cls == IdentityNorm:
|
245 |
+
norm2 = BalancedBasicNorm(
|
246 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
247 |
+
)
|
248 |
+
else:
|
249 |
+
norm2 = layer_norm_cls(
|
250 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
251 |
+
)
|
252 |
+
|
253 |
+
if adaptive_layer_norm:
|
254 |
+
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
|
255 |
+
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
|
256 |
+
else:
|
257 |
+
self.norm1 = norm1
|
258 |
+
self.norm2 = norm2
|
259 |
+
|
260 |
+
def __setstate__(self, state):
|
261 |
+
super(TransformerEncoderLayer, self).__setstate__(state)
|
262 |
+
if not hasattr(self, "activation"):
|
263 |
+
self.activation = F.relu
|
264 |
+
|
265 |
+
def forward(
|
266 |
+
self,
|
267 |
+
src: Tensor,
|
268 |
+
src_mask: Optional[Tensor] = None,
|
269 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
270 |
+
) -> Tensor:
|
271 |
+
r"""Pass the input through the encoder layer.
|
272 |
+
|
273 |
+
Args:
|
274 |
+
src: the sequence to the encoder layer (required).
|
275 |
+
src_mask: the mask for the src sequence (optional).
|
276 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
277 |
+
|
278 |
+
Shape:
|
279 |
+
see the docs in Transformer class.
|
280 |
+
"""
|
281 |
+
x, stage_embedding = src, None
|
282 |
+
is_src_tuple = False
|
283 |
+
if isinstance(src, tuple):
|
284 |
+
x, stage_embedding = src
|
285 |
+
is_src_tuple = True
|
286 |
+
|
287 |
+
if src_key_padding_mask is not None:
|
288 |
+
_skpm_dtype = src_key_padding_mask.dtype
|
289 |
+
if _skpm_dtype != torch.bool and not torch.is_floating_point(
|
290 |
+
src_key_padding_mask
|
291 |
+
):
|
292 |
+
raise AssertionError(
|
293 |
+
"only bool and floating types of key_padding_mask are supported"
|
294 |
+
)
|
295 |
+
|
296 |
+
if self.norm_first:
|
297 |
+
x = x + self._sa_block(
|
298 |
+
self.norm1(x, stage_embedding),
|
299 |
+
src_mask,
|
300 |
+
src_key_padding_mask,
|
301 |
+
)
|
302 |
+
x = x + self._ff_block(self.norm2(x, stage_embedding))
|
303 |
+
else:
|
304 |
+
x = self.norm1(
|
305 |
+
x + self._sa_block(x, src_mask, src_key_padding_mask),
|
306 |
+
stage_embedding,
|
307 |
+
)
|
308 |
+
x = self.norm2(x + self._ff_block(x), stage_embedding)
|
309 |
+
|
310 |
+
if is_src_tuple:
|
311 |
+
return (x, stage_embedding)
|
312 |
+
return x
|
313 |
+
|
314 |
+
def infer(
|
315 |
+
self,
|
316 |
+
src: Tensor,
|
317 |
+
src_mask: Optional[Tensor] = None,
|
318 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
319 |
+
past_kv: Optional[Tensor] = None,
|
320 |
+
use_cache: bool = False,
|
321 |
+
):
|
322 |
+
x, stage_embedding = src, None
|
323 |
+
is_src_tuple = False
|
324 |
+
if isinstance(src, tuple):
|
325 |
+
x, stage_embedding = src
|
326 |
+
is_src_tuple = True
|
327 |
+
|
328 |
+
if src_key_padding_mask is not None:
|
329 |
+
_skpm_dtype = src_key_padding_mask.dtype
|
330 |
+
if _skpm_dtype != torch.bool and not torch.is_floating_point(
|
331 |
+
src_key_padding_mask
|
332 |
+
):
|
333 |
+
raise AssertionError(
|
334 |
+
"only bool and floating types of key_padding_mask are supported"
|
335 |
+
)
|
336 |
+
|
337 |
+
if self.norm_first:
|
338 |
+
x_attn_out, kv = self.self_attn.infer(
|
339 |
+
self.norm1(x, stage_embedding),
|
340 |
+
attn_mask=src_mask,
|
341 |
+
key_padding_mask=src_key_padding_mask,
|
342 |
+
need_weights=False,
|
343 |
+
past_kv=past_kv,
|
344 |
+
use_cache=use_cache,
|
345 |
+
)
|
346 |
+
x = x + x_attn_out
|
347 |
+
x = x + self._ff_block(self.norm2(x, stage_embedding))
|
348 |
+
|
349 |
+
if is_src_tuple:
|
350 |
+
return (x, stage_embedding)
|
351 |
+
return (x, kv)
|
352 |
+
|
353 |
+
# self-attention block
|
354 |
+
def _sa_block(
|
355 |
+
self,
|
356 |
+
x: Tensor,
|
357 |
+
attn_mask: Optional[Tensor],
|
358 |
+
key_padding_mask: Optional[Tensor],
|
359 |
+
) -> Tensor:
|
360 |
+
x = self.self_attn(
|
361 |
+
x,
|
362 |
+
x,
|
363 |
+
x,
|
364 |
+
attn_mask=attn_mask,
|
365 |
+
key_padding_mask=key_padding_mask,
|
366 |
+
need_weights=False,
|
367 |
+
)[0]
|
368 |
+
return self.dropout1(x)
|
369 |
+
|
370 |
+
# feed forward block
|
371 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
372 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
373 |
+
return self.dropout2(x)
|
374 |
+
|
375 |
+
|
376 |
+
class TransformerEncoder(nn.Module):
|
377 |
+
r"""TransformerEncoder is a stack of N encoder layers. Users can build the
|
378 |
+
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
|
379 |
+
|
380 |
+
Args:
|
381 |
+
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
|
382 |
+
num_layers: the number of sub-encoder-layers in the encoder (required).
|
383 |
+
norm: the layer normalization component (optional).
|
384 |
+
enable_nested_tensor: if True, input will automatically convert to nested tensor
|
385 |
+
(and convert back on output). This will improve the overall performance of
|
386 |
+
TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
|
387 |
+
|
388 |
+
Examples::
|
389 |
+
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
|
390 |
+
>>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
|
391 |
+
>>> src = torch.rand(10, 32, 512)
|
392 |
+
>>> out = transformer_encoder(src)
|
393 |
+
"""
|
394 |
+
__constants__ = ["norm"]
|
395 |
+
|
396 |
+
def __init__(self, encoder_layer, num_layers, norm=None):
|
397 |
+
super(TransformerEncoder, self).__init__()
|
398 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
399 |
+
self.num_layers = num_layers
|
400 |
+
self.norm = norm
|
401 |
+
|
402 |
+
def forward(
|
403 |
+
self,
|
404 |
+
src: Tensor,
|
405 |
+
mask: Optional[Tensor] = None,
|
406 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
407 |
+
return_layer_states: bool = False,
|
408 |
+
) -> Tensor:
|
409 |
+
r"""Pass the input through the encoder layers in turn.
|
410 |
+
|
411 |
+
Args:
|
412 |
+
src: the sequence to the encoder (required).
|
413 |
+
mask: the mask for the src sequence (optional).
|
414 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
415 |
+
return_layer_states: return layers' state (optional).
|
416 |
+
|
417 |
+
Shape:
|
418 |
+
see the docs in Transformer class.
|
419 |
+
"""
|
420 |
+
if return_layer_states:
|
421 |
+
layer_states = [] # layers' output
|
422 |
+
output = src
|
423 |
+
for mod in self.layers:
|
424 |
+
output = mod(
|
425 |
+
output,
|
426 |
+
src_mask=mask,
|
427 |
+
src_key_padding_mask=src_key_padding_mask,
|
428 |
+
)
|
429 |
+
layer_states.append(output[0])
|
430 |
+
|
431 |
+
if self.norm is not None:
|
432 |
+
output = self.norm(output)
|
433 |
+
|
434 |
+
return layer_states, output
|
435 |
+
|
436 |
+
output = src
|
437 |
+
for mod in self.layers:
|
438 |
+
output = mod(
|
439 |
+
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
|
440 |
+
)
|
441 |
+
|
442 |
+
if self.norm is not None:
|
443 |
+
output = self.norm(output)
|
444 |
+
|
445 |
+
return output
|
446 |
+
|
447 |
+
def infer(
|
448 |
+
self,
|
449 |
+
src: Tensor,
|
450 |
+
mask: Optional[Tensor] = None,
|
451 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
452 |
+
return_layer_states: bool = False,
|
453 |
+
past_kv: Optional[Tensor] = None,
|
454 |
+
use_cache: bool = False,
|
455 |
+
):
|
456 |
+
if past_kv is None:
|
457 |
+
past_length = 0
|
458 |
+
past_kv = tuple([None] * self.num_layers)
|
459 |
+
else:
|
460 |
+
past_length = past_kv[0][0].size(-2)
|
461 |
+
new_kv = () if use_cache else None
|
462 |
+
output = src
|
463 |
+
for mod, past_layer_kv in zip(self.layers, past_kv):
|
464 |
+
output, kv = mod.infer(
|
465 |
+
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past_kv=past_layer_kv, use_cache=use_cache
|
466 |
+
)
|
467 |
+
if use_cache:
|
468 |
+
new_kv = new_kv + (kv,)
|
469 |
+
|
470 |
+
if self.norm is not None:
|
471 |
+
output = self.norm(output)
|
472 |
+
|
473 |
+
return output, new_kv
|
474 |
+
|
475 |
+
|
476 |
+
class TransformerDecoderLayer(nn.Module):
|
477 |
+
__constants__ = ["batch_first", "norm_first"]
|
478 |
+
|
479 |
+
def __init__(
|
480 |
+
self,
|
481 |
+
d_model: int,
|
482 |
+
nhead: int,
|
483 |
+
dim_feedforward: int = 2048,
|
484 |
+
dropout: float = 0.1,
|
485 |
+
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
|
486 |
+
linear1_self_attention_cls: nn.Module = nn.Linear,
|
487 |
+
linear2_self_attention_cls: nn.Module = nn.Linear,
|
488 |
+
linear1_feedforward_cls: nn.Module = nn.Linear,
|
489 |
+
linear2_feedforward_cls: nn.Module = nn.Linear,
|
490 |
+
batch_first: bool = False,
|
491 |
+
norm_first: bool = False,
|
492 |
+
device=None,
|
493 |
+
dtype=None,
|
494 |
+
layer_norm_cls: nn.Module = LayerNorm,
|
495 |
+
layer_norm_eps: float = 1e-5,
|
496 |
+
adaptive_layer_norm=False,
|
497 |
+
) -> None:
|
498 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
499 |
+
super(TransformerDecoderLayer, self).__init__()
|
500 |
+
self.self_attn = MultiheadAttention(
|
501 |
+
d_model,
|
502 |
+
nhead,
|
503 |
+
dropout=dropout,
|
504 |
+
batch_first=batch_first,
|
505 |
+
linear1_cls=linear1_self_attention_cls,
|
506 |
+
linear2_cls=linear2_self_attention_cls,
|
507 |
+
**factory_kwargs,
|
508 |
+
)
|
509 |
+
self.multihead_attn = MultiheadAttention(
|
510 |
+
d_model,
|
511 |
+
nhead,
|
512 |
+
dropout=dropout,
|
513 |
+
batch_first=batch_first,
|
514 |
+
linear1_cls=linear1_self_attention_cls,
|
515 |
+
linear2_cls=linear2_self_attention_cls,
|
516 |
+
**factory_kwargs,
|
517 |
+
)
|
518 |
+
# Implementation of Feedforward model
|
519 |
+
self.linear1 = linear1_feedforward_cls(
|
520 |
+
d_model, dim_feedforward, **factory_kwargs
|
521 |
+
)
|
522 |
+
self.dropout = nn.Dropout(dropout)
|
523 |
+
self.linear2 = linear2_feedforward_cls(
|
524 |
+
dim_feedforward, d_model, **factory_kwargs
|
525 |
+
)
|
526 |
+
|
527 |
+
self.norm_first = norm_first
|
528 |
+
self.dropout1 = nn.Dropout(dropout)
|
529 |
+
self.dropout2 = nn.Dropout(dropout)
|
530 |
+
self.dropout3 = nn.Dropout(dropout)
|
531 |
+
|
532 |
+
# Legacy string support for activation function.
|
533 |
+
if isinstance(activation, str):
|
534 |
+
self.activation = _get_activation_fn(activation)
|
535 |
+
elif isinstance(activation, partial):
|
536 |
+
self.activation = activation(d_model)
|
537 |
+
elif activation == BalancedDoubleSwish:
|
538 |
+
self.activation = BalancedDoubleSwish(d_model)
|
539 |
+
else:
|
540 |
+
self.activation = activation
|
541 |
+
|
542 |
+
if adaptive_layer_norm:
|
543 |
+
norm1 = layer_norm_cls(
|
544 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
545 |
+
)
|
546 |
+
norm2 = layer_norm_cls(
|
547 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
548 |
+
)
|
549 |
+
norm3 = layer_norm_cls(
|
550 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
551 |
+
)
|
552 |
+
|
553 |
+
self.norm1 = AdaptiveLayerNorm(d_model, norm1)
|
554 |
+
self.norm2 = AdaptiveLayerNorm(d_model, norm2)
|
555 |
+
self.norm3 = AdaptiveLayerNorm(d_model, norm3)
|
556 |
+
else:
|
557 |
+
self.norm1 = layer_norm_cls(
|
558 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
559 |
+
)
|
560 |
+
self.norm2 = layer_norm_cls(
|
561 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
562 |
+
)
|
563 |
+
if layer_norm_cls == IdentityNorm:
|
564 |
+
self.norm3 = BalancedBasicNorm(
|
565 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
566 |
+
)
|
567 |
+
else:
|
568 |
+
self.norm3 = layer_norm_cls(
|
569 |
+
d_model, eps=layer_norm_eps, **factory_kwargs
|
570 |
+
)
|
571 |
+
|
572 |
+
def forward(
|
573 |
+
self,
|
574 |
+
tgt: Tensor,
|
575 |
+
memory: Tensor,
|
576 |
+
tgt_mask: Optional[Tensor] = None,
|
577 |
+
memory_mask: Optional[Tensor] = None,
|
578 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
579 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
580 |
+
) -> Tensor:
|
581 |
+
r"""Pass the inputs (and mask) through the decoder layer.
|
582 |
+
|
583 |
+
Args:
|
584 |
+
tgt: the sequence to the decoder layer (required).
|
585 |
+
memory: the sequence from the last layer of the encoder (required).
|
586 |
+
tgt_mask: the mask for the tgt sequence (optional).
|
587 |
+
memory_mask: the mask for the memory sequence (optional).
|
588 |
+
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
|
589 |
+
memory_key_padding_mask: the mask for the memory keys per batch (optional).
|
590 |
+
|
591 |
+
Shape:
|
592 |
+
see the docs in Transformer class.
|
593 |
+
"""
|
594 |
+
tgt_is_tuple = False
|
595 |
+
if isinstance(tgt, tuple):
|
596 |
+
x, stage_embedding = tgt
|
597 |
+
tgt_is_tuple = True
|
598 |
+
else:
|
599 |
+
x, stage_embedding = tgt, None
|
600 |
+
|
601 |
+
if self.norm_first:
|
602 |
+
x = x + self._sa_block(
|
603 |
+
self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
|
604 |
+
)
|
605 |
+
x = x + self._mha_block(
|
606 |
+
self.norm2(x, stage_embedding),
|
607 |
+
memory,
|
608 |
+
memory_mask,
|
609 |
+
memory_key_padding_mask,
|
610 |
+
)
|
611 |
+
x = x + self._ff_block(self.norm3(x, stage_embedding))
|
612 |
+
else:
|
613 |
+
x = self.norm1(
|
614 |
+
x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
|
615 |
+
stage_embedding,
|
616 |
+
)
|
617 |
+
x = self.norm2(
|
618 |
+
x
|
619 |
+
+ self._mha_block(
|
620 |
+
x, memory, memory_mask, memory_key_padding_mask
|
621 |
+
),
|
622 |
+
stage_embedding,
|
623 |
+
)
|
624 |
+
x = self.norm3(x + self._ff_block(x), stage_embedding)
|
625 |
+
|
626 |
+
if tgt_is_tuple:
|
627 |
+
return (x, stage_embedding)
|
628 |
+
return x
|
629 |
+
|
630 |
+
# self-attention block
|
631 |
+
def _sa_block(
|
632 |
+
self,
|
633 |
+
x: Tensor,
|
634 |
+
attn_mask: Optional[Tensor],
|
635 |
+
key_padding_mask: Optional[Tensor],
|
636 |
+
) -> Tensor:
|
637 |
+
x = self.self_attn(
|
638 |
+
x,
|
639 |
+
x,
|
640 |
+
x,
|
641 |
+
attn_mask=attn_mask,
|
642 |
+
key_padding_mask=key_padding_mask,
|
643 |
+
need_weights=False,
|
644 |
+
)[0]
|
645 |
+
return self.dropout1(x)
|
646 |
+
|
647 |
+
# multihead attention block
|
648 |
+
def _mha_block(
|
649 |
+
self,
|
650 |
+
x: Tensor,
|
651 |
+
mem: Tensor,
|
652 |
+
attn_mask: Optional[Tensor],
|
653 |
+
key_padding_mask: Optional[Tensor],
|
654 |
+
) -> Tensor:
|
655 |
+
x = self.multihead_attn(
|
656 |
+
x,
|
657 |
+
mem,
|
658 |
+
mem,
|
659 |
+
attn_mask=attn_mask,
|
660 |
+
key_padding_mask=key_padding_mask,
|
661 |
+
need_weights=False,
|
662 |
+
)[0]
|
663 |
+
return self.dropout2(x)
|
664 |
+
|
665 |
+
# feed forward block
|
666 |
+
def _ff_block(self, x: Tensor) -> Tensor:
|
667 |
+
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
668 |
+
return self.dropout3(x)
|
669 |
+
|
670 |
+
|
671 |
+
def _get_clones(module, N):
|
672 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
673 |
+
|
674 |
+
|
675 |
+
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
|
676 |
+
if activation == "relu":
|
677 |
+
return F.relu
|
678 |
+
elif activation == "gelu":
|
679 |
+
return F.gelu
|
680 |
+
|
681 |
+
raise RuntimeError(
|
682 |
+
"activation should be relu/gelu, not {}".format(activation)
|
683 |
+
)
|
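A minimal smoke test for the encoder stack defined above; the sizes are illustrative rather than the ones used by the VALL-E X models, and it assumes the script is run from the repository root and that the custom MultiheadAttention in modules/activation.py accepts the same core arguments as torch.nn.MultiheadAttention:

import torch

from modules.transformer import TransformerEncoder, TransformerEncoderLayer

d_model, nhead = 256, 4
layer = TransformerEncoderLayer(
    d_model,
    nhead,
    dim_feedforward=1024,
    batch_first=True,    # inputs given as (batch, time, d_model)
    norm_first=True,     # pre-norm residual layout
)
encoder = TransformerEncoder(layer, num_layers=2)

x = torch.randn(2, 50, d_model)
padding = torch.zeros(2, 50, dtype=torch.bool)   # True would mark padded frames
y = encoder(x, src_key_padding_mask=padding)
print(y.shape)                                   # torch.Size([2, 50, 256])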
nltk_data/tokenizers/punkt/.DS_Store
ADDED
Binary file (6.15 kB).
|
|
nltk_data/tokenizers/punkt/PY3/README
ADDED
@@ -0,0 +1,98 @@
1 |
+
Pretrained Punkt Models -- Jan Strunk (New version trained after issues 313 and 514 had been corrected)
|
2 |
+
|
3 |
+
Most models were prepared using the test corpora from Kiss and Strunk (2006). Additional models have
|
4 |
+
been contributed by various people using NLTK for sentence boundary detection.
|
5 |
+
|
6 |
+
For information about how to use these models, please confer the tokenization HOWTO:
|
7 |
+
http://nltk.googlecode.com/svn/trunk/doc/howto/tokenize.html
|
8 |
+
and chapter 3.8 of the NLTK book:
|
9 |
+
http://nltk.googlecode.com/svn/trunk/doc/book/ch03.html#sec-segmentation
|
10 |
+
|
11 |
+
There are pretrained tokenizers for the following languages:
|
12 |
+
|
13 |
+
File Language Source Contents Size of training corpus(in tokens) Model contributed by
|
14 |
+
=======================================================================================================================================================================
|
15 |
+
czech.pickle Czech Multilingual Corpus 1 (ECI) Lidove Noviny ~345,000 Jan Strunk / Tibor Kiss
|
16 |
+
Literarni Noviny
|
17 |
+
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
18 |
+
danish.pickle Danish Avisdata CD-Rom Ver. 1.1. 1995 Berlingske Tidende ~550,000 Jan Strunk / Tibor Kiss
|
19 |
+
(Berlingske Avisdata, Copenhagen) Weekend Avisen
|
20 |
+
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
21 |
+
dutch.pickle Dutch Multilingual Corpus 1 (ECI) De Limburger ~340,000 Jan Strunk / Tibor Kiss
|
22 |
+
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
23 |
+
english.pickle English Penn Treebank (LDC) Wall Street Journal ~469,000 Jan Strunk / Tibor Kiss
|
24 |
+
(American)
|
25 |
+
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
26 |
+
estonian.pickle Estonian University of Tartu, Estonia Eesti Ekspress ~359,000 Jan Strunk / Tibor Kiss
|
27 |
+
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
28 |
+
finnish.pickle Finnish Finnish Parole Corpus, Finnish Books and major national ~364,000 Jan Strunk / Tibor Kiss
|
29 |
+
Text Bank (Suomen Kielen newspapers
|
30 |
+
Tekstipankki)
|
31 |
+
Finnish Center for IT Science
|
32 |
+
(CSC)
|
33 |
+
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
34 |
+
french.pickle French Multilingual Corpus 1 (ECI) Le Monde ~370,000 Jan Strunk / Tibor Kiss
|
35 |
+
(European)
|
36 |
+
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
37 |
+
german.pickle German Neue Zürcher Zeitung AG Neue Zürcher Zeitung ~847,000 Jan Strunk / Tibor Kiss
|
38 |
+
(Switzerland) CD-ROM
|
39 |
+
(Uses "ss"
|
40 |
+
instead of "ß")
|
41 |
+
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
42 |
+
greek.pickle Greek Efstathios Stamatatos To Vima (TO BHMA) ~227,000 Jan Strunk / Tibor Kiss
|
43 |
+
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
44 |
+
italian.pickle Italian Multilingual Corpus 1 (ECI) La Stampa, Il Mattino ~312,000 Jan Strunk / Tibor Kiss
|
45 |
+
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
46 |
+
norwegian.pickle Norwegian Centre for Humanities Bergens Tidende ~479,000 Jan Strunk / Tibor Kiss
|
47 |
+
(Bokmål and Information Technologies,
|
48 |
+
Nynorsk) Bergen
|
49 |
+
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
50 |
+
polish.pickle Polish Polish National Corpus Literature, newspapers, etc. ~1,000,000 Krzysztof Langner
|
51 |
+
(http://www.nkjp.pl/)
|
52 |
+
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
53 |
+
portuguese.pickle Portuguese CETENFolha Corpus Folha de São Paulo ~321,000 Jan Strunk / Tibor Kiss
|
54 |
+
(Brazilian) (Linguateca)
|
55 |
+
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
56 |
+
slovene.pickle Slovene TRACTOR Delo ~354,000 Jan Strunk / Tibor Kiss
|
57 |
+
Slovene Academy for Arts
|
58 |
+
and Sciences
|
59 |
+
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
60 |
+
spanish.pickle Spanish Multilingual Corpus 1 (ECI) Sur ~353,000 Jan Strunk / Tibor Kiss
|
61 |
+
(European)
|
62 |
+
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
63 |
+
swedish.pickle Swedish Multilingual Corpus 1 (ECI) Dagens Nyheter ~339,000 Jan Strunk / Tibor Kiss
|
64 |
+
(and some other texts)
|
65 |
+
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
66 |
+
turkish.pickle Turkish METU Turkish Corpus Milliyet ~333,000 Jan Strunk / Tibor Kiss
|
67 |
+
(Türkçe Derlem Projesi)
|
68 |
+
University of Ankara
|
69 |
+
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
70 |
+
|
71 |
+
The corpora contained about 400,000 tokens on average and mostly consisted of newspaper text converted to
|
72 |
+
Unicode using the codecs module.
|
73 |
+
|
74 |
+
Kiss, Tibor and Strunk, Jan (2006): Unsupervised Multilingual Sentence Boundary Detection.
|
75 |
+
Computational Linguistics 32: 485-525.
|
76 |
+
|
77 |
+
---- Training Code ----
|
78 |
+
|
79 |
+
# import punkt
|
80 |
+
import nltk.tokenize.punkt
|
81 |
+
|
82 |
+
# Make a new Tokenizer
|
83 |
+
tokenizer = nltk.tokenize.punkt.PunktSentenceTokenizer()
|
84 |
+
|
85 |
+
# Read in training corpus (one example: Slovene)
|
86 |
+
import codecs
|
87 |
+
text = codecs.open("slovene.plain","Ur","iso-8859-2").read()
|
88 |
+
|
89 |
+
# Train tokenizer
|
90 |
+
tokenizer.train(text)
|
91 |
+
|
92 |
+
# Dump pickled tokenizer
|
93 |
+
import pickle
|
94 |
+
out = open("slovene.pickle","wb")
|
95 |
+
pickle.dump(tokenizer, out)
|
96 |
+
out.close()
|
97 |
+
|
98 |
+
---------
|
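The training recipe above is the upstream NLTK documentation for building new punkt models; the pickles bundled in this commit are presumably meant to be loaded at runtime for sentence splitting. A minimal loading sketch, assuming NLTK is installed and the nltk_data directory sits at the repository root:

import nltk

nltk.data.path.insert(0, "./nltk_data")   # prefer the models shipped with this repo
tokenizer = nltk.data.load("tokenizers/punkt/PY3/english.pickle")
print(tokenizer.tokenize("VALL-E X clones a voice from a short prompt. Then it reads new text."))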
nltk_data/tokenizers/punkt/PY3/czech.pickle
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:64b0734b6fbe8e8d7cac79f48d1dd9f853824e57c4e3594dadd74ba2c1d97f50
|
3 |
+
size 1119050
|
nltk_data/tokenizers/punkt/PY3/danish.pickle
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6189c7dd254e29e2bd406a7f6a4336297c8953214792466a790ea4444223ceb3
|
3 |
+
size 1191710
|
nltk_data/tokenizers/punkt/PY3/dutch.pickle
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fda0d6a13f02e8898daec7fe923da88e25abe081bcfa755c0e015075c215fe4c
|
3 |
+
size 693759
|
nltk_data/tokenizers/punkt/PY3/english.pickle
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5cad3758596392364e3be9803dbd7ebeda384b68937b488a01365f5551bb942c
|
3 |
+
size 406697
|
nltk_data/tokenizers/punkt/PY3/estonian.pickle
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b364f72538d17b146a98009ad239a8096ce6c0a8b02958c0bc776ecd0c58a25f
|
3 |
+
size 1499502
|
nltk_data/tokenizers/punkt/PY3/finnish.pickle
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6a4b5ff5500ee851c456f9dd40d5fc0d8c1859c88eb3178de1317d26b7d22833
|
3 |
+
size 1852226
|
nltk_data/tokenizers/punkt/PY3/french.pickle
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:28e3a4cd2971989b3cb9fd3433a6f15d17981e464db2be039364313b5de94f29
|
3 |
+
size 553575
|
nltk_data/tokenizers/punkt/PY3/german.pickle
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ddcbbe85e2042a019b1a6e37fd8c153286c38ba201fae0f5bfd9a3f74abae25c
|
3 |
+
size 1463575
|