from librosa.filters import mel as librosa_mel_fn
from torch import nn
from torch.nn import functional as F
import math
import numpy as np
import torch
import torchaudio


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


class TorchMelSpectrogram(nn.Module):
    def __init__(
        self,
        filter_length=1024,
        hop_length=160,
        win_length=640,
        n_mel_channels=80,
        mel_fmin=0,
        mel_fmax=8000,
        sampling_rate=16000,
    ):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_mel_channels = n_mel_channels
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.sampling_rate = sampling_rate
        self.mel_basis = {}
        self.hann_window = {}

    def forward(self, inp, length=None):
        if len(inp.shape) == 3:
            inp = inp.squeeze(1) if inp.shape[1] == 1 else inp.squeeze(2)
        assert len(inp.shape) == 2
        y = inp

        if len(list(self.mel_basis.keys())) == 0:
            mel = librosa_mel_fn(
                sr=self.sampling_rate,
                n_fft=self.filter_length,
                n_mels=self.n_mel_channels,
                fmin=self.mel_fmin,
                fmax=self.mel_fmax,
            )
            self.mel_basis[str(self.mel_fmax) + "_" + str(y.device)] = (
                torch.from_numpy(mel).float().to(y.device)
            )
            self.hann_window[str(y.device)] = torch.hann_window(self.win_length).to(
                y.device
            )

        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (
                int((self.filter_length - self.hop_length) / 2),
                int((self.filter_length - self.hop_length) / 2),
            ),
            mode="reflect",
        )
        y = y.squeeze(1)

        # complex tensor as default, then use view_as_real for future pytorch compatibility
        spec = torch.stft(
            y,
            self.filter_length,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.hann_window[str(y.device)],
            center=False,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        spec = torch.matmul(
            self.mel_basis[str(self.mel_fmax) + "_" + str(y.device)], spec
        )
        spec = spectral_normalize_torch(spec)

        max_mel_length = math.ceil(y.shape[-1] / self.hop_length)
        spec = spec[..., :max_mel_length].transpose(1, 2)

        if length is None:
            return spec
        else:
            spec_len = torch.ceil(length / self.hop_length).clamp(max=spec.shape[1])
            return spec, spec_len


def length_to_mask(length, max_len=None, dtype=None, device=None):
    """Creates a binary mask for each sequence.

    Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3

    Arguments
    ---------
    length : torch.LongTensor
        Containing the length of each sequence in the batch. Must be 1D.
    max_len : int
        Max length for the mask, also the size of the second dimension.
    dtype : torch.dtype, default: None
        The dtype of the generated mask.
    device: torch.device, default: None
        The device to put the mask variable.

    Returns
    -------
    mask : tensor
        The binary mask.
    Example
    -------
    >>> length=torch.Tensor([1,2,3])
    >>> mask=length_to_mask(length)
    >>> mask
    tensor([[1., 0., 0.],
            [1., 1., 0.],
            [1., 1., 1.]])
    """
    assert len(length.shape) == 1

    if max_len is None:
        max_len = length.max().long().item()

    # using arange to generate mask
    mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand(
        len(length), max_len
    ) < length.unsqueeze(1)

    if dtype is None:
        dtype = length.dtype
    if device is None:
        device = length.device
    mask = torch.as_tensor(mask, dtype=dtype, device=device)
    return mask


def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
    """This function computes the number of elements to add for zero-padding.

    Arguments
    ---------
    L_in : int
    stride: int
    kernel_size : int
    dilation : int
    """
    if stride > 1:
        n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
        L_out = stride * (n_steps - 1) + kernel_size * dilation
        padding = [kernel_size // 2, kernel_size // 2]
    else:
        L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
        padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
    return padding


class Conv1d(nn.Module):
    """This class implements 1d convolution.

    Arguments
    ---------
    out_channels : int
        It is the number of output channels.
    kernel_size : int
        Kernel size of the convolutional filters.
    input_shape : tuple
        The shape of the input. Alternatively use ``in_channels``.
    in_channels : int
        The number of input channels. Alternatively use ``input_shape``.
    stride : int
        Stride factor of the convolutional filters. When the stride factor > 1,
        a decimation in time is performed.
    dilation : int
        Dilation factor of the convolutional filters.
    padding : str
        (same, valid, causal). If "valid", no padding is performed.
        If "same" and stride is 1, output shape is the same as the input shape.
        "causal" results in causal (dilated) convolutions.
    padding_mode : str
        This flag specifies the type of padding. See torch.nn documentation
        for more information.
    skip_transpose : bool
        If False, uses the batch x time x channel convention of speechbrain.
        If True, uses the batch x channel x time convention.

    Example
    -------
    >>> inp_tensor = torch.rand([10, 40, 16])
    >>> cnn_1d = Conv1d(
    ...     input_shape=inp_tensor.shape, out_channels=8, kernel_size=5, skip_transpose=False
    ... )
    >>> out_tensor = cnn_1d(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 40, 8])
    """

    def __init__(
        self,
        out_channels,
        kernel_size,
        input_shape=None,
        in_channels=None,
        stride=1,
        dilation=1,
        padding="same",
        groups=1,
        bias=True,
        padding_mode="reflect",
        skip_transpose=True,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.padding_mode = padding_mode
        self.unsqueeze = False
        self.skip_transpose = skip_transpose

        if input_shape is None and in_channels is None:
            raise ValueError("Must provide one of input_shape or in_channels")

        if in_channels is None:
            in_channels = self._check_input_shape(input_shape)

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=0,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Returns the output of the convolution.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channel)
            input to convolve. 2d or 3d tensors are expected.
        """
        if not self.skip_transpose:
            x = x.transpose(1, -1)

        if self.unsqueeze:
            x = x.unsqueeze(1)

        if self.padding == "same":
            x = self._manage_padding(x, self.kernel_size, self.dilation, self.stride)
        elif self.padding == "causal":
            num_pad = (self.kernel_size - 1) * self.dilation
            x = F.pad(x, (num_pad, 0))
        elif self.padding == "valid":
            pass
        else:
            raise ValueError(
                "Padding must be 'same', 'valid' or 'causal'. Got " + self.padding
            )

        wx = self.conv(x)

        if self.unsqueeze:
            wx = wx.squeeze(1)

        if not self.skip_transpose:
            wx = wx.transpose(1, -1)

        return wx

    def _manage_padding(
        self,
        x,
        kernel_size: int,
        dilation: int,
        stride: int,
    ):
        """This function performs zero-padding on the time axis such that the
        input length is unchanged after the convolution.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        kernel_size : int
            Size of kernel.
        dilation : int
            Dilation used.
        stride : int
            Stride.
        """
        # Detecting input shape
        L_in = x.shape[-1]

        # Time padding
        padding = get_padding_elem(L_in, stride, kernel_size, dilation)

        # Applying padding
        x = F.pad(x, padding, mode=self.padding_mode)

        return x

    def _check_input_shape(self, shape):
        """Checks the input shape and returns the number of input channels."""

        if len(shape) == 2:
            self.unsqueeze = True
            in_channels = 1
        elif self.skip_transpose:
            in_channels = shape[1]
        elif len(shape) == 3:
            in_channels = shape[2]
        else:
            raise ValueError("conv1d expects 2d, 3d inputs. Got " + str(len(shape)))

        # Kernel size must be odd
        if self.kernel_size % 2 == 0:
            raise ValueError(
                "The field kernel size must be an odd number. Got %s."
                % (self.kernel_size)
            )
        return in_channels


class Fp32BatchNorm(nn.Module):
    def __init__(self, sync=True, *args, **kwargs):
        super().__init__()

        if (
            not torch.distributed.is_initialized()
            or torch.distributed.get_world_size() == 1
        ):
            sync = False

        if sync:
            self.bn = nn.SyncBatchNorm(*args, **kwargs)
        else:
            self.bn = nn.BatchNorm1d(*args, **kwargs)

        self.sync = sync

    def forward(self, input):
        if self.bn.running_mean.dtype != torch.float:
            if self.sync:
                self.bn.running_mean = self.bn.running_mean.float()
                self.bn.running_var = self.bn.running_var.float()
                if self.bn.affine:
                    try:
                        self.bn.weight = self.bn.weight.float()
                        self.bn.bias = self.bn.bias.float()
                    except Exception:
                        self.bn.float()
            else:
                self.bn.float()

        output = self.bn(input.float())
        return output.type_as(input)


class BatchNorm1d(nn.Module):
    """Applies 1d batch normalization to the input tensor.

    Arguments
    ---------
    input_shape : tuple
        The expected shape of the input. Alternatively, use ``input_size``.
    input_size : int
        The expected size of the input. Alternatively, use ``input_shape``.
    eps : float
        This value is added to std deviation estimation to improve the numerical
        stability.
    momentum : float
        It is a value used for the running_mean and running_var computation.
    affine : bool
        When set to True, the affine parameters are learned.
    track_running_stats : bool
        When set to True, this module tracks the running mean and variance,
        and when set to False, this module does not track such statistics.
    combine_batch_time : bool
        When true, it combines the batch and time axes.
    Example
    -------
    >>> input = torch.randn(100, 10)
    >>> norm = BatchNorm1d(input_shape=input.shape)
    >>> output = norm(input)
    >>> output.shape
    torch.Size([100, 10])
    """

    def __init__(
        self,
        input_shape=None,
        input_size=None,
        eps=1e-05,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        combine_batch_time=False,
        skip_transpose=True,
        enabled=True,
    ):
        super().__init__()
        self.combine_batch_time = combine_batch_time
        self.skip_transpose = skip_transpose

        if input_size is None and skip_transpose:
            input_size = input_shape[1]
        elif input_size is None:
            input_size = input_shape[-1]

        if enabled:
            self.norm = Fp32BatchNorm(
                num_features=input_size,
                eps=eps,
                momentum=momentum,
                affine=affine,
                track_running_stats=track_running_stats,
            )
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        """Returns the normalized input tensor.

        Arguments
        ---------
        x : torch.Tensor (batch, time, [channels])
            input to normalize. 2d or 3d tensors are expected in input.
            4d tensors can be used when combine_batch_time=True.
        """
        shape_or = x.shape
        if self.combine_batch_time:
            if x.ndim == 3:
                x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
            else:
                x = x.reshape(shape_or[0] * shape_or[1], shape_or[3], shape_or[2])

        elif not self.skip_transpose:
            x = x.transpose(-1, 1)

        x_n = self.norm(x)

        if self.combine_batch_time:
            x_n = x_n.reshape(shape_or)
        elif not self.skip_transpose:
            x_n = x_n.transpose(1, -1)

        return x_n


class Linear(torch.nn.Module):
    """Computes a linear transformation y = wx + b.

    Arguments
    ---------
    n_neurons : int
        It is the number of output neurons (i.e., the dimensionality of the output).
    bias : bool
        If True, the additive bias b is adopted.
    combine_dims : bool
        If True and the input is 4D, combine the 3rd and 4th dimensions of the input.

    Example
    -------
    >>> inputs = torch.rand(10, 50, 40)
    >>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100)
    >>> output = lin_t(inputs)
    >>> output.shape
    torch.Size([10, 50, 100])
    """

    def __init__(
        self,
        n_neurons,
        input_shape=None,
        input_size=None,
        bias=True,
        combine_dims=False,
    ):
        super().__init__()
        self.combine_dims = combine_dims

        if input_shape is None and input_size is None:
            raise ValueError("Expected one of input_shape or input_size")

        if input_size is None:
            input_size = input_shape[-1]
            if len(input_shape) == 4 and self.combine_dims:
                input_size = input_shape[2] * input_shape[3]

        # Weights are initialized following pytorch approach
        self.w = nn.Linear(input_size, n_neurons, bias=bias)

    def forward(self, x):
        """Returns the linear transformation of input tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input to transform linearly.
        """
        if x.ndim == 4 and self.combine_dims:
            x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])

        wx = self.w(x)

        return wx


class TDNNBlock(nn.Module):
    """An implementation of a TDNN block.

    Arguments
    ---------
    in_channels : int
        Number of input channels.
    out_channels : int
        The number of output channels.
    kernel_size : int
        The kernel size of the TDNN blocks.
    dilation : int
        The dilation of the TDNN block.
    activation : torch class
        A class for constructing the activation layers.
    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> layer = TDNNBlock(64, 64, kernel_size=3, dilation=1)
    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 120, 64])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation,
        activation=nn.ReLU,
        batch_norm=True,
    ):
        super(TDNNBlock, self).__init__()
        self.conv = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
        )
        self.activation = activation()
        self.norm = BatchNorm1d(input_size=out_channels, enabled=batch_norm)

    def forward(self, x):
        return self.norm(self.activation(self.conv(x)))


class Res2NetBlock(torch.nn.Module):
    """An implementation of Res2NetBlock w/ dilation.

    Arguments
    ---------
    in_channels : int
        The number of channels expected in the input.
    out_channels : int
        The number of output channels.
    scale : int
        The scale of the Res2Net block.
    kernel_size: int
        The kernel size of the Res2Net block.
    dilation : int
        The dilation of the Res2Net block.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 120, 64])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        scale=8,
        kernel_size=3,
        dilation=1,
        batch_norm=True,
    ):
        super(Res2NetBlock, self).__init__()
        assert in_channels % scale == 0
        assert out_channels % scale == 0

        in_channel = in_channels // scale
        hidden_channel = out_channels // scale

        self.blocks = nn.ModuleList(
            [
                TDNNBlock(
                    in_channel,
                    hidden_channel,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    batch_norm=batch_norm,
                )
                for i in range(scale - 1)
            ]
        )
        self.scale = scale

    def forward(self, x):
        y = []
        for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
            if i == 0:
                y_i = x_i
            elif i == 1:
                y_i = self.blocks[i - 1](x_i)
            else:
                y_i = self.blocks[i - 1](x_i + y_i)
            y.append(y_i)
        y = torch.cat(y, dim=1)
        return y


class SEBlock(nn.Module):
    """An implementation of a squeeze-and-excitation block.

    Arguments
    ---------
    in_channels : int
        The number of input channels.
    se_channels : int
        The number of output channels after squeeze.
    out_channels : int
        The number of output channels.

    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> se_layer = SEBlock(64, 16, 64)
    >>> lengths = torch.rand((8,))
    >>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 120, 64])
    """

    def __init__(self, in_channels, se_channels, out_channels):
        super(SEBlock, self).__init__()

        self.conv1 = Conv1d(
            in_channels=in_channels, out_channels=se_channels, kernel_size=1
        )
        self.relu = torch.nn.ReLU(inplace=True)
        self.conv2 = Conv1d(
            in_channels=se_channels, out_channels=out_channels, kernel_size=1
        )
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x, lengths=None):
        L = x.shape[-1]
        if lengths is not None:
            mask = length_to_mask(lengths * L, max_len=L, device=x.device)
            mask = mask.unsqueeze(1)
            total = mask.sum(dim=2, keepdim=True)
            s = (x * mask).sum(dim=2, keepdim=True) / total
        else:
            s = x.mean(dim=2, keepdim=True)

        s = self.relu(self.conv1(s))
        s = self.sigmoid(self.conv2(s))

        return s * x


class AttentiveStatisticsPooling(nn.Module):
    """This class implements an attentive statistics pooling layer for each channel.
    It returns the concatenated mean and std of the input tensor.

    Arguments
    ---------
    channels: int
        The number of input channels.
    attention_channels: int
        The number of attention channels.
    Example
    -------
    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
    >>> asp_layer = AttentiveStatisticsPooling(64)
    >>> lengths = torch.rand((8,))
    >>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
    >>> out_tensor.shape
    torch.Size([8, 1, 128])
    """

    def __init__(
        self, channels, attention_channels=128, global_context=True, batch_norm=True
    ):
        super().__init__()

        self.eps = 1e-12
        self.global_context = global_context
        if global_context:
            self.tdnn = TDNNBlock(
                channels * 3, attention_channels, 1, 1, batch_norm=batch_norm
            )
        else:
            self.tdnn = TDNNBlock(
                channels, attention_channels, 1, 1, batch_norm=batch_norm
            )
        self.tanh = nn.Tanh()
        self.conv = Conv1d(
            in_channels=attention_channels, out_channels=channels, kernel_size=1
        )

    def forward(self, x, lengths=None):
        """Calculates mean and std for a batch (input tensor).

        Arguments
        ---------
        x : torch.Tensor
            Tensor of shape [N, C, L].
        """
        L = x.shape[-1]

        def _compute_statistics(x, m, dim=2, eps=self.eps):
            mean = (m * x).sum(dim)
            std = torch.sqrt(
                (m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)
            )
            return mean, std

        if lengths is None:
            lengths = torch.ones(x.shape[0], device=x.device)

        # Make binary mask of shape [N, 1, L]
        mask = length_to_mask(lengths * L, max_len=L, device=x.device)
        mask = mask.unsqueeze(1)

        # Expand the temporal context of the pooling layer by allowing the
        # self-attention to look at global properties of the utterance.
        if self.global_context:
            # torch.std is unstable for backward computation
            # https://github.com/pytorch/pytorch/issues/4320
            total = mask.sum(dim=2, keepdim=True).float()
            mean, std = _compute_statistics(x, mask / total)
            mean = mean.unsqueeze(2).repeat(1, 1, L)
            std = std.unsqueeze(2).repeat(1, 1, L)
            attn = torch.cat([x, mean, std], dim=1)
        else:
            attn = x

        # Apply layers
        attn = self.conv(self.tanh(self.tdnn(attn)))

        # Filter out zero-paddings
        attn = attn.masked_fill(mask == 0, float("-inf"))

        attn = F.softmax(attn, dim=2)
        mean, std = _compute_statistics(x, attn)
        # Append mean and std of the batch
        pooled_stats = torch.cat((mean, std), dim=1)
        pooled_stats = pooled_stats.unsqueeze(2)

        return pooled_stats


class SERes2NetBlock(nn.Module):
    """An implementation of the building block in ECAPA-TDNN, i.e.,
    TDNN-Res2Net-TDNN-SEBlock.

    Arguments
    ---------
    in_channels: int
        The number of input channels.
    out_channels: int
        The number of output channels.
    res2net_scale: int
        The scale of the Res2Net block.
    kernel_size: int
        The kernel size of the TDNN blocks.
    dilation: int
        The dilation of the Res2Net block.
    activation : torch class
        A class for constructing the activation layers.
    Example
    -------
    >>> x = torch.rand(8, 120, 64).transpose(1, 2)
    >>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
    >>> out = conv(x).transpose(1, 2)
    >>> out.shape
    torch.Size([8, 120, 64])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        res2net_scale=8,
        se_channels=128,
        kernel_size=1,
        dilation=1,
        activation=torch.nn.ReLU,
        batch_norm=True,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.tdnn1 = TDNNBlock(
            in_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
            batch_norm=batch_norm,
        )
        self.res2net_block = Res2NetBlock(
            out_channels,
            out_channels,
            res2net_scale,
            kernel_size,
            dilation,
            batch_norm=batch_norm,
        )
        self.tdnn2 = TDNNBlock(
            out_channels,
            out_channels,
            kernel_size=1,
            dilation=1,
            activation=activation,
            batch_norm=batch_norm,
        )
        self.se_block = SEBlock(out_channels, se_channels, out_channels)

        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x, lengths=None):
        residual = x
        if self.shortcut:
            residual = self.shortcut(x)

        x = self.tdnn1(x)
        x = self.res2net_block(x)
        x = self.tdnn2(x)
        x = self.se_block(x, lengths)

        return x + residual


class ECAPA_TDNN(torch.nn.Module):
    """An implementation of the speaker embedding model from the paper
    "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
    TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).

    Arguments
    ---------
    input_size : int
        The number of input features (e.g., mel channels) per frame.
    activation : torch class
        A class for constructing the activation layers.
    channels : list of ints
        Output channels for the TDNN/SERes2Net layers.
    kernel_sizes : list of ints
        List of kernel sizes for each layer.
    dilations : list of ints
        List of dilations for kernels in each layer.
    lin_neurons : int
        Number of neurons in linear layers.

    Example
    -------
    >>> input_feats = torch.rand([5, 120, 80])
    >>> compute_embedding = ECAPA_TDNN(80, lin_neurons=192)
    >>> outputs = compute_embedding(input_feats)
    >>> outputs.shape
    torch.Size([5, 192])
    """

    def __init__(
        self,
        input_size,
        lin_neurons=192,
        activation=torch.nn.ReLU,
        channels=[512, 512, 512, 512, 1536],
        kernel_sizes=[5, 3, 3, 3, 1],
        dilations=[1, 2, 3, 4, 1],
        attention_channels=128,
        res2net_scale=8,
        se_channels=128,
        global_context=True,
        batch_norm=True,
    ):
        super().__init__()
        assert len(channels) == len(kernel_sizes)
        assert len(channels) == len(dilations)
        self.channels = channels
        self.blocks = nn.ModuleList()

        # The initial TDNN layer
        self.blocks.append(
            TDNNBlock(
                input_size,
                channels[0],
                kernel_sizes[0],
                dilations[0],
                activation,
                batch_norm=batch_norm,
            )
        )

        # SE-Res2Net layers
        for i in range(1, len(channels) - 1):
            self.blocks.append(
                SERes2NetBlock(
                    channels[i - 1],
                    channels[i],
                    res2net_scale=res2net_scale,
                    se_channels=se_channels,
                    kernel_size=kernel_sizes[i],
                    dilation=dilations[i],
                    activation=activation,
                    batch_norm=batch_norm,
                )
            )

        # Multi-layer feature aggregation
        self.mfa = TDNNBlock(
            channels[-1],
            channels[-1],
            kernel_sizes[-1],
            dilations[-1],
            activation,
            batch_norm=batch_norm,
        )

        # Attentive Statistical Pooling
        self.asp = AttentiveStatisticsPooling(
            channels[-1],
            attention_channels=attention_channels,
            global_context=global_context,
            batch_norm=batch_norm,
        )
        self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2, enabled=batch_norm)

        # Final linear transformation
        self.fc = Conv1d(
            in_channels=channels[-1] * 2,
            out_channels=lin_neurons,
            kernel_size=1,
        )

    # @torch.cuda.amp.autocast(enabled=True, dtype=torch.float32)
    def forward(self, x, lengths=None):
        """Returns the embedding vector.
        Arguments
        ---------
        x : torch.Tensor
            Tensor of shape (batch, time, channel).
        """
        # Minimize transpose for efficiency
        x = x.transpose(1, 2)

        xl = []
        for layer in self.blocks:
            try:
                x = layer(x, lengths=lengths)
            except TypeError:
                x = layer(x)
            xl.append(x)

        # Multi-layer feature aggregation
        x = torch.cat(xl[1:], dim=1)
        x = self.mfa(x)

        # Attentive Statistical Pooling
        x = self.asp(x, lengths=lengths)
        x = self.asp_bn(x)

        # Final linear transformation
        x = self.fc(x)
        x = x.squeeze(-1)

        return x


class SpeakerEmbedddingExtractor(object):
    def __init__(self, ckpt_path, device="cuda"):
        # NOTE: The sampling rate is 16000
        self.mel_extractor = TorchMelSpectrogram()
        self.mel_extractor.to(device)

        model = ECAPA_TDNN(
            80,
            512,
            channels=[512, 512, 512, 512, 1536],
            kernel_sizes=[5, 3, 3, 3, 1],
            dilations=[1, 2, 3, 4, 1],
            attention_channels=128,
            res2net_scale=4,
            se_channels=128,
            global_context=True,
            batch_norm=True,
        )
        model.load_state_dict(torch.load(ckpt_path), strict=True)
        model.eval()

        self.model = model
        self.model.to(device)

    def __call__(self, wav):
        # wav, sr = torchaudio.load(audio_path)
        # assert sr == 16000, f"The sampling rate is not 16000"
        # print(wav.shape)
        mel = self.mel_extractor(wav.unsqueeze(0))
        spk = self.model(mel)
        spk = spk[0]
        return spk
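

# A minimal usage sketch of the pipeline above: random 16 kHz audio ->
# TorchMelSpectrogram -> ECAPA_TDNN embedding. The weights here are randomly
# initialized (no checkpoint is loaded), so this only demonstrates the expected
# tensor shapes, not meaningful speaker embeddings. SpeakerEmbedddingExtractor
# wraps the same two steps but additionally loads a trained checkpoint from a
# user-supplied path and, with its settings, produces 512-dimensional embeddings.
if __name__ == "__main__":
    mel_extractor = TorchMelSpectrogram()
    model = ECAPA_TDNN(80, lin_neurons=192)
    model.eval()

    wav = torch.randn(1, 16000)  # one second of fake 16 kHz mono audio, (batch, samples)
    with torch.no_grad():
        mel = mel_extractor(wav)  # (batch, frames, 80)
        emb = model(mel)  # (batch, 192)
    print(mel.shape, emb.shape)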