Dionyssos commited on
Commit
9cbdf67
·
1 Parent(s): 318370a

pkl per sentence - No audinterface

Browse files
Files changed (1) hide show
  1. correct_figure.py +378 -0
correct_figure.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # we have to evaluate emotion & cer per sentence -> not audinterface sliding window
2
+ import os
3
+ import audresample
4
+ import torch
5
+ import matplotlib.pyplot as plt
6
+ import soundfile
7
+ import json
8
+ import audb
9
+ from transformers import AutoModelForAudioClassification
10
+ from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel
11
+ import types
12
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
13
+ import pandas as pd
14
+ import json
15
+ import numpy as np
16
+ from pathlib import Path
17
+ import transformers
18
+ import torch
19
+ import audmodel
20
+ import audiofile
21
+ import jiwer
22
+ # https://arxiv.org/pdf/2407.12229
23
+ # https://arxiv.org/pdf/2312.05187
24
+ # https://arxiv.org/abs/2407.05407
25
+ # https://arxiv.org/pdf/2408.06577
26
+ # https://arxiv.org/pdf/2309.07405
27
+ import msinference
28
+ import os
29
+ from random import shuffle
30
+
31
+ config = transformers.Wav2Vec2Config() #finetuning_task='spef2feat_reg')
32
+ config.dev = torch.device('cuda:0')
33
+ config.dev2 = torch.device('cuda:0')
34
+
35
+
36
+
37
+
38
+ LABELS = ['arousal', 'dominance', 'valence',
39
+ 'Angry',
40
+ 'Sad',
41
+ 'Happy',
42
+ 'Surprise',
43
+ 'Fear',
44
+ 'Disgust',
45
+ 'Contempt',
46
+ 'Neutral'
47
+ ]
48
+
49
+ config = transformers.Wav2Vec2Config() #finetuning_task='spef2feat_reg')
50
+ config.dev = torch.device('cuda:0')
51
+ config.dev2 = torch.device('cuda:0')
52
+
53
+
54
+
55
+
56
+ # https://arxiv.org/pdf/2407.12229
57
+ # https://arxiv.org/pdf/2312.05187
58
+ # https://arxiv.org/abs/2407.05407
59
+ # https://arxiv.org/pdf/2408.06577
60
+ # https://arxiv.org/pdf/2309.07405
61
+
62
+
63
+ def _infer(self, x):
64
+ '''x: (batch, audio-samples-16KHz)'''
65
+ x = (x + self.config.mean) / self.config.std # plus
66
+ x = self.ssl_model(x, attention_mask=None).last_hidden_state
67
+ # pool
68
+ h = self.pool_model.sap_linear(x).tanh()
69
+ w = torch.matmul(h, self.pool_model.attention)
70
+ w = w.softmax(1)
71
+ mu = (x * w).sum(1)
72
+ x = torch.cat(
73
+ [
74
+ mu,
75
+ ((x * x * w).sum(1) - mu * mu).clamp(min=1e-7).sqrt()
76
+ ], 1)
77
+ return self.ser_model(x)
78
+
79
+ teacher_cat = AutoModelForAudioClassification.from_pretrained(
80
+ '3loi/SER-Odyssey-Baseline-WavLM-Categorical-Attributes',
81
+ trust_remote_code=True # fun definitions see 3loi/SER-.. repo
82
+ ).to(config.dev2).eval()
83
+ teacher_cat.forward = types.MethodType(_infer, teacher_cat)
84
+
85
+
86
+ # ===================[:]===================== Dawn
87
+ def _prenorm(x, attention_mask=None):
88
+ '''mean/var'''
89
+ if attention_mask is not None:
90
+ N = attention_mask.sum(1, keepdim=True) # here attn msk is unprocessed just the original input
91
+ x -= x.sum(1, keepdim=True) / N
92
+ var = (x * x).sum(1, keepdim=True) / N
93
+
94
+ else:
95
+ x -= x.mean(1, keepdim=True) # mean is an onnx operator reducemean saves some ops compared to casting integer N to float and the div
96
+ var = (x * x).mean(1, keepdim=True)
97
+ return x / torch.sqrt(var + 1e-7)
98
+
99
+ from torch import nn
100
+ from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel, Wav2Vec2Model
101
+ class RegressionHead(nn.Module):
102
+ r"""Classification head."""
103
+
104
+ def __init__(self, config):
105
+
106
+ super().__init__()
107
+
108
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
109
+ self.dropout = nn.Dropout(config.final_dropout)
110
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
111
+
112
+ def forward(self, features, **kwargs):
113
+
114
+ x = features
115
+ x = self.dropout(x)
116
+ x = self.dense(x)
117
+ x = torch.tanh(x)
118
+ x = self.dropout(x)
119
+ x = self.out_proj(x)
120
+
121
+ return x
122
+
123
+
124
+ class Dawn(Wav2Vec2PreTrainedModel):
125
+ r"""Speech emotion classifier."""
126
+
127
+ def __init__(self, config):
128
+
129
+ super().__init__(config)
130
+
131
+ self.config = config
132
+ self.wav2vec2 = Wav2Vec2Model(config)
133
+ self.classifier = RegressionHead(config)
134
+ self.init_weights()
135
+
136
+ def forward(
137
+ self,
138
+ input_values,
139
+ attention_mask=None,
140
+ ):
141
+ x = _prenorm(input_values, attention_mask=attention_mask)
142
+ outputs = self.wav2vec2(x, attention_mask=attention_mask)
143
+ hidden_states = outputs[0]
144
+ hidden_states = torch.mean(hidden_states, dim=1)
145
+ logits = self.classifier(hidden_states)
146
+ return logits
147
+ # return {'hidden_states': hidden_states,
148
+ # 'logits': logits}
149
+ dawn = Dawn.from_pretrained('audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim').to(config.dev).eval()
150
+ # =======================================
151
+
152
+
153
+
154
+
155
+
156
+
157
+
158
+
159
+
160
+
161
+
162
+ torch_dtype = torch.float16 #if torch.cuda.is_available() else torch.float32
163
+ model_id = "openai/whisper-large-v3"
164
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
165
+ model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
166
+ ).to(config.dev)
167
+ processor = AutoProcessor.from_pretrained(model_id)
168
+ _pipe = pipeline(
169
+ "automatic-speech-recognition",
170
+ model=model,
171
+ tokenizer=processor.tokenizer,
172
+ feature_extractor=processor.feature_extractor,
173
+ max_new_tokens=128,
174
+ chunk_length_s=30,
175
+ batch_size=16,
176
+ return_timestamps=True,
177
+ torch_dtype=torch_dtype,
178
+ device=config.dev,
179
+ )
180
+
181
+
182
+
183
+
184
+
185
+
186
+
187
+
188
+
189
+
190
+ def process_function(x, sampling_rate, idx):
191
+ # x = x[None , :] ASaHSuFDCN
192
+ # {0: 'Angry', 1: 'Sad', 2: 'Happy', 3: 'Surprise',
193
+ # 4: 'Fear', 5: 'Disgust', 6: 'Contempt', 7: 'Neutral'}
194
+ #tensor([[0.0015, 0.3651, 0.0593, 0.0315, 0.0600, 0.0125, 0.0319, 0.4382]])
195
+ logits_cat = teacher_cat(torch.from_numpy(x).to(config.dev)).softmax(1)
196
+ logits_adv = dawn(torch.from_numpy(x).to(config.dev))
197
+
198
+ out = torch.cat([logits_adv,
199
+ logits_cat],
200
+ 1).cpu().detach().numpy()
201
+ # print(out.shape)
202
+ return out[0, :]
203
+
204
+
205
+
206
+ def load_speech(split=None):
207
+ DB = [
208
+ # [dataset, version, table, has_timdeltas_or_is_full_wavfile]
209
+ # ['crema-d', '1.1.1', 'emotion.voice.test', False],
210
+ #['librispeech', '3.1.0', 'test-clean', False],
211
+ ['emodb', '1.2.0', 'emotion.categories.train.gold_standard', False],
212
+ # ['entertain-playtestcloud', '1.1.0', 'emotion.categories.train.gold_standard', True],
213
+ # ['erik', '2.2.0', 'emotion.categories.train.gold_standard', True],
214
+ # ['meld', '1.3.1', 'emotion.categories.train.gold_standard', False],
215
+ # ['msppodcast', '5.0.0', 'emotion.categories.train.gold_standard', False], # tandalone bucket because it has gt labels?
216
+ # ['myai', '1.0.1', 'emotion.categories.train.gold_standard', False],
217
+ # ['casia', None, 'emotion.categories.gold_standard', False],
218
+ # ['switchboard-1', None, 'sentiment', True],
219
+ # ['swiss-parliament', None, 'segments', True],
220
+ # ['argentinian-parliament', None, 'segments', True],
221
+ # ['austrian-parliament', None, 'segments', True],
222
+ # #'german', --> bundestag
223
+ # ['brazilian-parliament', None, 'segments', True],
224
+ # ['mexican-parliament', None, 'segments', True],
225
+ # ['portuguese-parliament', None, 'segments', True],
226
+ # ['spanish-parliament', None, 'segments', True],
227
+ # ['chinese-vocal-emotions-liu-pell', None, 'emotion.categories.desired', False],
228
+ # peoples-speech slow
229
+ # ['peoples-speech', None, 'train-initial', False]
230
+ ]
231
+
232
+ output_list = []
233
+ for database_name, ver, table, has_timedeltas in DB:
234
+
235
+ a = audb.load(database_name,
236
+ sampling_rate=16000,
237
+ format='wav',
238
+ mixdown=True,
239
+ version=ver,
240
+ cache_root='/cache/audb/')
241
+ a = a[table].get()
242
+ if has_timedeltas:
243
+ print(f'{has_timedeltas=}')
244
+ # a = a.reset_index()[['file', 'start', 'end']]
245
+ # output_list += [[*t] for t
246
+ # in zip(a.file.values, a.start.dt.total_seconds().values, a.end.dt.total_seconds().values)]
247
+ else:
248
+ output_list += [f for f in a.index] # use file (no timedeltas)
249
+ return output_list
250
+
251
+
252
+
253
+
254
+
255
+
256
+
257
+
258
+
259
+
260
+
261
+ natural_wav_paths = load_speech()
262
+
263
+
264
+
265
+
266
+
267
+
268
+
269
+ with open('harvard.json', 'r') as f:
270
+ harvard_individual_sentences = json.load(f)['sentences']
271
+
272
+
273
+
274
+ synthetic_wav_paths = ['./enslow/' + i for i in
275
+ os.listdir('./enslow/')]
276
+ synthetic_wav_paths_4x = ['./style_vector_v2/' + i for i in
277
+ os.listdir('./style_vector_v2/')]
278
+ synthetic_wav_paths_foreign = ['./mimic3_foreign/' + i for i in os.listdir('./mimic3_foreign/') if 'en_U' not in i]
279
+ synthetic_wav_paths_foreign_4x = ['./mimic3_foreign_4x/' + i for i in os.listdir('./mimic3_foreign_4x/') if 'en_U' not in i] # very short segments
280
+
281
+ # filter very short styles
282
+ synthetic_wav_paths_foreign = [i for i in synthetic_wav_paths_foreign if audiofile.duration(i) > 2]
283
+ synthetic_wav_paths_foreign_4x = [i for i in synthetic_wav_paths_foreign_4x if audiofile.duration(i) > 2]
284
+ synthetic_wav_paths = [i for i in synthetic_wav_paths if audiofile.duration(i) > 2]
285
+ synthetic_wav_pathsn_4x = [i for i in synthetic_wav_paths_4x if audiofile.duration(i) > 2]
286
+
287
+ shuffle(synthetic_wav_paths_foreign_4x)
288
+ shuffle(synthetic_wav_paths_foreign)
289
+ shuffle(synthetic_wav_paths)
290
+ shuffle(synthetic_wav_paths_4x)
291
+ print(len(synthetic_wav_paths_foreign_4x), len(synthetic_wav_paths_foreign),
292
+ len(synthetic_wav_paths), len(synthetic_wav_paths_4x)) # 134 204 134 204
293
+
294
+
295
+
296
+ for audio_prompt in ['english',
297
+ 'english_4x',
298
+ 'human',
299
+ 'foreign',
300
+ 'foreign_4x']: # each of these creates a separate pkl - so outer for
301
+ #
302
+ data = np.zeros((767, len(LABELS)*2 + 2)) # 720 x LABELS-prompt & LABELS-stts2 & cer-prompt & cer-stts2
303
+
304
+
305
+
306
+ #
307
+
308
+ OUT_FILE = f'{audio_prompt}_analytic.pkl'
309
+ if not os.path.isfile(OUT_FILE):
310
+ ix = 0
311
+ for list_of_10 in harvard_individual_sentences[:10004]:
312
+ # long_sentence = ' '.join(list_of_10['sentences'])
313
+ # harvard.append(long_sentence.replace('.', ' '))
314
+ for text in list_of_10['sentences']:
315
+ if audio_prompt == 'english':
316
+ _p = synthetic_wav_paths[ix % len(synthetic_wav_paths)]
317
+ # 134
318
+ style_vec = msinference.compute_style(_p)
319
+ elif audio_prompt == 'english_4x':
320
+ _p = synthetic_wav_paths_4x[ix % len(synthetic_wav_paths_4x)]
321
+ # 134]
322
+ style_vec = msinference.compute_style(_p)
323
+ elif audio_prompt == 'human':
324
+ _p = natural_wav_paths[ix % len(natural_wav_paths)]
325
+ # ?
326
+ style_vec = msinference.compute_style(_p)
327
+ elif audio_prompt == 'foreign':
328
+ _p = synthetic_wav_paths_foreign[ix % len(synthetic_wav_paths_foreign)]
329
+ # 204 some short styles are discarded ~ 1180
330
+ style_vec = msinference.compute_style(_p)
331
+ elif audio_prompt == 'foreign_4x':
332
+ _p = synthetic_wav_paths_foreign_4x[ix % len(synthetic_wav_paths_foreign_4x)]
333
+ # 174
334
+ style_vec = msinference.compute_style(_p)
335
+ else:
336
+ print('unknonw list of style vector')
337
+
338
+ x = msinference.inference(text,
339
+ style_vec,
340
+ alpha=0.3,
341
+ beta=0.7,
342
+ diffusion_steps=7,
343
+ embedding_scale=1)
344
+ x = audresample.resample(x, 24000, 16000)
345
+
346
+
347
+ _st, fsr = audiofile.read(_p)
348
+ _st = audresample.resample(_st, fsr, 16000)
349
+ print(_st.shape, x.shape)
350
+
351
+ emotion_of_prompt = process_function(_st, 16000, None)
352
+ emotion_of_out = process_function(x, 16000, None)
353
+ data[ix, :11] = emotion_of_prompt
354
+ data[ix, 11:22] = emotion_of_out
355
+
356
+ # 2 last columns is cer-prompt cer-styletts2
357
+
358
+ transcription_prompt = _pipe(_st[0])
359
+ transcription_styletts2 = _pipe(x[0]) # allow singleton for EMO process func
360
+ # print(len(emotion_of_prompt + emotion_of_out), ix, text)
361
+ print(transcription_prompt, transcription_styletts2)
362
+
363
+ data[ix, 22] = jiwer.cer('Sweet dreams are made of this. I travel the world and the seven seas.',
364
+ transcription_prompt['text'])
365
+
366
+ data[ix, 23] = jiwer.cer(text,
367
+ transcription_styletts2['text'])
368
+ print(data[ix, :])
369
+
370
+ ix += 1
371
+
372
+ df = pd.DataFrame(data, columns=['prompt-' + i for i in LABELS] + ['styletts2-' + i for i in LABELS] + ['cer-prompt', 'cer-styletts2'])
373
+ df.to_pickle(OUT_FILE)
374
+ else:
375
+
376
+ df = pd.read_pickle(OUT_FILE)
377
+ print('\nALREADY EXISTS\n{df}')
378
+ # From the pickle we should also run cer and whisper on every prompt