johntsi commited on
Commit
b19206a
·
1 Parent(s): 695d1bb

Initial Commit

Browse files
Files changed (4) hide show
  1. README.md +360 -3
  2. config.json +27 -0
  3. model.py +366 -0
  4. model.safetensors +3 -0
README.md CHANGED
@@ -1,3 +1,360 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - ace
4
+ - acm
5
+ - acq
6
+ - aeb
7
+ - af
8
+ - ajp
9
+ - ak
10
+ - als
11
+ - am
12
+ - apc
13
+ - ar
14
+ - ars
15
+ - ary
16
+ - arz
17
+ - as
18
+ - ast
19
+ - awa
20
+ - ayr
21
+ - azb
22
+ - azj
23
+ - ba
24
+ - bm
25
+ - ban
26
+ - be
27
+ - bem
28
+ - bn
29
+ - bho
30
+ - bjn
31
+ - bo
32
+ - bs
33
+ - bug
34
+ - bg
35
+ - ca
36
+ - ceb
37
+ - cs
38
+ - cjk
39
+ - ckb
40
+ - crh
41
+ - cy
42
+ - da
43
+ - de
44
+ - dik
45
+ - dyu
46
+ - dz
47
+ - el
48
+ - en
49
+ - eo
50
+ - et
51
+ - eu
52
+ - ee
53
+ - fo
54
+ - fj
55
+ - fi
56
+ - fon
57
+ - fr
58
+ - fur
59
+ - fuv
60
+ - gaz
61
+ - gd
62
+ - ga
63
+ - gl
64
+ - gn
65
+ - gu
66
+ - ht
67
+ - ha
68
+ - he
69
+ - hi
70
+ - hne
71
+ - hr
72
+ - hu
73
+ - hy
74
+ - ig
75
+ - ilo
76
+ - id
77
+ - is
78
+ - it
79
+ - jv
80
+ - ja
81
+ - kab
82
+ - kac
83
+ - kam
84
+ - kn
85
+ - ks
86
+ - ka
87
+ - kk
88
+ - kbp
89
+ - kea
90
+ - khk
91
+ - km
92
+ - ki
93
+ - rw
94
+ - ky
95
+ - kmb
96
+ - kmr
97
+ - knc
98
+ - kg
99
+ - ko
100
+ - lo
101
+ - lij
102
+ - li
103
+ - ln
104
+ - lt
105
+ - lmo
106
+ - ltg
107
+ - lb
108
+ - lua
109
+ - lg
110
+ - luo
111
+ - lus
112
+ - lvs
113
+ - mag
114
+ - mai
115
+ - ml
116
+ - mar
117
+ - min
118
+ - mk
119
+ - mt
120
+ - mni
121
+ - mos
122
+ - mi
123
+ - my
124
+ - nl
125
+ - nn
126
+ - nb
127
+ - npi
128
+ - nso
129
+ - nus
130
+ - ny
131
+ - oc
132
+ - ory
133
+ - pag
134
+ - pa
135
+ - pap
136
+ - pbt
137
+ - pes
138
+ - plt
139
+ - pl
140
+ - pt
141
+ - prs
142
+ - quy
143
+ - ro
144
+ - rn
145
+ - ru
146
+ - sg
147
+ - sa
148
+ - sat
149
+ - scn
150
+ - shn
151
+ - si
152
+ - sk
153
+ - sl
154
+ - sm
155
+ - sn
156
+ - sd
157
+ - so
158
+ - st
159
+ - es
160
+ - sc
161
+ - sr
162
+ - ss
163
+ - su
164
+ - sv
165
+ - swh
166
+ - szl
167
+ - ta
168
+ - taq
169
+ - tt
170
+ - te
171
+ - tg
172
+ - tl
173
+ - th
174
+ - ti
175
+ - tpi
176
+ - tn
177
+ - ts
178
+ - tk
179
+ - tum
180
+ - tr
181
+ - tw
182
+ - tzm
183
+ - ug
184
+ - uk
185
+ - umb
186
+ - ur
187
+ - uzn
188
+ - vec
189
+ - vi
190
+ - war
191
+ - wo
192
+ - xh
193
+ - ydd
194
+ - yo
195
+ - yue
196
+ - zh
197
+ - zsm
198
+ - zu
199
+ language_details: >-
200
+ ace_Arab, ace_Latn, acm_Arab, acq_Arab, aeb_Arab, afr_Latn, ajp_Arab,
201
+ aka_Latn, amh_Ethi, apc_Arab, arb_Arab, ars_Arab, ary_Arab, arz_Arab,
202
+ asm_Beng, ast_Latn, awa_Deva, ayr_Latn, azb_Arab, azj_Latn, bak_Cyrl,
203
+ bam_Latn, ban_Latn,bel_Cyrl, bem_Latn, ben_Beng, bho_Deva, bjn_Arab, bjn_Latn,
204
+ bod_Tibt, bos_Latn, bug_Latn, bul_Cyrl, cat_Latn, ceb_Latn, ces_Latn,
205
+ cjk_Latn, ckb_Arab, crh_Latn, cym_Latn, dan_Latn, deu_Latn, dik_Latn,
206
+ dyu_Latn, dzo_Tibt, ell_Grek, eng_Latn, epo_Latn, est_Latn, eus_Latn,
207
+ ewe_Latn, fao_Latn, pes_Arab, fij_Latn, fin_Latn, fon_Latn, fra_Latn,
208
+ fur_Latn, fuv_Latn, gla_Latn, gle_Latn, glg_Latn, grn_Latn, guj_Gujr,
209
+ hat_Latn, hau_Latn, heb_Hebr, hin_Deva, hne_Deva, hrv_Latn, hun_Latn,
210
+ hye_Armn, ibo_Latn, ilo_Latn, ind_Latn, isl_Latn, ita_Latn, jav_Latn,
211
+ jpn_Jpan, kab_Latn, kac_Latn, kam_Latn, kan_Knda, kas_Arab, kas_Deva,
212
+ kat_Geor, knc_Arab, knc_Latn, kaz_Cyrl, kbp_Latn, kea_Latn, khm_Khmr,
213
+ kik_Latn, kin_Latn, kir_Cyrl, kmb_Latn, kon_Latn, kor_Hang, kmr_Latn,
214
+ lao_Laoo, lvs_Latn, lij_Latn, lim_Latn, lin_Latn, lit_Latn, lmo_Latn,
215
+ ltg_Latn, ltz_Latn, lua_Latn, lug_Latn, luo_Latn, lus_Latn, mag_Deva,
216
+ mai_Deva, mal_Mlym, mar_Deva, min_Latn, mkd_Cyrl, plt_Latn, mlt_Latn,
217
+ mni_Beng, khk_Cyrl, mos_Latn, mri_Latn, zsm_Latn, mya_Mymr, nld_Latn,
218
+ nno_Latn, nob_Latn, npi_Deva, nso_Latn, nus_Latn, nya_Latn, oci_Latn,
219
+ gaz_Latn, ory_Orya, pag_Latn, pan_Guru, pap_Latn, pol_Latn, por_Latn,
220
+ prs_Arab, pbt_Arab, quy_Latn, ron_Latn, run_Latn, rus_Cyrl, sag_Latn,
221
+ san_Deva, sat_Beng, scn_Latn, shn_Mymr, sin_Sinh, slk_Latn, slv_Latn,
222
+ smo_Latn, sna_Latn, snd_Arab, som_Latn, sot_Latn, spa_Latn, als_Latn,
223
+ srd_Latn, srp_Cyrl, ssw_Latn, sun_Latn, swe_Latn, swh_Latn, szl_Latn,
224
+ tam_Taml, tat_Cyrl, tel_Telu, tgk_Cyrl, tgl_Latn, tha_Thai, tir_Ethi,
225
+ taq_Latn, taq_Tfng, tpi_Latn, tsn_Latn, tso_Latn, tuk_Latn, tum_Latn,
226
+ tur_Latn, twi_Latn, tzm_Tfng, uig_Arab, ukr_Cyrl, umb_Latn, urd_Arab,
227
+ uzn_Latn, vec_Latn, vie_Latn, war_Latn, wol_Latn, xho_Latn, ydd_Hebr,
228
+ yor_Latn, yue_Hant, zho_Hans, zho_Hant, zul_Latn
229
+ license: mit
230
+ metrics:
231
+ - bleu
232
+ datasets:
233
+ - mozilla-foundation/common_voice_8_0
234
+ pipeline_tag: automatic-speech-recognition
235
+ tags:
236
+ - zeroswot
237
+ - speech translation
238
+ - zero-shot
239
+ - end-to-end
240
+ - nllb
241
+ - wav2vec2
242
+ ---
243
+
244
+ # ZeroSwot ✨🤖✨
245
+
246
+ <!-- <div style='display:flex; gap: 0.25rem; '>
247
+ <a href='https://arxiv.org/abs/2402.10422'><img src='https://img.shields.io/badge/paper-PDF-green'></a>
248
+ <a href='https://github.com/mt-upc/ZeroSwot/blob/main/LICENSE'><img src='https://img.shields.io/badge/License-MIT-blue.svg'></a>
249
+ <a href='https://github.com/mt-upc/ZeroSwot'><img src='https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white'></a>
250
+ </div> -->
251
+
252
+ ZeroSwot is a state-of-the-art zero-shot end-to-end Speech Translation system.
253
+
254
+ <div align=center><img src="resources/intro.png" height="65%" width="65%"/></div>
255
+
256
+ The model is created by adapting a wav2vec2.0-based encoder to the embedding space of NLLB, using a novel subword compression module and Optimal Transport, while only utilizing ASR data. It thus enables **Zero-shot E2E Speech Translation to all the 200 languages supported by NLLB**.
257
+
258
+ For more details please refer to our [paper](https://arxiv.org/abs/2402.10422) and the [original repo](https://github.com/mt-upc/ZeroSwot) build on fairseq.
259
+
260
+ ## Architecture
261
+
262
+ The compression module is a light-weight transformer that takes as input the hidden state of wav2vec2.0 and the corresponding CTC predictions, and compresses them to subword-like embeddings similar to those expected from NLLB and aligns them using Optimal Transport. For inference we simply pass the output of the speech encoder to NLLB encoder.
263
+
264
+ <div align=center><img src="resources/methodology.png" height="120%" width="120%"/></div>
265
+
266
+ ## Version
267
+
268
+ This version of ZeroSwot is trained with ASR data from CommonVoice, and adapted [wav2vec2.0-large](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self) to the [nllb-200-distilled-600M](https://huggingface.co/facebook/nllb-200-distilled-600M) model.
269
+
270
+ We have more versions available:
271
+
272
+ | Models | ASR data | NLLB version |
273
+ |:------:|:--------:|:------------:|
274
+ | [ZeroSwot-Medium_asr-mustc](https://huggingface.co/johntsi/ZeroSwot-Medium_asr-mustc_en-to-200) | MuST-C v1.0 | [distilled-600M original](https://huggingface.co/facebook/nllb-200-distilled-600M)|
275
+ | [ZeroSwot-Medium_asr-mustc_mt-mustc](https://huggingface.co/johntsi/ZeroSwot-Medium_asr-mustc_mt-mustc_en-to-8) | MuST-C v1.0 | [distilled-600M finetuned w/ MuST-C](https://huggingface.co/johntsi/nllb-200-distilled-600M_mustc_en-to-8) |
276
+ | [ZeroSwot-Large_asr-mustc](https://huggingface.co/johntsi/ZeroSwot-Large_asr-mustc_en-to-200) | MuST-C v1.0 | [distilled-1.3B original](https://huggingface.co/facebook/nllb-200-distilled-1.3B) |
277
+ | [ZeroSwot-Large_asr-mustc_mt-mustc](https://huggingface.co/johntsi/ZeroSwot-Large_asr-mustc_mt-mustc_en-to-8) | MuST-C v1.0 | [distilled-1.3B finetuned w/ MuST-C](https://huggingface.co/johntsi/nllb-200-distilled-1.3B_mustc_en-to-8) |
278
+ | [ZeroSwot-Medium_asr-cv](https://huggingface.co/johntsi/ZeroSwot-Medium_asr-cv_en-to-200) | CommonVoice | [distilled-600M original](https://huggingface.co/facebook/nllb-200-distilled-600M)|
279
+ | [ZeroSwot-Medium_asr-cv_mt-covost2](https://huggingface.co/johntsi/ZeroSwot-Medium_asr-cv_mt-covost2_en-to-15) | CommonVoice | [distilled-600M finetuned w/ CoVoST2](https://huggingface.co/johntsi/nllb-200-distilled-600M_covost2_en-to-15) |
280
+ | [ZeroSwot-Large_asr-cv](https://huggingface.co/johntsi/ZeroSwot-Large_asr-cv_en-to-200) | CommonVoice | [distilled-1.3B original](https://huggingface.co/facebook/nllb-200-distilled-1.3B) |
281
+ | [ZeroSwot-Large_asr-cv_mt-covost2](https://huggingface.co/johntsi/ZeroSwot-Large_asr-cv_mt-covost2_en-to-15) | CommonVoice | [distilled-1.3B finetuned w/ CoVoST2](https://huggingface.co/johntsi/nllb-200-distilled-1.3B_covost2_en-to-15) |
282
+
283
+ ## Usage
284
+
285
+ The model is tested with python 3.9.16 and Transformer v4.41.2. Install also torchaudio and sentencepiece for processing.
286
+
287
+ ```bash
288
+ pip install transformers torchaudio sentencepiece
289
+ ```
290
+
291
+
292
+ ```python
293
+ from transformers import Wav2Vec2Processor, NllbTokenizer, AutoModel, AutoModelForSeq2SeqLM
294
+ import torchaudio
295
+
296
+ def load_and_resample_audio(audio_path, target_sr=16000):
297
+ audio, orig_freq = torchaudio.load(audio_path)
298
+ if orig_freq != target_sr:
299
+ audio = torchaudio.functional.resample(audio, orig_freq=orig_freq, new_freq=target_sr)
300
+ audio = audio.squeeze(0).numpy()
301
+ return audio
302
+
303
+ # Load processors and tokenizers
304
+ processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
305
+ tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
306
+
307
+ # Load ZeroSwot Encoder
308
+ commit_hash = "eafabee295ea1c8b45483d1fd26bd747d9a7d937"
309
+ zeroswot_encoder = AutoModel.from_pretrained(
310
+ "johntsi/ZeroSwot-Medium_asr-cv_en-to-200", trust_remote_code=True, revision=commit_hash,
311
+ )
312
+ zeroswot_encoder.eval()
313
+ zeroswot_encoder.to("cuda")
314
+
315
+ # Load NLLB Model
316
+ nllb_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
317
+ nllb_model.eval()
318
+ nllb_model.to("cuda")
319
+
320
+ # Load audio file
321
+ audio = load_and_resample_audio(path_to_audio_file) # you can use "resources/sample.wav" for testing
322
+ input_values = processor(audio, sampling_rate=16000, return_tensors="pt").to("cuda")
323
+
324
+ # translation to German
325
+ compressed_embeds, attention_mask = zeroswot_encoder(**input_values)
326
+ predicted_ids = nllb_model.generate(
327
+ inputs_embeds=compressed_embeds,
328
+ attention_mask=attention_mask,
329
+ forced_bos_token_id=tokenizer.lang_code_to_id["deu_Latn"],
330
+ num_beams=5,
331
+ )
332
+ translation = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
333
+ print(translation)
334
+ ```
335
+
336
+ ## Results
337
+
338
+ BLEU scores on CoVoST-2 test compared to supervised SOTA models [XLS-R-1B](https://huggingface.co/facebook/wav2vec2-xls-r-1b) and [SeamlessM4T-Medium](https://huggingface.co/facebook/seamless-m4t-medium). You can refer to Table 5 of the Results section in the paper for more details.
339
+
340
+ | Models | ZS | Size (B) | Ar | Ca | Cy | De | Et | Fa | Id | Ja | Lv | Mn | Sl | Sv | Ta | Tr | Zh | Average |
341
+ |:--------------:|:----:|:----------:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:----:|:-------:|
342
+ | [XLS-R-1B](https://huggingface.co/facebook/wav2vec2-xls-r-1b) | ✗ | 1.0 | 19.2 | 32.1 | **31.8** | 26.2 | 22.4 | 21.3 | 30.3 | 39.9 | 22.0 | 14.9 | 25.4 | 32.3 | 18.1 | 17.1 | 36.7 | 26.0 |
343
+ | [SeamlessM4T-Medium](https://huggingface.co/facebook/seamless-m4t-medium) | ✗ | 1.2 | 20.8 | 37.3 | 29.9 | **31.4** | 23.3 | 17.2 | 34.8 | 37.5 | 19.5 | 12.9 | 29.0 | 37.3 | 18.9 | **19.8** | 30.0 | 26.6 |
344
+ | [ZeroSwot-M_asr-cv](https://huggingface.co/johntsi/ZeroSwot-Medium_asr-cv_en-to-200) | ✓ | 0.35/0.95 | 17.6 | 32.5 | 18.0 | 29.9 | 20.4 | 16.3 | 32.4 | 32.0 | 13.3 | 10.0 | 25.2 | 34.4 | 17.8 | 15.6 | 30.5 | 23.1 |
345
+ | [ZeroSwot-M_asr-cv_mt-covost2](https://huggingface.co/johntsi/ZeroSwot-Medium_asr-cv_mt-covost2_en-to-200) | ✓ | 0.35/0.95 | **24.4** | **38.7** | 28.8 | 31.2 | **26.2** | **26.0** | **36.0** | **46.0** | **24.8** | **19.0** | **31.6** | **37.8** | **24.4** | 18.6 | **39.0** | **30.2** |
346
+
347
+ ## Citation
348
+
349
+ If you find ZeroSwot useful for your research, please cite our paper :)
350
+
351
+ ```
352
+ @misc{tsiamas2024pushing,
353
+ title={{Pushing the Limits of Zero-shot End-to-End Speech Translation}},
354
+ author={Ioannis Tsiamas and Gerard I. Gállego and José A. R. Fonollosa and Marta R. Costa-jussà},
355
+ year={2024},
356
+ eprint={2402.10422},
357
+ archivePrefix={arXiv},
358
+ primaryClass={cs.CL}
359
+ }
360
+ ```
config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "johntsi/ZeroSwot-Medium_asr-mustc_en-to-200/model.safetensors",
3
+ "architectures": [
4
+ "ZeroSwotEncoderModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "model.ZeroSwotEncoderConfig",
8
+ "AutoModel": "model.ZeroSwotEncoderModel"
9
+ },
10
+ "compression_adapter": {
11
+ "blank_idx": 0,
12
+ "dropout": 0.1,
13
+ "embed_dim": 1024,
14
+ "sep_idx": 4,
15
+ "transformer_layers": 3
16
+ },
17
+ "embed_dim": 1024,
18
+ "model_type": "zero_swot_encoder",
19
+ "nllb_model_name_or_path": "facebook/nllb-200-distilled-600M",
20
+ "speech_embedder": {
21
+ "nllb_eng_id": 256047,
22
+ "nllb_eos_id": 2
23
+ },
24
+ "torch_dtype": "float32",
25
+ "transformers_version": "4.41.2",
26
+ "wav2vec2_model_name_or_path": "facebook/wav2vec2-large-960h-lv60-self"
27
+ }
model.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel, PretrainedConfig, Wav2Vec2ForCTC
2
+ import json
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn.utils.rnn import pad_sequence
6
+ import math
7
+ from typing import Optional
8
+
9
+ # x: torch.FloatTensor [T, B, D]
10
+ # mask: torch.BoolTensor [B, T], where True indicates padding
11
+ # returns: torch.LongTensor [B]
12
+ def get_lengths(x, mask=None):
13
+ if mask is not None:
14
+ return (~mask).long().sum(dim=1)
15
+ else:
16
+ return torch.LongTensor([x.size(0)] * x.size(1)).to(x.device)
17
+
18
+ # lens: torch.LongTensor [B]
19
+ # returns: torch.BoolTensor [B, max_lens], where True indicates padding
20
+ def lengths_to_padding_mask(lens):
21
+ bsz, max_lens = lens.size(0), torch.max(lens).item()
22
+ mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
23
+ mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
24
+ return mask
25
+
26
+ # input_lengths: torch.LongTensor [B]
27
+ def get_output_lengths(input_lengths):
28
+ conv_feature_layers = "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]"
29
+ conv_cfg_list = eval(conv_feature_layers)
30
+
31
+ def _conv_out_length(input_length, kernel_size, stride):
32
+ return torch.floor((input_length - kernel_size) / stride + 1)
33
+
34
+ for i in range(len(conv_cfg_list)):
35
+ input_lengths = _conv_out_length(
36
+ input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
37
+ )
38
+
39
+ return input_lengths.to(torch.long)
40
+
41
+ class ZeroSwotEncoderConfig(PretrainedConfig):
42
+ model_type = "zero_swot_encoder"
43
+ def __init__(
44
+ self,
45
+ wav2vec2_model_name_or_path="",
46
+ compression_adapter=None,
47
+ embed_dim=1024,
48
+ **kwargs
49
+ ):
50
+ super().__init__(**kwargs)
51
+ self.wav2vec2_model_name_or_path = wav2vec2_model_name_or_path
52
+ self.compression_adapter = compression_adapter
53
+ self.embed_dim = embed_dim
54
+
55
+ @classmethod
56
+ def from_json_file(cls, json_file):
57
+ with open(json_file, "r") as reader:
58
+ text = reader.read()
59
+ config_dict = json.loads(text)
60
+ return cls(**config_dict)
61
+
62
+ class ZeroSwotEncoderModel(PreTrainedModel):
63
+ config_class = ZeroSwotEncoderConfig
64
+ model_type = "zero_swot_encoder"
65
+
66
+ def __init__(self, config):
67
+ super().__init__(config)
68
+
69
+ self.wav2vec2 = Wav2Vec2ForCTC.from_pretrained(config.wav2vec2_model_name_or_path)
70
+ self.compression_adapter = CompressionAdapter(config.compression_adapter)
71
+ self.speech_embedder = SpeechEmbedder(config.embed_dim)
72
+
73
+ def forward(self, input_values, attention_mask=None):
74
+ input_lens = get_lengths(input_values, ~attention_mask)
75
+
76
+ # Forward pass through wav2vec2 encoder
77
+ x = self.wav2vec2.wav2vec2(input_values, attention_mask)[0] # [B, T, D]
78
+ # CTC predictions
79
+ preds = self.wav2vec2.lm_head(x).argmax(-1) # [B, T]
80
+ # Get output lengths for x
81
+ output_lens = get_output_lengths(input_lens)
82
+
83
+ # Compression
84
+ x, mask, _ = self.compression_adapter(x, preds, output_lens) # [B, N, D] with N << T
85
+
86
+ # BOS and EOS embeddings
87
+ x, mask = self.speech_embedder(x, mask) # [B, N+2, D]
88
+
89
+ return x, ~mask
90
+
91
+
92
+ class SpeechEmbedder(nn.Module):
93
+ def __init__(self, embed_dim):
94
+ super().__init__()
95
+
96
+ self.embed_dim = embed_dim
97
+ self.bos_emb = nn.Parameter(torch.empty(embed_dim))
98
+ self.eos_emb = nn.Parameter(torch.empty(embed_dim))
99
+
100
+ self.scale = self.embed_dim ** 0.5
101
+
102
+ def forward(self, x, padding_mask=None):
103
+ """Add special embedding and positional embedding.
104
+ Args:
105
+ x (FloatTensor): (B, T, C)
106
+ padding_mask (ByteTensor): (B, T)
107
+ Outputs:
108
+ x (FloatTensor): (B, T+2, C)
109
+ padding_mask (ByteTensor): (B, T+2)
110
+ """
111
+ B = x.size(0)
112
+ lengths = get_lengths(x.transpose(0, 1), padding_mask)
113
+ assert B == len(lengths)
114
+
115
+ if padding_mask is not None:
116
+ x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))
117
+
118
+ # prepend bos
119
+ x = torch.cat([self.bos_emb.view(1, 1, -1).expand(B, 1, -1), x], dim=1)
120
+ lengths += 1
121
+
122
+ # append padding (zeros) and then convert first padding to eos
123
+ x = torch.cat([x, torch.zeros(B, 1, x.size(-1), device=x.device, dtype=x.dtype)], dim=1)
124
+ for i in range(B):
125
+ x[i, lengths[i], :] = self.eos_emb
126
+ lengths += 1
127
+
128
+ padding_mask = lengths_to_padding_mask(lengths)
129
+
130
+ x = x * self.scale
131
+
132
+ return x, padding_mask
133
+
134
+
135
+ class PositionalEmbedding(nn.Module):
136
+ def __init__(self, num_embeddings, embedding_dim, padding_idx):
137
+ super().__init__()
138
+ self.embedding_dim = embedding_dim
139
+ self.padding_idx = padding_idx if padding_idx is not None else 0
140
+ num_embeddings += padding_idx + 1
141
+ self.weights = PositionalEmbedding.get_embedding(
142
+ num_embeddings, embedding_dim, padding_idx
143
+ )
144
+ self.register_buffer("_float_tensor", torch.FloatTensor(1))
145
+ self.max_positions = int(1e5)
146
+
147
+ @staticmethod
148
+ def get_embedding(
149
+ num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
150
+ ):
151
+ half_dim = embedding_dim // 2
152
+ emb = math.log(10000) / (half_dim - 1)
153
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
154
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
155
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
156
+ if embedding_dim % 2 == 1:
157
+ # zero pad
158
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
159
+ if padding_idx is not None:
160
+ emb[padding_idx, :] = 0
161
+ return emb
162
+
163
+ def make_positions(self, x, padding_idx: int):
164
+ mask = x.ne(padding_idx).int()
165
+ return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
166
+
167
+ def forward(self, input):
168
+ """Input is expected to be of size [bsz x seqlen]."""
169
+ bsz, seq_len = input.size()
170
+ max_pos = self.padding_idx + 1 + seq_len
171
+ if self.weights is None or max_pos > self.weights.size(0):
172
+ # recompute/expand embeddings if needed
173
+ self.weights = PositionalEmbedding.get_embedding(
174
+ max_pos, self.embedding_dim, self.padding_idx
175
+ )
176
+ self.weights = self.weights.to(self._float_tensor)
177
+ positions = self.make_positions(input, self.padding_idx)
178
+ return (
179
+ self.weights.index_select(0, positions.view(-1))
180
+ .view(bsz, seq_len, -1)
181
+ .detach()
182
+ )
183
+
184
+
185
+ class CLSPooling(nn.Module):
186
+ def __init__(self, embed_dim, num_transformer_layers, dropout_rate):
187
+ super().__init__()
188
+
189
+ self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim))
190
+ nn.init.normal_(self.cls_token, mean=0.0, std=0.25)
191
+
192
+ self.transformer = nn.TransformerEncoder(
193
+ nn.TransformerEncoderLayer(
194
+ embed_dim,
195
+ nhead=16 if embed_dim == 1024 else 8,
196
+ dim_feedforward=4*embed_dim,
197
+ dropout=dropout_rate,
198
+ activation="relu",
199
+ batch_first=True,
200
+ norm_first=True
201
+ ),
202
+ num_layers=num_transformer_layers,
203
+ )
204
+
205
+ self.pos_emb = PositionalEmbedding(512, embed_dim, 1)
206
+ self.scale = math.sqrt(embed_dim)
207
+
208
+ def forward(self, x, lens):
209
+ # x: [B, N, D]
210
+ # lens: [B]
211
+
212
+ # prepend cls token
213
+ x = torch.cat(
214
+ [
215
+ self.cls_token.to(dtype=x.dtype, device=x.device).repeat(x.size(0), 1, 1), # B x 1 x D
216
+ x
217
+ ],
218
+ dim=1) # [B, N+1, D]
219
+
220
+ mask = lengths_to_padding_mask(lens+1)
221
+
222
+ x = x + self.pos_emb(mask.long()) / self.scale
223
+
224
+ x = self.transformer(x, src_key_padding_mask=mask) # [B, N+1, D]
225
+ x = x[:, 0] # [B, D]
226
+ return x
227
+
228
+
229
+ class CompressionAdapter(nn.Module):
230
+ def __init__(self, cfg):
231
+ super().__init__()
232
+ self.embed_dim = cfg["embed_dim"]
233
+ self.transformer_layers = cfg["transformer_layers"]
234
+ self.dropout = cfg["dropout"]
235
+ self.blank_idx = cfg["blank_idx"]
236
+ self.sep_idx = cfg["sep_idx"]
237
+
238
+ self.token_pooling_module = CLSPooling(
239
+ self.embed_dim, self.transformer_layers, self.dropout
240
+ )
241
+
242
+ def char_compression(self, x, preds, lens):
243
+ # x: B x T x D
244
+ # preds: B x T
245
+ # lens: B
246
+
247
+ B, T, D = x.size()
248
+ device = x.device
249
+ dtype = x.dtype
250
+
251
+ # zero-out the padding
252
+ mask = lengths_to_padding_mask(lens) # B x T
253
+ x = x.masked_fill(mask.unsqueeze(-1), 0)
254
+ preds = preds.masked_fill(mask, self.blank_idx)
255
+
256
+ # add a vector of -1 to know where each example ends after flattening the batch
257
+ preds = torch.cat([-torch.ones(B, 1, device=device, dtype=torch.long), preds], dim=1).view(-1)
258
+ x = torch.cat([torch.zeros(B, 1, D, device=device, dtype=dtype), x], dim=1).view(-1, D)
259
+
260
+ # get points of consecutive preds
261
+ preds, counts = preds.unique_consecutive(return_counts=True)
262
+
263
+ # split in representations of same chars
264
+ x = torch.split(x, counts.tolist())
265
+
266
+ # remove blanks
267
+ valid_mask = preds != self.blank_idx
268
+ preds = preds[valid_mask]
269
+ counts = counts[valid_mask] # [N]
270
+ x = [x_i for x_i, v_i in zip(x, valid_mask) if v_i]
271
+
272
+ # pack into tensor
273
+ x = pad_sequence(x, batch_first=True, padding_value=0)
274
+
275
+ # char pooling
276
+ x = torch.sum(x, dim=1) / counts.to(dtype=x.dtype).unsqueeze(1) # [B, N, D] -> [B, D]
277
+
278
+ # find split points for retrieving the examples
279
+ split_points = (preds == -1).nonzero(as_tuple=True)[0]
280
+ split_points = torch.cat([split_points, torch.tensor([len(preds)], device=device)])
281
+ split_points = (split_points[1:] - split_points[:-1]).tolist()
282
+
283
+ # split into examples
284
+ x = torch.split(x, split_points)
285
+ preds = torch.split(preds, split_points)
286
+ lens = torch.tensor([len(x_i) for x_i in x], device=device)
287
+
288
+ # pack into tensors
289
+ x = pad_sequence(x, batch_first=True, padding_value=0)
290
+ preds = pad_sequence(preds, batch_first=True, padding_value=self.blank_idx)
291
+
292
+ # remove the parts we add to identify the bounds for each example
293
+ x = x[:, 1:]
294
+ preds = preds[:, 1:]
295
+ lens -= 1
296
+
297
+ mask = lengths_to_padding_mask(lens)
298
+
299
+ # account for empty examples (just a sep token)
300
+ empty_examples = lens == 0
301
+ num_empty_examples = empty_examples.sum()
302
+ if num_empty_examples > 0:
303
+ mask[empty_examples, 0] = True
304
+ lens[empty_examples] = 1
305
+ preds[empty_examples, 0] = self.sep_idx
306
+
307
+ return x, mask, lens, preds, num_empty_examples
308
+
309
+ def token_compression(self, x, preds, lens):
310
+ # x: B x T x D
311
+ # preds: B x T
312
+ # lens: B
313
+
314
+ B, T, D = x.size()
315
+ device = x.device
316
+ dtype = x.dtype
317
+
318
+ # new lengths after compression
319
+ new_lens = preds.eq(self.sep_idx).sum(dim=1)
320
+
321
+ # unpad and unpack to list of tensors
322
+ preds = [preds[i, :lens[i]] for i in range(B)]
323
+ x = [x[i, :lens[i]] for i in range(B)]
324
+
325
+ # make sure every example ends with a separator
326
+ num_examples_without_ending_sep = torch.tensor(0, device=device, dtype=torch.long)
327
+ for i in range(B):
328
+ if preds[i][-1] != self.sep_idx:
329
+ preds[i] = torch.cat([preds[i], torch.tensor([self.sep_idx], device=device, dtype=torch.long)])
330
+ x[i] = torch.cat([x[i], torch.zeros(1, D, device=device, dtype=dtype)])
331
+ new_lens[i] += 1
332
+ num_examples_without_ending_sep += 1
333
+
334
+ # flatten
335
+ preds = torch.cat(preds)
336
+ x = torch.cat(x)
337
+
338
+ # split points according to separators
339
+ split_points = preds.eq(self.sep_idx).nonzero(as_tuple=True)[0] + 1
340
+ split_points = torch.cat([torch.tensor([0], device=device, dtype=torch.long), split_points])
341
+ split_points = (split_points[1:] - split_points[:-1]).tolist()
342
+
343
+ # re-arrange in 3d [total_num_tokens x max(count) x D]
344
+ x = torch.split(x, split_points) # Tuple[2d tensor]
345
+
346
+ counts = torch.tensor([len(x_i) for x_i in x], device=device, dtype=torch.long)
347
+ x = pad_sequence(x, batch_first=True, padding_value=0)
348
+
349
+ # reduce dim 1
350
+ x = self.token_pooling_module(x, counts)
351
+
352
+ # reconstruct the batch
353
+ split_points = new_lens.cumsum(dim=0)
354
+ split_points = torch.cat([torch.tensor([0], device=device, dtype=torch.long), split_points])
355
+ split_points = (split_points[1:] - split_points[:-1]).tolist()
356
+ x = torch.split(x, split_points)
357
+ x = pad_sequence(x, batch_first=True, padding_value=0) # B x ? x D
358
+
359
+ mask = lengths_to_padding_mask(new_lens)
360
+
361
+ return x, mask, new_lens, num_examples_without_ending_sep
362
+
363
+ def forward(self, x, preds, lens):
364
+ x, mask, lens, preds, _ = self.char_compression(x, preds, lens)
365
+ x, mask, lens, _ = self.token_compression(x, preds, lens)
366
+ return x, mask, lens
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:926efaaae71884e9747c9e3c874debcfbe040d6b0347148752cd32aa6ad8f4f0
3
+ size 1413115412