AlexK-PL committed on
Commit 57c2272 (verified)
1 Parent(s): 6ac5bee

Create network.py

Files changed (1)
  1. network.py +389 -0
network.py ADDED
@@ -0,0 +1,389 @@
# Copyright (c) 2022 NVIDIA CORPORATION.
# Licensed under the MIT license.

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from util import weight_scaling_init

torch.manual_seed(0)
np.random.seed(0)

# Transformer (encoder) https://github.com/jadore801120/attention-is-all-you-need-pytorch
# Original Copyright 2017 Victor Huang
# MIT License (https://opensource.org/licenses/MIT)

class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):

        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))

        if mask is not None:
            _MASKING_VALUE = -1e9 if attn.dtype == torch.float32 else -1e4
            attn = attn.masked_fill(mask == 0, _MASKING_VALUE)

        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)

        return output, attn

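
# Illustrative usage sketch (not part of the original file): ScaledDotProductAttention
# expects q, k, v shaped (batch, n_head, length, d_k/d_v) and an optional mask,
# broadcastable to the attention matrix, where 0 marks blocked positions.
# All sizes below are assumptions chosen only for this example.
def _example_scaled_dot_product_attention():
    d_k = 64
    q = torch.randn(2, 8, 10, d_k)
    k = torch.randn(2, 8, 10, d_k)
    v = torch.randn(2, 8, 10, d_k)
    attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
    output, attn = attention(q, k, v)   # output: (2, 8, 10, 64), attn: (2, 8, 10, 10)
    return output.shape, attn.shape
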
class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):

        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        residual = q

        # Pass through the pre-attention projection: b x lq x (n*dv)
        # Separate different heads: b x lq x n x dv
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Transpose for attention dot product: b x n x lq x dv
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)   # For head axis broadcasting.

        q, attn = self.attention(q, k, v, mask=mask)

        # Transpose to move the head dimension back: b x lq x n x dv
        # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        q = self.dropout(self.fc(q))
        q += residual

        q = self.layer_norm(q)

        return q, attn

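
# Illustrative usage sketch (not part of the original file): self-attention with
# MultiHeadAttention maps (batch, length, d_model) back to (batch, length, d_model)
# and also returns per-head attention weights of shape (batch, n_head, len_q, len_k).
# All sizes below are assumptions chosen only for this example.
def _example_multi_head_attention():
    mha = MultiHeadAttention(n_head=8, d_model=512, d_k=64, d_v=64)
    x = torch.randn(2, 100, 512)
    out, attn = mha(x, x, x)   # out: (2, 100, 512), attn: (2, 8, 100, 100)
    return out.shape, attn.shape
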
class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)   # position-wise
        self.w_2 = nn.Linear(d_hid, d_in)   # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):

        residual = x

        x = self.w_2(F.relu(self.w_1(x)))
        x = self.dropout(x)
        x += residual

        x = self.layer_norm(x)

        return x

def get_subsequent_mask(seq):
    ''' For masking out the subsequent info. '''
    sz_b, len_s = seq.size()
    subsequent_mask = (1 - torch.triu(
        torch.ones((1, len_s, len_s), device=seq.device), diagonal=1)).bool()
    return subsequent_mask

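
# Illustrative usage sketch (not part of the original file): for a length-4 sequence the
# mask is (1, 4, 4) and lower-triangular, True where a position may attend (itself and
# earlier positions) and False for future positions; only the shape of `seq` is used.
def _example_subsequent_mask():
    seq = torch.zeros(2, 4)            # (batch, length); values are irrelevant
    mask = get_subsequent_mask(seq)    # (1, 4, 4), dtype torch.bool
    # mask[0] == [[True, False, False, False],
    #             [True, True,  False, False],
    #             [True, True,  True,  False],
    #             [True, True,  True,  True ]]
    return mask
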
class PositionalEncoding(nn.Module):

    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()

        # Not a parameter
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        ''' Sinusoid position encoding table '''
        # TODO: make it with torch instead of numpy

        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

        sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])   # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])   # dim 2i+1

        return torch.FloatTensor(sinusoid_table).unsqueeze(0)

    def forward(self, x):
        return x + self.pos_table[:, :x.size(1)].clone().detach()

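
# Illustrative usage sketch (not part of the original file): PositionalEncoding adds the
# first x.size(1) rows of the fixed sinusoid table to the input, so the output keeps the
# (batch, length, d_hid) shape. Sizes below are assumptions chosen only for this example.
def _example_positional_encoding():
    pos_enc = PositionalEncoding(d_hid=512, n_position=200)
    x = torch.zeros(2, 50, 512)
    y = pos_enc(x)     # since x is zero, y[b, t] equals the t-th row of the table
    return y.shape     # torch.Size([2, 50, 512])
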
class EncoderLayer(nn.Module):
    ''' Compose with two layers '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.0):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn

class TransformerEncoder(nn.Module):
    ''' An encoder model with a self-attention mechanism. '''

    def __init__(
            self, d_word_vec=512, n_layers=2, n_head=8, d_k=64, d_v=64,
            d_model=512, d_inner=2048, dropout=0.1, n_position=624, scale_emb=False):

        super().__init__()

        # self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx)
        if n_position > 0:
            self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
        else:
            self.position_enc = lambda x: x
        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.scale_emb = scale_emb
        self.d_model = d_model

    def forward(self, src_seq, src_mask, return_attns=False):

        enc_slf_attn_list = []

        # -- Forward
        # enc_output = self.src_word_emb(src_seq)
        enc_output = src_seq
        if self.scale_emb:
            enc_output *= self.d_model ** 0.5
        enc_output = self.dropout(self.position_enc(enc_output))
        enc_output = self.layer_norm(enc_output)

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask)
            enc_slf_attn_list += [enc_slf_attn] if return_attns else []

        if return_attns:
            return enc_output, enc_slf_attn_list
        return enc_output

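
# Illustrative usage sketch (not part of the original file): a small TransformerEncoder
# applied to a (batch, length, d_model) feature sequence with a causal mask from
# get_subsequent_mask. n_position=0 skips the positional encoding, matching how CleanUNet
# instantiates it below. All sizes here are assumptions chosen only for this example.
def _example_transformer_encoder():
    encoder = TransformerEncoder(d_word_vec=512, n_layers=2, n_head=8, d_k=64, d_v=64,
                                 d_model=512, d_inner=2048, dropout=0.0, n_position=0)
    x = torch.randn(2, 100, 512)
    causal_mask = get_subsequent_mask(torch.zeros(2, 100))   # (1, 100, 100) boolean mask
    out = encoder(x, src_mask=causal_mask)                   # (2, 100, 512)
    return out.shape
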
# CleanUNet architecture


def padding(x, D, K, S):
    """Pad x with zeros so that the denoised audio has the same length as the input."""

    L = x.shape[-1]
    for _ in range(D):
        if L < K:
            L = 1
        else:
            L = 1 + np.ceil((L - K) / S)

    for _ in range(D):
        L = (L - 1) * S + K

    L = int(L)
    x = F.pad(x, (0, L - x.shape[-1]))
    return x

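
# Illustrative sketch (not part of the original file): padding() first walks the length
# through D downsampling steps (kernel K, stride S, rounding up so no samples are lost),
# then back through D transposed-convolution steps, and right-pads x with zeros to the
# resulting length. CleanUNet calls it with D=encoder_n_layers, K=kernel_size, S=stride.
def _example_padding_length():
    x = torch.zeros(1, 1, 29)
    x_padded = padding(x, D=3, K=4, S=2)   # 29 -> 14 -> 6 -> 2 at the bottleneck,
    return x_padded.shape[-1]              # then 2 -> 6 -> 14 -> 30: one zero appended
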
class CleanUNet(nn.Module):
    """ CleanUNet architecture. """

    def __init__(self, channels_input=1, channels_output=1,
                 channels_H=64, max_H=768,
                 encoder_n_layers=8, kernel_size=4, stride=2,
                 tsfm_n_layers=3,
                 tsfm_n_head=8,
                 tsfm_d_model=512,
                 tsfm_d_inner=2048):

        """
        Parameters:
        channels_input (int):   input channels
        channels_output (int):  output channels
        channels_H (int):       middle channels H that controls capacity
        max_H (int):            maximum H
        encoder_n_layers (int): number of encoder/decoder layers D
        kernel_size (int):      kernel size K
        stride (int):           stride S
        tsfm_n_layers (int):    number of self attention blocks N
        tsfm_n_head (int):      number of heads in each self attention block
        tsfm_d_model (int):     d_model of self attention
        tsfm_d_inner (int):     d_inner of self attention
        """

        super(CleanUNet, self).__init__()

        self.channels_input = channels_input
        self.channels_output = channels_output
        self.channels_H = channels_H
        self.max_H = max_H
        self.encoder_n_layers = encoder_n_layers
        self.kernel_size = kernel_size
        self.stride = stride

        self.tsfm_n_layers = tsfm_n_layers
        self.tsfm_n_head = tsfm_n_head
        self.tsfm_d_model = tsfm_d_model
        self.tsfm_d_inner = tsfm_d_inner

        # encoder and decoder
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for i in range(encoder_n_layers):
            self.encoder.append(nn.Sequential(
                nn.Conv1d(channels_input, channels_H, kernel_size, stride),
                nn.ReLU(),
                nn.Conv1d(channels_H, channels_H * 2, 1),
                nn.GLU(dim=1)
            ))
            channels_input = channels_H

            if i == 0:
                # no relu at end
                self.decoder.append(nn.Sequential(
                    nn.Conv1d(channels_H, channels_H * 2, 1),
                    nn.GLU(dim=1),
                    nn.ConvTranspose1d(channels_H, channels_output, kernel_size, stride)
                ))
            else:
                self.decoder.insert(0, nn.Sequential(
                    nn.Conv1d(channels_H, channels_H * 2, 1),
                    nn.GLU(dim=1),
                    nn.ConvTranspose1d(channels_H, channels_output, kernel_size, stride),
                    nn.ReLU()
                ))
            channels_output = channels_H

            # double H but keep below max_H
            channels_H *= 2
            channels_H = min(channels_H, max_H)

        # self attention block
        self.tsfm_conv1 = nn.Conv1d(channels_output, tsfm_d_model, kernel_size=1)
        self.tsfm_encoder = TransformerEncoder(d_word_vec=tsfm_d_model,
                                               n_layers=tsfm_n_layers,
                                               n_head=tsfm_n_head,
                                               d_k=tsfm_d_model // tsfm_n_head,
                                               d_v=tsfm_d_model // tsfm_n_head,
                                               d_model=tsfm_d_model,
                                               d_inner=tsfm_d_inner,
                                               dropout=0.0,
                                               n_position=0,
                                               scale_emb=False)
        self.tsfm_conv2 = nn.Conv1d(tsfm_d_model, channels_output, kernel_size=1)

        # weight scaling initialization
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                weight_scaling_init(layer)

    def forward(self, noisy_audio):
        # (B, L) -> (B, C, L)
        if len(noisy_audio.shape) == 2:
            noisy_audio = noisy_audio.unsqueeze(1)
        B, C, L = noisy_audio.shape
        assert C == 1

        # normalization and padding
        std = noisy_audio.std(dim=2, keepdim=True) + 1e-3
        noisy_audio /= std
        x = padding(noisy_audio, self.encoder_n_layers, self.kernel_size, self.stride)

        # encoder
        skip_connections = []
        for downsampling_block in self.encoder:
            x = downsampling_block(x)
            skip_connections.append(x)
        skip_connections = skip_connections[::-1]

        # attention mask for causal inference; for non-causal, set attn_mask to None
        len_s = x.shape[-1]   # length at bottleneck
        attn_mask = (1 - torch.triu(torch.ones((1, len_s, len_s), device=x.device), diagonal=1)).bool()

        x = self.tsfm_conv1(x)   # C: channels_output -> tsfm_d_model
        x = x.permute(0, 2, 1)
        x = self.tsfm_encoder(x, src_mask=attn_mask)
        x = x.permute(0, 2, 1)
        x = self.tsfm_conv2(x)   # C: tsfm_d_model -> channels_output

        # decoder
        for i, upsampling_block in enumerate(self.decoder):
            skip_i = skip_connections[i]
            x += skip_i[:, :, :x.shape[-1]]
            x = upsampling_block(x)

        x = x[:, :, :L] * std
        return x

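
# Illustrative usage sketch (not part of the original file): denoising a batch with the
# default hyper-parameters from the constructor signature, on CPU and without a config
# file. Input is (B, 1, L) (or (B, L)) waveform audio; the output has the same length.
# The 1-second 16 kHz clips below are an assumption chosen only for this example.
def _example_cleanunet_denoise():
    model = CleanUNet().eval()
    noisy = torch.randn(2, 1, 16000)
    with torch.no_grad():
        denoised = model(noisy)   # (2, 1, 16000), same shape as the input
    return denoised.shape
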
if __name__ == '__main__':
    import json
    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='configs/DNS-large-full.json',
                        help='JSON file for configuration')
    args = parser.parse_args()

    with open(args.config) as f:
        data = f.read()
    config = json.loads(data)
    network_config = config["network_config"]

    model = CleanUNet(**network_config).cuda()
    from util import print_size
    print_size(model, keyword="tsfm")

    input_data = torch.ones([4, 1, int(4.5 * 16000)]).cuda()
    output = model(input_data)
    print(output.shape)

    y = torch.rand([4, 1, int(4.5 * 16000)]).cuda()
    loss = torch.nn.MSELoss()(y, output)
    loss.backward()
    print(loss.item())