AlexK-PL commited on
Commit
ee2c994
verified
1 Parent(s): d188a55

Create util.py

Browse files
Files changed (1) hide show
  1. util.py +225 -0
util.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import functools
4
+ import numpy as np
5
+ from math import cos, pi, floor, sin
6
+ from tqdm import tqdm
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ from stft_loss import MultiResolutionSTFTLoss
13
+
14
+ torch.manual_seed(0)
15
+ np.random.seed(0)
16
+
17
+
18
+ def flatten(v):
19
+ return [x for y in v for x in y]
20
+
21
+
22
+ def rescale(x):
23
+ return (x - x.min()) / (x.max() - x.min())
24
+
25
+
26
+ def find_max_epoch(path):
27
+ """
28
+ Find latest checkpoint
29
+
30
+ Returns:
31
+ maximum iteration, -1 if there is no (valid) checkpoint
32
+ """
33
+
34
+ files = os.listdir(path)
35
+ epoch = -1
36
+ for f in files:
37
+ if len(f) <= 4:
38
+ continue
39
+ if f[-4:] == '.pkl':
40
+ number = f[:-4]
41
+ try:
42
+ epoch = max(epoch, int(number))
43
+ except:
44
+ continue
45
+ return epoch
46
+
47
+
48
+ def print_size(net, keyword=None):
49
+ """
50
+ Print the number of parameters of a network
51
+ """
52
+
53
+ if net is not None and isinstance(net, torch.nn.Module):
54
+ module_parameters = filter(lambda p: p.requires_grad, net.parameters())
55
+ params = sum([np.prod(p.size()) for p in module_parameters])
56
+
57
+ print("{} Parameters: {:.6f}M".format(
58
+ net.__class__.__name__, params / 1e6), flush=True, end="; ")
59
+
60
+ if keyword is not None:
61
+ keyword_parameters = [p for name, p in net.named_parameters() if p.requires_grad and keyword in name]
62
+ params = sum([np.prod(p.size()) for p in keyword_parameters])
63
+ print("{} Parameters: {:.6f}M".format(
64
+ keyword, params / 1e6), flush=True, end="; ")
65
+
66
+ print(" ")
67
+
68
+
69
+ ####################### lr scheduler: Linear Warmup then Cosine Decay #############################
70
+
71
+ # Adapted from https://github.com/rosinality/vq-vae-2-pytorch
72
+
73
+ # Original Copyright 2019 Kim Seonghyeon
74
+ # MIT License (https://opensource.org/licenses/MIT)
75
+
76
+
77
+ def anneal_linear(start, end, proportion):
78
+ return start + proportion * (end - start)
79
+
80
+
81
+ def anneal_cosine(start, end, proportion):
82
+ cos_val = cos(pi * proportion) + 1
83
+ return end + (start - end) / 2 * cos_val
84
+
85
+
86
+ class Phase:
87
+ def __init__(self, start, end, n_iter, cur_iter, anneal_fn):
88
+ self.start, self.end = start, end
89
+ self.n_iter = n_iter
90
+ self.anneal_fn = anneal_fn
91
+ self.n = cur_iter
92
+
93
+ def step(self):
94
+ self.n += 1
95
+
96
+ return self.anneal_fn(self.start, self.end, self.n / self.n_iter)
97
+
98
+ def reset(self):
99
+ self.n = 0
100
+
101
+ @property
102
+ def is_done(self):
103
+ return self.n >= self.n_iter
104
+
105
+
106
+ class LinearWarmupCosineDecay:
107
+ def __init__(
108
+ self,
109
+ optimizer,
110
+ lr_max,
111
+ n_iter,
112
+ iteration=0,
113
+ divider=25,
114
+ warmup_proportion=0.3,
115
+ phase=('linear', 'cosine'),
116
+ ):
117
+ self.optimizer = optimizer
118
+
119
+ phase1 = int(n_iter * warmup_proportion)
120
+ phase2 = n_iter - phase1
121
+ lr_min = lr_max / divider
122
+
123
+ phase_map = {'linear': anneal_linear, 'cosine': anneal_cosine}
124
+
125
+ cur_iter_phase1 = iteration
126
+ cur_iter_phase2 = max(0, iteration - phase1)
127
+ self.lr_phase = [
128
+ Phase(lr_min, lr_max, phase1, cur_iter_phase1, phase_map[phase[0]]),
129
+ Phase(lr_max, lr_min / 1e4, phase2, cur_iter_phase2, phase_map[phase[1]]),
130
+ ]
131
+
132
+ if iteration < phase1:
133
+ self.phase = 0
134
+ else:
135
+ self.phase = 1
136
+
137
+ def step(self):
138
+ lr = self.lr_phase[self.phase].step()
139
+
140
+ for group in self.optimizer.param_groups:
141
+ group['lr'] = lr
142
+
143
+ if self.lr_phase[self.phase].is_done:
144
+ self.phase += 1
145
+
146
+ if self.phase >= len(self.lr_phase):
147
+ for phase in self.lr_phase:
148
+ phase.reset()
149
+
150
+ self.phase = 0
151
+
152
+ return lr
153
+
154
+
155
+ ####################### model util #############################
156
+
157
+ def std_normal(size):
158
+ """
159
+ Generate the standard Gaussian variable of a certain size
160
+ """
161
+
162
+ return torch.normal(0, 1, size=size).cuda()
163
+
164
+
165
+ def weight_scaling_init(layer):
166
+ """
167
+ weight rescaling initialization from https://arxiv.org/abs/1911.13254
168
+ """
169
+ w = layer.weight.detach()
170
+ alpha = 10.0 * w.std()
171
+ layer.weight.data /= torch.sqrt(alpha)
172
+ layer.bias.data /= torch.sqrt(alpha)
173
+
174
+
175
+ @torch.no_grad()
176
+ def sampling(net, noisy_audio):
177
+ """
178
+ Perform denoising (forward) step
179
+ """
180
+
181
+ return net(noisy_audio)
182
+
183
+
184
+ def loss_fn(net, X, ell_p, ell_p_lambda, stft_lambda, mrstftloss, **kwargs):
185
+ """
186
+ Loss function in CleanUNet
187
+ Parameters:
188
+ net: network
189
+ X: training data pair (clean audio, noisy_audio)
190
+ ell_p: \ell_p norm (1 or 2) of the AE loss
191
+ ell_p_lambda: factor of the AE loss
192
+ stft_lambda: factor of the STFT loss
193
+ mrstftloss: multi-resolution STFT loss function
194
+ Returns:
195
+ loss: value of objective function
196
+ output_dic: values of each component of loss
197
+ """
198
+
199
+ assert type(X) == tuple and len(X) == 2
200
+
201
+ clean_audio, noisy_audio = X
202
+ B, C, L = clean_audio.shape
203
+ output_dic = {}
204
+ loss = 0.0
205
+
206
+ # AE loss
207
+ denoised_audio = net(noisy_audio)
208
+
209
+ if ell_p == 2:
210
+ ae_loss = nn.MSELoss()(denoised_audio, clean_audio)
211
+ elif ell_p == 1:
212
+ ae_loss = F.l1_loss(denoised_audio, clean_audio)
213
+ else:
214
+ raise NotImplementedError
215
+ loss += ae_loss * ell_p_lambda
216
+ output_dic["reconstruct"] = ae_loss.data * ell_p_lambda
217
+
218
+ if stft_lambda > 0:
219
+ sc_loss, mag_loss = mrstftloss(denoised_audio.squeeze(1), clean_audio.squeeze(1))
220
+ loss += (sc_loss + mag_loss) * stft_lambda
221
+ output_dic["stft_sc"] = sc_loss.data * stft_lambda
222
+ output_dic["stft_mag"] = mag_loss.data * stft_lambda
223
+
224
+ return loss, output_dic
225
+