import time

import torch
import torch.nn as nn
from torch.nn import functional as F

from fireredtts.modules.flow.utils import make_pad_mask


class InterpolateRegulator(nn.Module):
    """Length regulator: nearest-neighbour resampling followed by a conv stack.

    The conv stack is ``num_blocks`` repetitions of
    ``Conv1d(kernel=3, stride=1, padding=1) -> GroupNorm -> Mish`` followed by
    a final ``Conv1d(kernel=1, stride=1)``; the channel count is kept
    unchanged throughout.

    Args:
        channels: Number of input/output channels of every convolution.
        num_blocks: Number of Conv1d/GroupNorm/Mish blocks.
        groups: Number of groups passed to ``nn.GroupNorm``.
    """

    def __init__(
        self,
        channels: int = 512,
        num_blocks: int = 4,
        groups: int = 1,
    ):
        super().__init__()
        layers = []
        for _ in range(num_blocks):
            layers.extend([
                nn.Conv1d(channels, channels, 3, 1, 1),
                nn.GroupNorm(groups, channels),
                nn.Mish(),
            ])
        layers.append(nn.Conv1d(channels, channels, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x, ylens=None):
        """Resample ``x`` along time to ``ylens.max()`` frames and refine it.

        Args:
            x: Input features in (B, T, D).
            ylens: Target lengths, one per batch item. Required; the ``None``
                default is kept only for signature compatibility.

        Returns:
            Tuple ``(out, olens)`` where ``out`` is the regulated sequence in
            (B, ylens.max(), D) multiplied by the mask derived from ``ylens``,
            and ``olens`` is ``ylens`` unchanged.

        Raises:
            ValueError: If ``ylens`` is ``None``.
        """
        if ylens is None:
            # Without target lengths there is nothing to regulate to; fail
            # with a clear message instead of an opaque error further down.
            raise ValueError(
                "InterpolateRegulator.forward requires `ylens` (target lengths)"
            )

        # x in (B, T, D)
        # Inverted pad mask, cast to x's dtype/device, broadcast over D.
        # Presumably 1 at valid frames and 0 at padding -- confirm against
        # make_pad_mask.
        mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)

        # F.interpolate documents `size` as int / tuple of ints, so pass a
        # plain Python int rather than a 0-dim tensor.
        target_len = int(ylens.max())
        x = F.interpolate(
            x.transpose(1, 2).contiguous(),
            size=target_len,
            mode='nearest',
        )

        out = self.model(x).transpose(1, 2).contiguous()
        olens = ylens
        return out * mask, olens


class CrossAttnFlowMatching(nn.Module):
    """Token-to-feature pipeline: embed, encode, length-regulate, decode.

    Args:
        output_size: Output dimension of the final linear projection applied
            to the length-regulated encoder output.
        input_embedding: Module mapping token ids to embeddings.
        encoder: Module called as ``encoder(token, token_len, prompt_feat,
            prompt_feat_len)``; must expose an ``output_size`` attribute.
        length_regulator: Module called as ``length_regulator(h, feat_len)``.
        mel_encoder: Prompt pre-net applied to the prompt mel; must expose a
            ``reduction_rate`` attribute.
        decoder: Module exposing ``inference(mu=..., mask=..., n_timesteps=...)``.
    """

    def __init__(self,
                 output_size: int,
                 input_embedding: nn.Module,
                 encoder: nn.Module,
                 length_regulator: nn.Module,
                 mel_encoder: nn.Module,
                 decoder: nn.Module,
                 ):
        super().__init__()
        self.input_embedding = input_embedding
        self.encoder = encoder
        self.length_regulator = length_regulator
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size, output_size)
        self.prompt_prenet = mel_encoder
        self.decoder = decoder

    def inference(self,
                  token: torch.Tensor,
                  token_len: torch.Tensor,
                  prompt_mel: torch.Tensor,
                  prompt_mel_len: torch.Tensor,
                  n_timesteps: int = 10,
                  ):
        """Generate features from tokens conditioned on a prompt mel.

        Args:
            token: Token ids; negative ids are clamped to 0 before embedding.
            token_len: Length of each token sequence in the batch.
            prompt_mel: Prompt mel features fed to the prompt pre-net.
            prompt_mel_len: Length of each prompt mel in the batch.
            n_timesteps: Number of steps forwarded to ``decoder.inference``.

        Returns:
            Whatever ``self.decoder.inference`` returns.
        """
        # prompt projection
        prompt_feat = self.prompt_prenet(prompt_mel)
        # The pre-net shortens the prompt by `reduction_rate`; round up.
        prompt_feat_len = torch.ceil(
            prompt_mel_len / self.prompt_prenet.reduction_rate
        ).long()

        # concat text and prompt_text
        # NOTE(review): this mask is float32, so under reduced precision the
        # product below is promoted to float32 -- confirm that is intended.
        mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(token_len.device)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # 40ms shift to 10ms shift
        # NOTE(review): the x4 factor is hard-coded here.
        feat_len = (token_len * 4).int()

        # first encoder
        h, _ = self.encoder(token, token_len, prompt_feat, prompt_feat_len)

        # length regulate
        h, _ = self.length_regulator(h, feat_len)

        # final projection
        h = self.encoder_proj(h)

        # Mask over the regulated sequence, cast to h's dtype/device.
        mask = (~make_pad_mask(feat_len)).to(h)
        feat = self.decoder.inference(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            n_timesteps=n_timesteps,
        )
        return feat