import torch
import torch.nn as nn
# NOTE(review): this imported SincNet is shadowed by the SincNet class defined below.
from models.pyannote.layers import SincNet
# NOTE(review): Filterbank appears unused in this module.
from asteroid_filterbanks.enc_dec import Filterbank, Encoder
from asteroid_filterbanks.param_sinc_fb import ParamSincFB


class SincNet(nn.Module):
    """Filtering and convolutional front end of the Pyannote model.

    All layers live in one ``nn.Sequential`` (``self.sincnet_layer``):

    * ``InstanceNorm1d`` over the single input channel,
    * a learnable sinc filterbank (``ParamSincFB`` wrapped in ``Encoder``) with
      ``n_filters[0]`` filters, kernel size 251 and stride ``stride_``,
    * ``MaxPool1d(3)`` followed by ``InstanceNorm1d(n_filters[0])``,
    * then, for every consecutive pair ``(n_filters[i], n_filters[i + 1])``:
      ``Conv1d(kernel_size=5)`` -> ``MaxPool1d(3)`` -> ``InstanceNorm1d``.

    Arguments
    ---------
    n_filters : sequence of int
        Output channel count of each stage. The first entry is the number of
        sinc filters; every further entry adds one convolutional stage.
        Defaults to ``(80, 60, 60)``.
    stride_ : int
        Stride of the ``ParamSincFB`` filtering.

    Raises
    ------
    ValueError
        If ``n_filters`` is empty.
    """

    def __init__(self, n_filters=(80, 60, 60), stride_=10):
        super().__init__()
        # Tuple default + local copy: no shared mutable default, and any
        # sequence (list, tuple, ...) is accepted.
        n_filters = list(n_filters)
        if not n_filters:
            raise ValueError("n_filters must contain at least one entry")

        layers = [
            nn.InstanceNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False),
            Encoder(ParamSincFB(n_filters=n_filters[0], kernel_size=251, stride=stride_)),
            nn.MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False),
            nn.InstanceNorm1d(
                n_filters[0], eps=1e-05, momentum=0.1, affine=True, track_running_stats=False
            ),
        ]
        # NOTE(review): no non-linearity is inserted between stages; confirm this is intended.
        for in_channels, out_channels in zip(n_filters[:-1], n_filters[1:]):
            layers.extend(
                [
                    nn.Conv1d(in_channels, out_channels, kernel_size=(5,), stride=(1,)),
                    nn.MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False),
                    nn.InstanceNorm1d(
                        out_channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False
                    ),
                ]
            )
        # Same submodule indices as before (sincnet_layer.0, .1, ...), so
        # existing state dicts still load.
        self.sincnet_layer = nn.Sequential(*layers)

    def forward(self, x):
        """Run the SincNet front end.

        Arguments
        ---------
        x : torch.Tensor
            Single-channel input; presumably shaped (batch, 1, samples), which is
            how ``PyanNet.forward`` builds it with ``torch.unsqueeze(x, 1)``.

        Returns
        -------
        out : torch.Tensor
            Output of the last stage, with ``n_filters[-1]`` channels.
        """
        return self.sincnet_layer(x)


class PyanNet(nn.Module):
    """Pyannote-style model: SincNet front end, sequence block, linear head.

    Arguments
    ---------
    model_config : dict
        Model hyper-parameters. Keys read:

        * ``"sincnet_filters"``: ``n_filters`` of the SincNet front end.
        * ``"sincnet_stride"``: ``stride_`` of the SincNet front end.
        * ``"linear_blocks"``: output sizes of the successive linear layers.
        * ``"sequence_type"``: ``"lstm"``, ``"gru"`` or ``"attention"``.
        * ``"sequence_neuron"``: hidden size of the LSTM/GRU, or
          ``dim_feedforward`` of the transformer encoder layer.
        * ``"sequence_nlayers"``: number of LSTM/GRU layers, or the number of
          attention heads (``nhead``) for ``"attention"``.
        * ``"sequence_drop_out"``: dropout of the sequence block.
        * ``"sequence_bidirectional"``: bidirectional LSTM/GRU (read only for
          ``"lstm"`` / ``"gru"``).

    Raises
    ------
    ValueError
        If ``model_config["sequence_type"]`` is not a supported value.
    """

    def __init__(self, model_config):
        super().__init__()
        self.model_config = model_config
        sincnet_filters = model_config["sincnet_filters"]
        sincnet_stride = model_config["sincnet_stride"]
        linear_blocks = model_config["linear_blocks"]
        sequence_type = model_config["sequence_type"]

        self.sincnet = SincNet(n_filters=sincnet_filters, stride_=sincnet_stride)
        feature_dim = sincnet_filters[-1]  # channels produced by the SincNet front end

        if sequence_type in ("lstm", "gru"):
            rnn_cls = nn.LSTM if sequence_type == "lstm" else nn.GRU
            bidirectional = model_config["sequence_bidirectional"]
            self.sequence_blocks = rnn_cls(
                feature_dim,
                model_config["sequence_neuron"],
                num_layers=model_config["sequence_nlayers"],
                batch_first=True,
                dropout=model_config["sequence_drop_out"],
                bidirectional=bidirectional,
            )
            # Forward and backward outputs are concatenated when bidirectional.
            sequence_out = model_config["sequence_neuron"] * (2 if bidirectional else 1)
        elif sequence_type == "attention":
            # NOTE(review): "sequence_nlayers" is used as the number of heads here;
            # feature_dim must be divisible by it.
            self.sequence_blocks = nn.TransformerEncoderLayer(
                d_model=feature_dim,
                dim_feedforward=model_config["sequence_neuron"],
                nhead=model_config["sequence_nlayers"],
                batch_first=True,
                dropout=model_config["sequence_drop_out"],
            )
            # The encoder layer preserves its input width (d_model), not
            # dim_feedforward, so the head must be sized from feature_dim.
            sequence_out = feature_dim
        else:
            raise ValueError("Model type is not valid!!!")

        # Build a new list instead of concatenating onto the config value, so a
        # tuple "linear_blocks" also works and the config is never mutated.
        sizes = [sequence_out, *linear_blocks]
        layers = [
            nn.Linear(in_features=n_in, out_features=n_out, bias=True)
            for n_in, n_out in zip(sizes[:-1], sizes[1:])
        ]
        layers.append(nn.Sigmoid())  # squashes the final outputs into (0, 1)
        self.linears = nn.Sequential(*layers)

    def forward(self, x):
        """Run the full model.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor; presumably (batch, samples) since a channel dimension is
            inserted at dim 1 -- TODO confirm against callers.

        Returns
        -------
        out : torch.Tensor
            Sigmoid outputs of the linear head, shaped
            (batch, frames, linear_blocks[-1]).
        """
        x = torch.unsqueeze(x, 1)  # add the single input channel for SincNet
        x = self.sincnet(x)
        # (batch, channels, frames) -> (batch, frames, channels) for the
        # batch_first sequence blocks.
        x = x.permute(0, 2, 1)
        if self.model_config["sequence_type"] == "attention":
            x = self.sequence_blocks(x)
        else:
            # LSTM/GRU return (output, hidden_state); keep the per-frame output.
            x = self.sequence_blocks(x)[0]
        return self.linears(x)