|
import torch
|
|
import torchaudio
|
|
from transformers import AutoModel
|
|
|
|
|
|
def feature_loss(fmap_r, fmap_g):
    """Feature-matching loss between real and generated discriminator feature maps.

    ``fmap_r`` and ``fmap_g`` are parallel nested sequences: the outer level is
    iterated in lockstep, and each inner pair is zipped layer by layer. Real
    features are cast to float and detached, so no gradient flows into them.

    Returns:
        Twice the sum of the per-layer mean absolute differences (the integer
        ``0`` when there is nothing to compare).
    """
    per_layer_l1 = (
        torch.mean(torch.abs(real.float().detach() - fake.float()))
        for real_layers, fake_layers in zip(fmap_r, fmap_g)
        for real, fake in zip(real_layers, fake_layers)
    )
    return sum(per_layer_l1) * 2
|
|
|
|
|
|
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares GAN discriminator loss.

    Outputs on real data are pushed toward 1 and outputs on generated data
    toward 0. The two sequences are consumed pairwise.

    Returns:
        ``(total, real_losses, generated_losses)``: ``total`` is the summed
        loss (the integer ``0`` for empty inputs). The two lists hold the
        per-discriminator real / generated terms as Python floats.
    """
    term_pairs = [
        (torch.mean((1 - real.float()) ** 2), torch.mean(fake.float() ** 2))
        for real, fake in zip(disc_real_outputs, disc_generated_outputs)
    ]

    total = 0
    for real_term, fake_term in term_pairs:
        total += real_term + fake_term

    real_values = [real_term.item() for real_term, _ in term_pairs]
    fake_values = [fake_term.item() for _, fake_term in term_pairs]
    return total, real_values, fake_values
|
|
|
|
|
|
def generator_loss(disc_outputs):
    """Least-squares GAN generator loss: push every discriminator output toward 1.

    Returns:
        ``(total, gen_losses)``: ``gen_losses`` holds one tensor term per
        discriminator output, and ``total`` is their sum (the integer ``0``
        when ``disc_outputs`` is empty).
    """
    gen_losses = [torch.mean((1 - output.float()) ** 2) for output in disc_outputs]
    return sum(gen_losses), gen_losses
|
|
|
|
|
|
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    """Masked KL term between a posterior sample and a Gaussian prior.

    z_p, logs_q: [b, h, t_t]
    m_p, logs_p: [b, h, t_t]

    All inputs are cast to float32. The element-wise term
    ``logs_p - logs_q - 0.5 + 0.5 * (z_p - m_p)**2 * exp(-2 * logs_p)`` is
    multiplied by ``z_mask`` and summed. The sum is normalised by
    ``sum(z_mask)`` only, not by the number of masked elements after
    broadcasting.
    """
    z_p, logs_q, m_p, logs_p, z_mask = (
        t.float() for t in (z_p, logs_q, m_p, logs_p, z_mask)
    )

    divergence = logs_p - logs_q - 0.5
    divergence += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)

    masked_total = torch.sum(divergence * z_mask)
    return masked_total / torch.sum(z_mask)
|
|
|
|
|
|
class WavLMLoss(torch.nn.Module):
    """Feature-matching and adversarial losses on a frozen pretrained speech model.

    The pretrained model (loaded with ``AutoModel.from_pretrained``) is used as
    a fixed feature extractor. Its parameters are frozen and it is kept in eval
    mode. ``wd`` is a separate discriminator that scores the stacked hidden
    states of that model.

    Args:
        model: name or path passed to ``AutoModel.from_pretrained``.
        wd: discriminator module applied to the flattened hidden states.
        model_sr: sample rate of the waveforms given to this module.
        slm_sr: sample rate the pretrained model is fed (default 16000).
    """

    def __init__(self, model, wd, model_sr, slm_sr=16000):
        super().__init__()
        self.wavlm = AutoModel.from_pretrained(model)
        self.wd = wd
        self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)

        # The pretrained model is a fixed feature extractor: freeze its weights
        # and keep it in eval mode (see ``train`` below).
        self.wavlm.eval()
        for param in self.wavlm.parameters():
            param.requires_grad = False

    def train(self, mode=True):
        """Set training mode, but keep the frozen speech model in eval mode.

        ``nn.Module.train`` recurses into every submodule. Without this
        override, a parent ``.train()`` call would put ``self.wavlm`` back into
        train mode (e.g. re-enabling dropout) even though its weights are
        frozen. ``self.wd`` still follows ``mode``.
        """
        super().train(mode)
        self.wavlm.eval()
        return self

    @staticmethod
    def _as_batch(wav):
        """Drop a singleton channel axis: ``(B, 1, T) -> (B, T)``; other shapes pass through.

        Unlike a bare ``squeeze()``, this never removes the batch axis, so a
        batch of size one stays 2-D.
        """
        if wav.dim() == 3 and wav.size(1) == 1:
            return wav.squeeze(1)
        return wav

    def _hidden_states(self, wav):
        """Resample ``wav`` and return all hidden states of the pretrained model."""
        wav_16 = self._as_batch(self.resample(wav))
        return self.wavlm(input_values=wav_16, output_hidden_states=True).hidden_states

    @staticmethod
    def _stack_layers(hidden_states):
        """Stack hidden states on dim 1, swap the last two axes, then merge dims 1-2.

        Assuming each hidden state is ``(B, frames, dim)`` (TODO confirm), the
        result is ``(B, layers * dim, frames)``.
        """
        return (
            torch.stack(hidden_states, dim=1)
            .transpose(-1, -2)
            .flatten(start_dim=1, end_dim=2)
        )

    def forward(self, wav, y_rec):
        """Sum over layers of the mean absolute hidden-state difference.

        The reference ``wav`` branch runs under ``torch.no_grad``. The ``y_rec``
        branch does not, so gradients flow back to ``y_rec`` through the frozen
        model.
        """
        with torch.no_grad():
            wav_embeddings = self._hidden_states(wav)
        y_rec_embeddings = self._hidden_states(y_rec)

        floss = 0
        for er, eg in zip(wav_embeddings, y_rec_embeddings):
            floss += torch.mean(torch.abs(er - eg))
        return floss.mean()

    def generator(self, y_rec):
        """Least-squares adversarial loss for the generator: push ``wd`` output toward 1."""
        y_df_hat_g = self.wd(self._stack_layers(self._hidden_states(y_rec)))
        return torch.mean((1 - y_df_hat_g) ** 2)

    def discriminator(self, wav, y_rec):
        """Least-squares discriminator loss: real toward 1, generated toward 0.

        Both embedding branches are computed without gradients; only ``wd``
        receives gradients from this loss.
        """
        with torch.no_grad():
            y_embeddings = self._stack_layers(self._hidden_states(wav))
            y_rec_embeddings = self._stack_layers(self._hidden_states(y_rec))

        y_df_hat_r = self.wd(y_embeddings)
        y_df_hat_g = self.wd(y_rec_embeddings)

        r_loss = torch.mean((1 - y_df_hat_r) ** 2)
        g_loss = torch.mean(y_df_hat_g ** 2)
        return (r_loss + g_loss).mean()

    def discriminator_forward(self, wav):
        """Return the raw ``wd`` output for ``wav``; embeddings are computed without gradients."""
        with torch.no_grad():
            y_embeddings = self._stack_layers(self._hidden_states(wav))
        return self.wd(y_embeddings)
|
|
|