tjysdsg committed on
Commit
37c7fb5
·
1 Parent(s): 0f1c861

Simplify UI and update gradio

Browse files
Files changed (2) hide show
  1. app.py +22 -65
  2. requirements.txt +2 -1
app.py CHANGED
@@ -6,6 +6,8 @@ import soundfile as sf
6
  from s2st_inference import s2st_inference
7
  from utils import download_model
8
 
 
 
9
  SAMPLE_RATE = 16000
10
  MAX_INPUT_LENGTH = 60 # seconds
11
 
@@ -29,20 +31,9 @@ class App:
29
 
30
  def s2st(
31
  self,
32
- audio_source: str,
33
- input_audio_mic: Optional[str],
34
- input_audio_file: Optional[str],
35
  ):
36
- if audio_source == 'file':
37
- input_path = input_audio_file
38
- else:
39
- input_path = input_audio_mic
40
-
41
- if input_path is None:
42
- gr.Error(f"Input audio is too long. Truncated to {MAX_INPUT_LENGTH} seconds.")
43
- return (None, None), None
44
-
45
- orig_wav, orig_sr = torchaudio.load(input_path)
46
  wav = torchaudio.functional.resample(orig_wav, orig_freq=orig_sr, new_freq=SAMPLE_RATE)
47
  max_length = int(MAX_INPUT_LENGTH * SAMPLE_RATE)
48
  if wav.shape[1] > max_length:
@@ -94,71 +85,37 @@ class App:
94
  "PCM_16",
95
  )
96
 
97
- return output_path, f'Source: {audio_source}'
98
-
99
-
100
- def update_audio_ui(audio_source: str) -> Tuple[dict, dict]:
101
- mic = audio_source == "microphone"
102
- return (
103
- gr.update(visible=mic, value=None), # input_audio_mic
104
- gr.update(visible=not mic, value=None), # input_audio_file
105
- )
106
 
107
 
108
  def main():
109
  app = App()
110
 
111
  with gr.Blocks() as demo:
 
112
  with gr.Group():
113
- with gr.Row() as audio_box:
114
- audio_source = gr.Radio(
115
- label="Audio source",
116
- choices=["file", "microphone"],
117
- value="file",
118
- )
119
- input_audio_mic = gr.Audio(
120
- label="Input speech",
121
- type="filepath",
122
- source="microphone",
123
- visible=False,
124
- )
125
- input_audio_file = gr.Audio(
126
- label="Input speech",
127
- type="filepath",
128
- source="upload",
129
- visible=True,
130
- )
131
 
132
  btn = gr.Button("Translate")
133
 
134
- with gr.Column():
135
- output_audio = gr.Audio(
136
- label="Translated speech",
137
- autoplay=False,
138
- streaming=False,
139
- type="numpy",
140
- )
141
- output_text = gr.Textbox(label="Translated text")
142
-
143
- audio_source.change(
144
- fn=update_audio_ui,
145
- inputs=audio_source,
146
- outputs=[
147
- input_audio_mic,
148
- input_audio_file,
149
- ],
150
- queue=False,
151
- api_name=False,
152
- )
153
 
154
  btn.click(
155
  fn=app.s2st,
156
- inputs=[
157
- audio_source,
158
- input_audio_mic,
159
- input_audio_file,
160
- ],
161
- outputs=[output_audio, output_text],
162
  api_name="run",
163
  )
164
 
 
6
  from s2st_inference import s2st_inference
7
  from utils import download_model
8
 
9
+ DESCRIPTION = r"**Speech-to-Speech Translation from Spanish to English**"
10
+
11
  SAMPLE_RATE = 16000
12
  MAX_INPUT_LENGTH = 60 # seconds
13
 
 
31
 
32
  def s2st(
33
  self,
34
+ input_audio: Optional[str],
 
 
35
  ):
36
+ orig_wav, orig_sr = torchaudio.load(input_audio)
 
 
 
 
 
 
 
 
 
37
  wav = torchaudio.functional.resample(orig_wav, orig_freq=orig_sr, new_freq=SAMPLE_RATE)
38
  max_length = int(MAX_INPUT_LENGTH * SAMPLE_RATE)
39
  if wav.shape[1] > max_length:
 
85
  "PCM_16",
86
  )
87
 
88
+ return output_path
 
 
 
 
 
 
 
 
89
 
90
 
91
  def main():
92
  app = App()
93
 
94
  with gr.Blocks() as demo:
95
+ gr.Markdown(DESCRIPTION)
96
  with gr.Group():
97
+ input_audio = gr.Audio(
98
+ label="Input speech",
99
+ type="filepath",
100
+ sources=["upload", "microphone"],
101
+ format='wav',
102
+ streaming=False,
103
+ visible=True,
104
+ )
 
 
 
 
 
 
 
 
 
 
105
 
106
  btn = gr.Button("Translate")
107
 
108
+ output_audio = gr.Audio(
109
+ label="Translated speech",
110
+ autoplay=False,
111
+ streaming=False,
112
+ type="numpy",
113
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  btn.click(
116
  fn=app.s2st,
117
+ inputs=[input_audio],
118
+ outputs=[output_audio],
 
 
 
 
119
  api_name="run",
120
  )
121
 
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
- espnet
 
2
  torchaudio
3
  torch
4
  git+https://github.com/kan-bayashi/ParallelWaveGAN
 
1
+ gradio==4.1.1
2
+ espnet==202310
3
  torchaudio
4
  torch
5
  git+https://github.com/kan-bayashi/ParallelWaveGAN