LuisG07 commited on
Commit
577c3f8
·
1 Parent(s): 131a172

Add application file

Browse files
Files changed (3) hide show
  1. app.py +69 -0
  2. packages.txt +1 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nltk
2
+ import librosa
3
+ import torch
4
+ import kenlm
5
+ import gradio as gr
6
+ from pyctcdecode import build_ctcdecoder
7
+ from transformers import Wav2Vec2Processor,Wav2Vec2ProcessorWithLM,Wav2Vec2ForCTC
8
+
9
+ nltk.download("punkt")
10
+
11
+ def return_processor_and_model(model_name):
12
+ return Wav2Vec2Processor.from_pretrained(model_name), Wav2Vec2ForCTC.from_pretrained(model_name)
13
+
14
+ def return_processor_and_modelWithLM(model_name):
15
+ return Wav2Vec2ProcessorWithLM.from_pretrained(model_name), Wav2Vec2ForCTC.from_pretrained(model_name)
16
+
17
+ def load_and_fix_data(input_file):
18
+ speech, sample_rate = librosa.load(input_file)
19
+ if len(speech.shape) > 1:
20
+ speech = speech[:,0] + speech[:,1]
21
+ if sample_rate !=16000:
22
+ speech = librosa.resample(speech, sample_rate,16000)
23
+ return speech
24
+
25
+ def fix_transcription_casing(input_sentence):
26
+ sentences = nltk.sent_tokenize(input_sentence)
27
+ return (' '.join([s.replace(s[0],s[0].capitalize(),1) for s in sentences]))
28
+
29
+
30
+ def predict_and_ctc_lm_decode(input_file, model_name):
31
+ processor, model = return_processor_and_modelWithLM(model_name)
32
+ speech = load_and_fix_data(input_file)
33
+
34
+ input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
35
+ logits = model(input_values).logits.cpu().detach().numpy()[0]
36
+
37
+ pred = processor.batch_decode(logits.numpy()).text
38
+
39
+ transcribed_text = fix_transcription_casing(pred[0].lower())
40
+
41
+ return transcribed_text
42
+
43
+ def predict_and_greedy_decode(input_file, model_name):
44
+ processor, model = return_processor_and_model(model_name)
45
+ speech = load_and_fix_data(input_file)
46
+
47
+ input_values = processor(speech, return_tensors="pt", sampling_rate=16000).input_values
48
+ logits = model(input_values).logits
49
+
50
+ predicted_ids = torch.argmax(logits, dim=-1)
51
+ pred = processor.batch_decode(predicted_ids)
52
+
53
+ transcribed_text = fix_transcription_casing(pred[0].lower())
54
+
55
+ return transcribed_text
56
+
57
+ def return_all_predictions(input_file, model_name):
58
+ return predict_and_ctc_lm_decode(input_file, model_name), predict_and_greedy_decode(input_file, model_name)
59
+
60
+
61
+ gr.Interface(return_all_predictions,
62
+ inputs = [gr.inputs.Audio(source="microphone", type="filepath", label="Record/ Drop audio"), gr.inputs.Dropdown(["jonatasgrosman/wav2vec2-large-xlsr-53-spanish", "jonatasgrosman/wav2vec2-large-xlsr-53-spanish"], label="Model Name")],
63
+ outputs = [gr.outputs.Textbox(label="Beam CTC decoding w/ LM"), gr.outputs.Textbox(label="Greedy decoding")],
64
+ title="ASR using Wav2Vec2 & pyctcdecode in spanish",
65
+ description = "Comparing greedy decoder with beam search CTC decoder, record/ drop your audio!",
66
+ layout = "horizontal",
67
+ examples = [["test1.wav", "jonatasgrosman/wav2vec2-large-xlsr-53-spanish"], ["test2.wav", "jonatasgrosman/wav2vec2-large-xlsr-53-spanish"]],
68
+ theme="huggingface",
69
+ enable_queue=True).launch()
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ libsndfile1
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ nltk
2
+ transformers
3
+ torch
4
+ librosa
5
+ pyctcdecode
6
+ pypi-kenlm