Jackellie commited on
Commit
503d5c9
·
1 Parent(s): 1188f9a

Upload 3 files

Browse files
train_fix/bert_genV0.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader
3
+ from multiprocessing import Pool
4
+ import commons
5
+ import utils
6
+ from data_utils import TextAudioSpeakerLoader, TextAudioSpeakerCollate
7
+ from tqdm import tqdm
8
+ import warnings
9
+
10
+ from text import cleaned_text_to_sequence, get_bert
11
+
12
+ config_path = 'configs/base.json'
13
+ hps = utils.get_hparams_from_file(config_path)
14
+
15
+ def process_line(line):
16
+ _id, spk, language_str, text, phones, tone, word2ph = line.strip().split("|")
17
+ phone = phones.split(" ")
18
+ tone = [int(i) for i in tone.split(" ")]
19
+ word2ph = [int(i) for i in word2ph.split(" ")]
20
+ w2pho = [i for i in word2ph]
21
+ word2ph = [i for i in word2ph]
22
+ phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
23
+
24
+ if hps.data.add_blank:
25
+ phone = commons.intersperse(phone, 0)
26
+ tone = commons.intersperse(tone, 0)
27
+ language = commons.intersperse(language, 0)
28
+ for i in range(len(word2ph)):
29
+ word2ph[i] = word2ph[i] * 2
30
+ word2ph[0] += 1
31
+ wav_path = f'{_id}'
32
+
33
+ bert_path = wav_path.replace(".wav", ".bert.pt")
34
+ try:
35
+ bert = torch.load(bert_path)
36
+ assert bert.shape[-1] == len(phone)
37
+ except:
38
+ bert = get_bert(text, word2ph, language_str)
39
+ assert bert.shape[-1] == len(phone)
40
+ torch.save(bert, bert_path)
41
+
42
+
43
+ if __name__ == '__main__':
44
+ lines = []
45
+ with open(hps.data.training_files, encoding='utf-8' ) as f:
46
+ lines.extend(f.readlines())
47
+
48
+ # with open(hps.data.validation_files, encoding='utf-8' ) as f:
49
+ # lines.extend(f.readlines())
50
+
51
+ with Pool(processes=2) as pool: #A100 40GB suitable config,if coom,please decrease the processess number.
52
+ for _ in tqdm(pool.imap_unordered(process_line, lines)):
53
+ pass
train_fix/text/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from text.symbols import *
2
+
3
+
4
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
5
+
6
+
7
+ def cleaned_text_to_sequence(cleaned_text, tones, language):
8
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
9
+ Args:
10
+ text: string to convert to a sequence
11
+ Returns:
12
+ List of integers corresponding to the symbols in the text
13
+ """
14
+ phones = [_symbol_to_id[symbol] for symbol in cleaned_text]
15
+ tone_start = language_tone_start_map[language]
16
+ tones = [i + tone_start for i in tones]
17
+ lang_id = language_id_map[language]
18
+ lang_ids = [lang_id for i in phones]
19
+ return phones, tones, lang_ids
20
+
21
+
22
+ def get_bert(norm_text, word2ph, language, device=None):
23
+ from .chinese_bert import get_bert_feature as zh_bert
24
+ from .english_bert_mock import get_bert_feature as en_bert
25
+ from .japanese_bert import get_bert_feature as jp_bert
26
+
27
+ lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert}
28
+ bert = lang_bert_func_map[language](norm_text, word2ph, device)
29
+ return bert
train_fix/train_ms.py ADDED
@@ -0,0 +1,607 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: E402
2
+
3
+ import os
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from torch.utils.data import DataLoader
7
+ from torch.utils.tensorboard import SummaryWriter
8
+ import torch.multiprocessing as mp
9
+ import torch.distributed as dist
10
+ from torch.nn.parallel import DistributedDataParallel as DDP
11
+ from torch.cuda.amp import autocast, GradScaler
12
+ from tqdm import tqdm
13
+ import logging
14
+
15
+ logging.getLogger("numba").setLevel(logging.WARNING)
16
+ import commons
17
+ import utils
18
+ from data_utils import (
19
+ TextAudioSpeakerLoader,
20
+ TextAudioSpeakerCollate,
21
+ DistributedBucketSampler,
22
+ )
23
+ from models import (
24
+ SynthesizerTrn,
25
+ MultiPeriodDiscriminator,
26
+ DurationDiscriminator,
27
+ )
28
+ from losses import generator_loss, discriminator_loss, feature_loss, kl_loss
29
+ from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
30
+ from text.symbols import symbols
31
+
32
+ torch.backends.cuda.matmul.allow_tf32 = True
33
+ torch.backends.cudnn.allow_tf32 = (
34
+ True # If encontered training problem,please try to disable TF32.
35
+ )
36
+ torch.set_float32_matmul_precision("medium")
37
+ torch.backends.cudnn.benchmark = True
38
+ torch.backends.cuda.sdp_kernel("flash")
39
+ torch.backends.cuda.enable_flash_sdp(True)
40
+ torch.backends.cuda.enable_mem_efficient_sdp(
41
+ True
42
+ ) # Not available if torch version is lower than 2.0
43
+ torch.backends.cuda.enable_math_sdp(True)
44
+ global_step = 0
45
+
46
+
47
+ def main():
48
+ """Assume Single Node Multi GPUs Training Only"""
49
+ assert torch.cuda.is_available(), "CPU training is not allowed."
50
+
51
+ n_gpus = torch.cuda.device_count()
52
+ os.environ['MASTER_ADDR'] = 'localhost'
53
+ os.environ['MASTER_PORT'] = '65280'
54
+
55
+ hps = utils.get_hparams()
56
+ mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,))
57
+
58
+ def run(rank, n_gpus, hps):
59
+ dist.init_process_group(
60
+ backend="gloo",
61
+ init_method="env://", # Due to some training problem,we proposed to use gloo instead of nccl.
62
+ world_size=n_gpus,
63
+ rank=rank
64
+ ) # Use torchrun instead of mp.spawn
65
+ # rank = dist.get_rank()
66
+ # n_gpus = dist.get_world_size()
67
+ #hps = utils.get_hparams()
68
+ torch.manual_seed(hps.train.seed)
69
+ torch.cuda.set_device(rank)
70
+ global global_step
71
+ if rank == 0:
72
+ logger = utils.get_logger(hps.model_dir)
73
+ logger.info(hps)
74
+ utils.check_git_hash(hps.model_dir)
75
+ writer = SummaryWriter(log_dir=hps.model_dir)
76
+ writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
77
+ train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
78
+ train_sampler = DistributedBucketSampler(
79
+ train_dataset,
80
+ hps.train.batch_size,
81
+ [32, 300, 400, 500, 600, 700, 800, 900, 1000],
82
+ num_replicas=n_gpus,
83
+ rank=rank,
84
+ shuffle=True,
85
+ )
86
+ collate_fn = TextAudioSpeakerCollate()
87
+ train_loader = DataLoader(
88
+ train_dataset,
89
+ num_workers=16,
90
+ shuffle=False,
91
+ pin_memory=True,
92
+ collate_fn=collate_fn,
93
+ batch_sampler=train_sampler,
94
+ persistent_workers=True,
95
+ prefetch_factor=4,
96
+ ) # DataLoader config could be adjusted.
97
+ if rank == 0:
98
+ eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)
99
+ eval_loader = DataLoader(
100
+ eval_dataset,
101
+ num_workers=0,
102
+ shuffle=False,
103
+ batch_size=1,
104
+ pin_memory=True,
105
+ drop_last=False,
106
+ collate_fn=collate_fn,
107
+ )
108
+ if (
109
+ "use_noise_scaled_mas" in hps.model.keys()
110
+ and hps.model.use_noise_scaled_mas is True
111
+ ):
112
+ print("Using noise scaled MAS for VITS2")
113
+ mas_noise_scale_initial = 0.01
114
+ noise_scale_delta = 2e-6
115
+ else:
116
+ print("Using normal MAS for VITS1")
117
+ mas_noise_scale_initial = 0.0
118
+ noise_scale_delta = 0.0
119
+ if (
120
+ "use_duration_discriminator" in hps.model.keys()
121
+ and hps.model.use_duration_discriminator is True
122
+ ):
123
+ print("Using duration discriminator for VITS2")
124
+ net_dur_disc = DurationDiscriminator(
125
+ hps.model.hidden_channels,
126
+ hps.model.hidden_channels,
127
+ 3,
128
+ 0.1,
129
+ gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0,
130
+ ).cuda(rank)
131
+ if (
132
+ "use_spk_conditioned_encoder" in hps.model.keys()
133
+ and hps.model.use_spk_conditioned_encoder is True
134
+ ):
135
+ if hps.data.n_speakers == 0:
136
+ raise ValueError(
137
+ "n_speakers must be > 0 when using spk conditioned encoder to train multi-speaker model"
138
+ )
139
+ else:
140
+ print("Using normal encoder for VITS1")
141
+
142
+ net_g = SynthesizerTrn(
143
+ len(symbols),
144
+ hps.data.filter_length // 2 + 1,
145
+ hps.train.segment_size // hps.data.hop_length,
146
+ n_speakers=hps.data.n_speakers,
147
+ mas_noise_scale_initial=mas_noise_scale_initial,
148
+ noise_scale_delta=noise_scale_delta,
149
+ **hps.model,
150
+ ).cuda(rank)
151
+
152
+ net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
153
+ optim_g = torch.optim.AdamW(
154
+ filter(lambda p: p.requires_grad, net_g.parameters()),
155
+ hps.train.learning_rate,
156
+ betas=hps.train.betas,
157
+ eps=hps.train.eps,
158
+ )
159
+ optim_d = torch.optim.AdamW(
160
+ net_d.parameters(),
161
+ hps.train.learning_rate,
162
+ betas=hps.train.betas,
163
+ eps=hps.train.eps,
164
+ )
165
+ if net_dur_disc is not None:
166
+ optim_dur_disc = torch.optim.AdamW(
167
+ net_dur_disc.parameters(),
168
+ hps.train.learning_rate,
169
+ betas=hps.train.betas,
170
+ eps=hps.train.eps,
171
+ )
172
+ else:
173
+ optim_dur_disc = None
174
+ net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
175
+ net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
176
+ if net_dur_disc is not None:
177
+ net_dur_disc = DDP(net_dur_disc, device_ids=[rank], find_unused_parameters=True)
178
+ #dur_resume_lr=0.0003
179
+ try:
180
+ if net_dur_disc is not None:
181
+ _, _, dur_resume_lr, epoch_str = utils.load_checkpoint(
182
+ utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"),
183
+ net_dur_disc,
184
+ optim_dur_disc,
185
+ skip_optimizer=hps.train.skip_optimizer
186
+ if "skip_optimizer" in hps.train
187
+ else True,
188
+ )
189
+ _, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint(
190
+ utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"),
191
+ net_g,
192
+ optim_g,
193
+ skip_optimizer=hps.train.skip_optimizer
194
+ if "skip_optimizer" in hps.train
195
+ else True,
196
+ )
197
+ _, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint(
198
+ utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"),
199
+ net_d,
200
+ optim_d,
201
+ skip_optimizer=hps.train.skip_optimizer
202
+ if "skip_optimizer" in hps.train
203
+ else True,
204
+ )
205
+ if not optim_g.param_groups[0].get("initial_lr"):
206
+ optim_g.param_groups[0]["initial_lr"] = g_resume_lr
207
+ if not optim_d.param_groups[0].get("initial_lr"):
208
+ optim_d.param_groups[0]["initial_lr"] = d_resume_lr
209
+
210
+ epoch_str = max(epoch_str, 1)
211
+ global_step = (epoch_str - 1) * len(train_loader)
212
+ except Exception as e:
213
+ print(e)
214
+ epoch_str = 1
215
+ global_step = 0
216
+
217
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
218
+ optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
219
+ )
220
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
221
+ optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
222
+ )
223
+ if net_dur_disc is not None:
224
+ if not optim_dur_disc.param_groups[0].get("initial_lr"):
225
+ optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
226
+ scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR(
227
+ optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
228
+ )
229
+ else:
230
+ scheduler_dur_disc = None
231
+ scaler = GradScaler(enabled=hps.train.fp16_run)
232
+
233
+ for epoch in range(epoch_str, hps.train.epochs + 1):
234
+ if rank == 0:
235
+ train_and_evaluate(
236
+ rank,
237
+ epoch,
238
+ hps,
239
+ [net_g, net_d, net_dur_disc],
240
+ [optim_g, optim_d, optim_dur_disc],
241
+ [scheduler_g, scheduler_d, scheduler_dur_disc],
242
+ scaler,
243
+ [train_loader, eval_loader],
244
+ logger,
245
+ [writer, writer_eval],
246
+ )
247
+ else:
248
+ train_and_evaluate(
249
+ rank,
250
+ epoch,
251
+ hps,
252
+ [net_g, net_d, net_dur_disc],
253
+ [optim_g, optim_d, optim_dur_disc],
254
+ [scheduler_g, scheduler_d, scheduler_dur_disc],
255
+ scaler,
256
+ [train_loader, None],
257
+ None,
258
+ None,
259
+ )
260
+ scheduler_g.step()
261
+ scheduler_d.step()
262
+ if net_dur_disc is not None:
263
+ scheduler_dur_disc.step()
264
+
265
+
266
+ def train_and_evaluate(
267
+ rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
268
+ ):
269
+ net_g, net_d, net_dur_disc = nets
270
+ optim_g, optim_d, optim_dur_disc = optims
271
+ scheduler_g, scheduler_d, scheduler_dur_disc = schedulers
272
+ train_loader, eval_loader = loaders
273
+ if writers is not None:
274
+ writer, writer_eval = writers
275
+
276
+ train_loader.batch_sampler.set_epoch(epoch)
277
+ global global_step
278
+
279
+ net_g.train()
280
+ net_d.train()
281
+ if net_dur_disc is not None:
282
+ net_dur_disc.train()
283
+ for batch_idx, (
284
+ x,
285
+ x_lengths,
286
+ spec,
287
+ spec_lengths,
288
+ y,
289
+ y_lengths,
290
+ speakers,
291
+ tone,
292
+ language,
293
+ bert,
294
+ ja_bert,
295
+ ) in tqdm(enumerate(train_loader)):
296
+ if net_g.module.use_noise_scaled_mas:
297
+ current_mas_noise_scale = (
298
+ net_g.module.mas_noise_scale_initial
299
+ - net_g.module.noise_scale_delta * global_step
300
+ )
301
+ net_g.module.current_mas_noise_scale = max(current_mas_noise_scale, 0.0)
302
+ x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(
303
+ rank, non_blocking=True
304
+ )
305
+ spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
306
+ rank, non_blocking=True
307
+ )
308
+ y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
309
+ rank, non_blocking=True
310
+ )
311
+ speakers = speakers.cuda(rank, non_blocking=True)
312
+ tone = tone.cuda(rank, non_blocking=True)
313
+ language = language.cuda(rank, non_blocking=True)
314
+ bert = bert.cuda(rank, non_blocking=True)
315
+ ja_bert = ja_bert.cuda(rank, non_blocking=True)
316
+
317
+ with autocast(enabled=hps.train.fp16_run):
318
+ (
319
+ y_hat,
320
+ l_length,
321
+ attn,
322
+ ids_slice,
323
+ x_mask,
324
+ z_mask,
325
+ (z, z_p, m_p, logs_p, m_q, logs_q),
326
+ (hidden_x, logw, logw_),
327
+ ) = net_g(
328
+ x,
329
+ x_lengths,
330
+ spec,
331
+ spec_lengths,
332
+ speakers,
333
+ tone,
334
+ language,
335
+ bert,
336
+ ja_bert,
337
+ )
338
+ mel = spec_to_mel_torch(
339
+ spec,
340
+ hps.data.filter_length,
341
+ hps.data.n_mel_channels,
342
+ hps.data.sampling_rate,
343
+ hps.data.mel_fmin,
344
+ hps.data.mel_fmax,
345
+ )
346
+ y_mel = commons.slice_segments(
347
+ mel, ids_slice, hps.train.segment_size // hps.data.hop_length
348
+ )
349
+ y_hat_mel = mel_spectrogram_torch(
350
+ y_hat.squeeze(1),
351
+ hps.data.filter_length,
352
+ hps.data.n_mel_channels,
353
+ hps.data.sampling_rate,
354
+ hps.data.hop_length,
355
+ hps.data.win_length,
356
+ hps.data.mel_fmin,
357
+ hps.data.mel_fmax,
358
+ )
359
+
360
+ y = commons.slice_segments(
361
+ y, ids_slice * hps.data.hop_length, hps.train.segment_size
362
+ ) # slice
363
+
364
+ # Discriminator
365
+ y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
366
+ with autocast(enabled=False):
367
+ loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
368
+ y_d_hat_r, y_d_hat_g
369
+ )
370
+ loss_disc_all = loss_disc
371
+ if net_dur_disc is not None:
372
+ y_dur_hat_r, y_dur_hat_g = net_dur_disc(
373
+ hidden_x.detach(), x_mask.detach(), logw.detach(), logw_.detach()
374
+ )
375
+ with autocast(enabled=False):
376
+ # TODO: I think need to mean using the mask, but for now, just mean all
377
+ (
378
+ loss_dur_disc,
379
+ losses_dur_disc_r,
380
+ losses_dur_disc_g,
381
+ ) = discriminator_loss(y_dur_hat_r, y_dur_hat_g)
382
+ loss_dur_disc_all = loss_dur_disc
383
+ optim_dur_disc.zero_grad()
384
+ scaler.scale(loss_dur_disc_all).backward()
385
+ scaler.unscale_(optim_dur_disc)
386
+ commons.clip_grad_value_(net_dur_disc.parameters(), None)
387
+ scaler.step(optim_dur_disc)
388
+
389
+ optim_d.zero_grad()
390
+ scaler.scale(loss_disc_all).backward()
391
+ scaler.unscale_(optim_d)
392
+ grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
393
+ scaler.step(optim_d)
394
+
395
+ with autocast(enabled=hps.train.fp16_run):
396
+ # Generator
397
+ y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
398
+ if net_dur_disc is not None:
399
+ y_dur_hat_r, y_dur_hat_g = net_dur_disc(hidden_x, x_mask, logw, logw_)
400
+ with autocast(enabled=False):
401
+ loss_dur = torch.sum(l_length.float())
402
+ loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
403
+ loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
404
+
405
+ loss_fm = feature_loss(fmap_r, fmap_g)
406
+ loss_gen, losses_gen = generator_loss(y_d_hat_g)
407
+ loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl
408
+ if net_dur_disc is not None:
409
+ loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g)
410
+ loss_gen_all += loss_dur_gen
411
+ optim_g.zero_grad()
412
+ scaler.scale(loss_gen_all).backward()
413
+ scaler.unscale_(optim_g)
414
+ grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
415
+ scaler.step(optim_g)
416
+ scaler.update()
417
+
418
+ if rank == 0:
419
+ if global_step % hps.train.log_interval == 0:
420
+ lr = optim_g.param_groups[0]["lr"]
421
+ losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl]
422
+ logger.info(
423
+ "Train Epoch: {} [{:.0f}%]".format(
424
+ epoch, 100.0 * batch_idx / len(train_loader)
425
+ )
426
+ )
427
+ logger.info([x.item() for x in losses] + [global_step, lr])
428
+
429
+ scalar_dict = {
430
+ "loss/g/total": loss_gen_all,
431
+ "loss/d/total": loss_disc_all,
432
+ "learning_rate": lr,
433
+ "grad_norm_d": grad_norm_d,
434
+ "grad_norm_g": grad_norm_g,
435
+ }
436
+ scalar_dict.update(
437
+ {
438
+ "loss/g/fm": loss_fm,
439
+ "loss/g/mel": loss_mel,
440
+ "loss/g/dur": loss_dur,
441
+ "loss/g/kl": loss_kl,
442
+ }
443
+ )
444
+ scalar_dict.update(
445
+ {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
446
+ )
447
+ scalar_dict.update(
448
+ {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
449
+ )
450
+ scalar_dict.update(
451
+ {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
452
+ )
453
+
454
+ image_dict = {
455
+ "slice/mel_org": utils.plot_spectrogram_to_numpy(
456
+ y_mel[0].data.cpu().numpy()
457
+ ),
458
+ "slice/mel_gen": utils.plot_spectrogram_to_numpy(
459
+ y_hat_mel[0].data.cpu().numpy()
460
+ ),
461
+ "all/mel": utils.plot_spectrogram_to_numpy(
462
+ mel[0].data.cpu().numpy()
463
+ ),
464
+ "all/attn": utils.plot_alignment_to_numpy(
465
+ attn[0, 0].data.cpu().numpy()
466
+ ),
467
+ }
468
+ utils.summarize(
469
+ writer=writer,
470
+ global_step=global_step,
471
+ images=image_dict,
472
+ scalars=scalar_dict,
473
+ )
474
+
475
+ if global_step % hps.train.eval_interval == 0:
476
+ evaluate(hps, net_g, eval_loader, writer_eval)
477
+ utils.save_checkpoint(
478
+ net_g,
479
+ optim_g,
480
+ hps.train.learning_rate,
481
+ epoch,
482
+ os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
483
+ )
484
+ utils.save_checkpoint(
485
+ net_d,
486
+ optim_d,
487
+ hps.train.learning_rate,
488
+ epoch,
489
+ os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
490
+ )
491
+ if net_dur_disc is not None:
492
+ utils.save_checkpoint(
493
+ net_dur_disc,
494
+ optim_dur_disc,
495
+ hps.train.learning_rate,
496
+ epoch,
497
+ os.path.join(hps.model_dir, "DUR_{}.pth".format(global_step)),
498
+ )
499
+ keep_ckpts = getattr(hps.train, "keep_ckpts", 5)
500
+ if keep_ckpts > 0:
501
+ utils.clean_checkpoints(
502
+ path_to_models=hps.model_dir,
503
+ n_ckpts_to_keep=keep_ckpts,
504
+ sort_by_time=True,
505
+ )
506
+
507
+ global_step += 1
508
+
509
+ if rank == 0:
510
+ logger.info("====> Epoch: {}".format(epoch))
511
+
512
+
513
+ def evaluate(hps, generator, eval_loader, writer_eval):
514
+ generator.eval()
515
+ image_dict = {}
516
+ audio_dict = {}
517
+ print("Evaluating ...")
518
+ with torch.no_grad():
519
+ for batch_idx, (
520
+ x,
521
+ x_lengths,
522
+ spec,
523
+ spec_lengths,
524
+ y,
525
+ y_lengths,
526
+ speakers,
527
+ tone,
528
+ language,
529
+ bert,
530
+ ja_bert,
531
+ ) in enumerate(eval_loader):
532
+ x, x_lengths = x.cuda(), x_lengths.cuda()
533
+ spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
534
+ y, y_lengths = y.cuda(), y_lengths.cuda()
535
+ speakers = speakers.cuda()
536
+ bert = bert.cuda()
537
+ ja_bert = ja_bert.cuda()
538
+ tone = tone.cuda()
539
+ language = language.cuda()
540
+ for use_sdp in [True, False]:
541
+ y_hat, attn, mask, *_ = generator.module.infer(
542
+ x,
543
+ x_lengths,
544
+ speakers,
545
+ tone,
546
+ language,
547
+ bert,
548
+ ja_bert,
549
+ y=spec,
550
+ max_len=1000,
551
+ sdp_ratio=0.0 if not use_sdp else 1.0,
552
+ )
553
+ y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
554
+
555
+ mel = spec_to_mel_torch(
556
+ spec,
557
+ hps.data.filter_length,
558
+ hps.data.n_mel_channels,
559
+ hps.data.sampling_rate,
560
+ hps.data.mel_fmin,
561
+ hps.data.mel_fmax,
562
+ )
563
+ y_hat_mel = mel_spectrogram_torch(
564
+ y_hat.squeeze(1).float(),
565
+ hps.data.filter_length,
566
+ hps.data.n_mel_channels,
567
+ hps.data.sampling_rate,
568
+ hps.data.hop_length,
569
+ hps.data.win_length,
570
+ hps.data.mel_fmin,
571
+ hps.data.mel_fmax,
572
+ )
573
+ image_dict.update(
574
+ {
575
+ f"gen/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
576
+ y_hat_mel[0].cpu().numpy()
577
+ )
578
+ }
579
+ )
580
+ audio_dict.update(
581
+ {
582
+ f"gen/audio_{batch_idx}_{use_sdp}": y_hat[
583
+ 0, :, : y_hat_lengths[0]
584
+ ]
585
+ }
586
+ )
587
+ image_dict.update(
588
+ {
589
+ f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(
590
+ mel[0].cpu().numpy()
591
+ )
592
+ }
593
+ )
594
+ audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, : y_lengths[0]]})
595
+
596
+ utils.summarize(
597
+ writer=writer_eval,
598
+ global_step=global_step,
599
+ images=image_dict,
600
+ audios=audio_dict,
601
+ audio_sampling_rate=hps.data.sampling_rate,
602
+ )
603
+ generator.train()
604
+
605
+
606
+ if __name__ == "__main__":
607
+ main()