aiola committed on
Commit f9ab349
1 Parent(s): f2dec6f

Update README.md

Files changed (1)
  1. README.md +73 -3
README.md CHANGED
@@ -1,3 +1,73 @@
- ---
- license: mit
- ---
+ ---
+ license: mit
+ datasets:
+ - facebook/voxpopuli
+ tags:
+ - ASR
+ - Automatic Speech Recognition
+ - Whisper
+ - Medusa
+ - Speech
+ - Speculative Decoding
+ language:
+ - en
+ - es
+ - de
+ - fr
+ ---
+
+ # Whisper Medusa
+
+ Whisper is an advanced encoder-decoder model for speech transcription and
+ translation, processing audio through encoding and decoding stages. Given
+ its large size and slow inference speed, various optimization strategies like
+ Faster-Whisper and Speculative Decoding have been proposed to enhance performance.
+ Our Medusa model builds on Whisper by predicting multiple tokens per iteration,
+ which significantly improves speed with only a small degradation in word error
+ rate (WER). We train and evaluate our model on various datasets, demonstrating
+ speed improvements.
+
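+ In spirit, each decoding iteration drafts several future tokens with the extra
+ Medusa heads and lets the base decoder verify them, accepting the longest
+ agreeing prefix. The toy sketch below illustrates this idea only; `propose` and
+ `verify_logits` are hypothetical stand-ins, not part of the `whisper-medusa` API.
+
+ ```python
+ from typing import Callable, List
+
+ def medusa_step(
+     tokens: List[int],
+     propose: Callable[[List[int], int], List[int]],          # Medusa heads: draft k tokens
+     verify_logits: Callable[[List[int], List[int]], List[List[float]]],  # base decoder pass
+     k: int = 4,
+ ) -> List[int]:
+     draft = propose(tokens, k)             # speculate k future tokens at once
+     scores = verify_logits(tokens, draft)  # base-model logits at each draft position
+     accepted: List[int] = []
+     for cand, logits in zip(draft, scores):
+         best = max(range(len(logits)), key=logits.__getitem__)
+         accepted.append(best)              # always keep the base model's greedy choice
+         if best != cand:                   # first disagreement ends the speculated run
+             break
+     return tokens + accepted               # 1 to k tokens emitted per iteration
+ ```
+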
+ ---------
+
+ ## Training Details
+ `aiola/whisper-medusa-multilingual` was trained on the VoxPopuli dataset to perform audio translation.
+ The Medusa heads were optimized for English, Spanish, German, and French, so for the best performance and speed improvements, please use these languages only.
+
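+ For instance, one might guard the `language` argument before generation; this
+ check is illustrative and not part of the `whisper-medusa` API:
+
+ ```python
+ # Languages the Medusa heads were optimized for (see above).
+ SUPPORTED_LANGUAGES = {"en", "es", "de", "fr"}
+
+ language = "en"
+ if language not in SUPPORTED_LANGUAGES:
+     raise ValueError(f"Medusa heads are not optimized for '{language}'")
+ ```
+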
+ ---------
+
+ ## Usage
+ To use `aiola/whisper-medusa-multilingual`, install the [`whisper-medusa`](https://github.com/aiola-lab/whisper-medusa) repo by following its README instructions.
+
+ Inference can be done using the following code:
+ ```python
+ import torch
+ import torchaudio
+
+ from whisper_medusa import WhisperMedusaModel
+ from transformers import WhisperProcessor
+
+ model_name = "aiola/whisper-medusa-multilingual"
+ model = WhisperMedusaModel.from_pretrained(model_name)
+ processor = WhisperProcessor.from_pretrained(model_name)
+
+ path_to_audio = "path/to/audio.wav"
+ SAMPLING_RATE = 16000
+ language = "en"
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Whisper expects 16 kHz audio; resample if the file uses a different rate.
+ input_speech, sr = torchaudio.load(path_to_audio)
+ if sr != SAMPLING_RATE:
+     input_speech = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(input_speech)
+
+ # Convert the waveform into log-mel input features for the encoder.
+ input_features = processor(input_speech.squeeze(), return_tensors="pt", sampling_rate=SAMPLING_RATE).input_features
+ input_features = input_features.to(device)
+
+ model = model.to(device)
+ model_output = model.generate(
+     input_features,
+     language=language,
+ )
+ predict_ids = model_output[0]
+ pred = processor.decode(predict_ids, skip_special_tokens=True)
+ print(pred)
+ ```
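+
+ Since the model and processor are loaded once, they can be reused across files
+ and languages. A short sketch continuing from the snippet above (the file names
+ are placeholders):
+
+ ```python
+ # Continuing from the snippet above: model, processor, device, and
+ # SAMPLING_RATE are already defined. The clips listed here are placeholders.
+ clips = [("clip_en.wav", "en"), ("clip_es.wav", "es"), ("clip_de.wav", "de")]
+
+ for path, lang in clips:
+     speech, sr = torchaudio.load(path)
+     if sr != SAMPLING_RATE:
+         speech = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(speech)
+     features = processor(speech.squeeze(), return_tensors="pt",
+                          sampling_rate=SAMPLING_RATE).input_features
+     output = model.generate(features.to(device), language=lang)
+     print(lang, processor.decode(output[0], skip_special_tokens=True))
+ ```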