admin commited on
Commit
796529d
·
1 Parent(s): c87467b
.gitattributes CHANGED
@@ -1,35 +1,34 @@
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
  *.onnx filter=lfs diff=lfs merge=lfs -text
17
  *.ot filter=lfs diff=lfs merge=lfs -text
18
  *.parquet filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
  *.tflite filter=lfs diff=lfs merge=lfs -text
30
  *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
  *.bz2 filter=lfs diff=lfs merge=lfs -text
 
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
 
11
  *.model filter=lfs diff=lfs merge=lfs -text
12
  *.msgpack filter=lfs diff=lfs merge=lfs -text
 
 
13
  *.onnx filter=lfs diff=lfs merge=lfs -text
14
  *.ot filter=lfs diff=lfs merge=lfs -text
15
  *.parquet filter=lfs diff=lfs merge=lfs -text
16
  *.pb filter=lfs diff=lfs merge=lfs -text
 
 
17
  *.pt filter=lfs diff=lfs merge=lfs -text
18
  *.pth filter=lfs diff=lfs merge=lfs -text
19
  *.rar filter=lfs diff=lfs merge=lfs -text
 
20
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
  *.tar.* filter=lfs diff=lfs merge=lfs -text
 
22
  *.tflite filter=lfs diff=lfs merge=lfs -text
23
  *.tgz filter=lfs diff=lfs merge=lfs -text
 
24
  *.xz filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *.tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ *.db* filter=lfs diff=lfs merge=lfs -text
29
+ *.ark* filter=lfs diff=lfs merge=lfs -text
30
+ **/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
31
+ **/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
32
+ **/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
33
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
34
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ */__pycache__/*
2
+ __pycache__/*
3
+ *.pth
4
+ *.json
app.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from text.symbols import symbols
2
+ from text.cleaner import clean_text
3
+ from text import cleaned_text_to_sequence, get_bert
4
+ from modelscope import snapshot_download
5
+ from models import SynthesizerTrn
6
+ from tqdm import tqdm
7
+ import gradio as gr
8
+ import numpy as np
9
+ import commons
10
+ import random
11
+ import utils
12
+ import torch
13
+ import sys
14
+ import re
15
+ import os
16
+
17
+ if sys.platform == "darwin":
18
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
19
+
20
+ import logging
21
+
22
+ logging.getLogger("numba").setLevel(logging.WARNING)
23
+ logging.getLogger("markdown_it").setLevel(logging.WARNING)
24
+ logging.getLogger("urllib3").setLevel(logging.WARNING)
25
+ logging.getLogger("matplotlib").setLevel(logging.WARNING)
26
+ logging.basicConfig(
27
+ level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s"
28
+ )
29
+
30
+ logger = logging.getLogger(__name__)
31
+ net_g = None
32
+ debug = False
33
+
34
+
35
+ def get_text(text, language_str, hps):
36
+ norm_text, phone, tone, word2ph = clean_text(text, language_str)
37
+ phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
38
+ if hps.data.add_blank:
39
+ phone = commons.intersperse(phone, 0)
40
+ tone = commons.intersperse(tone, 0)
41
+ language = commons.intersperse(language, 0)
42
+ for i in range(len(word2ph)):
43
+ word2ph[i] = word2ph[i] * 2
44
+
45
+ word2ph[0] += 1
46
+
47
+ bert = get_bert(norm_text, word2ph, language_str)
48
+ del word2ph
49
+ assert bert.shape[-1] == len(phone)
50
+ phone = torch.LongTensor(phone)
51
+ tone = torch.LongTensor(tone)
52
+ language = torch.LongTensor(language)
53
+ return bert, phone, tone, language
54
+
55
+
56
+ def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid):
57
+ global net_g
58
+ bert, phones, tones, lang_ids = get_text(text, "ZH", hps)
59
+ with torch.no_grad():
60
+ x_tst = phones.to(device).unsqueeze(0)
61
+ tones = tones.to(device).unsqueeze(0)
62
+ lang_ids = lang_ids.to(device).unsqueeze(0)
63
+ bert = bert.to(device).unsqueeze(0)
64
+ x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
65
+ del phones
66
+ speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
67
+ audio = (
68
+ net_g.infer(
69
+ x_tst,
70
+ x_tst_lengths,
71
+ speakers,
72
+ tones,
73
+ lang_ids,
74
+ bert,
75
+ sdp_ratio=sdp_ratio,
76
+ noise_scale=noise_scale,
77
+ noise_scale_w=noise_scale_w,
78
+ length_scale=length_scale,
79
+ )[0][0, 0]
80
+ .data.cpu()
81
+ .float()
82
+ .numpy()
83
+ )
84
+ del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers
85
+ return audio
86
+
87
+
88
+ def tts_fn(text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale):
89
+ with torch.no_grad():
90
+ audio = infer(
91
+ text,
92
+ sdp_ratio=sdp_ratio,
93
+ noise_scale=noise_scale,
94
+ noise_scale_w=noise_scale_w,
95
+ length_scale=length_scale,
96
+ sid=speaker,
97
+ )
98
+
99
+ return (hps.data.sampling_rate, audio)
100
+
101
+
102
+ def text_splitter(text: str):
103
+ punctuation = r"[。,;,!,?,〜,\n,\r,\t,.,!,;,?,~, ]"
104
+ # 使用正则表达式根据标点符号分割文本,并忽略重叠的分隔符
105
+ sentences = re.split(punctuation, text.strip())
106
+ # 过滤掉空字符串
107
+ return [sentence.strip() for sentence in sentences if sentence.strip()]
108
+
109
+
110
+ def concatenate_audios(audio_samples, sample_rate=44100):
111
+ half_second_silence = np.zeros(int(sample_rate / 2))
112
+ # 初始化最终的音频数组
113
+ final_audio = audio_samples[0]
114
+ # 遍历音频样本列表,并将它们连接起来,每个样本之间插入半秒钟的静音
115
+ for sample in audio_samples[1:]:
116
+ final_audio = np.concatenate((final_audio, half_second_silence, sample))
117
+
118
+ print("Audio pieces concatenated!")
119
+ return (sample_rate, final_audio)
120
+
121
+
122
+ def read_text(file_path: str):
123
+ try:
124
+ # 打开文件并读取内容
125
+ with open(file_path, "r", encoding="utf-8") as file:
126
+ content = file.read()
127
+ return content
128
+
129
+ except FileNotFoundError:
130
+ print(f"文件未找到: {file_path}")
131
+
132
+ except IOError:
133
+ print(f"读取文件时发生错误: {file_path}")
134
+
135
+ except Exception as e:
136
+ print(f"发生未知错误: {e}")
137
+
138
+
139
+ def infer_tab1(text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale):
140
+ try:
141
+ content = read_text(text)
142
+ sentences = text_splitter(content)
143
+ audios = []
144
+ for sentence in tqdm(sentences, desc="TTS inferring..."):
145
+ with torch.no_grad():
146
+ audios.append(
147
+ infer(
148
+ sentence,
149
+ sdp_ratio=sdp_ratio,
150
+ noise_scale=noise_scale,
151
+ noise_scale_w=noise_scale_w,
152
+ length_scale=length_scale,
153
+ sid=speaker,
154
+ )
155
+ )
156
+
157
+ return concatenate_audios(audios, hps.data.sampling_rate), content
158
+
159
+ except Exception as e:
160
+ return None, f"{e}"
161
+
162
+
163
+ def infer_tab2(content, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale):
164
+ try:
165
+ sentences = text_splitter(content)
166
+ audios = []
167
+ for sentence in tqdm(sentences, desc="TTS inferring..."):
168
+ with torch.no_grad():
169
+ audios.append(
170
+ infer(
171
+ sentence,
172
+ sdp_ratio=sdp_ratio,
173
+ noise_scale=noise_scale,
174
+ noise_scale_w=noise_scale_w,
175
+ length_scale=length_scale,
176
+ sid=speaker,
177
+ )
178
+ )
179
+
180
+ return concatenate_audios(audios, hps.data.sampling_rate)
181
+
182
+ except Exception as e:
183
+ print(f"{e}")
184
+ return None
185
+
186
+
187
+ if __name__ == "__main__":
188
+ model_dir = snapshot_download("Genius-Society/hoyoTTS", cache_dir="./__pycache__")
189
+ if debug:
190
+ logger.info("Enable DEBUG-LEVEL log")
191
+ logging.basicConfig(level=logging.DEBUG)
192
+
193
+ hps = utils.get_hparams_from_dir(model_dir)
194
+ device = (
195
+ "cuda:0"
196
+ if torch.cuda.is_available()
197
+ else (
198
+ "mps"
199
+ if sys.platform == "darwin" and torch.backends.mps.is_available()
200
+ else "cpu"
201
+ )
202
+ )
203
+ net_g = SynthesizerTrn(
204
+ len(symbols),
205
+ hps.data.filter_length // 2 + 1,
206
+ hps.train.segment_size // hps.data.hop_length,
207
+ n_speakers=hps.data.n_speakers,
208
+ **hps.model,
209
+ ).to(device)
210
+ net_g.eval()
211
+ utils.load_checkpoint(f"{model_dir}/G_78000.pth", net_g, None, skip_optimizer=True)
212
+ speaker_ids = hps.data.spk2id
213
+ speakers = list(speaker_ids.keys())
214
+ random.shuffle(speakers)
215
+ with gr.Blocks() as app:
216
+ gr.Markdown(
217
+ """
218
+ <center>
219
+ 欢迎使用此创空间, 此创空间基于 <a href="https://github.com/fishaudio/Bert-VITS2">Bert-vits2</a> 开源项目制作,完全免费。使用此创空间必须遵守当地相关法律法规,禁止用其从事任何违法犯罪活动。首次推理需耗时下载模型,还请耐心等待。另外,移至最底端有原理浅讲。
220
+ </center>
221
+ """
222
+ )
223
+
224
+ with gr.Tab("输入模式"):
225
+ gr.Interface(
226
+ fn=infer_tab2, # 使用 text_to_speech 函数
227
+ inputs=[
228
+ gr.TextArea(label="请输入简体中文文案", show_copy_button=True),
229
+ gr.Dropdown(choices=speakers, value="莱依拉", label="角色"),
230
+ gr.Slider(
231
+ minimum=0, maximum=1, value=0.2, step=0.1, label="语调调节"
232
+ ), # SDP/DP混合比
233
+ gr.Slider(
234
+ minimum=0.1, maximum=2, value=0.6, step=0.1, label="感情调节"
235
+ ),
236
+ gr.Slider(
237
+ minimum=0.1, maximum=2, value=0.8, step=0.1, label="音素长度"
238
+ ),
239
+ gr.Slider(
240
+ minimum=0.1, maximum=2, value=1, step=0.1, label="生成时长"
241
+ ),
242
+ ],
243
+ outputs=gr.Audio(label="输出音频"),
244
+ flagging_mode="never",
245
+ concurrency_limit=4,
246
+ )
247
+
248
+ with gr.Tab("上传模式"):
249
+ gr.Interface(
250
+ fn=infer_tab1, # 使用 text_to_speech 函数
251
+ inputs=[
252
+ gr.components.File(
253
+ label="请上传简体中文TXT文案",
254
+ type="filepath",
255
+ file_types=[".txt"],
256
+ ),
257
+ gr.Dropdown(choices=speakers, value="莱依拉", label="角色"),
258
+ gr.Slider(
259
+ minimum=0, maximum=1, value=0.2, step=0.1, label="语调调节"
260
+ ), # SDP/DP混合比
261
+ gr.Slider(
262
+ minimum=0.1, maximum=2, value=0.6, step=0.1, label="感情调节"
263
+ ),
264
+ gr.Slider(
265
+ minimum=0.1, maximum=2, value=0.8, step=0.1, label="音素长度"
266
+ ),
267
+ gr.Slider(
268
+ minimum=0.1, maximum=2, value=1, step=0.1, label="生成时长"
269
+ ),
270
+ ],
271
+ outputs=[
272
+ gr.Audio(label="输出音频"),
273
+ gr.TextArea(label="文案提取结果", show_copy_button=True),
274
+ ],
275
+ flagging_mode="never",
276
+ concurrency_limit=4,
277
+ )
278
+
279
+ gr.HTML(
280
+ """
281
+ <iframe src="//player.bilibili.com/player.html?bvid=BV1gXDZYnECi&autoplay=0" scrolling="no" border="0" frameborder="no" framespacing="0" allowfullscreen="true" width="100%" style="aspect-ratio: 16 / 9;">
282
+ </iframe>
283
+ """
284
+ )
285
+
286
+ app.launch()
attentions.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import commons
3
+ import logging
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class LayerNorm(nn.Module):
13
+ def __init__(self, channels, eps=1e-5):
14
+ super().__init__()
15
+ self.channels = channels
16
+ self.eps = eps
17
+
18
+ self.gamma = nn.Parameter(torch.ones(channels))
19
+ self.beta = nn.Parameter(torch.zeros(channels))
20
+
21
+ def forward(self, x):
22
+ x = x.transpose(1, -1)
23
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
24
+ return x.transpose(1, -1)
25
+
26
+
27
+ @torch.jit.script
28
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
29
+ n_channels_int = n_channels[0]
30
+ in_act = input_a + input_b
31
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
32
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
33
+ acts = t_act * s_act
34
+ return acts
35
+
36
+
37
+ class Encoder(nn.Module):
38
+ def __init__(
39
+ self,
40
+ hidden_channels,
41
+ filter_channels,
42
+ n_heads,
43
+ n_layers,
44
+ kernel_size=1,
45
+ p_dropout=0.0,
46
+ window_size=4,
47
+ isflow=True,
48
+ **kwargs
49
+ ):
50
+ super().__init__()
51
+ self.hidden_channels = hidden_channels
52
+ self.filter_channels = filter_channels
53
+ self.n_heads = n_heads
54
+ self.n_layers = n_layers
55
+ self.kernel_size = kernel_size
56
+ self.p_dropout = p_dropout
57
+ self.window_size = window_size
58
+ # if isflow:
59
+ # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
60
+ # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
61
+ # self.cond_layer = weight_norm(cond_layer, name='weight')
62
+ # self.gin_channels = 256
63
+ self.cond_layer_idx = self.n_layers
64
+ if "gin_channels" in kwargs:
65
+ self.gin_channels = kwargs["gin_channels"]
66
+ if self.gin_channels != 0:
67
+ self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
68
+ # vits2 says 3rd block, so idx is 2 by default
69
+ self.cond_layer_idx = (
70
+ kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
71
+ )
72
+ logging.debug(self.gin_channels, self.cond_layer_idx)
73
+ assert (
74
+ self.cond_layer_idx < self.n_layers
75
+ ), "cond_layer_idx should be less than n_layers"
76
+ self.drop = nn.Dropout(p_dropout)
77
+ self.attn_layers = nn.ModuleList()
78
+ self.norm_layers_1 = nn.ModuleList()
79
+ self.ffn_layers = nn.ModuleList()
80
+ self.norm_layers_2 = nn.ModuleList()
81
+ for i in range(self.n_layers):
82
+ self.attn_layers.append(
83
+ MultiHeadAttention(
84
+ hidden_channels,
85
+ hidden_channels,
86
+ n_heads,
87
+ p_dropout=p_dropout,
88
+ window_size=window_size,
89
+ )
90
+ )
91
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
92
+ self.ffn_layers.append(
93
+ FFN(
94
+ hidden_channels,
95
+ hidden_channels,
96
+ filter_channels,
97
+ kernel_size,
98
+ p_dropout=p_dropout,
99
+ )
100
+ )
101
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
102
+
103
+ def forward(self, x, x_mask, g=None):
104
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
105
+ x = x * x_mask
106
+ for i in range(self.n_layers):
107
+ if i == self.cond_layer_idx and g is not None:
108
+ g = self.spk_emb_linear(g.transpose(1, 2))
109
+ g = g.transpose(1, 2)
110
+ x = x + g
111
+ x = x * x_mask
112
+ y = self.attn_layers[i](x, x, attn_mask)
113
+ y = self.drop(y)
114
+ x = self.norm_layers_1[i](x + y)
115
+
116
+ y = self.ffn_layers[i](x, x_mask)
117
+ y = self.drop(y)
118
+ x = self.norm_layers_2[i](x + y)
119
+ x = x * x_mask
120
+ return x
121
+
122
+
123
+ class Decoder(nn.Module):
124
+ def __init__(
125
+ self,
126
+ hidden_channels,
127
+ filter_channels,
128
+ n_heads,
129
+ n_layers,
130
+ kernel_size=1,
131
+ p_dropout=0.0,
132
+ proximal_bias=False,
133
+ proximal_init=True,
134
+ **kwargs
135
+ ):
136
+ super().__init__()
137
+ self.hidden_channels = hidden_channels
138
+ self.filter_channels = filter_channels
139
+ self.n_heads = n_heads
140
+ self.n_layers = n_layers
141
+ self.kernel_size = kernel_size
142
+ self.p_dropout = p_dropout
143
+ self.proximal_bias = proximal_bias
144
+ self.proximal_init = proximal_init
145
+
146
+ self.drop = nn.Dropout(p_dropout)
147
+ self.self_attn_layers = nn.ModuleList()
148
+ self.norm_layers_0 = nn.ModuleList()
149
+ self.encdec_attn_layers = nn.ModuleList()
150
+ self.norm_layers_1 = nn.ModuleList()
151
+ self.ffn_layers = nn.ModuleList()
152
+ self.norm_layers_2 = nn.ModuleList()
153
+ for i in range(self.n_layers):
154
+ self.self_attn_layers.append(
155
+ MultiHeadAttention(
156
+ hidden_channels,
157
+ hidden_channels,
158
+ n_heads,
159
+ p_dropout=p_dropout,
160
+ proximal_bias=proximal_bias,
161
+ proximal_init=proximal_init,
162
+ )
163
+ )
164
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
165
+ self.encdec_attn_layers.append(
166
+ MultiHeadAttention(
167
+ hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
168
+ )
169
+ )
170
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
171
+ self.ffn_layers.append(
172
+ FFN(
173
+ hidden_channels,
174
+ hidden_channels,
175
+ filter_channels,
176
+ kernel_size,
177
+ p_dropout=p_dropout,
178
+ causal=True,
179
+ )
180
+ )
181
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
182
+
183
+ def forward(self, x, x_mask, h, h_mask):
184
+ """
185
+ x: decoder input
186
+ h: encoder output
187
+ """
188
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
189
+ device=x.device, dtype=x.dtype
190
+ )
191
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
192
+ x = x * x_mask
193
+ for i in range(self.n_layers):
194
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
195
+ y = self.drop(y)
196
+ x = self.norm_layers_0[i](x + y)
197
+
198
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
199
+ y = self.drop(y)
200
+ x = self.norm_layers_1[i](x + y)
201
+
202
+ y = self.ffn_layers[i](x, x_mask)
203
+ y = self.drop(y)
204
+ x = self.norm_layers_2[i](x + y)
205
+ x = x * x_mask
206
+ return x
207
+
208
+
209
+ class MultiHeadAttention(nn.Module):
210
+ def __init__(
211
+ self,
212
+ channels,
213
+ out_channels,
214
+ n_heads,
215
+ p_dropout=0.0,
216
+ window_size=None,
217
+ heads_share=True,
218
+ block_length=None,
219
+ proximal_bias=False,
220
+ proximal_init=False,
221
+ ):
222
+ super().__init__()
223
+ assert channels % n_heads == 0
224
+
225
+ self.channels = channels
226
+ self.out_channels = out_channels
227
+ self.n_heads = n_heads
228
+ self.p_dropout = p_dropout
229
+ self.window_size = window_size
230
+ self.heads_share = heads_share
231
+ self.block_length = block_length
232
+ self.proximal_bias = proximal_bias
233
+ self.proximal_init = proximal_init
234
+ self.attn = None
235
+
236
+ self.k_channels = channels // n_heads
237
+ self.conv_q = nn.Conv1d(channels, channels, 1)
238
+ self.conv_k = nn.Conv1d(channels, channels, 1)
239
+ self.conv_v = nn.Conv1d(channels, channels, 1)
240
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
241
+ self.drop = nn.Dropout(p_dropout)
242
+
243
+ if window_size is not None:
244
+ n_heads_rel = 1 if heads_share else n_heads
245
+ rel_stddev = self.k_channels**-0.5
246
+ self.emb_rel_k = nn.Parameter(
247
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
248
+ * rel_stddev
249
+ )
250
+ self.emb_rel_v = nn.Parameter(
251
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
252
+ * rel_stddev
253
+ )
254
+
255
+ nn.init.xavier_uniform_(self.conv_q.weight)
256
+ nn.init.xavier_uniform_(self.conv_k.weight)
257
+ nn.init.xavier_uniform_(self.conv_v.weight)
258
+ if proximal_init:
259
+ with torch.no_grad():
260
+ self.conv_k.weight.copy_(self.conv_q.weight)
261
+ self.conv_k.bias.copy_(self.conv_q.bias)
262
+
263
+ def forward(self, x, c, attn_mask=None):
264
+ q = self.conv_q(x)
265
+ k = self.conv_k(c)
266
+ v = self.conv_v(c)
267
+
268
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
269
+
270
+ x = self.conv_o(x)
271
+ return x
272
+
273
+ def attention(self, query, key, value, mask=None):
274
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
275
+ b, d, t_s, t_t = (*key.size(), query.size(2))
276
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
277
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
278
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
279
+
280
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
281
+ if self.window_size is not None:
282
+ assert (
283
+ t_s == t_t
284
+ ), "Relative attention is only available for self-attention."
285
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
286
+ rel_logits = self._matmul_with_relative_keys(
287
+ query / math.sqrt(self.k_channels), key_relative_embeddings
288
+ )
289
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
290
+ scores = scores + scores_local
291
+ if self.proximal_bias:
292
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
293
+ scores = scores + self._attention_bias_proximal(t_s).to(
294
+ device=scores.device, dtype=scores.dtype
295
+ )
296
+ if mask is not None:
297
+ scores = scores.masked_fill(mask == 0, -1e4)
298
+ if self.block_length is not None:
299
+ assert (
300
+ t_s == t_t
301
+ ), "Local attention is only available for self-attention."
302
+ block_mask = (
303
+ torch.ones_like(scores)
304
+ .triu(-self.block_length)
305
+ .tril(self.block_length)
306
+ )
307
+ scores = scores.masked_fill(block_mask == 0, -1e4)
308
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
309
+ p_attn = self.drop(p_attn)
310
+ output = torch.matmul(p_attn, value)
311
+ if self.window_size is not None:
312
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
313
+ value_relative_embeddings = self._get_relative_embeddings(
314
+ self.emb_rel_v, t_s
315
+ )
316
+ output = output + self._matmul_with_relative_values(
317
+ relative_weights, value_relative_embeddings
318
+ )
319
+ output = (
320
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
321
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
322
+ return output, p_attn
323
+
324
+ def _matmul_with_relative_values(self, x, y):
325
+ """
326
+ x: [b, h, l, m]
327
+ y: [h or 1, m, d]
328
+ ret: [b, h, l, d]
329
+ """
330
+ ret = torch.matmul(x, y.unsqueeze(0))
331
+ return ret
332
+
333
+ def _matmul_with_relative_keys(self, x, y):
334
+ """
335
+ x: [b, h, l, d]
336
+ y: [h or 1, m, d]
337
+ ret: [b, h, l, m]
338
+ """
339
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
340
+ return ret
341
+
342
+ def _get_relative_embeddings(self, relative_embeddings, length):
343
+ max_relative_position = 2 * self.window_size + 1
344
+ # Pad first before slice to avoid using cond ops.
345
+ pad_length = max(length - (self.window_size + 1), 0)
346
+ slice_start_position = max((self.window_size + 1) - length, 0)
347
+ slice_end_position = slice_start_position + 2 * length - 1
348
+ if pad_length > 0:
349
+ padded_relative_embeddings = F.pad(
350
+ relative_embeddings,
351
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
352
+ )
353
+ else:
354
+ padded_relative_embeddings = relative_embeddings
355
+ used_relative_embeddings = padded_relative_embeddings[
356
+ :, slice_start_position:slice_end_position
357
+ ]
358
+ return used_relative_embeddings
359
+
360
+ def _relative_position_to_absolute_position(self, x):
361
+ """
362
+ x: [b, h, l, 2*l-1]
363
+ ret: [b, h, l, l]
364
+ """
365
+ batch, heads, length, _ = x.size()
366
+ # Concat columns of pad to shift from relative to absolute indexing.
367
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
368
+
369
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
370
+ x_flat = x.view([batch, heads, length * 2 * length])
371
+ x_flat = F.pad(
372
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
373
+ )
374
+
375
+ # Reshape and slice out the padded elements.
376
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
377
+ :, :, :length, length - 1 :
378
+ ]
379
+ return x_final
380
+
381
+ def _absolute_position_to_relative_position(self, x):
382
+ """
383
+ x: [b, h, l, l]
384
+ ret: [b, h, l, 2*l-1]
385
+ """
386
+ batch, heads, length, _ = x.size()
387
+ # padd along column
388
+ x = F.pad(
389
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
390
+ )
391
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
392
+ # add 0's in the beginning that will skew the elements after reshape
393
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
394
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
395
+ return x_final
396
+
397
+ def _attention_bias_proximal(self, length):
398
+ """Bias for self-attention to encourage attention to close positions.
399
+ Args:
400
+ length: an integer scalar.
401
+ Returns:
402
+ a Tensor with shape [1, 1, length, length]
403
+ """
404
+ r = torch.arange(length, dtype=torch.float32)
405
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
406
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
407
+
408
+
409
+ class FFN(nn.Module):
410
+ def __init__(
411
+ self,
412
+ in_channels,
413
+ out_channels,
414
+ filter_channels,
415
+ kernel_size,
416
+ p_dropout=0.0,
417
+ activation=None,
418
+ causal=False,
419
+ ):
420
+ super().__init__()
421
+ self.in_channels = in_channels
422
+ self.out_channels = out_channels
423
+ self.filter_channels = filter_channels
424
+ self.kernel_size = kernel_size
425
+ self.p_dropout = p_dropout
426
+ self.activation = activation
427
+ self.causal = causal
428
+
429
+ if causal:
430
+ self.padding = self._causal_padding
431
+ else:
432
+ self.padding = self._same_padding
433
+
434
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
435
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
436
+ self.drop = nn.Dropout(p_dropout)
437
+
438
+ def forward(self, x, x_mask):
439
+ x = self.conv_1(self.padding(x * x_mask))
440
+ if self.activation == "gelu":
441
+ x = x * torch.sigmoid(1.702 * x)
442
+ else:
443
+ x = torch.relu(x)
444
+ x = self.drop(x)
445
+ x = self.conv_2(self.padding(x * x_mask))
446
+ return x * x_mask
447
+
448
+ def _causal_padding(self, x):
449
+ if self.kernel_size == 1:
450
+ return x
451
+ pad_l = self.kernel_size - 1
452
+ pad_r = 0
453
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
454
+ x = F.pad(x, commons.convert_pad_shape(padding))
455
+ return x
456
+
457
+ def _same_padding(self, x):
458
+ if self.kernel_size == 1:
459
+ return x
460
+ pad_l = (self.kernel_size - 1) // 2
461
+ pad_r = self.kernel_size // 2
462
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
463
+ x = F.pad(x, commons.convert_pad_shape(padding))
464
+ return x
commons.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+
6
+ def init_weights(m, mean=0.0, std=0.01):
7
+ classname = m.__class__.__name__
8
+ if classname.find("Conv") != -1:
9
+ m.weight.data.normal_(mean, std)
10
+
11
+
12
+ def get_padding(kernel_size, dilation=1):
13
+ return int((kernel_size * dilation - dilation) / 2)
14
+
15
+
16
+ def convert_pad_shape(pad_shape):
17
+ l = pad_shape[::-1]
18
+ pad_shape = [item for sublist in l for item in sublist]
19
+ return pad_shape
20
+
21
+
22
+ def intersperse(lst, item):
23
+ result = [item] * (len(lst) * 2 + 1)
24
+ result[1::2] = lst
25
+ return result
26
+
27
+
28
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
29
+ """KL(P||Q)"""
30
+ kl = (logs_q - logs_p) - 0.5
31
+ kl += (
32
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
33
+ )
34
+ return kl
35
+
36
+
37
+ def rand_gumbel(shape):
38
+ """Sample from the Gumbel distribution, protect from overflows."""
39
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
40
+ return -torch.log(-torch.log(uniform_samples))
41
+
42
+
43
+ def rand_gumbel_like(x):
44
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
45
+ return g
46
+
47
+
48
+ def slice_segments(x, ids_str, segment_size=4):
49
+ ret = torch.zeros_like(x[:, :, :segment_size])
50
+ for i in range(x.size(0)):
51
+ idx_str = ids_str[i]
52
+ idx_end = idx_str + segment_size
53
+ ret[i] = x[i, :, idx_str:idx_end]
54
+ return ret
55
+
56
+
57
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
58
+ b, d, t = x.size()
59
+ if x_lengths is None:
60
+ x_lengths = t
61
+ ids_str_max = x_lengths - segment_size + 1
62
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
63
+ ret = slice_segments(x, ids_str, segment_size)
64
+ return ret, ids_str
65
+
66
+
67
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
68
+ position = torch.arange(length, dtype=torch.float)
69
+ num_timescales = channels // 2
70
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
71
+ num_timescales - 1
72
+ )
73
+ inv_timescales = min_timescale * torch.exp(
74
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
75
+ )
76
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
77
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
78
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
79
+ signal = signal.view(1, channels, length)
80
+ return signal
81
+
82
+
83
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
84
+ b, channels, length = x.size()
85
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
86
+ return x + signal.to(dtype=x.dtype, device=x.device)
87
+
88
+
89
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
90
+ b, channels, length = x.size()
91
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
92
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
93
+
94
+
95
+ def subsequent_mask(length):
96
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
97
+ return mask
98
+
99
+
100
+ @torch.jit.script
101
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
102
+ n_channels_int = n_channels[0]
103
+ in_act = input_a + input_b
104
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
105
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
106
+ acts = t_act * s_act
107
+ return acts
108
+
109
+
110
+ def convert_pad_shape(pad_shape):
111
+ l = pad_shape[::-1]
112
+ pad_shape = [item for sublist in l for item in sublist]
113
+ return pad_shape
114
+
115
+
116
+ def shift_1d(x):
117
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
118
+ return x
119
+
120
+
121
+ def sequence_mask(length, max_length=None):
122
+ if max_length is None:
123
+ max_length = length.max()
124
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
125
+ return x.unsqueeze(0) < length.unsqueeze(1)
126
+
127
+
128
+ def generate_path(duration, mask):
129
+ """
130
+ duration: [b, 1, t_x]
131
+ mask: [b, 1, t_y, t_x]
132
+ """
133
+ device = duration.device
134
+
135
+ b, _, t_y, t_x = mask.shape
136
+ cum_duration = torch.cumsum(duration, -1)
137
+
138
+ cum_duration_flat = cum_duration.view(b * t_x)
139
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
140
+ path = path.view(b, t_x, t_y)
141
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
142
+ path = path.unsqueeze(1).transpose(2, 3) * mask
143
+ return path
144
+
145
+
146
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
147
+ if isinstance(parameters, torch.Tensor):
148
+ parameters = [parameters]
149
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
150
+ norm_type = float(norm_type)
151
+ if clip_value is not None:
152
+ clip_value = float(clip_value)
153
+
154
+ total_norm = 0
155
+ for p in parameters:
156
+ param_norm = p.grad.data.norm(norm_type)
157
+ total_norm += param_norm.item() ** norm_type
158
+ if clip_value is not None:
159
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
160
+ total_norm = total_norm ** (1.0 / norm_type)
161
+ return total_norm
models.py ADDED
@@ -0,0 +1,977 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import commons
4
+ import modules
5
+ import attentions
6
+ import monotonic_align
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
10
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
11
+ from commons import init_weights, get_padding
12
+ from text import symbols, num_tones, num_languages
13
+
14
+
15
+ class DurationDiscriminator(nn.Module): # vits2
16
+ def __init__(
17
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
18
+ ):
19
+ super().__init__()
20
+
21
+ self.in_channels = in_channels
22
+ self.filter_channels = filter_channels
23
+ self.kernel_size = kernel_size
24
+ self.p_dropout = p_dropout
25
+ self.gin_channels = gin_channels
26
+
27
+ self.drop = nn.Dropout(p_dropout)
28
+ self.conv_1 = nn.Conv1d(
29
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
30
+ )
31
+ self.norm_1 = modules.LayerNorm(filter_channels)
32
+ self.conv_2 = nn.Conv1d(
33
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
34
+ )
35
+ self.norm_2 = modules.LayerNorm(filter_channels)
36
+ self.dur_proj = nn.Conv1d(1, filter_channels, 1)
37
+
38
+ self.pre_out_conv_1 = nn.Conv1d(
39
+ 2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
40
+ )
41
+ self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
42
+ self.pre_out_conv_2 = nn.Conv1d(
43
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
44
+ )
45
+ self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
46
+
47
+ if gin_channels != 0:
48
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
49
+
50
+ self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
51
+
52
+ def forward_probability(self, x, x_mask, dur, g=None):
53
+ dur = self.dur_proj(dur)
54
+ x = torch.cat([x, dur], dim=1)
55
+ x = self.pre_out_conv_1(x * x_mask)
56
+ x = torch.relu(x)
57
+ x = self.pre_out_norm_1(x)
58
+ x = self.drop(x)
59
+ x = self.pre_out_conv_2(x * x_mask)
60
+ x = torch.relu(x)
61
+ x = self.pre_out_norm_2(x)
62
+ x = self.drop(x)
63
+ x = x * x_mask
64
+ x = x.transpose(1, 2)
65
+ output_prob = self.output_layer(x)
66
+ return output_prob
67
+
68
+ def forward(self, x, x_mask, dur_r, dur_hat, g=None):
69
+ x = torch.detach(x)
70
+ if g is not None:
71
+ g = torch.detach(g)
72
+ x = x + self.cond(g)
73
+ x = self.conv_1(x * x_mask)
74
+ x = torch.relu(x)
75
+ x = self.norm_1(x)
76
+ x = self.drop(x)
77
+ x = self.conv_2(x * x_mask)
78
+ x = torch.relu(x)
79
+ x = self.norm_2(x)
80
+ x = self.drop(x)
81
+
82
+ output_probs = []
83
+ for dur in [dur_r, dur_hat]:
84
+ output_prob = self.forward_probability(x, x_mask, dur, g)
85
+ output_probs.append(output_prob)
86
+
87
+ return output_probs
88
+
89
+
90
+ class TransformerCouplingBlock(nn.Module):
91
+ def __init__(
92
+ self,
93
+ channels,
94
+ hidden_channels,
95
+ filter_channels,
96
+ n_heads,
97
+ n_layers,
98
+ kernel_size,
99
+ p_dropout,
100
+ n_flows=4,
101
+ gin_channels=0,
102
+ share_parameter=False,
103
+ ):
104
+
105
+ super().__init__()
106
+ self.channels = channels
107
+ self.hidden_channels = hidden_channels
108
+ self.kernel_size = kernel_size
109
+ self.n_layers = n_layers
110
+ self.n_flows = n_flows
111
+ self.gin_channels = gin_channels
112
+
113
+ self.flows = nn.ModuleList()
114
+
115
+ self.wn = (
116
+ attentions.FFT(
117
+ hidden_channels,
118
+ filter_channels,
119
+ n_heads,
120
+ n_layers,
121
+ kernel_size,
122
+ p_dropout,
123
+ isflow=True,
124
+ gin_channels=self.gin_channels,
125
+ )
126
+ if share_parameter
127
+ else None
128
+ )
129
+
130
+ for i in range(n_flows):
131
+ self.flows.append(
132
+ modules.TransformerCouplingLayer(
133
+ channels,
134
+ hidden_channels,
135
+ kernel_size,
136
+ n_layers,
137
+ n_heads,
138
+ p_dropout,
139
+ filter_channels,
140
+ mean_only=True,
141
+ wn_sharing_parameter=self.wn,
142
+ gin_channels=self.gin_channels,
143
+ )
144
+ )
145
+ self.flows.append(modules.Flip())
146
+
147
+ def forward(self, x, x_mask, g=None, reverse=False):
148
+ if not reverse:
149
+ for flow in self.flows:
150
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
151
+ else:
152
+ for flow in reversed(self.flows):
153
+ x = flow(x, x_mask, g=g, reverse=reverse)
154
+ return x
155
+
156
+
157
+ class StochasticDurationPredictor(nn.Module):
158
+ def __init__(
159
+ self,
160
+ in_channels,
161
+ filter_channels,
162
+ kernel_size,
163
+ p_dropout,
164
+ n_flows=4,
165
+ gin_channels=0,
166
+ ):
167
+ super().__init__()
168
+ filter_channels = in_channels # it needs to be removed from future version.
169
+ self.in_channels = in_channels
170
+ self.filter_channels = filter_channels
171
+ self.kernel_size = kernel_size
172
+ self.p_dropout = p_dropout
173
+ self.n_flows = n_flows
174
+ self.gin_channels = gin_channels
175
+
176
+ self.log_flow = modules.Log()
177
+ self.flows = nn.ModuleList()
178
+ self.flows.append(modules.ElementwiseAffine(2))
179
+ for i in range(n_flows):
180
+ self.flows.append(
181
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
182
+ )
183
+ self.flows.append(modules.Flip())
184
+
185
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
186
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
187
+ self.post_convs = modules.DDSConv(
188
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
189
+ )
190
+ self.post_flows = nn.ModuleList()
191
+ self.post_flows.append(modules.ElementwiseAffine(2))
192
+ for i in range(4):
193
+ self.post_flows.append(
194
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
195
+ )
196
+ self.post_flows.append(modules.Flip())
197
+
198
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
199
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
200
+ self.convs = modules.DDSConv(
201
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
202
+ )
203
+ if gin_channels != 0:
204
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
205
+
206
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
207
+ x = torch.detach(x)
208
+ x = self.pre(x)
209
+ if g is not None:
210
+ g = torch.detach(g)
211
+ x = x + self.cond(g)
212
+ x = self.convs(x, x_mask)
213
+ x = self.proj(x) * x_mask
214
+
215
+ if not reverse:
216
+ flows = self.flows
217
+ assert w is not None
218
+
219
+ logdet_tot_q = 0
220
+ h_w = self.post_pre(w)
221
+ h_w = self.post_convs(h_w, x_mask)
222
+ h_w = self.post_proj(h_w) * x_mask
223
+ e_q = (
224
+ torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
225
+ * x_mask
226
+ )
227
+ z_q = e_q
228
+ for flow in self.post_flows:
229
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
230
+ logdet_tot_q += logdet_q
231
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
232
+ u = torch.sigmoid(z_u) * x_mask
233
+ z0 = (w - u) * x_mask
234
+ logdet_tot_q += torch.sum(
235
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
236
+ )
237
+ logq = (
238
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
239
+ - logdet_tot_q
240
+ )
241
+
242
+ logdet_tot = 0
243
+ z0, logdet = self.log_flow(z0, x_mask)
244
+ logdet_tot += logdet
245
+ z = torch.cat([z0, z1], 1)
246
+ for flow in flows:
247
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
248
+ logdet_tot = logdet_tot + logdet
249
+ nll = (
250
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
251
+ - logdet_tot
252
+ )
253
+ return nll + logq # [b]
254
+ else:
255
+ flows = list(reversed(self.flows))
256
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
257
+ z = (
258
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
259
+ * noise_scale
260
+ )
261
+ for flow in flows:
262
+ z = flow(z, x_mask, g=x, reverse=reverse)
263
+ z0, z1 = torch.split(z, [1, 1], 1)
264
+ logw = z0
265
+ return logw
266
+
267
+
268
+ class DurationPredictor(nn.Module):
269
+ def __init__(
270
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
271
+ ):
272
+ super().__init__()
273
+
274
+ self.in_channels = in_channels
275
+ self.filter_channels = filter_channels
276
+ self.kernel_size = kernel_size
277
+ self.p_dropout = p_dropout
278
+ self.gin_channels = gin_channels
279
+
280
+ self.drop = nn.Dropout(p_dropout)
281
+ self.conv_1 = nn.Conv1d(
282
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
283
+ )
284
+ self.norm_1 = modules.LayerNorm(filter_channels)
285
+ self.conv_2 = nn.Conv1d(
286
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
287
+ )
288
+ self.norm_2 = modules.LayerNorm(filter_channels)
289
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
290
+
291
+ if gin_channels != 0:
292
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
293
+
294
+ def forward(self, x, x_mask, g=None):
295
+ x = torch.detach(x)
296
+ if g is not None:
297
+ g = torch.detach(g)
298
+ x = x + self.cond(g)
299
+ x = self.conv_1(x * x_mask)
300
+ x = torch.relu(x)
301
+ x = self.norm_1(x)
302
+ x = self.drop(x)
303
+ x = self.conv_2(x * x_mask)
304
+ x = torch.relu(x)
305
+ x = self.norm_2(x)
306
+ x = self.drop(x)
307
+ x = self.proj(x * x_mask)
308
+ return x * x_mask
309
+
310
+
311
+ class TextEncoder(nn.Module):
312
+ def __init__(
313
+ self,
314
+ n_vocab,
315
+ out_channels,
316
+ hidden_channels,
317
+ filter_channels,
318
+ n_heads,
319
+ n_layers,
320
+ kernel_size,
321
+ p_dropout,
322
+ gin_channels=0,
323
+ ):
324
+ super().__init__()
325
+ self.n_vocab = n_vocab
326
+ self.out_channels = out_channels
327
+ self.hidden_channels = hidden_channels
328
+ self.filter_channels = filter_channels
329
+ self.n_heads = n_heads
330
+ self.n_layers = n_layers
331
+ self.kernel_size = kernel_size
332
+ self.p_dropout = p_dropout
333
+ self.gin_channels = gin_channels
334
+ self.emb = nn.Embedding(len(symbols), hidden_channels)
335
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
336
+ self.tone_emb = nn.Embedding(num_tones, hidden_channels)
337
+ nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
338
+ self.language_emb = nn.Embedding(num_languages, hidden_channels)
339
+ nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
340
+ self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
341
+
342
+ self.encoder = attentions.Encoder(
343
+ hidden_channels,
344
+ filter_channels,
345
+ n_heads,
346
+ n_layers,
347
+ kernel_size,
348
+ p_dropout,
349
+ gin_channels=self.gin_channels,
350
+ )
351
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
352
+
353
+ def forward(self, x, x_lengths, tone, language, bert, g=None):
354
+ x = (
355
+ self.emb(x)
356
+ + self.tone_emb(tone)
357
+ + self.language_emb(language)
358
+ + self.bert_proj(bert).transpose(1, 2)
359
+ ) * math.sqrt(
360
+ self.hidden_channels
361
+ ) # [b, t, h]
362
+ x = torch.transpose(x, 1, -1) # [b, h, t]
363
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
364
+ x.dtype
365
+ )
366
+
367
+ x = self.encoder(x * x_mask, x_mask, g=g)
368
+ stats = self.proj(x) * x_mask
369
+
370
+ m, logs = torch.split(stats, self.out_channels, dim=1)
371
+ return x, m, logs, x_mask
372
+
373
+
374
+ class ResidualCouplingBlock(nn.Module):
375
+ def __init__(
376
+ self,
377
+ channels,
378
+ hidden_channels,
379
+ kernel_size,
380
+ dilation_rate,
381
+ n_layers,
382
+ n_flows=4,
383
+ gin_channels=0,
384
+ ):
385
+ super().__init__()
386
+ self.channels = channels
387
+ self.hidden_channels = hidden_channels
388
+ self.kernel_size = kernel_size
389
+ self.dilation_rate = dilation_rate
390
+ self.n_layers = n_layers
391
+ self.n_flows = n_flows
392
+ self.gin_channels = gin_channels
393
+
394
+ self.flows = nn.ModuleList()
395
+ for i in range(n_flows):
396
+ self.flows.append(
397
+ modules.ResidualCouplingLayer(
398
+ channels,
399
+ hidden_channels,
400
+ kernel_size,
401
+ dilation_rate,
402
+ n_layers,
403
+ gin_channels=gin_channels,
404
+ mean_only=True,
405
+ )
406
+ )
407
+ self.flows.append(modules.Flip())
408
+
409
+ def forward(self, x, x_mask, g=None, reverse=False):
410
+ if not reverse:
411
+ for flow in self.flows:
412
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
413
+ else:
414
+ for flow in reversed(self.flows):
415
+ x = flow(x, x_mask, g=g, reverse=reverse)
416
+ return x
417
+
418
+
419
+ class PosteriorEncoder(nn.Module):
420
+ def __init__(
421
+ self,
422
+ in_channels,
423
+ out_channels,
424
+ hidden_channels,
425
+ kernel_size,
426
+ dilation_rate,
427
+ n_layers,
428
+ gin_channels=0,
429
+ ):
430
+ super().__init__()
431
+ self.in_channels = in_channels
432
+ self.out_channels = out_channels
433
+ self.hidden_channels = hidden_channels
434
+ self.kernel_size = kernel_size
435
+ self.dilation_rate = dilation_rate
436
+ self.n_layers = n_layers
437
+ self.gin_channels = gin_channels
438
+
439
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
440
+ self.enc = modules.WN(
441
+ hidden_channels,
442
+ kernel_size,
443
+ dilation_rate,
444
+ n_layers,
445
+ gin_channels=gin_channels,
446
+ )
447
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
448
+
449
+ def forward(self, x, x_lengths, g=None):
450
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
451
+ x.dtype
452
+ )
453
+ x = self.pre(x) * x_mask
454
+ x = self.enc(x, x_mask, g=g)
455
+ stats = self.proj(x) * x_mask
456
+ m, logs = torch.split(stats, self.out_channels, dim=1)
457
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
458
+ return z, m, logs, x_mask
459
+
460
+
461
+ class Generator(torch.nn.Module):
462
+ def __init__(
463
+ self,
464
+ initial_channel,
465
+ resblock,
466
+ resblock_kernel_sizes,
467
+ resblock_dilation_sizes,
468
+ upsample_rates,
469
+ upsample_initial_channel,
470
+ upsample_kernel_sizes,
471
+ gin_channels=0,
472
+ ):
473
+ super(Generator, self).__init__()
474
+ self.num_kernels = len(resblock_kernel_sizes)
475
+ self.num_upsamples = len(upsample_rates)
476
+ self.conv_pre = Conv1d(
477
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
478
+ )
479
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
480
+
481
+ self.ups = nn.ModuleList()
482
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
483
+ self.ups.append(
484
+ weight_norm(
485
+ ConvTranspose1d(
486
+ upsample_initial_channel // (2**i),
487
+ upsample_initial_channel // (2 ** (i + 1)),
488
+ k,
489
+ u,
490
+ padding=(k - u) // 2,
491
+ )
492
+ )
493
+ )
494
+
495
+ self.resblocks = nn.ModuleList()
496
+ for i in range(len(self.ups)):
497
+ ch = upsample_initial_channel // (2 ** (i + 1))
498
+ for j, (k, d) in enumerate(
499
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
500
+ ):
501
+ self.resblocks.append(resblock(ch, k, d))
502
+
503
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
504
+ self.ups.apply(init_weights)
505
+
506
+ if gin_channels != 0:
507
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
508
+
509
+ def forward(self, x, g=None):
510
+ x = self.conv_pre(x)
511
+ if g is not None:
512
+ x = x + self.cond(g)
513
+
514
+ for i in range(self.num_upsamples):
515
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
516
+ x = self.ups[i](x)
517
+ xs = None
518
+ for j in range(self.num_kernels):
519
+ if xs is None:
520
+ xs = self.resblocks[i * self.num_kernels + j](x)
521
+ else:
522
+ xs += self.resblocks[i * self.num_kernels + j](x)
523
+ x = xs / self.num_kernels
524
+ x = F.leaky_relu(x)
525
+ x = self.conv_post(x)
526
+ x = torch.tanh(x)
527
+
528
+ return x
529
+
530
+ def remove_weight_norm(self):
531
+ print("Removing weight norm...")
532
+ for l in self.ups:
533
+ remove_weight_norm(l)
534
+ for l in self.resblocks:
535
+ l.remove_weight_norm()
536
+
537
+
538
+ class DiscriminatorP(torch.nn.Module):
539
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
540
+ super(DiscriminatorP, self).__init__()
541
+ self.period = period
542
+ self.use_spectral_norm = use_spectral_norm
543
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
544
+ self.convs = nn.ModuleList(
545
+ [
546
+ norm_f(
547
+ Conv2d(
548
+ 1,
549
+ 32,
550
+ (kernel_size, 1),
551
+ (stride, 1),
552
+ padding=(get_padding(kernel_size, 1), 0),
553
+ )
554
+ ),
555
+ norm_f(
556
+ Conv2d(
557
+ 32,
558
+ 128,
559
+ (kernel_size, 1),
560
+ (stride, 1),
561
+ padding=(get_padding(kernel_size, 1), 0),
562
+ )
563
+ ),
564
+ norm_f(
565
+ Conv2d(
566
+ 128,
567
+ 512,
568
+ (kernel_size, 1),
569
+ (stride, 1),
570
+ padding=(get_padding(kernel_size, 1), 0),
571
+ )
572
+ ),
573
+ norm_f(
574
+ Conv2d(
575
+ 512,
576
+ 1024,
577
+ (kernel_size, 1),
578
+ (stride, 1),
579
+ padding=(get_padding(kernel_size, 1), 0),
580
+ )
581
+ ),
582
+ norm_f(
583
+ Conv2d(
584
+ 1024,
585
+ 1024,
586
+ (kernel_size, 1),
587
+ 1,
588
+ padding=(get_padding(kernel_size, 1), 0),
589
+ )
590
+ ),
591
+ ]
592
+ )
593
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
594
+
595
+ def forward(self, x):
596
+ fmap = []
597
+
598
+ # 1d to 2d
599
+ b, c, t = x.shape
600
+ if t % self.period != 0: # pad first
601
+ n_pad = self.period - (t % self.period)
602
+ x = F.pad(x, (0, n_pad), "reflect")
603
+ t = t + n_pad
604
+ x = x.view(b, c, t // self.period, self.period)
605
+
606
+ for l in self.convs:
607
+ x = l(x)
608
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
609
+ fmap.append(x)
610
+ x = self.conv_post(x)
611
+ fmap.append(x)
612
+ x = torch.flatten(x, 1, -1)
613
+
614
+ return x, fmap
615
+
616
+
617
+ class DiscriminatorS(torch.nn.Module):
618
+ def __init__(self, use_spectral_norm=False):
619
+ super(DiscriminatorS, self).__init__()
620
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
621
+ self.convs = nn.ModuleList(
622
+ [
623
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
624
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
625
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
626
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
627
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
628
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
629
+ ]
630
+ )
631
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
632
+
633
+ def forward(self, x):
634
+ fmap = []
635
+
636
+ for l in self.convs:
637
+ x = l(x)
638
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
639
+ fmap.append(x)
640
+ x = self.conv_post(x)
641
+ fmap.append(x)
642
+ x = torch.flatten(x, 1, -1)
643
+
644
+ return x, fmap
645
+
646
+
647
+ class MultiPeriodDiscriminator(torch.nn.Module):
648
+ def __init__(self, use_spectral_norm=False):
649
+ super(MultiPeriodDiscriminator, self).__init__()
650
+ periods = [2, 3, 5, 7, 11]
651
+
652
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
653
+ discs = discs + [
654
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
655
+ ]
656
+ self.discriminators = nn.ModuleList(discs)
657
+
658
+ def forward(self, y, y_hat):
659
+ y_d_rs = []
660
+ y_d_gs = []
661
+ fmap_rs = []
662
+ fmap_gs = []
663
+ for i, d in enumerate(self.discriminators):
664
+ y_d_r, fmap_r = d(y)
665
+ y_d_g, fmap_g = d(y_hat)
666
+ y_d_rs.append(y_d_r)
667
+ y_d_gs.append(y_d_g)
668
+ fmap_rs.append(fmap_r)
669
+ fmap_gs.append(fmap_g)
670
+
671
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
672
+
673
+
674
+ class ReferenceEncoder(nn.Module):
675
+ """
676
+ inputs --- [N, Ty/r, n_mels*r] mels
677
+ outputs --- [N, ref_enc_gru_size]
678
+ """
679
+
680
+ def __init__(self, spec_channels, gin_channels=0):
681
+
682
+ super().__init__()
683
+ self.spec_channels = spec_channels
684
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
685
+ K = len(ref_enc_filters)
686
+ filters = [1] + ref_enc_filters
687
+ convs = [
688
+ weight_norm(
689
+ nn.Conv2d(
690
+ in_channels=filters[i],
691
+ out_channels=filters[i + 1],
692
+ kernel_size=(3, 3),
693
+ stride=(2, 2),
694
+ padding=(1, 1),
695
+ )
696
+ )
697
+ for i in range(K)
698
+ ]
699
+ self.convs = nn.ModuleList(convs)
700
+ # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])
701
+
702
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
703
+ self.gru = nn.GRU(
704
+ input_size=ref_enc_filters[-1] * out_channels,
705
+ hidden_size=256 // 2,
706
+ batch_first=True,
707
+ )
708
+ self.proj = nn.Linear(128, gin_channels)
709
+
710
+ def forward(self, inputs, mask=None):
711
+ N = inputs.size(0)
712
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
713
+ for conv in self.convs:
714
+ out = conv(out)
715
+ # out = wn(out)
716
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
717
+
718
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
719
+ T = out.size(1)
720
+ N = out.size(0)
721
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
722
+
723
+ self.gru.flatten_parameters()
724
+ memory, out = self.gru(out) # out --- [1, N, 128]
725
+
726
+ return self.proj(out.squeeze(0))
727
+
728
+ def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
729
+ for i in range(n_convs):
730
+ L = (L - kernel_size + 2 * pad) // stride + 1
731
+ return L
732
+
733
+
734
+ class SynthesizerTrn(nn.Module):
735
+ """
736
+ Synthesizer for Training
737
+ """
738
+
739
+ def __init__(
740
+ self,
741
+ n_vocab,
742
+ spec_channels,
743
+ segment_size,
744
+ inter_channels,
745
+ hidden_channels,
746
+ filter_channels,
747
+ n_heads,
748
+ n_layers,
749
+ kernel_size,
750
+ p_dropout,
751
+ resblock,
752
+ resblock_kernel_sizes,
753
+ resblock_dilation_sizes,
754
+ upsample_rates,
755
+ upsample_initial_channel,
756
+ upsample_kernel_sizes,
757
+ n_speakers=256,
758
+ gin_channels=256,
759
+ use_sdp=True,
760
+ n_flow_layer=4,
761
+ n_layers_trans_flow=3,
762
+ flow_share_parameter=False,
763
+ use_transformer_flow=True,
764
+ **kwargs
765
+ ):
766
+
767
+ super().__init__()
768
+ self.n_vocab = n_vocab
769
+ self.spec_channels = spec_channels
770
+ self.inter_channels = inter_channels
771
+ self.hidden_channels = hidden_channels
772
+ self.filter_channels = filter_channels
773
+ self.n_heads = n_heads
774
+ self.n_layers = n_layers
775
+ self.kernel_size = kernel_size
776
+ self.p_dropout = p_dropout
777
+ self.resblock = resblock
778
+ self.resblock_kernel_sizes = resblock_kernel_sizes
779
+ self.resblock_dilation_sizes = resblock_dilation_sizes
780
+ self.upsample_rates = upsample_rates
781
+ self.upsample_initial_channel = upsample_initial_channel
782
+ self.upsample_kernel_sizes = upsample_kernel_sizes
783
+ self.segment_size = segment_size
784
+ self.n_speakers = n_speakers
785
+ self.gin_channels = gin_channels
786
+ self.n_layers_trans_flow = n_layers_trans_flow
787
+ self.use_spk_conditioned_encoder = kwargs.get(
788
+ "use_spk_conditioned_encoder", True
789
+ )
790
+ self.use_sdp = use_sdp
791
+ self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
792
+ self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
793
+ self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
794
+ self.current_mas_noise_scale = self.mas_noise_scale_initial
795
+ if self.use_spk_conditioned_encoder and gin_channels > 0:
796
+ self.enc_gin_channels = gin_channels
797
+ self.enc_p = TextEncoder(
798
+ n_vocab,
799
+ inter_channels,
800
+ hidden_channels,
801
+ filter_channels,
802
+ n_heads,
803
+ n_layers,
804
+ kernel_size,
805
+ p_dropout,
806
+ gin_channels=self.enc_gin_channels,
807
+ )
808
+ self.dec = Generator(
809
+ inter_channels,
810
+ resblock,
811
+ resblock_kernel_sizes,
812
+ resblock_dilation_sizes,
813
+ upsample_rates,
814
+ upsample_initial_channel,
815
+ upsample_kernel_sizes,
816
+ gin_channels=gin_channels,
817
+ )
818
+ self.enc_q = PosteriorEncoder(
819
+ spec_channels,
820
+ inter_channels,
821
+ hidden_channels,
822
+ 5,
823
+ 1,
824
+ 16,
825
+ gin_channels=gin_channels,
826
+ )
827
+ if use_transformer_flow:
828
+ self.flow = TransformerCouplingBlock(
829
+ inter_channels,
830
+ hidden_channels,
831
+ filter_channels,
832
+ n_heads,
833
+ n_layers_trans_flow,
834
+ 5,
835
+ p_dropout,
836
+ n_flow_layer,
837
+ gin_channels=gin_channels,
838
+ share_parameter=flow_share_parameter,
839
+ )
840
+ else:
841
+ self.flow = ResidualCouplingBlock(
842
+ inter_channels,
843
+ hidden_channels,
844
+ 5,
845
+ 1,
846
+ n_flow_layer,
847
+ gin_channels=gin_channels,
848
+ )
849
+ self.sdp = StochasticDurationPredictor(
850
+ hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
851
+ )
852
+ self.dp = DurationPredictor(
853
+ hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
854
+ )
855
+
856
+ if n_speakers > 1:
857
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
858
+ else:
859
+ self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
860
+
861
+ def forward(self, x, x_lengths, y, y_lengths, sid, tone, language, bert):
862
+ if self.n_speakers > 0:
863
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
864
+ else:
865
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
866
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert, g=g)
867
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
868
+ z_p = self.flow(z, y_mask, g=g)
869
+
870
+ with torch.no_grad():
871
+ # negative cross-entropy
872
+ s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
873
+ neg_cent1 = torch.sum(
874
+ -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
875
+ ) # [b, 1, t_s]
876
+ neg_cent2 = torch.matmul(
877
+ -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
878
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
879
+ neg_cent3 = torch.matmul(
880
+ z_p.transpose(1, 2), (m_p * s_p_sq_r)
881
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
882
+ neg_cent4 = torch.sum(
883
+ -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
884
+ ) # [b, 1, t_s]
885
+ neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
886
+ if self.use_noise_scaled_mas:
887
+ epsilon = (
888
+ torch.std(neg_cent)
889
+ * torch.randn_like(neg_cent)
890
+ * self.current_mas_noise_scale
891
+ )
892
+ neg_cent = neg_cent + epsilon
893
+
894
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
895
+ attn = (
896
+ monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
897
+ .unsqueeze(1)
898
+ .detach()
899
+ )
900
+
901
+ w = attn.sum(2)
902
+
903
+ l_length_sdp = self.sdp(x, x_mask, w, g=g)
904
+ l_length_sdp = l_length_sdp / torch.sum(x_mask)
905
+
906
+ logw_ = torch.log(w + 1e-6) * x_mask
907
+ logw = self.dp(x, x_mask, g=g)
908
+ l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
909
+ x_mask
910
+ ) # for averaging
911
+
912
+ l_length = l_length_dp + l_length_sdp
913
+
914
+ # expand prior
915
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
916
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
917
+
918
+ z_slice, ids_slice = commons.rand_slice_segments(
919
+ z, y_lengths, self.segment_size
920
+ )
921
+ o = self.dec(z_slice, g=g)
922
+ return (
923
+ o,
924
+ l_length,
925
+ attn,
926
+ ids_slice,
927
+ x_mask,
928
+ y_mask,
929
+ (z, z_p, m_p, logs_p, m_q, logs_q),
930
+ (x, logw, logw_),
931
+ )
932
+
933
+ def infer(
934
+ self,
935
+ x,
936
+ x_lengths,
937
+ sid,
938
+ tone,
939
+ language,
940
+ bert,
941
+ noise_scale=0.667,
942
+ length_scale=1,
943
+ noise_scale_w=0.8,
944
+ max_len=None,
945
+ sdp_ratio=0,
946
+ y=None,
947
+ ):
948
+ # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
949
+ # g = self.gst(y)
950
+ if self.n_speakers > 0:
951
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
952
+ else:
953
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
954
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert, g=g)
955
+ logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
956
+ sdp_ratio
957
+ ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
958
+ w = torch.exp(logw) * x_mask * length_scale
959
+ w_ceil = torch.ceil(w)
960
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
961
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
962
+ x_mask.dtype
963
+ )
964
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
965
+ attn = commons.generate_path(w_ceil, attn_mask)
966
+
967
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
968
+ 1, 2
969
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
970
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
971
+ 1, 2
972
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
973
+
974
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
975
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
976
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
977
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
modules.py ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import commons
4
+ from attentions import Encoder
5
+ from torch import nn
6
+ from torch.nn import Conv1d
7
+ from torch.nn import functional as F
8
+ from torch.nn.utils import weight_norm, remove_weight_norm
9
+ from transforms import piecewise_rational_quadratic_transform
10
+ from commons import init_weights, get_padding
11
+
12
+
13
+ LRELU_SLOPE = 0.1
14
+
15
+
16
+ class LayerNorm(nn.Module):
17
+ def __init__(self, channels, eps=1e-5):
18
+ super().__init__()
19
+ self.channels = channels
20
+ self.eps = eps
21
+
22
+ self.gamma = nn.Parameter(torch.ones(channels))
23
+ self.beta = nn.Parameter(torch.zeros(channels))
24
+
25
+ def forward(self, x):
26
+ x = x.transpose(1, -1)
27
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
28
+ return x.transpose(1, -1)
29
+
30
+
31
+ class ConvReluNorm(nn.Module):
32
+ def __init__(
33
+ self,
34
+ in_channels,
35
+ hidden_channels,
36
+ out_channels,
37
+ kernel_size,
38
+ n_layers,
39
+ p_dropout,
40
+ ):
41
+ super().__init__()
42
+ self.in_channels = in_channels
43
+ self.hidden_channels = hidden_channels
44
+ self.out_channels = out_channels
45
+ self.kernel_size = kernel_size
46
+ self.n_layers = n_layers
47
+ self.p_dropout = p_dropout
48
+ assert n_layers > 1, "Number of layers should be larger than 0."
49
+
50
+ self.conv_layers = nn.ModuleList()
51
+ self.norm_layers = nn.ModuleList()
52
+ self.conv_layers.append(
53
+ nn.Conv1d(
54
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
55
+ )
56
+ )
57
+ self.norm_layers.append(LayerNorm(hidden_channels))
58
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
59
+ for _ in range(n_layers - 1):
60
+ self.conv_layers.append(
61
+ nn.Conv1d(
62
+ hidden_channels,
63
+ hidden_channels,
64
+ kernel_size,
65
+ padding=kernel_size // 2,
66
+ )
67
+ )
68
+ self.norm_layers.append(LayerNorm(hidden_channels))
69
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
70
+ self.proj.weight.data.zero_()
71
+ self.proj.bias.data.zero_()
72
+
73
+ def forward(self, x, x_mask):
74
+ x_org = x
75
+ for i in range(self.n_layers):
76
+ x = self.conv_layers[i](x * x_mask)
77
+ x = self.norm_layers[i](x)
78
+ x = self.relu_drop(x)
79
+ x = x_org + self.proj(x)
80
+ return x * x_mask
81
+
82
+
83
+ class DDSConv(nn.Module):
84
+ """
85
+ Dialted and Depth-Separable Convolution
86
+ """
87
+
88
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
89
+ super().__init__()
90
+ self.channels = channels
91
+ self.kernel_size = kernel_size
92
+ self.n_layers = n_layers
93
+ self.p_dropout = p_dropout
94
+
95
+ self.drop = nn.Dropout(p_dropout)
96
+ self.convs_sep = nn.ModuleList()
97
+ self.convs_1x1 = nn.ModuleList()
98
+ self.norms_1 = nn.ModuleList()
99
+ self.norms_2 = nn.ModuleList()
100
+ for i in range(n_layers):
101
+ dilation = kernel_size**i
102
+ padding = (kernel_size * dilation - dilation) // 2
103
+ self.convs_sep.append(
104
+ nn.Conv1d(
105
+ channels,
106
+ channels,
107
+ kernel_size,
108
+ groups=channels,
109
+ dilation=dilation,
110
+ padding=padding,
111
+ )
112
+ )
113
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
114
+ self.norms_1.append(LayerNorm(channels))
115
+ self.norms_2.append(LayerNorm(channels))
116
+
117
+ def forward(self, x, x_mask, g=None):
118
+ if g is not None:
119
+ x = x + g
120
+ for i in range(self.n_layers):
121
+ y = self.convs_sep[i](x * x_mask)
122
+ y = self.norms_1[i](y)
123
+ y = F.gelu(y)
124
+ y = self.convs_1x1[i](y)
125
+ y = self.norms_2[i](y)
126
+ y = F.gelu(y)
127
+ y = self.drop(y)
128
+ x = x + y
129
+ return x * x_mask
130
+
131
+
132
+ class WN(torch.nn.Module):
133
+ def __init__(
134
+ self,
135
+ hidden_channels,
136
+ kernel_size,
137
+ dilation_rate,
138
+ n_layers,
139
+ gin_channels=0,
140
+ p_dropout=0,
141
+ ):
142
+ super(WN, self).__init__()
143
+ assert kernel_size % 2 == 1
144
+ self.hidden_channels = hidden_channels
145
+ self.kernel_size = (kernel_size,)
146
+ self.dilation_rate = dilation_rate
147
+ self.n_layers = n_layers
148
+ self.gin_channels = gin_channels
149
+ self.p_dropout = p_dropout
150
+
151
+ self.in_layers = torch.nn.ModuleList()
152
+ self.res_skip_layers = torch.nn.ModuleList()
153
+ self.drop = nn.Dropout(p_dropout)
154
+
155
+ if gin_channels != 0:
156
+ cond_layer = torch.nn.Conv1d(
157
+ gin_channels, 2 * hidden_channels * n_layers, 1
158
+ )
159
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
160
+
161
+ for i in range(n_layers):
162
+ dilation = dilation_rate**i
163
+ padding = int((kernel_size * dilation - dilation) / 2)
164
+ in_layer = torch.nn.Conv1d(
165
+ hidden_channels,
166
+ 2 * hidden_channels,
167
+ kernel_size,
168
+ dilation=dilation,
169
+ padding=padding,
170
+ )
171
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
172
+ self.in_layers.append(in_layer)
173
+
174
+ # last one is not necessary
175
+ if i < n_layers - 1:
176
+ res_skip_channels = 2 * hidden_channels
177
+ else:
178
+ res_skip_channels = hidden_channels
179
+
180
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
181
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
182
+ self.res_skip_layers.append(res_skip_layer)
183
+
184
+ def forward(self, x, x_mask, g=None, **kwargs):
185
+ output = torch.zeros_like(x)
186
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
187
+
188
+ if g is not None:
189
+ g = self.cond_layer(g)
190
+
191
+ for i in range(self.n_layers):
192
+ x_in = self.in_layers[i](x)
193
+ if g is not None:
194
+ cond_offset = i * 2 * self.hidden_channels
195
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
196
+ else:
197
+ g_l = torch.zeros_like(x_in)
198
+
199
+ acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
200
+ acts = self.drop(acts)
201
+
202
+ res_skip_acts = self.res_skip_layers[i](acts)
203
+ if i < self.n_layers - 1:
204
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
205
+ x = (x + res_acts) * x_mask
206
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
207
+ else:
208
+ output = output + res_skip_acts
209
+ return output * x_mask
210
+
211
+ def remove_weight_norm(self):
212
+ if self.gin_channels != 0:
213
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
214
+ for l in self.in_layers:
215
+ torch.nn.utils.remove_weight_norm(l)
216
+ for l in self.res_skip_layers:
217
+ torch.nn.utils.remove_weight_norm(l)
218
+
219
+
220
+ class ResBlock1(torch.nn.Module):
221
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
222
+ super(ResBlock1, self).__init__()
223
+ self.convs1 = nn.ModuleList(
224
+ [
225
+ weight_norm(
226
+ Conv1d(
227
+ channels,
228
+ channels,
229
+ kernel_size,
230
+ 1,
231
+ dilation=dilation[0],
232
+ padding=get_padding(kernel_size, dilation[0]),
233
+ )
234
+ ),
235
+ weight_norm(
236
+ Conv1d(
237
+ channels,
238
+ channels,
239
+ kernel_size,
240
+ 1,
241
+ dilation=dilation[1],
242
+ padding=get_padding(kernel_size, dilation[1]),
243
+ )
244
+ ),
245
+ weight_norm(
246
+ Conv1d(
247
+ channels,
248
+ channels,
249
+ kernel_size,
250
+ 1,
251
+ dilation=dilation[2],
252
+ padding=get_padding(kernel_size, dilation[2]),
253
+ )
254
+ ),
255
+ ]
256
+ )
257
+ self.convs1.apply(init_weights)
258
+
259
+ self.convs2 = nn.ModuleList(
260
+ [
261
+ weight_norm(
262
+ Conv1d(
263
+ channels,
264
+ channels,
265
+ kernel_size,
266
+ 1,
267
+ dilation=1,
268
+ padding=get_padding(kernel_size, 1),
269
+ )
270
+ ),
271
+ weight_norm(
272
+ Conv1d(
273
+ channels,
274
+ channels,
275
+ kernel_size,
276
+ 1,
277
+ dilation=1,
278
+ padding=get_padding(kernel_size, 1),
279
+ )
280
+ ),
281
+ weight_norm(
282
+ Conv1d(
283
+ channels,
284
+ channels,
285
+ kernel_size,
286
+ 1,
287
+ dilation=1,
288
+ padding=get_padding(kernel_size, 1),
289
+ )
290
+ ),
291
+ ]
292
+ )
293
+ self.convs2.apply(init_weights)
294
+
295
+ def forward(self, x, x_mask=None):
296
+ for c1, c2 in zip(self.convs1, self.convs2):
297
+ xt = F.leaky_relu(x, LRELU_SLOPE)
298
+ if x_mask is not None:
299
+ xt = xt * x_mask
300
+ xt = c1(xt)
301
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
302
+ if x_mask is not None:
303
+ xt = xt * x_mask
304
+ xt = c2(xt)
305
+ x = xt + x
306
+ if x_mask is not None:
307
+ x = x * x_mask
308
+ return x
309
+
310
+ def remove_weight_norm(self):
311
+ for l in self.convs1:
312
+ remove_weight_norm(l)
313
+ for l in self.convs2:
314
+ remove_weight_norm(l)
315
+
316
+
317
+ class ResBlock2(torch.nn.Module):
318
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
319
+ super(ResBlock2, self).__init__()
320
+ self.convs = nn.ModuleList(
321
+ [
322
+ weight_norm(
323
+ Conv1d(
324
+ channels,
325
+ channels,
326
+ kernel_size,
327
+ 1,
328
+ dilation=dilation[0],
329
+ padding=get_padding(kernel_size, dilation[0]),
330
+ )
331
+ ),
332
+ weight_norm(
333
+ Conv1d(
334
+ channels,
335
+ channels,
336
+ kernel_size,
337
+ 1,
338
+ dilation=dilation[1],
339
+ padding=get_padding(kernel_size, dilation[1]),
340
+ )
341
+ ),
342
+ ]
343
+ )
344
+ self.convs.apply(init_weights)
345
+
346
+ def forward(self, x, x_mask=None):
347
+ for c in self.convs:
348
+ xt = F.leaky_relu(x, LRELU_SLOPE)
349
+ if x_mask is not None:
350
+ xt = xt * x_mask
351
+ xt = c(xt)
352
+ x = xt + x
353
+ if x_mask is not None:
354
+ x = x * x_mask
355
+ return x
356
+
357
+ def remove_weight_norm(self):
358
+ for l in self.convs:
359
+ remove_weight_norm(l)
360
+
361
+
362
+ class Log(nn.Module):
363
+ def forward(self, x, x_mask, reverse=False, **kwargs):
364
+ if not reverse:
365
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
366
+ logdet = torch.sum(-y, [1, 2])
367
+ return y, logdet
368
+ else:
369
+ x = torch.exp(x) * x_mask
370
+ return x
371
+
372
+
373
+ class Flip(nn.Module):
374
+ def forward(self, x, *args, reverse=False, **kwargs):
375
+ x = torch.flip(x, [1])
376
+ if not reverse:
377
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
378
+ return x, logdet
379
+ else:
380
+ return x
381
+
382
+
383
+ class ElementwiseAffine(nn.Module):
384
+ def __init__(self, channels):
385
+ super().__init__()
386
+ self.channels = channels
387
+ self.m = nn.Parameter(torch.zeros(channels, 1))
388
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
389
+
390
+ def forward(self, x, x_mask, reverse=False, **kwargs):
391
+ if not reverse:
392
+ y = self.m + torch.exp(self.logs) * x
393
+ y = y * x_mask
394
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
395
+ return y, logdet
396
+ else:
397
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
398
+ return x
399
+
400
+
401
+ class ResidualCouplingLayer(nn.Module):
402
+ def __init__(
403
+ self,
404
+ channels,
405
+ hidden_channels,
406
+ kernel_size,
407
+ dilation_rate,
408
+ n_layers,
409
+ p_dropout=0,
410
+ gin_channels=0,
411
+ mean_only=False,
412
+ ):
413
+ assert channels % 2 == 0, "channels should be divisible by 2"
414
+ super().__init__()
415
+ self.channels = channels
416
+ self.hidden_channels = hidden_channels
417
+ self.kernel_size = kernel_size
418
+ self.dilation_rate = dilation_rate
419
+ self.n_layers = n_layers
420
+ self.half_channels = channels // 2
421
+ self.mean_only = mean_only
422
+
423
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
424
+ self.enc = WN(
425
+ hidden_channels,
426
+ kernel_size,
427
+ dilation_rate,
428
+ n_layers,
429
+ p_dropout=p_dropout,
430
+ gin_channels=gin_channels,
431
+ )
432
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
433
+ self.post.weight.data.zero_()
434
+ self.post.bias.data.zero_()
435
+
436
+ def forward(self, x, x_mask, g=None, reverse=False):
437
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
438
+ h = self.pre(x0) * x_mask
439
+ h = self.enc(h, x_mask, g=g)
440
+ stats = self.post(h) * x_mask
441
+ if not self.mean_only:
442
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
443
+ else:
444
+ m = stats
445
+ logs = torch.zeros_like(m)
446
+
447
+ if not reverse:
448
+ x1 = m + x1 * torch.exp(logs) * x_mask
449
+ x = torch.cat([x0, x1], 1)
450
+ logdet = torch.sum(logs, [1, 2])
451
+ return x, logdet
452
+ else:
453
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
454
+ x = torch.cat([x0, x1], 1)
455
+ return x
456
+
457
+
458
+ class ConvFlow(nn.Module):
459
+ def __init__(
460
+ self,
461
+ in_channels,
462
+ filter_channels,
463
+ kernel_size,
464
+ n_layers,
465
+ num_bins=10,
466
+ tail_bound=5.0,
467
+ ):
468
+ super().__init__()
469
+ self.in_channels = in_channels
470
+ self.filter_channels = filter_channels
471
+ self.kernel_size = kernel_size
472
+ self.n_layers = n_layers
473
+ self.num_bins = num_bins
474
+ self.tail_bound = tail_bound
475
+ self.half_channels = in_channels // 2
476
+
477
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
478
+ self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
479
+ self.proj = nn.Conv1d(
480
+ filter_channels, self.half_channels * (num_bins * 3 - 1), 1
481
+ )
482
+ self.proj.weight.data.zero_()
483
+ self.proj.bias.data.zero_()
484
+
485
+ def forward(self, x, x_mask, g=None, reverse=False):
486
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
487
+ h = self.pre(x0)
488
+ h = self.convs(h, x_mask, g=g)
489
+ h = self.proj(h) * x_mask
490
+
491
+ b, c, t = x0.shape
492
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
493
+
494
+ unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
495
+ unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
496
+ self.filter_channels
497
+ )
498
+ unnormalized_derivatives = h[..., 2 * self.num_bins :]
499
+
500
+ x1, logabsdet = piecewise_rational_quadratic_transform(
501
+ x1,
502
+ unnormalized_widths,
503
+ unnormalized_heights,
504
+ unnormalized_derivatives,
505
+ inverse=reverse,
506
+ tails="linear",
507
+ tail_bound=self.tail_bound,
508
+ )
509
+
510
+ x = torch.cat([x0, x1], 1) * x_mask
511
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
512
+ if not reverse:
513
+ return x, logdet
514
+ else:
515
+ return x
516
+
517
+
518
+ class TransformerCouplingLayer(nn.Module):
519
+ def __init__(
520
+ self,
521
+ channels,
522
+ hidden_channels,
523
+ kernel_size,
524
+ n_layers,
525
+ n_heads,
526
+ p_dropout=0,
527
+ filter_channels=0,
528
+ mean_only=False,
529
+ wn_sharing_parameter=None,
530
+ gin_channels=0,
531
+ ):
532
+ assert channels % 2 == 0, "channels should be divisible by 2"
533
+ super().__init__()
534
+ self.channels = channels
535
+ self.hidden_channels = hidden_channels
536
+ self.kernel_size = kernel_size
537
+ self.n_layers = n_layers
538
+ self.half_channels = channels // 2
539
+ self.mean_only = mean_only
540
+
541
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
542
+ self.enc = (
543
+ Encoder(
544
+ hidden_channels,
545
+ filter_channels,
546
+ n_heads,
547
+ n_layers,
548
+ kernel_size,
549
+ p_dropout,
550
+ isflow=True,
551
+ gin_channels=gin_channels,
552
+ )
553
+ if wn_sharing_parameter is None
554
+ else wn_sharing_parameter
555
+ )
556
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
557
+ self.post.weight.data.zero_()
558
+ self.post.bias.data.zero_()
559
+
560
+ def forward(self, x, x_mask, g=None, reverse=False):
561
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
562
+ h = self.pre(x0) * x_mask
563
+ h = self.enc(h, x_mask, g=g)
564
+ stats = self.post(h) * x_mask
565
+ if not self.mean_only:
566
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
567
+ else:
568
+ m = stats
569
+ logs = torch.zeros_like(m)
570
+
571
+ if not reverse:
572
+ x1 = m + x1 * torch.exp(logs) * x_mask
573
+ x = torch.cat([x0, x1], 1)
574
+ logdet = torch.sum(logs, [1, 2])
575
+ return x, logdet
576
+ else:
577
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
578
+ x = torch.cat([x0, x1], 1)
579
+ return x
580
+
581
+ x1, logabsdet = piecewise_rational_quadratic_transform(
582
+ x1,
583
+ unnormalized_widths,
584
+ unnormalized_heights,
585
+ unnormalized_derivatives,
586
+ inverse=reverse,
587
+ tails="linear",
588
+ tail_bound=self.tail_bound,
589
+ )
590
+
591
+ x = torch.cat([x0, x1], 1) * x_mask
592
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
593
+ if not reverse:
594
+ return x, logdet
595
+ else:
596
+ return x
monotonic_align/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from numpy import zeros, int32, float32
2
+ from torch import from_numpy
3
+
4
+ from .core import maximum_path_jit
5
+
6
+ def maximum_path(neg_cent, mask):
7
+ device = neg_cent.device
8
+ dtype = neg_cent.dtype
9
+ neg_cent = neg_cent.data.cpu().numpy().astype(float32)
10
+ path = zeros(neg_cent.shape, dtype=int32)
11
+
12
+ t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
13
+ t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
14
+ maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
15
+ return from_numpy(path).to(device=device, dtype=dtype)
monotonic_align/core.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numba
2
+
3
+
4
+ @numba.jit(numba.void(numba.int32[:,:,::1], numba.float32[:,:,::1], numba.int32[::1], numba.int32[::1]), nopython=True, nogil=True)
5
+ def maximum_path_jit(paths, values, t_ys, t_xs):
6
+ b = paths.shape[0]
7
+ max_neg_val=-1e9
8
+ for i in range(int(b)):
9
+ path = paths[i]
10
+ value = values[i]
11
+ t_y = t_ys[i]
12
+ t_x = t_xs[i]
13
+
14
+ v_prev = v_cur = 0.0
15
+ index = t_x - 1
16
+
17
+ for y in range(t_y):
18
+ for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
19
+ if x == y:
20
+ v_cur = max_neg_val
21
+ else:
22
+ v_cur = value[y-1, x]
23
+ if x == 0:
24
+ if y == 0:
25
+ v_prev = 0.
26
+ else:
27
+ v_prev = max_neg_val
28
+ else:
29
+ v_prev = value[y-1, x-1]
30
+ value[y, x] += max(v_prev, v_cur)
31
+
32
+ for y in range(t_y - 1, -1, -1):
33
+ path[y, index] = 1
34
+ if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]):
35
+ index = index - 1
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.3.1+cu118
2
+ -f https://mirrors.aliyun.com/pytorch-wheels/cu118
3
+ modelscope[framework]
4
+ amfm_decompy
5
+ tensorboard
6
+ matplotlib
7
+ phonemizer
8
+ Unidecode
9
+ pypinyin
10
+ gradio
11
+ cn2an
12
+ jieba
13
+ numba
14
+ scipy
15
+ av
16
+ librosa==0.9.1
17
+ numpy==1.26.4
text/__init__.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from text.symbols import *
2
+
3
+
4
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
5
+
6
+ def cleaned_text_to_sequence(cleaned_text, tones, language):
7
+ '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
8
+ Args:
9
+ text: string to convert to a sequence
10
+ Returns:
11
+ List of integers corresponding to the symbols in the text
12
+ '''
13
+ phones = [_symbol_to_id[symbol] for symbol in cleaned_text]
14
+ tone_start = language_tone_start_map[language]
15
+ tones = [i + tone_start for i in tones]
16
+ lang_id = language_id_map[language]
17
+ lang_ids = [lang_id for i in phones]
18
+ return phones, tones, lang_ids
19
+
20
+ def get_bert(norm_text, word2ph, language):
21
+ from .chinese_bert import get_bert_feature as zh_bert
22
+ from .english_bert_mock import get_bert_feature as en_bert
23
+ lang_bert_func_map = {
24
+ 'ZH': zh_bert,
25
+ 'EN': en_bert
26
+ }
27
+ bert = lang_bert_func_map[language](norm_text, word2ph)
28
+ return bert
text/chinese.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+ import cn2an
5
+ from pypinyin import lazy_pinyin, Style
6
+
7
+ from text import symbols
8
+ from text.symbols import punctuation
9
+ from text.tone_sandhi import ToneSandhi
10
+
11
+ current_file_path = os.path.dirname(__file__)
12
+ pinyin_to_symbol_map = {line.split("\t")[0]: line.strip().split("\t")[1] for line in
13
+ open(os.path.join(current_file_path, 'opencpop-strict.txt')).readlines()}
14
+
15
+ import jieba.posseg as psg
16
+
17
+
18
+ rep_map = {
19
+ ':': ',',
20
+ ';': ',',
21
+ ',': ',',
22
+ '。': '.',
23
+ '!': '!',
24
+ '?': '?',
25
+ '\n': '.',
26
+ "·": ",",
27
+ '、': ",",
28
+ '...': '…',
29
+ '$': '.',
30
+ '“': "'",
31
+ '”': "'",
32
+ '‘': "'",
33
+ '’': "'",
34
+ '(': "'",
35
+ ')': "'",
36
+ '(': "'",
37
+ ')': "'",
38
+ '《': "'",
39
+ '》': "'",
40
+ '【': "'",
41
+ '】': "'",
42
+ '[': "'",
43
+ ']': "'",
44
+ '—': "-",
45
+ '~': "-",
46
+ '~': "-",
47
+ '「': "'",
48
+ '」': "'",
49
+
50
+ }
51
+
52
+ tone_modifier = ToneSandhi()
53
+
54
+ def replace_punctuation(text):
55
+ text = text.replace("嗯", "恩").replace("呣","母")
56
+ pattern = re.compile('|'.join(re.escape(p) for p in rep_map.keys()))
57
+
58
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
59
+
60
+ replaced_text = re.sub(r'[^\u4e00-\u9fa5'+"".join(punctuation)+r']+', '', replaced_text)
61
+
62
+ return replaced_text
63
+
64
+ def g2p(text):
65
+ pattern = r'(?<=[{0}])\s*'.format(''.join(punctuation))
66
+ sentences = [i for i in re.split(pattern, text) if i.strip()!='']
67
+ phones, tones, word2ph = _g2p(sentences)
68
+ assert sum(word2ph) == len(phones)
69
+ assert len(word2ph) == len(text) #Sometimes it will crash,you can add a try-catch.
70
+ phones = ['_'] + phones + ["_"]
71
+ tones = [0] + tones + [0]
72
+ word2ph = [1] + word2ph + [1]
73
+ return phones, tones, word2ph
74
+
75
+
76
+ def _get_initials_finals(word):
77
+ initials = []
78
+ finals = []
79
+ orig_initials = lazy_pinyin(
80
+ word, neutral_tone_with_five=True, style=Style.INITIALS)
81
+ orig_finals = lazy_pinyin(
82
+ word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
83
+ for c, v in zip(orig_initials, orig_finals):
84
+ initials.append(c)
85
+ finals.append(v)
86
+ return initials, finals
87
+
88
+
89
+ def _g2p(segments):
90
+ phones_list = []
91
+ tones_list = []
92
+ word2ph = []
93
+ for seg in segments:
94
+ pinyins = []
95
+ # Replace all English words in the sentence
96
+ seg = re.sub('[a-zA-Z]+', '', seg)
97
+ seg_cut = psg.lcut(seg)
98
+ initials = []
99
+ finals = []
100
+ seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
101
+ for word, pos in seg_cut:
102
+ if pos == 'eng':
103
+ continue
104
+ sub_initials, sub_finals = _get_initials_finals(word)
105
+ sub_finals = tone_modifier.modified_tone(word, pos,
106
+ sub_finals)
107
+ initials.append(sub_initials)
108
+ finals.append(sub_finals)
109
+
110
+ # assert len(sub_initials) == len(sub_finals) == len(word)
111
+ initials = sum(initials, [])
112
+ finals = sum(finals, [])
113
+ #
114
+ for c, v in zip(initials, finals):
115
+ raw_pinyin = c+v
116
+ # NOTE: post process for pypinyin outputs
117
+ # we discriminate i, ii and iii
118
+ if c == v:
119
+ assert c in punctuation
120
+ phone = [c]
121
+ tone = '0'
122
+ word2ph.append(1)
123
+ else:
124
+ v_without_tone = v[:-1]
125
+ tone = v[-1]
126
+
127
+ pinyin = c+v_without_tone
128
+ assert tone in '12345'
129
+
130
+ if c:
131
+ # 多音节
132
+ v_rep_map = {
133
+ "uei": 'ui',
134
+ 'iou': 'iu',
135
+ 'uen': 'un',
136
+ }
137
+ if v_without_tone in v_rep_map.keys():
138
+ pinyin = c+v_rep_map[v_without_tone]
139
+ else:
140
+ # 单音节
141
+ pinyin_rep_map = {
142
+ 'ing': 'ying',
143
+ 'i': 'yi',
144
+ 'in': 'yin',
145
+ 'u': 'wu',
146
+ }
147
+ if pinyin in pinyin_rep_map.keys():
148
+ pinyin = pinyin_rep_map[pinyin]
149
+ else:
150
+ single_rep_map = {
151
+ 'v': 'yu',
152
+ 'e': 'e',
153
+ 'i': 'y',
154
+ 'u': 'w',
155
+ }
156
+ if pinyin[0] in single_rep_map.keys():
157
+ pinyin = single_rep_map[pinyin[0]]+pinyin[1:]
158
+
159
+ assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
160
+ phone = pinyin_to_symbol_map[pinyin].split(' ')
161
+ word2ph.append(len(phone))
162
+
163
+ phones_list += phone
164
+ tones_list += [int(tone)] * len(phone)
165
+ return phones_list, tones_list, word2ph
166
+
167
+
168
+
169
+ def text_normalize(text):
170
+ numbers = re.findall(r'\d+(?:\.?\d+)?', text)
171
+ for number in numbers:
172
+ text = text.replace(number, cn2an.an2cn(number), 1)
173
+ text = replace_punctuation(text)
174
+ return text
175
+
176
+ def get_bert_feature(text, word2ph):
177
+ from text import chinese_bert
178
+ return chinese_bert.get_bert_feature(text, word2ph)
179
+
180
+ if __name__ == '__main__':
181
+ from text.chinese_bert import get_bert_feature
182
+ text = "啊!但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏"
183
+ text = text_normalize(text)
184
+ print(text)
185
+ phones, tones, word2ph = g2p(text)
186
+ bert = get_bert_feature(text, word2ph)
187
+
188
+ print(phones, tones, word2ph, bert.shape)
189
+
190
+
191
+ # # 示例用法
192
+ # text = "这是一个示例文本:,你好!这是一个测试...."
193
+ # print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试
text/chinese_bert.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import torch
3
+ from modelscope import snapshot_download
4
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
5
+
6
+ device = torch.device(
7
+ "cuda"
8
+ if torch.cuda.is_available()
9
+ else (
10
+ "mps"
11
+ if sys.platform == "darwin" and torch.backends.mps.is_available()
12
+ else "cpu"
13
+ )
14
+ )
15
+
16
+ # 模型下载
17
+ model_dir = snapshot_download("dienstag/chinese-roberta-wwm-ext-large")
18
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
19
+ model = AutoModelForMaskedLM.from_pretrained(model_dir).to(device)
20
+
21
+
22
+ def get_bert_feature(text, word2ph):
23
+ with torch.no_grad():
24
+ inputs = tokenizer(text, return_tensors="pt")
25
+ for i in inputs:
26
+ inputs[i] = inputs[i].to(device)
27
+ res = model(**inputs, output_hidden_states=True)
28
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
29
+
30
+ assert len(word2ph) == len(text) + 2
31
+ word2phone = word2ph
32
+ phone_level_feature = []
33
+ for i in range(len(word2phone)):
34
+ repeat_feature = res[i].repeat(word2phone[i], 1)
35
+ phone_level_feature.append(repeat_feature)
36
+
37
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
38
+
39
+ return phone_level_feature.T
40
+
41
+
42
+ if __name__ == "__main__":
43
+ # feature = get_bert_feature('你好,我是说的道理。')
44
+ import torch
45
+
46
+ word_level_feature = torch.rand(38, 1024) # 12个词,每个词1024维特征
47
+ word2phone = [
48
+ 1,
49
+ 2,
50
+ 1,
51
+ 2,
52
+ 2,
53
+ 1,
54
+ 2,
55
+ 2,
56
+ 1,
57
+ 2,
58
+ 2,
59
+ 1,
60
+ 2,
61
+ 2,
62
+ 2,
63
+ 2,
64
+ 2,
65
+ 1,
66
+ 1,
67
+ 2,
68
+ 2,
69
+ 1,
70
+ 2,
71
+ 2,
72
+ 2,
73
+ 2,
74
+ 1,
75
+ 2,
76
+ 2,
77
+ 2,
78
+ 2,
79
+ 2,
80
+ 1,
81
+ 2,
82
+ 2,
83
+ 2,
84
+ 2,
85
+ 1,
86
+ ]
87
+
88
+ # 计算总帧数
89
+ total_frames = sum(word2phone)
90
+ print(word_level_feature.shape)
91
+ print(word2phone)
92
+ phone_level_feature = []
93
+ for i in range(len(word2phone)):
94
+ print(word_level_feature[i].shape)
95
+
96
+ # 对每个词重复word2phone[i]次
97
+ repeat_feature = word_level_feature[i].repeat(word2phone[i], 1)
98
+ phone_level_feature.append(repeat_feature)
99
+
100
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
101
+ print(phone_level_feature.shape) # torch.Size([36, 1024])
text/cleaner.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from text import chinese, cleaned_text_to_sequence
2
+
3
+
4
+ language_module_map = {
5
+ 'ZH': chinese
6
+ }
7
+
8
+
9
+ def clean_text(text, language):
10
+ language_module = language_module_map[language]
11
+ norm_text = language_module.text_normalize(text)
12
+ phones, tones, word2ph = language_module.g2p(norm_text)
13
+ return norm_text, phones, tones, word2ph
14
+
15
+ def clean_text_bert(text, language):
16
+ language_module = language_module_map[language]
17
+ norm_text = language_module.text_normalize(text)
18
+ phones, tones, word2ph = language_module.g2p(norm_text)
19
+ bert = language_module.get_bert_feature(norm_text, word2ph)
20
+ return phones, tones, bert
21
+
22
+ def text_to_sequence(text, language):
23
+ norm_text, phones, tones, word2ph = clean_text(text, language)
24
+ return cleaned_text_to_sequence(phones, tones, language)
25
+
26
+ if __name__ == '__main__':
27
+ pass
text/english_bert_mock.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def get_bert_feature(norm_text, word2ph):
5
+ return torch.zeros(1024, sum(word2ph))
text/opencpop-strict.txt ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ a AA a
2
+ ai AA ai
3
+ an AA an
4
+ ang AA ang
5
+ ao AA ao
6
+ ba b a
7
+ bai b ai
8
+ ban b an
9
+ bang b ang
10
+ bao b ao
11
+ bei b ei
12
+ ben b en
13
+ beng b eng
14
+ bi b i
15
+ bian b ian
16
+ biao b iao
17
+ bie b ie
18
+ bin b in
19
+ bing b ing
20
+ bo b o
21
+ bu b u
22
+ ca c a
23
+ cai c ai
24
+ can c an
25
+ cang c ang
26
+ cao c ao
27
+ ce c e
28
+ cei c ei
29
+ cen c en
30
+ ceng c eng
31
+ cha ch a
32
+ chai ch ai
33
+ chan ch an
34
+ chang ch ang
35
+ chao ch ao
36
+ che ch e
37
+ chen ch en
38
+ cheng ch eng
39
+ chi ch ir
40
+ chong ch ong
41
+ chou ch ou
42
+ chu ch u
43
+ chua ch ua
44
+ chuai ch uai
45
+ chuan ch uan
46
+ chuang ch uang
47
+ chui ch ui
48
+ chun ch un
49
+ chuo ch uo
50
+ ci c i0
51
+ cong c ong
52
+ cou c ou
53
+ cu c u
54
+ cuan c uan
55
+ cui c ui
56
+ cun c un
57
+ cuo c uo
58
+ da d a
59
+ dai d ai
60
+ dan d an
61
+ dang d ang
62
+ dao d ao
63
+ de d e
64
+ dei d ei
65
+ den d en
66
+ deng d eng
67
+ di d i
68
+ dia d ia
69
+ dian d ian
70
+ diao d iao
71
+ die d ie
72
+ ding d ing
73
+ diu d iu
74
+ dong d ong
75
+ dou d ou
76
+ du d u
77
+ duan d uan
78
+ dui d ui
79
+ dun d un
80
+ duo d uo
81
+ e EE e
82
+ ei EE ei
83
+ en EE en
84
+ eng EE eng
85
+ er EE er
86
+ fa f a
87
+ fan f an
88
+ fang f ang
89
+ fei f ei
90
+ fen f en
91
+ feng f eng
92
+ fo f o
93
+ fou f ou
94
+ fu f u
95
+ ga g a
96
+ gai g ai
97
+ gan g an
98
+ gang g ang
99
+ gao g ao
100
+ ge g e
101
+ gei g ei
102
+ gen g en
103
+ geng g eng
104
+ gong g ong
105
+ gou g ou
106
+ gu g u
107
+ gua g ua
108
+ guai g uai
109
+ guan g uan
110
+ guang g uang
111
+ gui g ui
112
+ gun g un
113
+ guo g uo
114
+ ha h a
115
+ hai h ai
116
+ han h an
117
+ hang h ang
118
+ hao h ao
119
+ he h e
120
+ hei h ei
121
+ hen h en
122
+ heng h eng
123
+ hong h ong
124
+ hou h ou
125
+ hu h u
126
+ hua h ua
127
+ huai h uai
128
+ huan h uan
129
+ huang h uang
130
+ hui h ui
131
+ hun h un
132
+ huo h uo
133
+ ji j i
134
+ jia j ia
135
+ jian j ian
136
+ jiang j iang
137
+ jiao j iao
138
+ jie j ie
139
+ jin j in
140
+ jing j ing
141
+ jiong j iong
142
+ jiu j iu
143
+ ju j v
144
+ jv j v
145
+ juan j van
146
+ jvan j van
147
+ jue j ve
148
+ jve j ve
149
+ jun j vn
150
+ jvn j vn
151
+ ka k a
152
+ kai k ai
153
+ kan k an
154
+ kang k ang
155
+ kao k ao
156
+ ke k e
157
+ kei k ei
158
+ ken k en
159
+ keng k eng
160
+ kong k ong
161
+ kou k ou
162
+ ku k u
163
+ kua k ua
164
+ kuai k uai
165
+ kuan k uan
166
+ kuang k uang
167
+ kui k ui
168
+ kun k un
169
+ kuo k uo
170
+ la l a
171
+ lai l ai
172
+ lan l an
173
+ lang l ang
174
+ lao l ao
175
+ le l e
176
+ lei l ei
177
+ leng l eng
178
+ li l i
179
+ lia l ia
180
+ lian l ian
181
+ liang l iang
182
+ liao l iao
183
+ lie l ie
184
+ lin l in
185
+ ling l ing
186
+ liu l iu
187
+ lo l o
188
+ long l ong
189
+ lou l ou
190
+ lu l u
191
+ luan l uan
192
+ lun l un
193
+ luo l uo
194
+ lv l v
195
+ lve l ve
196
+ ma m a
197
+ mai m ai
198
+ man m an
199
+ mang m ang
200
+ mao m ao
201
+ me m e
202
+ mei m ei
203
+ men m en
204
+ meng m eng
205
+ mi m i
206
+ mian m ian
207
+ miao m iao
208
+ mie m ie
209
+ min m in
210
+ ming m ing
211
+ miu m iu
212
+ mo m o
213
+ mou m ou
214
+ mu m u
215
+ na n a
216
+ nai n ai
217
+ nan n an
218
+ nang n ang
219
+ nao n ao
220
+ ne n e
221
+ nei n ei
222
+ nen n en
223
+ neng n eng
224
+ ni n i
225
+ nian n ian
226
+ niang n iang
227
+ niao n iao
228
+ nie n ie
229
+ nin n in
230
+ ning n ing
231
+ niu n iu
232
+ nong n ong
233
+ nou n ou
234
+ nu n u
235
+ nuan n uan
236
+ nun n un
237
+ nuo n uo
238
+ nv n v
239
+ nve n ve
240
+ o OO o
241
+ ou OO ou
242
+ pa p a
243
+ pai p ai
244
+ pan p an
245
+ pang p ang
246
+ pao p ao
247
+ pei p ei
248
+ pen p en
249
+ peng p eng
250
+ pi p i
251
+ pian p ian
252
+ piao p iao
253
+ pie p ie
254
+ pin p in
255
+ ping p ing
256
+ po p o
257
+ pou p ou
258
+ pu p u
259
+ qi q i
260
+ qia q ia
261
+ qian q ian
262
+ qiang q iang
263
+ qiao q iao
264
+ qie q ie
265
+ qin q in
266
+ qing q ing
267
+ qiong q iong
268
+ qiu q iu
269
+ qu q v
270
+ qv q v
271
+ quan q van
272
+ qvan q van
273
+ que q ve
274
+ qve q ve
275
+ qun q vn
276
+ qvn q vn
277
+ ran r an
278
+ rang r ang
279
+ rao r ao
280
+ re r e
281
+ ren r en
282
+ reng r eng
283
+ ri r ir
284
+ rong r ong
285
+ rou r ou
286
+ ru r u
287
+ rua r ua
288
+ ruan r uan
289
+ rui r ui
290
+ run r un
291
+ ruo r uo
292
+ sa s a
293
+ sai s ai
294
+ san s an
295
+ sang s ang
296
+ sao s ao
297
+ se s e
298
+ sen s en
299
+ seng s eng
300
+ sha sh a
301
+ shai sh ai
302
+ shan sh an
303
+ shang sh ang
304
+ shao sh ao
305
+ she sh e
306
+ shei sh ei
307
+ shen sh en
308
+ sheng sh eng
309
+ shi sh ir
310
+ shou sh ou
311
+ shu sh u
312
+ shua sh ua
313
+ shuai sh uai
314
+ shuan sh uan
315
+ shuang sh uang
316
+ shui sh ui
317
+ shun sh un
318
+ shuo sh uo
319
+ si s i0
320
+ song s ong
321
+ sou s ou
322
+ su s u
323
+ suan s uan
324
+ sui s ui
325
+ sun s un
326
+ suo s uo
327
+ ta t a
328
+ tai t ai
329
+ tan t an
330
+ tang t ang
331
+ tao t ao
332
+ te t e
333
+ tei t ei
334
+ teng t eng
335
+ ti t i
336
+ tian t ian
337
+ tiao t iao
338
+ tie t ie
339
+ ting t ing
340
+ tong t ong
341
+ tou t ou
342
+ tu t u
343
+ tuan t uan
344
+ tui t ui
345
+ tun t un
346
+ tuo t uo
347
+ wa w a
348
+ wai w ai
349
+ wan w an
350
+ wang w ang
351
+ wei w ei
352
+ wen w en
353
+ weng w eng
354
+ wo w o
355
+ wu w u
356
+ xi x i
357
+ xia x ia
358
+ xian x ian
359
+ xiang x iang
360
+ xiao x iao
361
+ xie x ie
362
+ xin x in
363
+ xing x ing
364
+ xiong x iong
365
+ xiu x iu
366
+ xu x v
367
+ xv x v
368
+ xuan x van
369
+ xvan x van
370
+ xue x ve
371
+ xve x ve
372
+ xun x vn
373
+ xvn x vn
374
+ ya y a
375
+ yan y En
376
+ yang y ang
377
+ yao y ao
378
+ ye y E
379
+ yi y i
380
+ yin y in
381
+ ying y ing
382
+ yo y o
383
+ yong y ong
384
+ you y ou
385
+ yu y v
386
+ yv y v
387
+ yuan y van
388
+ yvan y van
389
+ yue y ve
390
+ yve y ve
391
+ yun y vn
392
+ yvn y vn
393
+ za z a
394
+ zai z ai
395
+ zan z an
396
+ zang z ang
397
+ zao z ao
398
+ ze z e
399
+ zei z ei
400
+ zen z en
401
+ zeng z eng
402
+ zha zh a
403
+ zhai zh ai
404
+ zhan zh an
405
+ zhang zh ang
406
+ zhao zh ao
407
+ zhe zh e
408
+ zhei zh ei
409
+ zhen zh en
410
+ zheng zh eng
411
+ zhi zh ir
412
+ zhong zh ong
413
+ zhou zh ou
414
+ zhu zh u
415
+ zhua zh ua
416
+ zhuai zh uai
417
+ zhuan zh uan
418
+ zhuang zh uang
419
+ zhui zh ui
420
+ zhun zh un
421
+ zhuo zh uo
422
+ zi z i0
423
+ zong z ong
424
+ zou z ou
425
+ zu z u
426
+ zuan z uan
427
+ zui z ui
428
+ zun z un
429
+ zuo z uo
text/symbols.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ punctuation = ['!', '?', '…', ",", ".", "'", '-']
2
+ pu_symbols = punctuation + ["SP", "UNK"]
3
+ pad = '_'
4
+
5
+ # chinese
6
+ zh_symbols = ['E', 'En', 'a', 'ai', 'an', 'ang', 'ao', 'b', 'c', 'ch', 'd', 'e', 'ei', 'en', 'eng', 'er', 'f', 'g', 'h',
7
+ 'i', 'i0', 'ia', 'ian', 'iang', 'iao', 'ie', 'in', 'ing', 'iong', 'ir', 'iu', 'j', 'k', 'l', 'm', 'n', 'o',
8
+ 'ong',
9
+ 'ou', 'p', 'q', 'r', 's', 'sh', 't', 'u', 'ua', 'uai', 'uan', 'uang', 'ui', 'un', 'uo', 'v', 'van', 've', 'vn',
10
+ 'w', 'x', 'y', 'z', 'zh',
11
+ "AA", "EE", "OO"]
12
+ num_zh_tones = 6
13
+
14
+ # japanese
15
+ ja_symbols = ['I', 'N', 'U', 'a', 'b', 'by', 'ch', 'cl', 'd', 'dy', 'e', 'f', 'g', 'gy', 'h', 'hy', 'i', 'j', 'k', 'ky',
16
+ 'm', 'my', 'n', 'ny', 'o', 'p', 'py', 'r', 'ry', 's', 'sh', 't', 'ts', 'u', 'V', 'w', 'y', 'z']
17
+ num_ja_tones = 1
18
+
19
+ # English
20
+ en_symbols = ['aa', 'ae', 'ah', 'ao', 'aw', 'ay', 'b', 'ch', 'd', 'dh', 'eh', 'er', 'ey', 'f', 'g', 'hh', 'ih', 'iy',
21
+ 'jh', 'k', 'l', 'm', 'n', 'ng', 'ow', 'oy', 'p', 'r', 's',
22
+ 'sh', 't', 'th', 'uh', 'uw', 'V', 'w', 'y', 'z', 'zh']
23
+ num_en_tones = 4
24
+
25
+ # combine all symbols
26
+ normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols))
27
+ symbols = [pad] + normal_symbols + pu_symbols
28
+ sil_phonemes_ids = [symbols.index(i) for i in pu_symbols]
29
+
30
+ # combine all tones
31
+ num_tones = num_zh_tones + num_ja_tones + num_en_tones
32
+
33
+ # language maps
34
+ language_id_map = {
35
+ 'ZH': 0,
36
+ "JA": 1,
37
+ "EN": 2
38
+ }
39
+ num_languages = len(language_id_map.keys())
40
+
41
+ language_tone_start_map = {
42
+ 'ZH': 0,
43
+ "JA": num_zh_tones,
44
+ "EN": num_zh_tones + num_ja_tones
45
+ }
46
+
47
+ if __name__ == '__main__':
48
+ a = set(zh_symbols)
49
+ b = set(en_symbols)
50
+ print(sorted(a&b))
51
+
text/tone_sandhi.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List
15
+ from typing import Tuple
16
+
17
+ import jieba
18
+ from pypinyin import lazy_pinyin
19
+ from pypinyin import Style
20
+
21
+
22
+ class ToneSandhi():
23
+ def __init__(self):
24
+ self.must_neural_tone_words = {
25
+ '麻烦', '麻利', '鸳鸯', '高粱', '骨头', '骆驼', '马虎', '首饰', '馒头', '馄饨', '风筝',
26
+ '难为', '队伍', '阔气', '闺女', '门道', '锄头', '铺盖', '铃铛', '铁匠', '钥匙', '里脊',
27
+ '里头', '部分', '那么', '道士', '造化', '迷糊', '连累', '这么', '这个', '运气', '过去',
28
+ '软和', '转悠', '踏实', '跳蚤', '跟头', '趔趄', '财主', '豆腐', '讲究', '记性', '记号',
29
+ '认识', '规矩', '见识', '裁缝', '补丁', '衣裳', '衣服', '衙门', '街坊', '行李', '行当',
30
+ '蛤蟆', '蘑菇', '薄荷', '葫芦', '葡萄', '萝卜', '荸荠', '苗条', '苗头', '苍蝇', '芝麻',
31
+ '舒服', '舒坦', '舌头', '自在', '膏药', '脾气', '脑袋', '脊梁', '能耐', '胳膊', '胭脂',
32
+ '胡萝', '胡琴', '胡同', '聪明', '耽误', '耽搁', '耷拉', '耳朵', '老爷', '老实', '老婆',
33
+ '老头', '老太', '翻腾', '罗嗦', '罐头', '编辑', '结实', '红火', '累赘', '糨糊', '糊涂',
34
+ '精神', '粮食', '簸箕', '篱笆', '算计', '算盘', '答应', '笤帚', '笑语', '笑话', '窟窿',
35
+ '窝囊', '窗户', '稳当', '稀罕', '称呼', '秧歌', '秀气', '秀才', '福气', '祖宗', '砚台',
36
+ '码头', '石榴', '石头', '石匠', '知识', '眼睛', '眯缝', '眨巴', '眉毛', '相声', '盘算',
37
+ '白净', '痢疾', '痛快', '疟疾', '疙瘩', '疏忽', '畜生', '生意', '甘蔗', '琵琶', '琢磨',
38
+ '琉璃', '玻璃', '玫瑰', '玄乎', '狐狸', '状元', '特务', '牲口', '牙碜', '牌楼', '爽快',
39
+ '爱人', '热闹', '烧饼', '烟筒', '烂糊', '点心', '炊帚', '灯笼', '火候', '漂亮', '滑溜',
40
+ '溜达', '温和', '清楚', '消息', '浪头', '活泼', '比方', '正经', '欺负', '模糊', '槟榔',
41
+ '棺材', '棒槌', '棉花', '核桃', '栅栏', '柴火', '架势', '枕头', '枇杷', '机灵', '本事',
42
+ '木头', '木匠', '朋友', '月饼', '月亮', '暖和', '明白', '时候', '新鲜', '故事', '收拾',
43
+ '收成', '提防', '挖苦', '挑剔', '指甲', '指头', '拾掇', '拳头', '拨弄', '招牌', '招呼',
44
+ '抬举', '护士', '折腾', '扫帚', '打量', '打算', '打点', '打扮', '打听', '打发', '扎实',
45
+ '扁担', '戒指', '懒得', '意识', '意思', '情形', '悟性', '怪物', '思量', '怎么', '念头',
46
+ '念叨', '快活', '忙活', '志气', '心思', '得罪', '张罗', '弟兄', '开通', '应酬', '庄稼',
47
+ '干事', '帮手', '帐篷', '希罕', '师父', '师傅', '巴结', '巴掌', '差事', '工夫', '岁数',
48
+ '屁股', '尾巴', '少爷', '小气', '小伙', '将就', '对头', '对付', '寡妇', '家伙', '客气',
49
+ '实在', '官司', '学问', '学生', '字号', '嫁妆', '媳妇', '媒人', '婆家', '娘家', '委屈',
50
+ '姑娘', '姐夫', '妯娌', '妥当', '妖精', '奴才', '女婿', '头发', '太阳', '大爷', '大方',
51
+ '大意', '大夫', '多少', '多么', '外甥', '壮实', '地道', '地方', '在乎', '困难', '嘴巴',
52
+ '嘱咐', '嘟囔', '嘀咕', '喜欢', '喇嘛', '喇叭', '商量', '唾沫', '哑巴', '哈欠', '哆嗦',
53
+ '咳嗽', '和尚', '告诉', '告示', '含糊', '吓唬', '后头', '名字', '名堂', '合同', '吆喝',
54
+ '叫唤', '口袋', '厚道', '厉害', '千斤', '包袱', '包涵', '匀称', '勤快', '动静', '动弹',
55
+ '功夫', '力气', '前头', '刺猬', '刺激', '别扭', '利落', '利索', '利害', '分析', '出息',
56
+ '凑合', '凉快', '冷战', '冤枉', '冒失', '养活', '关系', '先生', '兄弟', '便宜', '使唤',
57
+ '佩服', '作坊', '体面', '位置', '似的', '伙计', '休息', '什么', '人家', '亲戚', '亲家',
58
+ '交情', '云彩', '事情', '买卖', '主意', '丫头', '丧气', '两口', '东西', '东家', '世故',
59
+ '不由', '不在', '下水', '下巴', '上头', '上司', '丈夫', '丈人', '一辈', '那个', '菩萨',
60
+ '父亲', '母亲', '咕噜', '邋遢', '费用', '冤家', '甜头', '介绍', '荒唐', '大人', '泥鳅',
61
+ '幸福', '熟悉', '计划', '扑腾', '蜡烛', '姥爷', '照顾', '喉咙', '吉他', '弄堂', '蚂蚱',
62
+ '凤凰', '拖沓', '寒碜', '糟蹋', '倒腾', '报复', '逻辑', '盘缠', '喽啰', '牢骚', '咖喱',
63
+ '扫把', '惦记'
64
+ }
65
+ self.must_not_neural_tone_words = {
66
+ "男子", "女子", "分子", "原子", "量子", "莲子", "石子", "瓜子", "电子", "人人", "虎虎"
67
+ }
68
+ self.punc = ":,;。?!“”‘’':,;.?!"
69
+
70
+ # the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041
71
+ # e.g.
72
+ # word: "家里"
73
+ # pos: "s"
74
+ # finals: ['ia1', 'i3']
75
+ def _neural_sandhi(self, word: str, pos: str,
76
+ finals: List[str]) -> List[str]:
77
+
78
+ # reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺
79
+ for j, item in enumerate(word):
80
+ if j - 1 >= 0 and item == word[j - 1] and pos[0] in {
81
+ "n", "v", "a"
82
+ } and word not in self.must_not_neural_tone_words:
83
+ finals[j] = finals[j][:-1] + "5"
84
+ ge_idx = word.find("个")
85
+ if len(word) >= 1 and word[-1] in "吧呢啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶":
86
+ finals[-1] = finals[-1][:-1] + "5"
87
+ elif len(word) >= 1 and word[-1] in "的地得":
88
+ finals[-1] = finals[-1][:-1] + "5"
89
+ # e.g. 走了, 看着, 去过
90
+ # elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
91
+ # finals[-1] = finals[-1][:-1] + "5"
92
+ elif len(word) > 1 and word[-1] in "们子" and pos in {
93
+ "r", "n"
94
+ } and word not in self.must_not_neural_tone_words:
95
+ finals[-1] = finals[-1][:-1] + "5"
96
+ # e.g. 桌上, 地下, 家里
97
+ elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
98
+ finals[-1] = finals[-1][:-1] + "5"
99
+ # e.g. 上来, 下去
100
+ elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开":
101
+ finals[-1] = finals[-1][:-1] + "5"
102
+ # 个做量词
103
+ elif (ge_idx >= 1 and
104
+ (word[ge_idx - 1].isnumeric() or
105
+ word[ge_idx - 1] in "几有两半多各整每做是")) or word == '个':
106
+ finals[ge_idx] = finals[ge_idx][:-1] + "5"
107
+ else:
108
+ if word in self.must_neural_tone_words or word[
109
+ -2:] in self.must_neural_tone_words:
110
+ finals[-1] = finals[-1][:-1] + "5"
111
+
112
+ word_list = self._split_word(word)
113
+ finals_list = [finals[:len(word_list[0])], finals[len(word_list[0]):]]
114
+ for i, word in enumerate(word_list):
115
+ # conventional neural in Chinese
116
+ if word in self.must_neural_tone_words or word[
117
+ -2:] in self.must_neural_tone_words:
118
+ finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
119
+ finals = sum(finals_list, [])
120
+ return finals
121
+
122
+ def _bu_sandhi(self, word: str, finals: List[str]) -> List[str]:
123
+ # e.g. 看不懂
124
+ if len(word) == 3 and word[1] == "不":
125
+ finals[1] = finals[1][:-1] + "5"
126
+ else:
127
+ for i, char in enumerate(word):
128
+ # "不" before tone4 should be bu2, e.g. 不怕
129
+ if char == "不" and i + 1 < len(word) and finals[i +
130
+ 1][-1] == "4":
131
+ finals[i] = finals[i][:-1] + "2"
132
+ return finals
133
+
134
+ def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
135
+ # "一" in number sequences, e.g. 一零零, 二一零
136
+ if word.find("一") != -1 and all(
137
+ [item.isnumeric() for item in word if item != "一"]):
138
+ return finals
139
+ # "一" between reduplication words shold be yi5, e.g. 看一看
140
+ elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
141
+ finals[1] = finals[1][:-1] + "5"
142
+ # when "一" is ordinal word, it should be yi1
143
+ elif word.startswith("第一"):
144
+ finals[1] = finals[1][:-1] + "1"
145
+ else:
146
+ for i, char in enumerate(word):
147
+ if char == "一" and i + 1 < len(word):
148
+ # "一" before tone4 should be yi2, e.g. 一段
149
+ if finals[i + 1][-1] == "4":
150
+ finals[i] = finals[i][:-1] + "2"
151
+ # "一" before non-tone4 should be yi4, e.g. 一天
152
+ else:
153
+ # "一" 后面如果是标点,还读一声
154
+ if word[i + 1] not in self.punc:
155
+ finals[i] = finals[i][:-1] + "4"
156
+ return finals
157
+
158
+ def _split_word(self, word: str) -> List[str]:
159
+ word_list = jieba.cut_for_search(word)
160
+ word_list = sorted(word_list, key=lambda i: len(i), reverse=False)
161
+ first_subword = word_list[0]
162
+ first_begin_idx = word.find(first_subword)
163
+ if first_begin_idx == 0:
164
+ second_subword = word[len(first_subword):]
165
+ new_word_list = [first_subword, second_subword]
166
+ else:
167
+ second_subword = word[:-len(first_subword)]
168
+ new_word_list = [second_subword, first_subword]
169
+ return new_word_list
170
+
171
+ def _three_sandhi(self, word: str, finals: List[str]) -> List[str]:
172
+ if len(word) == 2 and self._all_tone_three(finals):
173
+ finals[0] = finals[0][:-1] + "2"
174
+ elif len(word) == 3:
175
+ word_list = self._split_word(word)
176
+ if self._all_tone_three(finals):
177
+ # disyllabic + monosyllabic, e.g. 蒙古/包
178
+ if len(word_list[0]) == 2:
179
+ finals[0] = finals[0][:-1] + "2"
180
+ finals[1] = finals[1][:-1] + "2"
181
+ # monosyllabic + disyllabic, e.g. 纸/老虎
182
+ elif len(word_list[0]) == 1:
183
+ finals[1] = finals[1][:-1] + "2"
184
+ else:
185
+ finals_list = [
186
+ finals[:len(word_list[0])], finals[len(word_list[0]):]
187
+ ]
188
+ if len(finals_list) == 2:
189
+ for i, sub in enumerate(finals_list):
190
+ # e.g. 所有/人
191
+ if self._all_tone_three(sub) and len(sub) == 2:
192
+ finals_list[i][0] = finals_list[i][0][:-1] + "2"
193
+ # e.g. 好/喜欢
194
+ elif i == 1 and not self._all_tone_three(sub) and finals_list[i][0][-1] == "3" and \
195
+ finals_list[0][-1][-1] == "3":
196
+
197
+ finals_list[0][-1] = finals_list[0][-1][:-1] + "2"
198
+ finals = sum(finals_list, [])
199
+ # split idiom into two words who's length is 2
200
+ elif len(word) == 4:
201
+ finals_list = [finals[:2], finals[2:]]
202
+ finals = []
203
+ for sub in finals_list:
204
+ if self._all_tone_three(sub):
205
+ sub[0] = sub[0][:-1] + "2"
206
+ finals += sub
207
+
208
+ return finals
209
+
210
+ def _all_tone_three(self, finals: List[str]) -> bool:
211
+ return all(x[-1] == "3" for x in finals)
212
+
213
+ # merge "不" and the word behind it
214
+ # if don't merge, "不" sometimes appears alone according to jieba, which may occur sandhi error
215
+ def _merge_bu(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
216
+ new_seg = []
217
+ last_word = ""
218
+ for word, pos in seg:
219
+ if last_word == "不":
220
+ word = last_word + word
221
+ if word != "不":
222
+ new_seg.append((word, pos))
223
+ last_word = word[:]
224
+ if last_word == "不":
225
+ new_seg.append((last_word, 'd'))
226
+ last_word = ""
227
+ return new_seg
228
+
229
+ # function 1: merge "一" and reduplication words in it's left and right, e.g. "听","一","听" ->"听一听"
230
+ # function 2: merge single "一" and the word behind it
231
+ # if don't merge, "一" sometimes appears alone according to jieba, which may occur sandhi error
232
+ # e.g.
233
+ # input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')]
234
+ # output seg: [['听一听', 'v']]
235
+ def _merge_yi(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
236
+ new_seg = []
237
+ # function 1
238
+ for i, (word, pos) in enumerate(seg):
239
+ if i - 1 >= 0 and word == "一" and i + 1 < len(seg) and seg[i - 1][
240
+ 0] == seg[i + 1][0] and seg[i - 1][1] == "v":
241
+ new_seg[i - 1][0] = new_seg[i - 1][0] + "一" + new_seg[i - 1][0]
242
+ else:
243
+ if i - 2 >= 0 and seg[i - 1][0] == "一" and seg[i - 2][
244
+ 0] == word and pos == "v":
245
+ continue
246
+ else:
247
+ new_seg.append([word, pos])
248
+ seg = new_seg
249
+ new_seg = []
250
+ # function 2
251
+ for i, (word, pos) in enumerate(seg):
252
+ if new_seg and new_seg[-1][0] == "一":
253
+ new_seg[-1][0] = new_seg[-1][0] + word
254
+ else:
255
+ new_seg.append([word, pos])
256
+ return new_seg
257
+
258
+ # the first and the second words are all_tone_three
259
+ def _merge_continuous_three_tones(
260
+ self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
261
+ new_seg = []
262
+ sub_finals_list = [
263
+ lazy_pinyin(
264
+ word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
265
+ for (word, pos) in seg
266
+ ]
267
+ assert len(sub_finals_list) == len(seg)
268
+ merge_last = [False] * len(seg)
269
+ for i, (word, pos) in enumerate(seg):
270
+ if i - 1 >= 0 and self._all_tone_three(
271
+ sub_finals_list[i - 1]) and self._all_tone_three(
272
+ sub_finals_list[i]) and not merge_last[i - 1]:
273
+ # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
274
+ if not self._is_reduplication(seg[i - 1][0]) and len(
275
+ seg[i - 1][0]) + len(seg[i][0]) <= 3:
276
+ new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
277
+ merge_last[i] = True
278
+ else:
279
+ new_seg.append([word, pos])
280
+ else:
281
+ new_seg.append([word, pos])
282
+
283
+ return new_seg
284
+
285
+ def _is_reduplication(self, word: str) -> bool:
286
+ return len(word) == 2 and word[0] == word[1]
287
+
288
+ # the last char of first word and the first char of second word is tone_three
289
+ def _merge_continuous_three_tones_2(
290
+ self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
291
+ new_seg = []
292
+ sub_finals_list = [
293
+ lazy_pinyin(
294
+ word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
295
+ for (word, pos) in seg
296
+ ]
297
+ assert len(sub_finals_list) == len(seg)
298
+ merge_last = [False] * len(seg)
299
+ for i, (word, pos) in enumerate(seg):
300
+ if i - 1 >= 0 and sub_finals_list[i - 1][-1][-1] == "3" and sub_finals_list[i][0][-1] == "3" and not \
301
+ merge_last[i - 1]:
302
+ # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
303
+ if not self._is_reduplication(seg[i - 1][0]) and len(
304
+ seg[i - 1][0]) + len(seg[i][0]) <= 3:
305
+ new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
306
+ merge_last[i] = True
307
+ else:
308
+ new_seg.append([word, pos])
309
+ else:
310
+ new_seg.append([word, pos])
311
+ return new_seg
312
+
313
+ def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
314
+ new_seg = []
315
+ for i, (word, pos) in enumerate(seg):
316
+ if i - 1 >= 0 and word == "儿" and seg[i-1][0] != "#":
317
+ new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
318
+ else:
319
+ new_seg.append([word, pos])
320
+ return new_seg
321
+
322
+ def _merge_reduplication(
323
+ self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
324
+ new_seg = []
325
+ for i, (word, pos) in enumerate(seg):
326
+ if new_seg and word == new_seg[-1][0]:
327
+ new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
328
+ else:
329
+ new_seg.append([word, pos])
330
+ return new_seg
331
+
332
+ def pre_merge_for_modify(
333
+ self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
334
+ seg = self._merge_bu(seg)
335
+ try:
336
+ seg = self._merge_yi(seg)
337
+ except:
338
+ print("_merge_yi failed")
339
+ seg = self._merge_reduplication(seg)
340
+ seg = self._merge_continuous_three_tones(seg)
341
+ seg = self._merge_continuous_three_tones_2(seg)
342
+ seg = self._merge_er(seg)
343
+ return seg
344
+
345
+ def modified_tone(self, word: str, pos: str,
346
+ finals: List[str]) -> List[str]:
347
+ finals = self._bu_sandhi(word, finals)
348
+ finals = self._yi_sandhi(word, finals)
349
+ finals = self._neural_sandhi(word, pos, finals)
350
+ finals = self._three_sandhi(word, finals)
351
+ return finals
transforms.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from torch.nn import functional as F
4
+
5
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
6
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
7
+ DEFAULT_MIN_DERIVATIVE = 1e-3
8
+
9
+
10
+ def piecewise_rational_quadratic_transform(
11
+ inputs,
12
+ unnormalized_widths,
13
+ unnormalized_heights,
14
+ unnormalized_derivatives,
15
+ inverse=False,
16
+ tails=None,
17
+ tail_bound=1.0,
18
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
19
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
20
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
21
+ ):
22
+
23
+ if tails is None:
24
+ spline_fn = rational_quadratic_spline
25
+ spline_kwargs = {}
26
+ else:
27
+ spline_fn = unconstrained_rational_quadratic_spline
28
+ spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
29
+
30
+ outputs, logabsdet = spline_fn(
31
+ inputs=inputs,
32
+ unnormalized_widths=unnormalized_widths,
33
+ unnormalized_heights=unnormalized_heights,
34
+ unnormalized_derivatives=unnormalized_derivatives,
35
+ inverse=inverse,
36
+ min_bin_width=min_bin_width,
37
+ min_bin_height=min_bin_height,
38
+ min_derivative=min_derivative,
39
+ **spline_kwargs
40
+ )
41
+ return outputs, logabsdet
42
+
43
+
44
+ def searchsorted(bin_locations, inputs, eps=1e-6):
45
+ bin_locations[..., -1] += eps
46
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
47
+
48
+
49
+ def unconstrained_rational_quadratic_spline(
50
+ inputs,
51
+ unnormalized_widths,
52
+ unnormalized_heights,
53
+ unnormalized_derivatives,
54
+ inverse=False,
55
+ tails="linear",
56
+ tail_bound=1.0,
57
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
58
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
59
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
60
+ ):
61
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
62
+ outside_interval_mask = ~inside_interval_mask
63
+
64
+ outputs = torch.zeros_like(inputs)
65
+ logabsdet = torch.zeros_like(inputs)
66
+
67
+ if tails == "linear":
68
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
69
+ constant = np.log(np.exp(1 - min_derivative) - 1)
70
+ unnormalized_derivatives[..., 0] = constant
71
+ unnormalized_derivatives[..., -1] = constant
72
+
73
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
74
+ logabsdet[outside_interval_mask] = 0
75
+ else:
76
+ raise RuntimeError("{} tails are not implemented.".format(tails))
77
+
78
+ outputs[inside_interval_mask], logabsdet[inside_interval_mask] = (
79
+ rational_quadratic_spline(
80
+ inputs=inputs[inside_interval_mask],
81
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
82
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
83
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
84
+ inverse=inverse,
85
+ left=-tail_bound,
86
+ right=tail_bound,
87
+ bottom=-tail_bound,
88
+ top=tail_bound,
89
+ min_bin_width=min_bin_width,
90
+ min_bin_height=min_bin_height,
91
+ min_derivative=min_derivative,
92
+ )
93
+ )
94
+
95
+ return outputs, logabsdet
96
+
97
+
98
+ def rational_quadratic_spline(
99
+ inputs,
100
+ unnormalized_widths,
101
+ unnormalized_heights,
102
+ unnormalized_derivatives,
103
+ inverse=False,
104
+ left=0.0,
105
+ right=1.0,
106
+ bottom=0.0,
107
+ top=1.0,
108
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
109
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
110
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
111
+ ):
112
+ if torch.min(inputs) < left or torch.max(inputs) > right:
113
+ raise ValueError("Input to a transform is not within its domain")
114
+
115
+ num_bins = unnormalized_widths.shape[-1]
116
+
117
+ if min_bin_width * num_bins > 1.0:
118
+ raise ValueError("Minimal bin width too large for the number of bins")
119
+ if min_bin_height * num_bins > 1.0:
120
+ raise ValueError("Minimal bin height too large for the number of bins")
121
+
122
+ widths = F.softmax(unnormalized_widths, dim=-1)
123
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
124
+ cumwidths = torch.cumsum(widths, dim=-1)
125
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
126
+ cumwidths = (right - left) * cumwidths + left
127
+ cumwidths[..., 0] = left
128
+ cumwidths[..., -1] = right
129
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
130
+
131
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
132
+
133
+ heights = F.softmax(unnormalized_heights, dim=-1)
134
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
135
+ cumheights = torch.cumsum(heights, dim=-1)
136
+ cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
137
+ cumheights = (top - bottom) * cumheights + bottom
138
+ cumheights[..., 0] = bottom
139
+ cumheights[..., -1] = top
140
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
141
+
142
+ if inverse:
143
+ bin_idx = searchsorted(cumheights, inputs)[..., None]
144
+ else:
145
+ bin_idx = searchsorted(cumwidths, inputs)[..., None]
146
+
147
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
148
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
149
+
150
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
151
+ delta = heights / widths
152
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
153
+
154
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
155
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
156
+
157
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
158
+
159
+ if inverse:
160
+ a = (inputs - input_cumheights) * (
161
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
162
+ ) + input_heights * (input_delta - input_derivatives)
163
+ b = input_heights * input_derivatives - (inputs - input_cumheights) * (
164
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
165
+ )
166
+ c = -input_delta * (inputs - input_cumheights)
167
+
168
+ discriminant = b.pow(2) - 4 * a * c
169
+ assert (discriminant >= 0).all()
170
+
171
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
172
+ outputs = root * input_bin_widths + input_cumwidths
173
+
174
+ theta_one_minus_theta = root * (1 - root)
175
+ denominator = input_delta + (
176
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
177
+ * theta_one_minus_theta
178
+ )
179
+ derivative_numerator = input_delta.pow(2) * (
180
+ input_derivatives_plus_one * root.pow(2)
181
+ + 2 * input_delta * theta_one_minus_theta
182
+ + input_derivatives * (1 - root).pow(2)
183
+ )
184
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
185
+
186
+ return outputs, -logabsdet
187
+ else:
188
+ theta = (inputs - input_cumwidths) / input_bin_widths
189
+ theta_one_minus_theta = theta * (1 - theta)
190
+
191
+ numerator = input_heights * (
192
+ input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
193
+ )
194
+ denominator = input_delta + (
195
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
196
+ * theta_one_minus_theta
197
+ )
198
+ outputs = input_cumheights + numerator / denominator
199
+
200
+ derivative_numerator = input_delta.pow(2) * (
201
+ input_derivatives_plus_one * theta.pow(2)
202
+ + 2 * input_delta * theta_one_minus_theta
203
+ + input_derivatives * (1 - theta).pow(2)
204
+ )
205
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
206
+
207
+ return outputs, logabsdet
utils.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import json
4
+ import torch
5
+ import logging
6
+ import argparse
7
+ import requests
8
+ import subprocess
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+ from scipy.io.wavfile import read
12
+
13
+
14
+ MATPLOTLIB_FLAG = False
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
20
+ assert os.path.isfile(checkpoint_path)
21
+ checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
22
+ iteration = checkpoint_dict["iteration"]
23
+ learning_rate = checkpoint_dict["learning_rate"]
24
+ if (
25
+ optimizer is not None
26
+ and not skip_optimizer
27
+ and checkpoint_dict["optimizer"] is not None
28
+ ):
29
+ optimizer.load_state_dict(checkpoint_dict["optimizer"])
30
+
31
+ elif optimizer is None and not skip_optimizer:
32
+ # else: Disable this line if Infer and resume checkpoint,then enable the line upper
33
+ new_opt_dict = optimizer.state_dict()
34
+ new_opt_dict_params = new_opt_dict["param_groups"][0]["params"]
35
+ new_opt_dict["param_groups"] = checkpoint_dict["optimizer"]["param_groups"]
36
+ new_opt_dict["param_groups"][0]["params"] = new_opt_dict_params
37
+ optimizer.load_state_dict(new_opt_dict)
38
+
39
+ saved_state_dict = checkpoint_dict["model"]
40
+ if hasattr(model, "module"):
41
+ state_dict = model.module.state_dict()
42
+
43
+ else:
44
+ state_dict = model.state_dict()
45
+
46
+ new_state_dict = {}
47
+ for k, v in state_dict.items():
48
+ try:
49
+ # assert "emb_g" not in k
50
+ # print("load", k)
51
+ new_state_dict[k] = saved_state_dict[k]
52
+ assert saved_state_dict[k].shape == v.shape, (
53
+ saved_state_dict[k].shape,
54
+ v.shape,
55
+ )
56
+
57
+ except:
58
+ logger.error("%s is not in the checkpoint" % k)
59
+ new_state_dict[k] = v
60
+
61
+ if hasattr(model, "module"):
62
+ model.module.load_state_dict(new_state_dict, strict=False)
63
+
64
+ else:
65
+ model.load_state_dict(new_state_dict, strict=False)
66
+
67
+ logger.info(
68
+ "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
69
+ )
70
+ return model, optimizer, learning_rate, iteration
71
+
72
+
73
+ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
74
+ logger.info(
75
+ "Saving model and optimizer state at iteration {} to {}".format(
76
+ iteration, checkpoint_path
77
+ )
78
+ )
79
+ if hasattr(model, "module"):
80
+ state_dict = model.module.state_dict()
81
+
82
+ else:
83
+ state_dict = model.state_dict()
84
+
85
+ torch.save(
86
+ {
87
+ "model": state_dict,
88
+ "iteration": iteration,
89
+ "optimizer": optimizer.state_dict(),
90
+ "learning_rate": learning_rate,
91
+ },
92
+ checkpoint_path,
93
+ )
94
+
95
+
96
+ def summarize(
97
+ writer,
98
+ global_step,
99
+ scalars={},
100
+ histograms={},
101
+ images={},
102
+ audios={},
103
+ audio_sampling_rate=22050,
104
+ ):
105
+ for k, v in scalars.items():
106
+ writer.add_scalar(k, v, global_step)
107
+ for k, v in histograms.items():
108
+ writer.add_histogram(k, v, global_step)
109
+ for k, v in images.items():
110
+ writer.add_image(k, v, global_step, dataformats="HWC")
111
+ for k, v in audios.items():
112
+ writer.add_audio(k, v, global_step, audio_sampling_rate)
113
+
114
+
115
+ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
116
+ f_list = glob.glob(os.path.join(dir_path, regex))
117
+ f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
118
+ x = f_list[-1]
119
+ print(x)
120
+ return x
121
+
122
+
123
+ def plot_spectrogram_to_numpy(spectrogram):
124
+ global MATPLOTLIB_FLAG
125
+ if not MATPLOTLIB_FLAG:
126
+ import matplotlib
127
+
128
+ matplotlib.use("Agg")
129
+ MATPLOTLIB_FLAG = True
130
+ mpl_logger = logging.getLogger("matplotlib")
131
+ mpl_logger.setLevel(logging.WARNING)
132
+
133
+ import matplotlib.pylab as plt
134
+ import numpy as np
135
+
136
+ fig, ax = plt.subplots(figsize=(10, 2))
137
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
138
+ plt.colorbar(im, ax=ax)
139
+ plt.xlabel("Frames")
140
+ plt.ylabel("Channels")
141
+ plt.tight_layout()
142
+ fig.canvas.draw()
143
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
144
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
145
+ plt.close()
146
+ return data
147
+
148
+
149
+ def plot_alignment_to_numpy(alignment, info=None):
150
+ global MATPLOTLIB_FLAG
151
+ if not MATPLOTLIB_FLAG:
152
+ import matplotlib
153
+
154
+ matplotlib.use("Agg")
155
+ MATPLOTLIB_FLAG = True
156
+ mpl_logger = logging.getLogger("matplotlib")
157
+ mpl_logger.setLevel(logging.WARNING)
158
+
159
+ import matplotlib.pylab as plt
160
+ import numpy as np
161
+
162
+ fig, ax = plt.subplots(figsize=(6, 4))
163
+ im = ax.imshow(
164
+ alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
165
+ )
166
+ fig.colorbar(im, ax=ax)
167
+ xlabel = "Decoder timestep"
168
+ if info is not None:
169
+ xlabel += "\n\n" + info
170
+
171
+ plt.xlabel(xlabel)
172
+ plt.ylabel("Encoder timestep")
173
+ plt.tight_layout()
174
+ fig.canvas.draw()
175
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
176
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
177
+ plt.close()
178
+ return data
179
+
180
+
181
+ def load_wav_to_torch(full_path):
182
+ sampling_rate, data = read(full_path)
183
+ return torch.FloatTensor(data.astype(np.float32)), sampling_rate
184
+
185
+
186
+ def load_filepaths_and_text(filename, split="|"):
187
+ with open(filename, encoding="utf-8") as f:
188
+ filepaths_and_text = [line.strip().split(split) for line in f]
189
+
190
+ return filepaths_and_text
191
+
192
+
193
+ def get_hparams(init=True):
194
+ parser = argparse.ArgumentParser()
195
+ parser.add_argument(
196
+ "-c",
197
+ "--config",
198
+ type=str,
199
+ default="./configs/base.json",
200
+ help="JSON file for configuration",
201
+ )
202
+ parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
203
+ args = parser.parse_args()
204
+ model_dir = os.path.join("./logs", args.model)
205
+ if not os.path.exists(model_dir):
206
+ os.makedirs(model_dir)
207
+
208
+ config_path = args.config
209
+ config_save_path = os.path.join(model_dir, "config.json")
210
+ if init:
211
+ with open(config_path, "r") as f:
212
+ data = f.read()
213
+
214
+ with open(config_save_path, "w") as f:
215
+ f.write(data)
216
+
217
+ else:
218
+ with open(config_save_path, "r") as f:
219
+ data = f.read()
220
+
221
+ config = json.loads(data)
222
+ hparams = HParams(**config)
223
+ hparams.model_dir = model_dir
224
+ return hparams
225
+
226
+
227
+ def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
228
+ """Freeing up space by deleting saved ckpts
229
+
230
+ Arguments:
231
+ path_to_models -- Path to the model directory
232
+ n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
233
+ sort_by_time -- True -> chronologically delete ckpts
234
+ False -> lexicographically delete ckpts
235
+ """
236
+ import re
237
+
238
+ ckpts_files = [
239
+ f
240
+ for f in os.listdir(path_to_models)
241
+ if os.path.isfile(os.path.join(path_to_models, f))
242
+ ]
243
+ name_key = lambda _f: int(re.compile("._(\d+)\.pth").match(_f).group(1))
244
+ time_key = lambda _f: os.path.getmtime(os.path.join(path_to_models, _f))
245
+ sort_key = time_key if sort_by_time else name_key
246
+ x_sorted = lambda _x: sorted(
247
+ [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
248
+ key=sort_key,
249
+ )
250
+ to_del = [
251
+ os.path.join(path_to_models, fn)
252
+ for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
253
+ ]
254
+ del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
255
+ del_routine = lambda x: [os.remove(x), del_info(x)]
256
+ rs = [del_routine(fn) for fn in to_del]
257
+ print(rs)
258
+
259
+
260
+ def get_hparams_from_dir(model_dir):
261
+ config_save_path = os.path.join(model_dir, "config.json")
262
+ with open(config_save_path, "r", encoding="utf-8") as f:
263
+ data = f.read()
264
+
265
+ config = json.loads(data)
266
+ hparams = HParams(**config)
267
+ hparams.model_dir = model_dir
268
+ return hparams
269
+
270
+
271
+ def download_file(file_url: str):
272
+ filename = file_url.split("&FilePath=")[-1]
273
+ if os.path.exists(filename):
274
+ return filename
275
+
276
+ response = requests.get(file_url, stream=True)
277
+ # 检查请求是否成功
278
+ if response.status_code == 200:
279
+ # 获取文件总大小
280
+ file_size = int(response.headers.get("Content-Length", 0))
281
+ # 打开文件以写入二进制数据
282
+ with open(filename, "wb") as file:
283
+ # 创建进度条
284
+ progress_bar = tqdm(
285
+ total=file_size,
286
+ unit="B",
287
+ unit_scale=True,
288
+ desc=f"Downloading {filename}...",
289
+ )
290
+ # 以块的形式下载文件
291
+ for chunk in response.iter_content(chunk_size=8192):
292
+ if chunk: # 过滤掉保持连接的新块
293
+ file.write(chunk)
294
+ progress_bar.update(len(chunk)) # 更新进度条
295
+
296
+ progress_bar.close() # 关闭进度条
297
+
298
+ print(f"模型文件 '{file_url}' 下载成功。")
299
+
300
+ else:
301
+ print(f"下载失败,状态码:{response.status_code}")
302
+
303
+ return filename
304
+
305
+
306
+ def get_hparams_from_url(config_url):
307
+ response = requests.get(config_url)
308
+ config = response.json()
309
+ return HParams(**config)
310
+
311
+
312
+ def check_git_hash(model_dir):
313
+ source_dir = os.path.dirname(os.path.realpath(__file__))
314
+ if not os.path.exists(os.path.join(source_dir, ".git")):
315
+ logger.warn(
316
+ "{} is not a git repository, therefore hash value comparison will be ignored.".format(
317
+ source_dir
318
+ )
319
+ )
320
+ return
321
+
322
+ cur_hash = subprocess.getoutput("git rev-parse HEAD")
323
+ path = os.path.join(model_dir, "githash")
324
+ if os.path.exists(path):
325
+ saved_hash = open(path).read()
326
+ if saved_hash != cur_hash:
327
+ logger.warn(
328
+ "git hash values are different. {}(saved) != {}(current)".format(
329
+ saved_hash[:8], cur_hash[:8]
330
+ )
331
+ )
332
+ else:
333
+ open(path, "w").write(cur_hash)
334
+
335
+
336
+ def get_logger(model_dir, filename="train.log"):
337
+ global logger
338
+ logger = logging.getLogger(os.path.basename(model_dir))
339
+ logger.setLevel(logging.DEBUG)
340
+ formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
341
+ if not os.path.exists(model_dir):
342
+ os.makedirs(model_dir)
343
+
344
+ h = logging.FileHandler(os.path.join(model_dir, filename))
345
+ h.setLevel(logging.DEBUG)
346
+ h.setFormatter(formatter)
347
+ logger.addHandler(h)
348
+ return logger
349
+
350
+
351
+ class HParams:
352
+ def __init__(self, **kwargs):
353
+ for k, v in kwargs.items():
354
+ if type(v) == dict:
355
+ v = HParams(**v)
356
+ self[k] = v
357
+
358
+ def keys(self):
359
+ return self.__dict__.keys()
360
+
361
+ def items(self):
362
+ return self.__dict__.items()
363
+
364
+ def values(self):
365
+ return self.__dict__.values()
366
+
367
+ def __len__(self):
368
+ return len(self.__dict__)
369
+
370
+ def __getitem__(self, key):
371
+ return getattr(self, key)
372
+
373
+ def __setitem__(self, key, value):
374
+ return setattr(self, key, value)
375
+
376
+ def __contains__(self, key):
377
+ return key in self.__dict__
378
+
379
+ def __repr__(self):
380
+ return self.__dict__.__repr__()