---
license: apache-2.0
language: de
library_name: transformers
thumbnail: null
tags:
- automatic-speech-recognition
- whisper-event
datasets:
- mozilla-foundation/common_voice_11_0
metrics:
- wer
model-index:
- name: Fine-tuned whisper-medium model for ASR in German
  results:
  - task:
      name: Automatic Speech Recognition
      type: automatic-speech-recognition
    dataset:
      name: Common Voice 11.0
      type: mozilla-foundation/common_voice_11_0
      config: de
      split: test
      args: de
    metrics:
    - name: WER (Greedy)
      type: wer
      value: 7.05
---

![Model architecture](https://img.shields.io/badge/Model_Architecture-seq2seq-lightgrey)
![Model size](https://img.shields.io/badge/Params-769M-lightgrey)
![Language](https://img.shields.io/badge/Language-German-lightgrey)

This model is a version of [bofenghuang/whisper-medium-cv11-german](https://huggingface.co/bofenghuang/whisper-medium-cv11-german/) converted to the CTranslate2 format.

# Fine-tuned whisper-medium model for ASR in German

This model is a fine-tuned version of [openai/whisper-medium](https://huggingface.co/openai/whisper-medium), trained on the German (`de`) subset of the mozilla-foundation/common_voice_11_0 dataset. When using the model, make sure that your speech input is sampled at 16 kHz. **This model also predicts casing and punctuation.**

## Performance

*Below are the WERs of the pre-trained models on [Common Voice 9.0](https://huggingface.co/datasets/mozilla-foundation/common_voice_9_0). These results are reported in the original [paper](https://cdn.openai.com/papers/whisper.pdf).*

| Model | Common Voice 9.0 (WER, %) |
| --- | :---: |
| [openai/whisper-small](https://huggingface.co/openai/whisper-small) | 13.0 |
| [openai/whisper-medium](https://huggingface.co/openai/whisper-medium) | 8.5 |
| [openai/whisper-large-v2](https://huggingface.co/openai/whisper-large-v2) | 6.4 |

*Below are the WERs of the fine-tuned models on [Common Voice 11.0](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0).*

| Model | Common Voice 11.0 (WER, %) |
| --- | :---: |
| [bofenghuang/whisper-small-cv11-german](https://huggingface.co/bofenghuang/whisper-small-cv11-german) | 11.35 |
| [bofenghuang/whisper-medium-cv11-german](https://huggingface.co/bofenghuang/whisper-medium-cv11-german) | 7.05 |
| [bofenghuang/whisper-large-v2-cv11-german](https://huggingface.co/bofenghuang/whisper-large-v2-cv11-german) | **5.76** |

## Usage

Inference with 🤗 Pipeline

```python
import torch

from datasets import load_dataset
from transformers import pipeline

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load pipeline
pipe = pipeline("automatic-speech-recognition", model="bofenghuang/whisper-medium-cv11-german", device=device)

# NB: set forced_decoder_ids for generation utils
pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language="de", task="transcribe")

# Load data (streaming avoids downloading the full test split)
ds_mcv_test = load_dataset("mozilla-foundation/common_voice_11_0", "de", split="test", streaming=True)
test_segment = next(iter(ds_mcv_test))
waveform = test_segment["audio"]

# NB: decoding options
# limit the maximum number of generated tokens to 225
pipe.model.config.max_length = 225 + 1
# sampling
# pipe.model.config.do_sample = True
# beam search
# pipe.model.config.num_beams = 5
# return
# pipe.model.config.return_dict_in_generate = True
# pipe.model.config.output_scores = True
# pipe.model.config.num_return_sequences = 5

# Run
generated_sentences = pipe(waveform)["text"]
```

Inference with 🤗 low-level APIs

```python
import torch
import torchaudio

from datasets import load_dataset
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load model
model = AutoModelForSpeechSeq2Seq.from_pretrained("bofenghuang/whisper-medium-cv11-german").to(device)
processor = AutoProcessor.from_pretrained("bofenghuang/whisper-medium-cv11-german", language="german", task="transcribe")

# NB: set forced_decoder_ids for generation utils
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="de", task="transcribe")

# 16_000
model_sample_rate = processor.feature_extractor.sampling_rate

# Load data (streaming avoids downloading the full test split)
ds_mcv_test = load_dataset("mozilla-foundation/common_voice_11_0", "de", split="test", streaming=True)
test_segment = next(iter(ds_mcv_test))
# Cast to float32 so the dtype matches the resampling kernel
waveform = torch.from_numpy(test_segment["audio"]["array"]).float()
sample_rate = test_segment["audio"]["sampling_rate"]

# Resample to the rate the model expects
if sample_rate != model_sample_rate:
    resampler = torchaudio.transforms.Resample(sample_rate, model_sample_rate)
    waveform = resampler(waveform)

# Extract input features
inputs = processor(waveform, sampling_rate=model_sample_rate, return_tensors="pt")
input_features = inputs.input_features
input_features = input_features.to(device)

# Generate
generated_ids = model.generate(inputs=input_features, max_new_tokens=225)  # greedy
# generated_ids = model.generate(inputs=input_features, max_new_tokens=225, num_beams=5)  # beam search

# Detokenize
generated_sentences = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

# Normalise predicted sentences if necessary
```
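
Because the model predicts casing and punctuation, a word error rate computed on raw output can differ noticeably from one computed on normalised text. The sketch below illustrates this with a plain-Python WER based on word-level edit distance. The helpers `normalize` and `word_error_rate` are illustrative names introduced here, and the normalisation rules (lowercasing, stripping ASCII punctuation) are an assumption; they are not necessarily the rules used to produce the scores reported above.

```python
import re
import string


def normalize(text: str) -> str:
    """Lowercase, strip ASCII punctuation, collapse whitespace (illustrative rules only)."""
    text = text.lower()
    # string.punctuation covers ASCII only; typographic quotes such as „ “ are left untouched
    text = text.translate(str.maketrans("", "", string.punctuation))
    return re.sub(r"\s+", " ", text).strip()


def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level Levenshtein distance divided by the number of reference words."""
    ref, hyp = reference.split(), hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")

    # prev[j] holds the edit distance between the reference prefix seen so far and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            substitution = prev[j - 1] + (0 if ref_word == hyp_word else 1)
            deletion = prev[j] + 1
            insertion = curr[j - 1] + 1
            curr.append(min(substitution, deletion, insertion))
        prev = curr
    return prev[-1] / len(ref)


reference = "Das ist ein Test."
hypothesis = "das ist ein test"

raw_wer = word_error_rate(reference, hypothesis)
normalized_wer = word_error_rate(normalize(reference), normalize(hypothesis))
print(f"raw WER: {raw_wer:.2f}, normalised WER: {normalized_wer:.2f}")
```

In this example, the raw comparison counts "Das" vs. "das" and "Test." vs. "test" as substitutions, while the normalised comparison treats the two sentences as identical. Whichever normalisation is chosen, it should be applied to references and predictions alike.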