mtauro commited on
Commit
8f58aff
·
1 Parent(s): a08d6e0

Upload custom_interface.py

Browse files
Files changed (1) hide show
  1. custom_interface.py +161 -0
custom_interface.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from speechbrain.pretrained import Pretrained
3
+
4
+
5
+ class CustomEncoderWav2vec2Classifier(Pretrained):
6
+ """A ready-to-use class for utterance-level classification (e.g, speaker-id,
7
+ language-id, emotion recognition, keyword spotting, etc).
8
+
9
+ The class assumes that an self-supervised encoder like wav2vec2/hubert and a classifier model
10
+ are defined in the yaml file. If you want to
11
+ convert the predicted index into a corresponding text label, please
12
+ provide the path of the label_encoder in a variable called 'lab_encoder_file'
13
+ within the yaml.
14
+
15
+ The class can be used either to run only the encoder (encode_batch()) to
16
+ extract embeddings or to run a classification step (classify_batch()).
17
+ ```
18
+
19
+ Example
20
+ -------
21
+ >>> import torchaudio
22
+ >>> from speechbrain.pretrained import EncoderClassifier
23
+ >>> # Model is downloaded from the speechbrain HuggingFace repo
24
+ >>> tmpdir = getfixture("tmpdir")
25
+ >>> classifier = EncoderClassifier.from_hparams(
26
+ ... source="speechbrain/spkrec-ecapa-voxceleb",
27
+ ... savedir=tmpdir,
28
+ ... )
29
+
30
+ >>> # Compute embeddings
31
+ >>> signal, fs = torchaudio.load("samples/audio_samples/example1.wav")
32
+ >>> embeddings = classifier.encode_batch(signal)
33
+
34
+ >>> # Classification
35
+ >>> prediction = classifier .classify_batch(signal)
36
+ """
37
+
38
+ def __init__(self, *args, **kwargs):
39
+ super().__init__(*args, **kwargs)
40
+
41
+ def encode_batch(self, wavs, wav_lens=None, normalize=False):
42
+ """Encodes the input audio into a single vector embedding.
43
+
44
+ The waveforms should already be in the model's desired format.
45
+ You can call:
46
+ ``normalized = <this>.normalizer(signal, sample_rate)``
47
+ to get a correctly converted signal in most cases.
48
+
49
+ Arguments
50
+ ---------
51
+ wavs : torch.tensor
52
+ Batch of waveforms [batch, time, channels] or [batch, time]
53
+ depending on the model. Make sure the sample rate is fs=16000 Hz.
54
+ wav_lens : torch.tensor
55
+ Lengths of the waveforms relative to the longest one in the
56
+ batch, tensor of shape [batch]. The longest one should have
57
+ relative length 1.0 and others len(waveform) / max_length.
58
+ Used for ignoring padding.
59
+ normalize : bool
60
+ If True, it normalizes the embeddings with the statistics
61
+ contained in mean_var_norm_emb.
62
+
63
+ Returns
64
+ -------
65
+ torch.tensor
66
+ The encoded batch
67
+ """
68
+ # Manage single waveforms in input
69
+ if len(wavs.shape) == 1:
70
+ wavs = wavs.unsqueeze(0)
71
+
72
+ # Assign full length if wav_lens is not assigned
73
+ if wav_lens is None:
74
+ wav_lens = torch.ones(wavs.shape[0], device=self.device)
75
+
76
+ # Storing waveform in the specified device
77
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
78
+ wavs = wavs.float()
79
+
80
+ # Computing features and embeddings
81
+ outputs = self.mods.transf(wavs)
82
+
83
+ # last dim will be used for AdaptativeAVG pool
84
+ outputs = self.mods.avg_pool(outputs, wav_lens)
85
+ outputs = self.mods.enc(outputs)
86
+ outputs = outputs.view(outputs.shape[0], -1)
87
+ return outputs
88
+
89
+ def classify_batch(self, wavs, wav_lens=None):
90
+ """Performs classification on the top of the encoded features.
91
+
92
+ It returns the posterior probabilities, the index and, if the label
93
+ encoder is specified it also the text label.
94
+
95
+ Arguments
96
+ ---------
97
+ wavs : torch.tensor
98
+ Batch of waveforms [batch, time, channels] or [batch, time]
99
+ depending on the model. Make sure the sample rate is fs=16000 Hz.
100
+ wav_lens : torch.tensor
101
+ Lengths of the waveforms relative to the longest one in the
102
+ batch, tensor of shape [batch]. The longest one should have
103
+ relative length 1.0 and others len(waveform) / max_length.
104
+ Used for ignoring padding.
105
+
106
+ Returns
107
+ -------
108
+ out_prob
109
+ The log posterior probabilities of each class ([batch, N_class])
110
+ score:
111
+ It is the value of the log-posterior for the best class ([batch,])
112
+ index
113
+ The indexes of the best class ([batch,])
114
+ text_lab:
115
+ List with the text labels corresponding to the indexes.
116
+ (label encoder should be provided).
117
+ """
118
+ outputs = self.encode_batch(wavs, wav_lens)
119
+ outputs = self.mods.classifier(outputs)
120
+ #out_prob = self.hparams.softmax(outputs)
121
+ out_prob = outputs # added
122
+ score, index = torch.max(out_prob, dim=-1)
123
+ text_lab = self.hparams.label_encoder.decode_torch(index)
124
+ return out_prob, score, index, text_lab
125
+
126
+ def classify_file(self, path):
127
+ """Classifies the given audiofile into the given set of labels.
128
+
129
+ Arguments
130
+ ---------
131
+ path : str
132
+ Path to audio file to classify.
133
+
134
+ Returns
135
+ -------
136
+ out_prob
137
+ The log posterior probabilities of each class ([batch, N_class])
138
+ score:
139
+ It is the value of the log-posterior for the best class ([batch,])
140
+ index
141
+ The indexes of the best class ([batch,])
142
+ text_lab:
143
+ List with the text labels corresponding to the indexes.
144
+ (label encoder should be provided).
145
+ """
146
+ waveform = self.load_audio(path)
147
+ # Fake a batch:
148
+ batch = waveform.unsqueeze(0)
149
+ rel_length = torch.tensor([1.0])
150
+ outputs = self.encode_batch(batch, rel_length)
151
+ outputs = self.mods.classifier(outputs).squeeze(1)
152
+ #out_prob = self.hparams.softmax(outputs)
153
+ out_prob = outputs # added
154
+ score, index = torch.max(out_prob, dim=-1)
155
+ text_lab = self.hparams.label_encoder.decode_torch(index)
156
+ return out_prob, score, index, text_lab
157
+
158
+ def forward(self, wavs, wav_lens=None, normalize=False):
159
+ return self.encode_batch(
160
+ wavs=wavs, wav_lens=wav_lens, normalize=normalize
161
+ )