akera committed on
Commit
59bf002
·
verified ·
1 Parent(s): d4afb45

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -37
app.py CHANGED
@@ -15,22 +15,24 @@ auth_token = os.environ.get("HF_TOKEN")
15
 
16
 
17
  target_lang_options = {"English": "eng", "Luganda": "lug", "Acholi": "ach", "Runyankole": "nyn", "Lugbara": "lgg"}
18
- target_lang_code = target_lang_options[target_lang]
19
 
20
  languages = list(target_lang_options.keys())
21
 
22
 
23
- if target_lang_code=="eng":
24
- model_id = "facebook/mms-1b-all"
25
- else:
26
- model_id = "Sunbird/sunbird-mms"
27
-
28
-
29
  # Transcribe audio using custom model
30
- def transcribe_audio(input_file, target_lang_code,
31
  device, model_id=model_id,
32
  chunk_length_s=10, stride_length_s=(4, 2), return_timestamps="word"):
33
 
 
 
 
 
 
 
 
 
 
34
  pipe = pipeline(model=model_id, device=device, token=hf_auth_token)
35
  pipe.tokenizer.set_target_lang(target_lang_code)
36
  pipe.model.load_adapter(target_lang_code)
@@ -41,41 +43,13 @@ def transcribe_audio(input_file, target_lang_code,
41
  return output
42
 
43
 
44
- # def transcribe(audio_file_mic=None, audio_file_upload=None, language="Luganda (lug)"):
45
- # if audio_file_mic:
46
- # audio_file = audio_file_mic
47
- # elif audio_file_upload:
48
- # audio_file = audio_file_upload
49
- # else:
50
- # return "Please upload an audio file or record one"
51
-
52
- # # Make sure audio is 16kHz
53
- # speech, sample_rate = librosa.load(audio_file)
54
- # if sample_rate != 16000:
55
- # speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
56
-
57
- # # Keep the same model in memory and simply switch out the language adapters by calling load_adapter() for the model and set_target_lang() for the tokenizer
58
- # language_code = language
59
- # processor.tokenizer.set_target_lang(language_code)
60
- # model.load_adapter(language_code)
61
-
62
- # inputs = processor(speech, sampling_rate=16_000, return_tensors="pt")
63
-
64
- # with torch.no_grad():
65
- # outputs = model(**inputs).logits
66
-
67
- # ids = torch.argmax(outputs, dim=-1)[0]
68
- # transcription = processor.decode(ids)
69
- # return transcription
70
-
71
-
72
  description = '''ASR with salt-mms'''
73
 
74
  iface = gr.Interface(fn=transcribe_audio,
75
  inputs=[
76
  gr.Audio(source="microphone", type="filepath", label="Record Audio"),
77
  gr.Audio(source="upload", type="filepath", label="Upload Audio"),
78
- gr.Dropdown(choices=languages, label="Language", value="lug")
79
  ],
80
  outputs=gr.Textbox(label="Transcription"),
81
  description=description
 
15
 
16
 
17
  target_lang_options = {"English": "eng", "Luganda": "lug", "Acholi": "ach", "Runyankole": "nyn", "Lugbara": "lgg"}
 
18
 
19
  languages = list(target_lang_options.keys())
20
 
21
 
 
 
 
 
 
 
22
  # Transcribe audio using custom model
23
+ def transcribe_audio(input_file, language,
24
  device, model_id=None,
25
  chunk_length_s=10, stride_length_s=(4, 2), return_timestamps="word"):
26
 
27
+
28
+ target_lang_code = target_lang_options[language]
29
+
30
+ # Determine the model_id based on the language
31
+ if target_lang_code == "eng":
32
+ model_id = "facebook/mms-1b-all"
33
+ else:
34
+ model_id = "Sunbird/sunbird-mms"
35
+
36
  pipe = pipeline(model=model_id, device=device, token=auth_token)
37
  pipe.tokenizer.set_target_lang(target_lang_code)
38
  pipe.model.load_adapter(target_lang_code)
 
43
  return output
44
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  description = '''ASR with salt-mms'''
47
 
48
  iface = gr.Interface(fn=transcribe_audio,
49
  inputs=[
50
  gr.Audio(source="microphone", type="filepath", label="Record Audio"),
51
  gr.Audio(source="upload", type="filepath", label="Upload Audio"),
52
+ gr.Dropdown(choices=languages, label="Language", value="English")
53
  ],
54
  outputs=gr.Textbox(label="Transcription"),
55
  description=description