sparanoid commited on
Commit
55836f8
1 Parent(s): c8318dc

feat: update deps

Browse files
Files changed (4) hide show
  1. requirements.txt +16 -0
  2. slicer.py +163 -0
  3. transforms.py +191 -0
  4. utils.py +263 -0
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Cython==0.29.21
2
+ librosa==0.8.0
3
+ matplotlib==3.3.1
4
+ numpy==1.18.5
5
+ phonemizer==2.2.1
6
+ scipy==1.5.2
7
+ torch
8
+ torchvision
9
+ Unidecode==1.1.1
10
+ torchaudio
11
+ pyworld
12
+ scipy
13
+ keras
14
+ mir-eval
15
+ pretty-midi
16
+ pydub
slicer.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ import time
3
+ from argparse import ArgumentParser
4
+
5
+ import librosa
6
+ import numpy as np
7
+ import soundfile
8
+ from scipy.ndimage import maximum_filter1d, uniform_filter1d
9
+
10
+
11
+ def timeit(func):
12
+ def run(*args, **kwargs):
13
+ t = time.time()
14
+ res = func(*args, **kwargs)
15
+ print('executing \'%s\' costed %.3fs' % (func.__name__, time.time() - t))
16
+ return res
17
+
18
+ return run
19
+
20
+
21
+ # @timeit
22
+ def _window_maximum(arr, win_sz):
23
+ return maximum_filter1d(arr, size=win_sz)[win_sz // 2: win_sz // 2 + arr.shape[0] - win_sz + 1]
24
+
25
+
26
+ # @timeit
27
+ def _window_rms(arr, win_sz):
28
+ filtered = np.sqrt(uniform_filter1d(np.power(arr, 2), win_sz) - np.power(uniform_filter1d(arr, win_sz), 2))
29
+ return filtered[win_sz // 2: win_sz // 2 + arr.shape[0] - win_sz + 1]
30
+
31
+
32
+ def level2db(levels, eps=1e-12):
33
+ return 20 * np.log10(np.clip(levels, a_min=eps, a_max=1))
34
+
35
+
36
+ def _apply_slice(audio, begin, end):
37
+ if len(audio.shape) > 1:
38
+ return audio[:, begin: end]
39
+ else:
40
+ return audio[begin: end]
41
+
42
+
43
+ class Slicer:
44
+ def __init__(self,
45
+ sr: int,
46
+ db_threshold: float = -40,
47
+ min_length: int = 5000,
48
+ win_l: int = 300,
49
+ win_s: int = 20,
50
+ max_silence_kept: int = 500):
51
+ self.db_threshold = db_threshold
52
+ self.min_samples = round(sr * min_length / 1000)
53
+ self.win_ln = round(sr * win_l / 1000)
54
+ self.win_sn = round(sr * win_s / 1000)
55
+ self.max_silence = round(sr * max_silence_kept / 1000)
56
+ if not self.min_samples >= self.win_ln >= self.win_sn:
57
+ raise ValueError('The following condition must be satisfied: min_length >= win_l >= win_s')
58
+ if not self.max_silence >= self.win_sn:
59
+ raise ValueError('The following condition must be satisfied: max_silence_kept >= win_s')
60
+
61
+ @timeit
62
+ def slice(self, audio):
63
+ if len(audio.shape) > 1:
64
+ samples = librosa.to_mono(audio)
65
+ else:
66
+ samples = audio
67
+ if samples.shape[0] <= self.min_samples:
68
+ return [audio]
69
+ # get absolute amplitudes
70
+ abs_amp = np.abs(samples - np.mean(samples))
71
+ # calculate local maximum with large window
72
+ win_max_db = level2db(_window_maximum(abs_amp, win_sz=self.win_ln))
73
+ sil_tags = []
74
+ left = right = 0
75
+ while right < win_max_db.shape[0]:
76
+ if win_max_db[right] < self.db_threshold:
77
+ right += 1
78
+ elif left == right:
79
+ left += 1
80
+ right += 1
81
+ else:
82
+ if left == 0:
83
+ split_loc_l = left
84
+ else:
85
+ sil_left_n = min(self.max_silence, (right + self.win_ln - left) // 2)
86
+ rms_db_left = level2db(_window_rms(samples[left: left + sil_left_n], win_sz=self.win_sn))
87
+ split_win_l = left + np.argmin(rms_db_left)
88
+ split_loc_l = split_win_l + np.argmin(abs_amp[split_win_l: split_win_l + self.win_sn])
89
+ if len(sil_tags) != 0 and split_loc_l - sil_tags[-1][1] < self.min_samples and right < win_max_db.shape[
90
+ 0] - 1:
91
+ right += 1
92
+ left = right
93
+ continue
94
+ if right == win_max_db.shape[0] - 1:
95
+ split_loc_r = right + self.win_ln
96
+ else:
97
+ sil_right_n = min(self.max_silence, (right + self.win_ln - left) // 2)
98
+ rms_db_right = level2db(_window_rms(samples[right + self.win_ln - sil_right_n: right + self.win_ln],
99
+ win_sz=self.win_sn))
100
+ split_win_r = right + self.win_ln - sil_right_n + np.argmin(rms_db_right)
101
+ split_loc_r = split_win_r + np.argmin(abs_amp[split_win_r: split_win_r + self.win_sn])
102
+ sil_tags.append((split_loc_l, split_loc_r))
103
+ right += 1
104
+ left = right
105
+ if left != right:
106
+ sil_left_n = min(self.max_silence, (right + self.win_ln - left) // 2)
107
+ rms_db_left = level2db(_window_rms(samples[left: left + sil_left_n], win_sz=self.win_sn))
108
+ split_win_l = left + np.argmin(rms_db_left)
109
+ split_loc_l = split_win_l + np.argmin(abs_amp[split_win_l: split_win_l + self.win_sn])
110
+ sil_tags.append((split_loc_l, samples.shape[0]))
111
+ if len(sil_tags) == 0:
112
+ return [audio]
113
+ else:
114
+ chunks = []
115
+ for i in range(0, len(sil_tags)):
116
+ chunks.append(int((sil_tags[i][0] + sil_tags[i][1]) / 2))
117
+ return chunks
118
+
119
+
120
+ def main():
121
+ parser = ArgumentParser()
122
+ parser.add_argument('audio', type=str, help='The audio to be sliced')
123
+ parser.add_argument('--out_name', type=str, help='Output directory of the sliced audio clips')
124
+ parser.add_argument('--out', type=str, help='Output directory of the sliced audio clips')
125
+ parser.add_argument('--db_thresh', type=float, required=False, default=-40,
126
+ help='The dB threshold for silence detection')
127
+ parser.add_argument('--min_len', type=int, required=False, default=5000,
128
+ help='The minimum milliseconds required for each sliced audio clip')
129
+ parser.add_argument('--win_l', type=int, required=False, default=300,
130
+ help='Size of the large sliding window, presented in milliseconds')
131
+ parser.add_argument('--win_s', type=int, required=False, default=20,
132
+ help='Size of the small sliding window, presented in milliseconds')
133
+ parser.add_argument('--max_sil_kept', type=int, required=False, default=500,
134
+ help='The maximum silence length kept around the sliced audio, presented in milliseconds')
135
+ args = parser.parse_args()
136
+ out = args.out
137
+ if out is None:
138
+ out = os.path.dirname(os.path.abspath(args.audio))
139
+ audio, sr = librosa.load(args.audio, sr=None)
140
+ slicer = Slicer(
141
+ sr=sr,
142
+ db_threshold=args.db_thresh,
143
+ min_length=args.min_len,
144
+ win_l=args.win_l,
145
+ win_s=args.win_s,
146
+ max_silence_kept=args.max_sil_kept
147
+ )
148
+ chunks = slicer.slice(audio)
149
+ if not os.path.exists(args.out):
150
+ os.makedirs(args.out)
151
+ start = 0
152
+ end_id = 0
153
+ for i, chunk in enumerate(chunks):
154
+ end = chunk
155
+ soundfile.write(os.path.join(out, f'%s-%s.wav' % (args.out_name, str(i).zfill(2))), audio[start:end], sr)
156
+ start = end
157
+ end_id = i + 1
158
+ soundfile.write(os.path.join(out, f'%s-%s.wav' % (args.out_name, str(end_id).zfill(2))), audio[start:len(audio)],
159
+ sr)
160
+
161
+
162
+ if __name__ == '__main__':
163
+ main()
transforms.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torch.nn import functional as t_func
4
+
5
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
6
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
7
+ DEFAULT_MIN_DERIVATIVE = 1e-3
8
+
9
+
10
+ def piecewise_rational_quadratic_transform(inputs,
11
+ unnormalized_widths,
12
+ unnormalized_heights,
13
+ unnormalized_derivatives,
14
+ inverse=False,
15
+ tails=None,
16
+ tail_bound=1.,
17
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
18
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
19
+ min_derivative=DEFAULT_MIN_DERIVATIVE):
20
+ if tails is None:
21
+ spline_fn = rational_quadratic_spline
22
+ spline_kwargs = {}
23
+ else:
24
+ spline_fn = unconstrained_rational_quadratic_spline
25
+ spline_kwargs = {
26
+ 'tails': tails,
27
+ 'tail_bound': tail_bound
28
+ }
29
+
30
+ outputs, logabsdet = spline_fn(
31
+ inputs=inputs,
32
+ unnormalized_widths=unnormalized_widths,
33
+ unnormalized_heights=unnormalized_heights,
34
+ unnormalized_derivatives=unnormalized_derivatives,
35
+ inverse=inverse,
36
+ min_bin_width=min_bin_width,
37
+ min_bin_height=min_bin_height,
38
+ min_derivative=min_derivative,
39
+ **spline_kwargs
40
+ )
41
+ return outputs, logabsdet
42
+
43
+
44
+ def searchsorted(bin_locations, inputs, eps=1e-6):
45
+ bin_locations[..., -1] += eps
46
+ return torch.sum(
47
+ inputs[..., None] >= bin_locations,
48
+ dim=-1
49
+ ) - 1
50
+
51
+
52
+ def unconstrained_rational_quadratic_spline(inputs,
53
+ unnormalized_widths,
54
+ unnormalized_heights,
55
+ unnormalized_derivatives,
56
+ inverse=False,
57
+ tails='linear',
58
+ tail_bound=1.,
59
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
60
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
61
+ min_derivative=DEFAULT_MIN_DERIVATIVE):
62
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
63
+ outside_interval_mask = ~inside_interval_mask
64
+
65
+ outputs = torch.zeros_like(inputs)
66
+ logabsdet = torch.zeros_like(inputs)
67
+
68
+ if tails == 'linear':
69
+ unnormalized_derivatives = t_func.pad(unnormalized_derivatives, pad=(1, 1))
70
+ constant = np.log(np.exp(1 - min_derivative) - 1)
71
+ unnormalized_derivatives[..., 0] = constant
72
+ unnormalized_derivatives[..., -1] = constant
73
+
74
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
75
+ logabsdet[outside_interval_mask] = 0
76
+ else:
77
+ raise RuntimeError('{} tails are not implemented.'.format(tails))
78
+
79
+ outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
80
+ inputs=inputs[inside_interval_mask],
81
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
82
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
83
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
84
+ inverse=inverse,
85
+ left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
86
+ min_bin_width=min_bin_width,
87
+ min_bin_height=min_bin_height,
88
+ min_derivative=min_derivative
89
+ )
90
+
91
+ return outputs, logabsdet
92
+
93
+
94
+ def rational_quadratic_spline(inputs,
95
+ unnormalized_widths,
96
+ unnormalized_heights,
97
+ unnormalized_derivatives,
98
+ inverse=False,
99
+ left=0., right=1., bottom=0., top=1.,
100
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
101
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
102
+ min_derivative=DEFAULT_MIN_DERIVATIVE):
103
+ if torch.min(inputs) < left or torch.max(inputs) > right:
104
+ raise ValueError('Input to a transform is not within its domain')
105
+
106
+ num_bins = unnormalized_widths.shape[-1]
107
+
108
+ if min_bin_width * num_bins > 1.0:
109
+ raise ValueError('Minimal bin width too large for the number of bins')
110
+ if min_bin_height * num_bins > 1.0:
111
+ raise ValueError('Minimal bin height too large for the number of bins')
112
+
113
+ widths = t_func.softmax(unnormalized_widths, dim=-1)
114
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
115
+ cumwidths = torch.cumsum(widths, dim=-1)
116
+ cumwidths = t_func.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
117
+ cumwidths = (right - left) * cumwidths + left
118
+ cumwidths[..., 0] = left
119
+ cumwidths[..., -1] = right
120
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
121
+
122
+ derivatives = min_derivative + t_func.softplus(unnormalized_derivatives)
123
+
124
+ heights = t_func.softmax(unnormalized_heights, dim=-1)
125
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
126
+ cumheights = torch.cumsum(heights, dim=-1)
127
+ cumheights = t_func.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
128
+ cumheights = (top - bottom) * cumheights + bottom
129
+ cumheights[..., 0] = bottom
130
+ cumheights[..., -1] = top
131
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
132
+
133
+ if inverse:
134
+ bin_idx = searchsorted(cumheights, inputs)[..., None]
135
+ else:
136
+ bin_idx = searchsorted(cumwidths, inputs)[..., None]
137
+
138
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
139
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
140
+
141
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
142
+ delta = heights / widths
143
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
144
+
145
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
146
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
147
+
148
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
149
+
150
+ if inverse:
151
+ a = (((inputs - input_cumheights) * (input_derivatives
152
+ + input_derivatives_plus_one
153
+ - 2 * input_delta)
154
+ + input_heights * (input_delta - input_derivatives)))
155
+ b = (input_heights * input_derivatives
156
+ - (inputs - input_cumheights) * (input_derivatives
157
+ + input_derivatives_plus_one
158
+ - 2 * input_delta))
159
+ c = - input_delta * (inputs - input_cumheights)
160
+
161
+ discriminant = b.pow(2) - 4 * a * c
162
+ assert (discriminant >= 0).all()
163
+
164
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
165
+ outputs = root * input_bin_widths + input_cumwidths
166
+
167
+ theta_one_minus_theta = root * (1 - root)
168
+ denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
169
+ * theta_one_minus_theta)
170
+ derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
171
+ + 2 * input_delta * theta_one_minus_theta
172
+ + input_derivatives * (1 - root).pow(2))
173
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
174
+
175
+ return outputs, -logabsdet
176
+ else:
177
+ theta = (inputs - input_cumwidths) / input_bin_widths
178
+ theta_one_minus_theta = theta * (1 - theta)
179
+
180
+ numerator = input_heights * (input_delta * theta.pow(2)
181
+ + input_derivatives * theta_one_minus_theta)
182
+ denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
183
+ * theta_one_minus_theta)
184
+ outputs = input_cumheights + numerator / denominator
185
+
186
+ derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
187
+ + 2 * input_delta * theta_one_minus_theta
188
+ + input_derivatives * (1 - theta).pow(2))
189
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
190
+
191
+ return outputs, logabsdet
utils.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import glob
3
+ import json
4
+ import logging
5
+ import os
6
+ import subprocess
7
+ import sys
8
+
9
+ import numpy as np
10
+ import torch
11
+ from scipy.io.wavfile import read
12
+
13
+ MATPLOTLIB_FLAG = False
14
+
15
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
16
+ logger = logging
17
+
18
+
19
+ def load_checkpoint(checkpoint_path, model, optimizer=None):
20
+ assert os.path.isfile(checkpoint_path)
21
+ checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
22
+ iteration = checkpoint_dict['iteration']
23
+ learning_rate = checkpoint_dict['learning_rate']
24
+ if optimizer is not None:
25
+ optimizer.load_state_dict(checkpoint_dict['optimizer'])
26
+ # print(1111)
27
+ saved_state_dict = checkpoint_dict['model']
28
+ # print(1111)
29
+
30
+ if hasattr(model, 'module'):
31
+ state_dict = model.module.state_dict()
32
+ else:
33
+ state_dict = model.state_dict()
34
+ new_state_dict = {}
35
+ for k, v in state_dict.items():
36
+ try:
37
+ new_state_dict[k] = saved_state_dict[k]
38
+ except Exception as e:
39
+ logger.info(e)
40
+ logger.info("%s is not in the checkpoint" % k)
41
+ new_state_dict[k] = v
42
+ if hasattr(model, 'module'):
43
+ model.module.load_state_dict(new_state_dict)
44
+ else:
45
+ model.load_state_dict(new_state_dict)
46
+ logger.info("Loaded checkpoint '{}' (iteration {})".format(
47
+ checkpoint_path, iteration))
48
+ return model, optimizer, learning_rate, iteration
49
+
50
+
51
+ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
52
+ logger.info("Saving model and optimizer state at iteration {} to {}".format(
53
+ iteration, checkpoint_path))
54
+ if hasattr(model, 'module'):
55
+ state_dict = model.module.state_dict()
56
+ else:
57
+ state_dict = model.state_dict()
58
+ torch.save({'model': state_dict,
59
+ 'iteration': iteration,
60
+ 'optimizer': optimizer.state_dict(),
61
+ 'learning_rate': learning_rate}, checkpoint_path)
62
+
63
+
64
+ def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
65
+ for k, v in scalars.items():
66
+ writer.add_scalar(k, v, global_step)
67
+ for k, v in histograms.items():
68
+ writer.add_histogram(k, v, global_step)
69
+ for k, v in images.items():
70
+ writer.add_image(k, v, global_step, dataformats='HWC')
71
+ for k, v in audios.items():
72
+ writer.add_audio(k, v, global_step, audio_sampling_rate)
73
+
74
+
75
+ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
76
+ f_list = glob.glob(os.path.join(dir_path, regex))
77
+ f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
78
+ x = f_list[-1]
79
+ print(x)
80
+ return x
81
+
82
+
83
+ def plot_spectrogram_to_numpy(spectrogram):
84
+ global MATPLOTLIB_FLAG
85
+ if not MATPLOTLIB_FLAG:
86
+ import matplotlib
87
+ matplotlib.use("Agg")
88
+ MATPLOTLIB_FLAG = True
89
+ mpl_logger = logging.getLogger('matplotlib')
90
+ mpl_logger.setLevel(logging.WARNING)
91
+ import matplotlib.pylab as plt
92
+ import numpy
93
+
94
+ fig, ax = plt.subplots(figsize=(10, 2))
95
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
96
+ interpolation='none')
97
+ plt.colorbar(im, ax=ax)
98
+ plt.xlabel("Frames")
99
+ plt.ylabel("Channels")
100
+ plt.tight_layout()
101
+
102
+ fig.canvas.draw()
103
+ data = numpy.fromstring(fig.canvas.tostring_rgb(), dtype=numpy.uint8, sep='')
104
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
105
+ plt.close()
106
+ return data
107
+
108
+
109
+ def plot_alignment_to_numpy(alignment, info=None):
110
+ global MATPLOTLIB_FLAG
111
+ if not MATPLOTLIB_FLAG:
112
+ import matplotlib
113
+ matplotlib.use("Agg")
114
+ MATPLOTLIB_FLAG = True
115
+ mpl_logger = logging.getLogger('matplotlib')
116
+ mpl_logger.setLevel(logging.WARNING)
117
+ import matplotlib.pylab as plt
118
+ import numpy
119
+
120
+ fig, ax = plt.subplots(figsize=(6, 4))
121
+ im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
122
+ interpolation='none')
123
+ fig.colorbar(im, ax=ax)
124
+ xlabel = 'Decoder timestep'
125
+ if info is not None:
126
+ xlabel += '\n\n' + info
127
+ plt.xlabel(xlabel)
128
+ plt.ylabel('Encoder timestep')
129
+ plt.tight_layout()
130
+
131
+ fig.canvas.draw()
132
+ data = numpy.fromstring(fig.canvas.tostring_rgb(), dtype=numpy.uint8, sep='')
133
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
134
+ plt.close()
135
+ return data
136
+
137
+
138
+ def load_wav_to_torch(full_path):
139
+ sampling_rate, data = read(full_path)
140
+ return torch.FloatTensor(data.astype(np.float32)), sampling_rate
141
+
142
+
143
+ def load_filepaths_and_text(filename, split="|"):
144
+ with open(filename, encoding='utf-8') as f:
145
+ filepaths_and_text = [line.strip().split(split) for line in f]
146
+ return filepaths_and_text
147
+
148
+
149
+ def get_hparams(init=True):
150
+ parser = argparse.ArgumentParser()
151
+ parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
152
+ help='JSON file for configuration')
153
+ parser.add_argument('-m', '--model', type=str, required=True,
154
+ help='Model name')
155
+
156
+ args = parser.parse_args()
157
+ model_dir = os.path.join("./logs", args.model)
158
+
159
+ if not os.path.exists(model_dir):
160
+ os.makedirs(model_dir)
161
+
162
+ config_path = args.config
163
+ config_save_path = os.path.join(model_dir, "config.json")
164
+ if init:
165
+ with open(config_path, "r") as f:
166
+ data = f.read()
167
+ with open(config_save_path, "w") as f:
168
+ f.write(data)
169
+ else:
170
+ with open(config_save_path, "r") as f:
171
+ data = f.read()
172
+ config = json.loads(data)
173
+
174
+ hparams = HParams(**config)
175
+ hparams.model_dir = model_dir
176
+ return hparams
177
+
178
+
179
+ def get_hparams_from_dir(model_dir):
180
+ config_save_path = os.path.join(model_dir, "config.json")
181
+ with open(config_save_path, "r") as f:
182
+ data = f.read()
183
+ config = json.loads(data)
184
+
185
+ hparams = HParams(**config)
186
+ hparams.model_dir = model_dir
187
+ return hparams
188
+
189
+
190
+ def get_hparams_from_file(config_path):
191
+ with open(config_path, "r", encoding="utf-8") as f:
192
+ data = f.read()
193
+ config = json.loads(data)
194
+
195
+ hparams = HParams(**config)
196
+ return hparams
197
+
198
+
199
+ def check_git_hash(model_dir):
200
+ source_dir = os.path.dirname(os.path.realpath(__file__))
201
+ if not os.path.exists(os.path.join(source_dir, ".git")):
202
+ logger.warning("{} is not a git repository, therefore hash value comparison will be ignored.".format(
203
+ source_dir
204
+ ))
205
+ return
206
+
207
+ cur_hash = subprocess.getoutput("git rev-parse HEAD")
208
+
209
+ path = os.path.join(model_dir, "githash")
210
+ if os.path.exists(path):
211
+ saved_hash = open(path).read()
212
+ if saved_hash != cur_hash:
213
+ logger.warning("git hash values are different. {}(saved) != {}(current)".format(
214
+ saved_hash[:8], cur_hash[:8]))
215
+ else:
216
+ open(path, "w").write(cur_hash)
217
+
218
+
219
+ def get_logger(model_dir, filename="train.log"):
220
+ global logger
221
+ logger = logging.getLogger(os.path.basename(model_dir))
222
+ logger.setLevel(logging.DEBUG)
223
+
224
+ formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
225
+ if not os.path.exists(model_dir):
226
+ os.makedirs(model_dir)
227
+ h = logging.FileHandler(os.path.join(model_dir, filename))
228
+ h.setLevel(logging.DEBUG)
229
+ h.setFormatter(formatter)
230
+ logger.addHandler(h)
231
+ return logger
232
+
233
+
234
+ class HParams:
235
+ def __init__(self, **kwargs):
236
+ for k, v in kwargs.items():
237
+ if type(v) == dict:
238
+ v = HParams(**v)
239
+ self[k] = v
240
+
241
+ def keys(self):
242
+ return self.__dict__.keys()
243
+
244
+ def items(self):
245
+ return self.__dict__.items()
246
+
247
+ def values(self):
248
+ return self.__dict__.values()
249
+
250
+ def __len__(self):
251
+ return len(self.__dict__)
252
+
253
+ def __getitem__(self, key):
254
+ return getattr(self, key)
255
+
256
+ def __setitem__(self, key, value):
257
+ return setattr(self, key, value)
258
+
259
+ def __contains__(self, key):
260
+ return key in self.__dict__
261
+
262
+ def __repr__(self):
263
+ return self.__dict__.__repr__()