Spaces:
Running
on
L4
Running
on
L4
File size: 3,231 Bytes
5b4c852 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
from typing import Dict, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
from cosyvoice.utils.losses import tpr_loss, mel_loss
class HiFiGan(nn.Module):
    """Wrap a generator/discriminator pair and compute one GAN training turn.

    Each ``forward`` call runs one side of the GAN update, chosen by
    ``batch['turn']``:
    - ``'generator'`` returns the combined generator objective.
    - Any other value returns the discriminator objective.

    Args:
        generator: Module called as ``generator(batch, device)``. It must return
            a ``(speech, f0)`` pair.
        discriminator: Module called as ``discriminator(real, generated)``. It
            must return four values ``(y_d_rs, y_d_gs, fmap_rs, fmap_gs)``.
            These are presumably the real/generated outputs and feature maps;
            confirm against the discriminator implementation.
        mel_spec_transform: Passed through unchanged to ``mel_loss``.
        multi_mel_spectral_recon_loss_weight: Weight of the mel reconstruction
            term in the generator loss.
        feat_match_loss_weight: Weight of the feature-matching term in the
            generator loss.
        tpr_loss_weight: Weight of the TPR term in both losses. ``0`` skips
            computing it entirely.
        tpr_loss_tau: ``tau`` argument forwarded to ``tpr_loss``.
    """

    def __init__(self, generator, discriminator, mel_spec_transform,
                 multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
                 tpr_loss_weight=1.0, tpr_loss_tau=0.04):
        super(HiFiGan, self).__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.mel_spec_transform = mel_spec_transform
        self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
        self.feat_match_loss_weight = feat_match_loss_weight
        self.tpr_loss_weight = tpr_loss_weight
        self.tpr_loss_tau = tpr_loss_tau

    def forward(
            self,
            batch: dict,
            device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        """Dispatch to the generator or discriminator turn.

        Args:
            batch: Training batch. ``batch['turn']`` selects the turn. Any value
                other than ``'generator'`` runs the discriminator turn.
            device: Device the input tensors are moved to.

        Returns:
            Dict holding the total ``'loss'`` and its individual components.
        """
        if batch['turn'] == 'generator':
            return self.forward_generator(batch, device)
        return self.forward_discriminator(batch, device)

    def _optional_tpr_loss(self, y_d_rs, y_d_gs, device):
        """Return the TPR loss, or a zero tensor when ``tpr_loss_weight`` is 0.

        Skipping the call avoids wasted work when the term is disabled. The
        zero placeholder keeps ``'loss_tpr'`` present in both returned dicts.
        """
        if self.tpr_loss_weight != 0:
            return tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
        # Allocate directly on the target device rather than CPU + copy.
        return torch.zeros(1, device=device)

    def forward_generator(self, batch, device):
        """Compute the generator-side training objective.

        The total loss is::

            adversarial
            + feat_match_loss_weight * feature_matching
            + multi_mel_spectral_recon_loss_weight * mel_reconstruction
            + tpr_loss_weight * tpr
            + l1(f0)            # the F0 term is unweighted

        Args:
            batch: Dict containing at least ``'speech'`` and ``'pitch_feat'``
                tensors, plus whatever the generator itself reads.
            device: Device the targets are moved to.

        Returns:
            Dict with ``'loss'`` (the total) and each component
            (``'loss_gen'``, ``'loss_fm'``, ``'loss_mel'``, ``'loss_tpr'``,
            ``'loss_f0'``).
        """
        real_speech = batch['speech'].to(device)
        pitch_feat = batch['pitch_feat'].to(device)
        # 1. Generator forward pass with autograd enabled, so all terms below
        #    can backpropagate into the generator.
        generated_speech, generated_f0 = self.generator(batch, device)
        # 2. Discriminator on real vs. generated audio.
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
        # 3. Generator losses. Only the aggregate adversarial loss (first
        #    element returned by generator_loss) is used.
        loss_gen, _ = generator_loss(y_d_gs)
        loss_fm = feature_loss(fmap_rs, fmap_gs)
        loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
        loss_tpr = self._optional_tpr_loss(y_d_rs, y_d_gs, device)
        loss_f0 = F.l1_loss(generated_f0, pitch_feat)
        loss = (loss_gen
                + self.feat_match_loss_weight * loss_fm
                + self.multi_mel_spectral_recon_loss_weight * loss_mel
                + self.tpr_loss_weight * loss_tpr
                + loss_f0)
        return {
            'loss': loss,
            'loss_gen': loss_gen,
            'loss_fm': loss_fm,
            'loss_mel': loss_mel,
            'loss_tpr': loss_tpr,
            'loss_f0': loss_f0,
        }

    def forward_discriminator(self, batch, device):
        """Compute the discriminator-side training objective.

        The total loss is ``discriminator_loss + tpr_loss_weight * tpr``.

        Args:
            batch: Dict containing at least a ``'speech'`` tensor, plus whatever
                the generator reads.
            device: Device the real audio is moved to.

        Returns:
            Dict with ``'loss'`` (the total), ``'loss_disc'`` and ``'loss_tpr'``.
        """
        real_speech = batch['speech'].to(device)
        # 1. Generator forward pass without autograd. No graph is built
        #    through the generator on this turn. The predicted F0 is not
        #    needed here, so it is discarded.
        with torch.no_grad():
            generated_speech, _ = self.generator(batch, device)
        # 2. Discriminator on real vs. generated audio.
        y_d_rs, y_d_gs, _, _ = self.discriminator(real_speech, generated_speech)
        # 3. Discriminator losses. Only the aggregate loss (first element
        #    returned by discriminator_loss) is used.
        loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
        loss_tpr = self._optional_tpr_loss(y_d_rs, y_d_gs, device)
        loss = loss_disc + self.tpr_loss_weight * loss_tpr
        return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr}
|