# Text-to-Audio/modules/encoder/condition_encoder.py
# yuancwang | init | 5548515
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import torch.nn as nn
from torchaudio.models import Conformer
from models.svc.transformer.transformer import PositionalEncoding
from utils.f0 import f0_to_coarse
class ContentEncoder(nn.Module):
    """Project a frame-level content feature sequence to ``output_dim``.

    If ``cfg`` contains a truthy ``use_conformer_for_content_features``, the
    input first goes through ``PositionalEncoding`` and a 6-layer torchaudio
    ``Conformer`` (both operating at ``input_dim``) before the final linear
    projection. Otherwise only the linear projection is applied.

    Args:
        cfg: Config object. It must support ``in`` membership tests and
            attribute access.
        input_dim: Feature size of the incoming sequence. Must be non-zero.
        output_dim: Feature size produced by the linear projection.
    """

    def __init__(self, cfg, input_dim, output_dim):
        super().__init__()
        self.cfg = cfg
        assert input_dim != 0
        self.nn = nn.Linear(input_dim, output_dim)

        # Optionally contextualise the features with a conformer first.
        if (
            "use_conformer_for_content_features" in cfg
            and cfg.use_conformer_for_content_features
        ):
            self.pos_encoder = PositionalEncoding(input_dim)
            self.conformer = Conformer(
                input_dim=input_dim,
                num_heads=2,
                ffn_dim=256,
                num_layers=6,
                depthwise_conv_kernel_size=3,
            )
        else:
            self.conformer = None

    def forward(self, x, length=None):
        """Encode ``x``: (N, seq_len, input_dim) -> (N, seq_len, output_dim).

        Args:
            x: Input features of shape ``(N, seq_len, input_dim)``.
            length: Optional per-example valid lengths of shape ``(N,)``.
                Only the conformer path uses it. When it is ``None``, every
                frame is treated as valid.

        Returns:
            Tensor of shape ``(N, seq_len, output_dim)``.
        """
        if self.conformer is not None:
            if length is None:
                # torchaudio's Conformer requires a lengths tensor. Without
                # one, default to "all frames valid" instead of crashing.
                length = torch.full(
                    (x.shape[0],), x.shape[1], dtype=torch.long, device=x.device
                )
            x = self.pos_encoder(x)
            x, _ = self.conformer(x, length)
        return self.nn(x)
class MelodyEncoder(nn.Module):
    """Turn a frame-level pitch contour into feature vectors.

    The mode is chosen by ``cfg.n_bins_melody``:

    * ``0``: the raw value is projected with
      ``nn.Linear(input_melody_dim, output_melody_dim)``.
    * otherwise: the value is quantised with ``f0_to_coarse`` into
      ``n_bins_melody`` bins (using ``cfg.f0_min`` / ``cfg.f0_max``) and looked
      up in an ``nn.Embedding``.

    A 2-row ``uv_embedding`` table is also built, and its rows can be added to
    the output. When ``cfg.input_melody_dim`` is 0, no sub-modules are created,
    so ``forward`` must not be called.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.input_dim = cfg.input_melody_dim
        self.output_dim = cfg.output_melody_dim
        self.n_bins = cfg.n_bins_melody
        # Stored from cfg. Nothing inside this class reads them.
        self.pitch_min = cfg.pitch_min
        self.pitch_max = cfg.pitch_max

        if self.input_dim == 0:
            return

        if self.n_bins != 0:
            # Quantised pitch: bin index -> learned embedding.
            self.f0_min = cfg.f0_min
            self.f0_max = cfg.f0_max
            self.nn = nn.Embedding(
                num_embeddings=self.n_bins,
                embedding_dim=self.output_dim,
                padding_idx=None,
            )
        else:
            # Continuous pitch: plain linear projection, no quantisation.
            self.nn = nn.Linear(self.input_dim, self.output_dim)

        # Indexed by ``uv`` (values 0 or 1); presumably an unvoiced/voiced flag.
        self.uv_embedding = nn.Embedding(2, self.output_dim)

    def forward(self, x, uv=None, length=None):
        """Encode the pitch contour.

        Args:
            x: Pitch values of shape ``(N, frame_len)``.
            uv: Optional integer tensor of 0/1 indices. Its embedding is added
                to the output, so it must broadcast against it (presumably
                shape ``(N, frame_len)``; confirm with callers).
            length: Accepted for interface compatibility. Not used.

        Returns:
            Encoded features. In the linear mode the shape is
            ``(N, frame_len, output_dim)``.
        """
        if self.n_bins != 0:
            tokens = f0_to_coarse(x, self.n_bins, self.f0_min, self.f0_max)
        else:
            tokens = x.unsqueeze(-1)

        encoded = self.nn(tokens)
        if uv is None:
            return encoded
        return encoded + self.uv_embedding(uv)
class LoudnessEncoder(nn.Module):
    """Turn a frame-level loudness (energy) contour into feature vectors.

    The mode is chosen by ``cfg.n_bins_loudness``:

    * ``0``: the raw value is projected with
      ``nn.Linear(input_loudness_dim, output_loudness_dim)``.
    * otherwise: each value is bucketised against ``n_bins - 1`` boundaries
      spanning ``[loudness_min, loudness_max]`` and the bucket index is looked
      up in an ``nn.Embedding`` with ``n_bins`` rows. The boundaries are
      log-spaced when ``cfg.use_log_loudness`` is truthy and linearly spaced
      otherwise.

    When ``cfg.input_loudness_dim`` is 0, no sub-modules are created.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.input_dim = self.cfg.input_loudness_dim
        self.output_dim = self.cfg.output_loudness_dim
        self.n_bins = self.cfg.n_bins_loudness

        if self.input_dim != 0:
            if self.n_bins == 0:
                # No quantisation: project the raw value.
                self.nn = nn.Linear(self.input_dim, self.output_dim)
            else:
                # TODO: the range is set trivially for now.
                self.loudness_min = 1e-30
                self.loudness_max = 1.5
                if cfg.use_log_loudness:
                    # Log-spaced bucket boundaries (frozen, saved in state_dict).
                    self.energy_bins = nn.Parameter(
                        torch.exp(
                            torch.linspace(
                                np.log(self.loudness_min),
                                np.log(self.loudness_max),
                                self.n_bins - 1,
                            )
                        ),
                        requires_grad=False,
                    )
                else:
                    # Linearly spaced bucket boundaries. Without this branch,
                    # forward() had no ``energy_bins`` in this configuration.
                    # The buffer is non-persistent, so this config's
                    # state_dict layout does not change.
                    self.register_buffer(
                        "energy_bins",
                        torch.linspace(
                            self.loudness_min, self.loudness_max, self.n_bins - 1
                        ),
                        persistent=False,
                    )
                self.nn = nn.Embedding(
                    num_embeddings=self.n_bins,
                    embedding_dim=self.output_dim,
                    padding_idx=None,
                )

    def forward(self, x):
        """Encode loudness values of shape ``(N, frame_len)``.

        Returns:
            Tensor of shape ``(N, frame_len, output_dim)``.
        """
        if self.n_bins == 0:
            x = x.unsqueeze(-1)
        else:
            # n_bins - 1 boundaries give bucket indices in [0, n_bins - 1],
            # which matches the embedding table size.
            x = torch.bucketize(x, self.energy_bins)
        return self.nn(x)
class SingerEncoder(nn.Module):
    """Look up a learned embedding vector for each singer id.

    The lookup table has ``cfg.singer_table_size`` rows, each of width
    ``cfg.output_singer_dim``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.input_dim = 1
        self.output_dim = cfg.output_singer_dim
        table_size = cfg.singer_table_size
        self.nn = nn.Embedding(table_size, self.output_dim, padding_idx=None)

    def forward(self, x):
        """Map integer ids of shape ``(N, 1)`` to ``(N, 1, output_dim)`` embeddings."""
        embedded = self.nn(x)
        return embedded
class ConditionEncoder(nn.Module):
    """Encode every available conditioning stream and merge the results.

    Content-feature encoders (whisper / contentvec / mert / wenet) and the
    singer encoder are built only when the matching ``cfg.use_*`` flag is set.
    The melody and loudness encoders are always built. ``forward`` encodes
    each stream whose key is present in the input dict and merges the outputs
    according to ``cfg.merge_mode`` (``"concat"`` or ``"add"``).
    """

    # (input key, encoder attribute name) for content features. Merge order
    # follows this tuple.
    _CONTENT_STREAMS = (
        ("whisper_feat", "whisper_encoder"),
        ("contentvec_feat", "contentvec_encoder"),
        ("mert_feat", "mert_encoder"),
        ("wenet_feat", "wenet_encoder"),
    )

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.merge_mode = cfg.merge_mode

        if cfg.use_whisper:
            self.whisper_encoder = ContentEncoder(
                self.cfg, self.cfg.whisper_dim, self.cfg.content_encoder_dim
            )
        if cfg.use_contentvec:
            self.contentvec_encoder = ContentEncoder(
                self.cfg, self.cfg.contentvec_dim, self.cfg.content_encoder_dim
            )
        if cfg.use_mert:
            self.mert_encoder = ContentEncoder(
                self.cfg, self.cfg.mert_dim, self.cfg.content_encoder_dim
            )
        if cfg.use_wenet:
            self.wenet_encoder = ContentEncoder(
                self.cfg, self.cfg.wenet_dim, self.cfg.content_encoder_dim
            )

        self.melody_encoder = MelodyEncoder(self.cfg)
        self.loudness_encoder = LoudnessEncoder(self.cfg)
        if cfg.use_spkid:
            self.singer_encoder = SingerEncoder(self.cfg)

    def forward(self, x):
        """Encode and merge the conditioning streams found in ``x``.

        Args:
            x: Dict of input tensors. Recognised keys are ``frame_pitch``
                (optional ``frame_uv``), ``frame_energy``, ``whisper_feat``,
                ``contentvec_feat``, ``mert_feat``, ``wenet_feat`` and
                ``spk_id``. The optional ``target_len`` is forwarded as
                ``length`` to the melody and content encoders. ``x`` is not
                modified.

        Returns:
            The merged tensor: outputs concatenated on the last dim
            (``"concat"``), or summed element-wise (``"add"``).

        Raises:
            ValueError: ``spk_id`` is given without any content feature (its
                length is needed to broadcast the singer embedding), no
                recognised stream is present, or ``merge_mode`` is unknown.
        """
        length = x.get("target_len")
        outputs = []

        if "frame_pitch" in x:
            # Read the optional uv without writing a default into the
            # caller's dict.
            outputs.append(
                self.melody_encoder(
                    x["frame_pitch"], uv=x.get("frame_uv"), length=length
                )
            )

        if "frame_energy" in x:
            outputs.append(self.loudness_encoder(x["frame_energy"]))

        seq_len = None
        for key, encoder_name in self._CONTENT_STREAMS:
            if key in x:
                # e.g. whisper_feat: [b, T, 1024]
                content_out = getattr(self, encoder_name)(x[key], length=length)
                outputs.append(content_out)
                # The singer embedding is broadcast to the length of the last
                # content stream encoded.
                seq_len = content_out.shape[1]

        if "spk_id" in x:
            if seq_len is None:
                raise ValueError(
                    "spk_id requires at least one content feature "
                    "(whisper_feat, contentvec_feat, mert_feat or wenet_feat)"
                )
            speaker_enc_out = self.singer_encoder(x["spk_id"])  # [b, 1, D]
            outputs.append(speaker_enc_out.expand(-1, seq_len, -1))

        if not outputs:
            raise ValueError("ConditionEncoder received no recognised input stream")

        if self.merge_mode == "concat":
            return torch.cat(outputs, dim=-1)
        if self.merge_mode == "add":
            # (#streams, N, seq_len, dim) -> (N, seq_len, dim)
            return torch.stack(outputs, dim=0).sum(dim=0)
        raise ValueError(f"Unsupported merge_mode: {self.merge_mode!r}")