FanavaranPars commited on
Commit
c7c097f
·
verified ·
1 Parent(s): cdf0255

Upload modeling_pyannote.py

Browse files
Files changed (1) hide show
  1. modeling_pyannote.py +162 -0
modeling_pyannote.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from models.pyannote.layers import SincNet
5
+
6
+ from asteroid_filterbanks.enc_dec import Filterbank, Encoder
7
+ from asteroid_filterbanks.param_sinc_fb import ParamSincFB
8
+
9
+
10
+
11
+ class SincNet(nn.Module):
12
+ """Filtering and convolutional part of Pyannote
13
+
14
+ Arguments
15
+ ---------
16
+ n_filters : list, int
17
+ List consist of number of each convolution kernel
18
+ stride_ : in
19
+ Stride of ParamSincFB fliltering.
20
+
21
+ Returns
22
+ -------
23
+ Sincnet model: class
24
+
25
+ """
26
+
27
+ def __init__(self,
28
+ n_filters = [80,60,60],
29
+ stride_ = 10,
30
+ ):
31
+ super(SincNet,self).__init__()
32
+
33
+
34
+ sincnet_list = nn.ModuleList(
35
+ [
36
+ nn.InstanceNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False),
37
+ Encoder(ParamSincFB(n_filters=n_filters[0], kernel_size=251, stride=stride_)),
38
+ nn.MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False),
39
+ nn.InstanceNorm1d(n_filters[0], eps=1e-05, momentum=0.1, affine=True, track_running_stats=False),
40
+ ]
41
+ )
42
+ for counter in range(len(n_filters) - 1):
43
+ sincnet_list.append(nn.Conv1d(n_filters[counter], n_filters[counter+1], kernel_size=(5,), stride=(1,)))
44
+ sincnet_list.append(nn.MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False))
45
+ sincnet_list.append(nn.InstanceNorm1d(n_filters[counter+1], eps=1e-05, momentum=0.1, affine=True, track_running_stats=False))
46
+
47
+ self.sincnet_layer = nn.Sequential(*sincnet_list)
48
+
49
+ def forward(self, x):
50
+ """This method should implement forwarding operation in the SincNet model.
51
+
52
+ Arguments
53
+ ---------
54
+ x : float (Tensor)
55
+ The input of SincNet model.
56
+
57
+ Returns
58
+ -------
59
+ out : float (Tensor)
60
+ The output of SincNet model.
61
+ """
62
+ out = self.sincnet_layer(x)
63
+ return out
64
+
65
+
66
+
67
+ class PyanNet(nn.Module):
68
+ """Pyannote model
69
+
70
+ Arguments
71
+ ---------
72
+ model_config : dict, str
73
+ consist of model parameters
74
+
75
+ Returns
76
+ -------
77
+ Pyannote model: class
78
+
79
+ """
80
+ def __init__(self,
81
+ model_config,
82
+ ):
83
+ super(PyanNet,self).__init__()
84
+
85
+ self.model_config = model_config
86
+
87
+ sincnet_filters = model_config["sincnet_filters"]
88
+ sincnet_stride = model_config["sincnet_stride"]
89
+ linear_blocks = model_config["linear_blocks"]
90
+
91
+ self.sincnet = SincNet(n_filters=sincnet_filters, stride_ = sincnet_stride)
92
+
93
+ if model_config["sequence_type"] == "lstm":
94
+ self.sequence_blocks = nn.LSTM(sincnet_filters[-1],
95
+ model_config["sequence_neuron"],
96
+ num_layers=model_config["sequence_nlayers"],
97
+ batch_first=True,
98
+ dropout=model_config["sequence_drop_out"],
99
+ bidirectional=model_config["sequence_bidirectional"],
100
+ )
101
+ elif model_config["sequence_type"] == "gru":
102
+ self.sequence_blocks = nn.GRU(sincnet_filters[-1],
103
+ model_config["sequence_neuron"],
104
+ num_layers=model_config["sequence_nlayers"],
105
+ batch_first=True,
106
+ dropout=model_config["sequence_drop_out"],
107
+ bidirectional=model_config["sequence_bidirectional"],
108
+ )
109
+ elif model_config["sequence_type"] == "attention":
110
+ self.sequence_blocks = nn.TransformerEncoderLayer(d_model=sincnet_filters[-1],
111
+ dim_feedforward=model_config["sequence_neuron"],
112
+ nhead=model_config["sequence_nlayers"],
113
+ batch_first=True,
114
+ dropout=model_config["sequence_drop_out"])
115
+ else:
116
+ raise ValueError("Model type is not valid!!!")
117
+
118
+
119
+ if model_config["sequence_bidirectional"]:
120
+ last_sequence_block = model_config["sequence_neuron"] * 2
121
+ else:
122
+ last_sequence_block = model_config["sequence_neuron"]
123
+
124
+
125
+ linear_blocks = [last_sequence_block] + linear_blocks
126
+ linears_list = nn.ModuleList()
127
+ for counter in range(len(linear_blocks) - 1):
128
+ linears_list.append(
129
+ nn.Linear(
130
+ in_features=linear_blocks[counter],
131
+ out_features=linear_blocks[counter+1],
132
+ bias=True,
133
+ )
134
+ )
135
+ linears_list.append(nn.Sigmoid())
136
+ self.linears = nn.Sequential(*linears_list)
137
+
138
+
139
+ def forward(self, x):
140
+ """This method should implement forwarding operation in the Pyannote model.
141
+
142
+ Arguments
143
+ ---------
144
+ x : float (Tensor)
145
+ The input of Pyannote model.
146
+
147
+ Returns
148
+ -------
149
+ out : float (Tensor)
150
+ The output of Pyannote model.
151
+ """
152
+ x = torch.unsqueeze(x, 1)
153
+ x = self.sincnet(x)
154
+ x = x.permute(0,2,1)
155
+
156
+ if self.model_config["sequence_type"] == "attention":
157
+ x = self.sequence_blocks(x)
158
+ else:
159
+ x = self.sequence_blocks(x)[0]
160
+
161
+ out = self.linears(x)
162
+ return out