ylacombe committed on
Commit
ac1e9e7
·
1 Parent(s): 3029bff

Improve UX a bit and switch back to Whisper large v2

Browse files
Files changed (1) hide show
  1. app.py +19 -36
app.py CHANGED
@@ -19,7 +19,6 @@ import datetime
19
 
20
  from scipy.io.wavfile import write
21
  from pydub import AudioSegment
22
- import ffmpeg
23
 
24
  import re
25
  import io, wave
@@ -57,7 +56,7 @@ model.load_checkpoint(
57
  checkpoint_path=os.path.join(model_path, "model.pth"),
58
  vocab_path=os.path.join(model_path, "vocab.json"),
59
  eval=True,
60
- use_deepspeed=True,
61
  )
62
  model.cuda()
63
  print("Done loading TTS")
@@ -113,10 +112,7 @@ from gradio_client import Client
113
  from huggingface_hub import InferenceClient
114
 
115
  WHISPER_TIMEOUT = int(os.environ.get("WHISPER_TIMEOUT", 30))
116
- # This client is down
117
- # whisper_client = Client("https://sanchit-gandhi-whisper-large-v2.hf.space/")
118
- # Replacement whisper client, it may be time limited
119
- whisper_client = Client("https://sanchit-gandhi-whisper-jax.hf.space")
120
  text_client = InferenceClient(
121
  "mistralai/Mistral-7B-Instruct-v0.1",
122
  timeout=WHISPER_TIMEOUT,
@@ -203,13 +199,12 @@ def generate(
203
 
204
  def transcribe(wav_path):
205
  try:
206
- # get first element from whisper_jax and strip it to delete begin and end space
207
  return whisper_client.predict(
208
- wav_path, # str (filepath or URL to file) in 'inputs' Audio component
209
- "transcribe", # str in 'Task' Radio component
210
- False, # return_timestamps=False for whisper-jax https://gist.github.com/sanchit-gandhi/781dd7003c5b201bfe16d28634c8d4cf#file-whisper_jax_endpoint-py
211
- api_name="/predict",
212
- )[0].strip()
213
  except:
214
  gr.Warning("There was a problem with Whisper endpoint, telling a joke for you.")
215
  return "There was a problem with my voice, tell me joke"
@@ -242,8 +237,8 @@ def add_file(history, file):
242
 
243
  ##NOTE: not using this as it yields a character each time while we need to feed history to TTS
244
  def bot(history, system_prompt=""):
245
- history = [] if history is None else history
246
-
247
  if system_prompt == "":
248
  system_prompt = system_message
249
 
@@ -267,21 +262,6 @@ latent_map = {}
267
  latent_map["Female_Voice"] = get_latents("examples/female.wav")
268
 
269
 
270
- def get_voice(prompt, language, latent_tuple, suffix="0"):
271
- gpt_cond_latent, diffusion_conditioning, speaker_embedding = latent_tuple
272
- # Direct version
273
- t0 = time.time()
274
- out = model.inference(
275
- prompt, language, gpt_cond_latent, speaker_embedding, diffusion_conditioning
276
- )
277
- inference_time = time.time() - t0
278
- print(f"I: Time to generate audio: {round(inference_time*1000)} milliseconds")
279
- real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
280
- print(f"Real-time factor (RTF): {real_time_factor}")
281
- wav_filename = f"output_{suffix}.wav"
282
- torchaudio.save(wav_filename, torch.tensor(out["wav"]).unsqueeze(0), 24000)
283
- return wav_filename
284
-
285
 
286
  def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=24000):
287
  # This will create a wave header then append the frame input
@@ -333,7 +313,7 @@ def get_voice_streaming(prompt, language, latent_tuple, suffix="0"):
333
  if "device-side assert" in str(e):
334
  # cannot do anything on cuda device side error, need to restart
335
  print(
336
- f"Exit due to: Unrecoverable exception caused by prompt:{sentence}",
337
  flush=True,
338
  )
339
  gr.Warning("Unhandled Exception encounter, please retry in a minute")
@@ -353,10 +333,12 @@ def get_voice_streaming(prompt, language, latent_tuple, suffix="0"):
353
 
354
  def get_sentence(history, system_prompt=""):
355
  history = [["", None]] if history is None else history
356
- print(history)
357
  if system_prompt == "":
358
  system_prompt = system_message
359
 
 
 
360
  mistral_start = time.time()
361
  print("Mistral start")
362
  sentence_list = []
@@ -422,8 +404,8 @@ def generate_speech(history):
422
  try:
423
  # generate speech using precomputed latents
424
  # This is not streaming but it will be fast
425
- # wav = get_voice(sentence,language, latent_map["Female_Voice"], suffix=len(wav_list))
426
  if len(sentence) > 250:
 
427
  # should not generate voice it will hit token limit
428
  # It should not generate audio for it
429
  audio_stream = None
@@ -520,6 +502,7 @@ with gr.Blocks(title=title) as demo:
520
  show_label=False,
521
  placeholder="Enter text and press enter, or speak to your microphone",
522
  container=False,
 
523
  )
524
  txt_btn = gr.Button(value="Submit text", scale=1)
525
  btn = gr.Audio(source="microphone", type="filepath", scale=4)
@@ -536,7 +519,7 @@ with gr.Blocks(title=title) as demo:
536
  # final_audio = gr.Audio(label="Final audio response", streaming=False, autoplay=False, interactive=False,show_label=True, visible=False)
537
 
538
  clear_btn = gr.ClearButton([chatbot, audio])
539
-
540
  txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
541
  generate_speech, chatbot, [audio, chatbot]
542
  )
@@ -553,13 +536,13 @@ with gr.Blocks(title=title) as demo:
553
  add_file, [chatbot, btn], [chatbot, txt], queue=False
554
  ).then(generate_speech, chatbot, [audio, chatbot])
555
 
556
- file_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)
557
 
558
  gr.Markdown(
559
  """
560
  This Space demonstrates how to speak to a chatbot, based solely on open-source models.
561
  It relies on 3 models:
562
- 1. [Whisper-large-v2](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax) as an ASR model, to transcribe recorded audio to text. It is called through a [gradio client](https://www.gradio.app/docs/client).
563
  2. [Mistral-7b-instruct](https://huggingface.co/spaces/osanseviero/mistral-super-fast) as the chat model, the actual chat model. It is called from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference).
564
  3. [Coqui's XTTS](https://huggingface.co/spaces/coqui/xtts) as a TTS model, to generate the chatbot answers. This time, the model is hosted locally.
565
 
@@ -567,4 +550,4 @@ Note:
567
  - By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml"""
568
  )
569
  demo.queue()
570
- demo.launch(debug=True, share=True)
 
19
 
20
  from scipy.io.wavfile import write
21
  from pydub import AudioSegment
 
22
 
23
  import re
24
  import io, wave
 
56
  checkpoint_path=os.path.join(model_path, "model.pth"),
57
  vocab_path=os.path.join(model_path, "vocab.json"),
58
  eval=True,
59
+ use_deepspeed=False, # TODO: replace by True
60
  )
61
  model.cuda()
62
  print("Done loading TTS")
 
112
  from huggingface_hub import InferenceClient
113
 
114
  WHISPER_TIMEOUT = int(os.environ.get("WHISPER_TIMEOUT", 30))
115
+ whisper_client = Client("https://sanchit-gandhi-whisper-large-v2.hf.space/")
 
 
 
116
  text_client = InferenceClient(
117
  "mistralai/Mistral-7B-Instruct-v0.1",
118
  timeout=WHISPER_TIMEOUT,
 
199
 
200
  def transcribe(wav_path):
201
  try:
202
+ # get result from whisper and strip it to delete begin and end space
203
  return whisper_client.predict(
204
+ wav_path, # str (filepath or URL to file) in 'inputs' Audio component
205
+ "transcribe", # str in 'Task' Radio component
206
+ api_name="/predict"
207
+ ).strip()
 
208
  except:
209
  gr.Warning("There was a problem with Whisper endpoint, telling a joke for you.")
210
  return "There was a problem with my voice, tell me joke"
 
237
 
238
  ##NOTE: not using this as it yields a character each time while we need to feed history to TTS
239
  def bot(history, system_prompt=""):
240
+ history = [["", None]] if history is None else history
241
+
242
  if system_prompt == "":
243
  system_prompt = system_message
244
 
 
262
  latent_map["Female_Voice"] = get_latents("examples/female.wav")
263
 
264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
  def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=24000):
267
  # This will create a wave header then append the frame input
 
313
  if "device-side assert" in str(e):
314
  # cannot do anything on cuda device side error, need to restart
315
  print(
316
+ f"Exit due to: Unrecoverable exception caused by prompt:{prompt}",
317
  flush=True,
318
  )
319
  gr.Warning("Unhandled Exception encounter, please retry in a minute")
 
333
 
334
  def get_sentence(history, system_prompt=""):
335
  history = [["", None]] if history is None else history
336
+
337
  if system_prompt == "":
338
  system_prompt = system_message
339
 
340
+ history[-1][1] = ""
341
+
342
  mistral_start = time.time()
343
  print("Mistral start")
344
  sentence_list = []
 
404
  try:
405
  # generate speech using precomputed latents
406
  # This is not streaming but it will be fast
 
407
  if len(sentence) > 250:
408
+ gr.Warning("There was a problem with the last sentence, which was too long, so it won't be spoken.")
409
  # should not generate voice it will hit token limit
410
  # It should not generate audio for it
411
  audio_stream = None
 
502
  show_label=False,
503
  placeholder="Enter text and press enter, or speak to your microphone",
504
  container=False,
505
+ interactive=True,
506
  )
507
  txt_btn = gr.Button(value="Submit text", scale=1)
508
  btn = gr.Audio(source="microphone", type="filepath", scale=4)
 
519
  # final_audio = gr.Audio(label="Final audio response", streaming=False, autoplay=False, interactive=False,show_label=True, visible=False)
520
 
521
  clear_btn = gr.ClearButton([chatbot, audio])
522
+
523
  txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
524
  generate_speech, chatbot, [audio, chatbot]
525
  )
 
536
  add_file, [chatbot, btn], [chatbot, txt], queue=False
537
  ).then(generate_speech, chatbot, [audio, chatbot])
538
 
539
+ file_msg.then(lambda: (gr.update(interactive=True),gr.update(interactive=True,value=None)), None, [txt, btn], queue=False)
540
 
541
  gr.Markdown(
542
  """
543
  This Space demonstrates how to speak to a chatbot, based solely on open-source models.
544
  It relies on 3 models:
545
+ 1. [Whisper-large-v2](https://sanchit-gandhi-whisper-large-v2.hf.space/) as an ASR model, to transcribe recorded audio to text. It is called through a [gradio client](https://www.gradio.app/docs/client).
546
  2. [Mistral-7b-instruct](https://huggingface.co/spaces/osanseviero/mistral-super-fast) as the chat model, the actual chat model. It is called from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference).
547
  3. [Coqui's XTTS](https://huggingface.co/spaces/coqui/xtts) as a TTS model, to generate the chatbot answers. This time, the model is hosted locally.
548
 
 
550
  - By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml"""
551
  )
552
  demo.queue()
553
+ demo.launch(debug=True)