ylacombe commited on
Commit
c534b30
·
1 Parent(s): 245ae02

xtts-and-whisper-update (#3)

Browse files

- xtts and whisper jax (145f28e71d0018b8125e2a088a6e2fc7efff9ce9)
- remove unneded testfile (4d98613abf7d496dd3b12d2af377f8ba4ed587f4)
- add ffmpeg import (0e056f7dc33bdc6d3eab93d5ded7bb4ef630e102)

Files changed (2) hide show
  1. app.py +133 -31
  2. requirements.txt +5 -2
app.py CHANGED
@@ -11,8 +11,36 @@ import nltk # we'll use this to split into sentences
11
  nltk.download('punkt')
12
  import uuid
13
 
 
 
 
14
  from TTS.api import TTS
15
- tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1", gpu=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  title = "Voice chat with Mistral 7B Instruct"
18
 
@@ -44,11 +72,20 @@ from gradio_client import Client
44
  from huggingface_hub import InferenceClient
45
 
46
 
47
- whisper_client = Client("https://sanchit-gandhi-whisper-large-v2.hf.space/")
 
 
 
48
  text_client = InferenceClient(
49
  "mistralai/Mistral-7B-Instruct-v0.1"
50
  )
51
 
 
 
 
 
 
 
52
 
53
  def format_prompt(message, history):
54
  prompt = "<s>"
@@ -77,22 +114,35 @@ def generate(
77
 
78
  formatted_prompt = format_prompt(prompt, history)
79
 
80
- stream = text_client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
81
- output = ""
82
-
83
- for response in stream:
84
- output += response.token.text
85
- yield output
 
 
 
 
 
 
 
 
 
 
 
86
  return output
87
 
88
 
89
  def transcribe(wav_path):
90
 
 
91
  return whisper_client.predict(
92
  wav_path, # str (filepath or URL to file) in 'inputs' Audio component
93
  "transcribe", # str in 'Task' Radio component
 
94
  api_name="/predict"
95
- )
96
 
97
 
98
  # Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.
@@ -106,9 +156,17 @@ def add_text(history, text):
106
 
107
  def add_file(history, file):
108
  history = [] if history is None else history
109
- text = transcribe(
110
- file
111
- )
 
 
 
 
 
 
 
 
112
 
113
  history = history + [(text, None)]
114
  return history
@@ -126,29 +184,65 @@ def bot(history, system_prompt=""):
126
  history[-1][1] = character
127
  yield history
128
 
129
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  def generate_speech(history):
131
  text_to_generate = history[-1][1]
132
  text_to_generate = text_to_generate.replace("\n", " ").strip()
133
  text_to_generate = nltk.sent_tokenize(text_to_generate)
134
-
135
- filename = f"{uuid.uuid4()}.wav"
136
- sampling_rate = tts.synthesizer.tts_config.audio["sample_rate"]
137
- silence = [0] * int(0.25 * sampling_rate)
138
 
139
-
140
- for sentence in text_to_generate:
141
- try:
142
 
143
- # generate speech by cloning a voice using default settings
144
- wav = tts.tts(text=sentence,
145
- speaker_wav="examples/female.wav",
146
- decoder_iterations=25,
147
- decoder_sampler="dpm++2m",
148
- speed=1.2,
149
- language="en")
 
 
 
 
 
 
 
150
 
151
- yield (sampling_rate, np.array(wav)) #np.array(wav + silence))
 
 
 
 
 
 
 
 
152
 
153
  except RuntimeError as e :
154
  if "device-side assert" in str(e):
@@ -163,6 +257,14 @@ def generate_speech(history):
163
  else:
164
  print("RuntimeError: non device-side assert error:", str(e))
165
  raise e
 
 
 
 
 
 
 
 
166
 
167
  with gr.Blocks(title=title) as demo:
168
  gr.Markdown(DESCRIPTION)
@@ -186,7 +288,7 @@ with gr.Blocks(title=title) as demo:
186
  btn = gr.Audio(source="microphone", type="filepath", scale=4)
187
 
188
  with gr.Row():
189
- audio = gr.Audio(type="numpy", streaming=True, autoplay=True, label="Generated audio response", show_label=True)
190
 
191
  clear_btn = gr.ClearButton([chatbot, audio])
192
 
@@ -210,11 +312,11 @@ with gr.Blocks(title=title) as demo:
210
  gr.Markdown("""
211
  This Space demonstrates how to speak to a chatbot, based solely on open-source models.
212
  It relies on 3 models:
213
- 1. [Whisper-large-v2](https://huggingface.co/spaces/sanchit-gandhi/whisper-large-v2) as an ASR model, to transcribe recorded audio to text. It is called through a [gradio client](https://www.gradio.app/docs/client).
214
  2. [Mistral-7b-instruct](https://huggingface.co/spaces/osanseviero/mistral-super-fast) as the chat model, the actual chat model. It is called from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference).
215
  3. [Coqui's XTTS](https://huggingface.co/spaces/coqui/xtts) as a TTS model, to generate the chatbot answers. This time, the model is hosted locally.
216
 
217
  Note:
218
  - By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml""")
219
  demo.queue()
220
- demo.launch(debug=True)
 
11
  nltk.download('punkt')
12
  import uuid
13
 
14
+ import ffmpeg
15
+ import librosa
16
+ import torchaudio
17
  from TTS.api import TTS
18
+ from TTS.tts.configs.xtts_config import XttsConfig
19
+ from TTS.tts.models.xtts import Xtts
20
+ from TTS.utils.generic_utils import get_user_data_dir
21
+
22
+ # This will trigger downloading model
23
+ print("Downloading if not downloaded Coqui XTTS V1")
24
+ tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1")
25
+ del tts
26
+ print("XTTS downloaded")
27
+
28
+ print("Loading XTTS")
29
+ #Below will use model directly for inference
30
+ model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
31
+ config = XttsConfig()
32
+ config.load_json(os.path.join(model_path, "config.json"))
33
+ model = Xtts.init_from_config(config)
34
+ model.load_checkpoint(
35
+ config,
36
+ checkpoint_path=os.path.join(model_path, "model.pth"),
37
+ vocab_path=os.path.join(model_path, "vocab.json"),
38
+ eval=True,
39
+ use_deepspeed=True
40
+ )
41
+ model.cuda()
42
+ print("Done loading TTS")
43
+
44
 
45
  title = "Voice chat with Mistral 7B Instruct"
46
 
 
72
  from huggingface_hub import InferenceClient
73
 
74
 
75
+ # This client is down
76
+ #whisper_client = Client("https://sanchit-gandhi-whisper-large-v2.hf.space/")
77
+ # Replacement whisper client, it may be time limited
78
+ whisper_client = Client("https://sanchit-gandhi-whisper-jax.hf.space")
79
  text_client = InferenceClient(
80
  "mistralai/Mistral-7B-Instruct-v0.1"
81
  )
82
 
83
+ ###### COQUI TTS FUNCTIONS ######
84
+ def get_latents(speaker_wav):
85
+ # create as function as we can populate here with voice cleanup/filtering
86
+ gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
87
+ return gpt_cond_latent, diffusion_conditioning, speaker_embedding
88
+
89
 
90
  def format_prompt(message, history):
91
  prompt = "<s>"
 
114
 
115
  formatted_prompt = format_prompt(prompt, history)
116
 
117
+ try:
118
+ stream = text_client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
119
+ output = ""
120
+ for response in stream:
121
+ output += response.token.text
122
+ yield output
123
+
124
+ except Exception as e:
125
+ if "Too Many Requests" in str(e):
126
+ print("ERROR: Too many requests on mistral client")
127
+ gr.Warning("Unfortunately Mistral is unable to process")
128
+ output = "Unfortuanately I am not able to process your request now !"
129
+ else:
130
+ print("Unhandled Exception: ", str(e))
131
+ gr.Warning("Unfortunately Mistral is unable to process")
132
+ output = "I do not know what happened but I could not understand you ."
133
+
134
  return output
135
 
136
 
137
  def transcribe(wav_path):
138
 
139
+ # get first element from whisper_jax and strip it to delete begin and end space
140
  return whisper_client.predict(
141
  wav_path, # str (filepath or URL to file) in 'inputs' Audio component
142
  "transcribe", # str in 'Task' Radio component
143
+ False, # return_timestamps=False for whisper-jax https://gist.github.com/sanchit-gandhi/781dd7003c5b201bfe16d28634c8d4cf#file-whisper_jax_endpoint-py
144
  api_name="/predict"
145
+ )[0].strip()
146
 
147
 
148
  # Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.
 
156
 
157
  def add_file(history, file):
158
  history = [] if history is None else history
159
+
160
+ try:
161
+ text = transcribe(
162
+ file
163
+ )
164
+ print("Transcribed text:",text)
165
+ except Exception as e:
166
+ print(str(e))
167
+ gr.Warning("There was an issue with transcription, please try writing for now")
168
+ # Apply a null text on error
169
+ text = "Transcription seems failed, please tell me a joke about chickens"
170
 
171
  history = history + [(text, None)]
172
  return history
 
184
  history[-1][1] = character
185
  yield history
186
 
187
+
188
+ def get_latents(speaker_wav):
189
+ # Generate speaker embedding and latents for TTS
190
+ gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
191
+ return gpt_cond_latent, diffusion_conditioning, speaker_embedding
192
+
193
+ latent_map={}
194
+ latent_map["Female_Voice"] = get_latents("examples/female.wav")
195
+
196
+ def get_voice(prompt,language, latent_tuple,suffix="0"):
197
+ gpt_cond_latent,diffusion_conditioning, speaker_embedding = latent_tuple
198
+ # Direct version
199
+ t0 = time.time()
200
+ out = model.inference(
201
+ prompt,
202
+ language,
203
+ gpt_cond_latent,
204
+ speaker_embedding,
205
+ diffusion_conditioning
206
+ )
207
+ inference_time = time.time() - t0
208
+ print(f"I: Time to generate audio: {round(inference_time*1000)} milliseconds")
209
+ real_time_factor= (time.time() - t0) / out['wav'].shape[-1] * 24000
210
+ print(f"Real-time factor (RTF): {real_time_factor}")
211
+ wav_filename=f"output_{suffix}.wav"
212
+ torchaudio.save(wav_filename, torch.tensor(out["wav"]).unsqueeze(0), 24000)
213
+ return wav_filename
214
+
215
  def generate_speech(history):
216
  text_to_generate = history[-1][1]
217
  text_to_generate = text_to_generate.replace("\n", " ").strip()
218
  text_to_generate = nltk.sent_tokenize(text_to_generate)
 
 
 
 
219
 
220
+ language = "en"
 
 
221
 
222
+ wav_list = []
223
+ for i,sentence in enumerate(text_to_generate):
224
+ # Sometimes prompt </s> coming on output remove it
225
+ sentence= sentence.replace("</s>","")
226
+ # A fast fix for last chacter, may produce weird sounds if it is with text
227
+ if sentence[-1] in ["!","?",".",","]:
228
+ #just add a space
229
+ sentence = sentence[:-1] + " " + sentence[-1]
230
+
231
+ print("Sentence:", sentence)
232
+
233
+ try:
234
+ # generate speech using precomputed latents
235
+ # This is not streaming but it will be fast
236
 
237
+ # giving sentence suffix so we can merge all to single audio at end
238
+ # On mobile there is no autoplay support due to mobile security!
239
+ wav = get_voice(sentence,language, latent_map["Female_Voice"], suffix=i)
240
+ wav_list.append(wav)
241
+
242
+ yield wav
243
+ wait_time= librosa.get_duration(path=wav)
244
+ print("Sleeping till audio end")
245
+ time.sleep(wait_time)
246
 
247
  except RuntimeError as e :
248
  if "device-side assert" in str(e):
 
257
  else:
258
  print("RuntimeError: non device-side assert error:", str(e))
259
  raise e
260
+ #Spoken on autoplay everysencen now produce a concataned one at the one
261
+ #requires pip install ffmpeg-python
262
+ files_to_concat= [ffmpeg.input(w) for w in wav_list]
263
+ combined_file_name="combined.wav"
264
+ ffmpeg.concat(*files_to_concat,v=0, a=1).output(combined_file_name).run(overwrite_output=True)
265
+
266
+ return gr.Audio.update(value=combined_file_name, autoplay=False)
267
+
268
 
269
  with gr.Blocks(title=title) as demo:
270
  gr.Markdown(DESCRIPTION)
 
288
  btn = gr.Audio(source="microphone", type="filepath", scale=4)
289
 
290
  with gr.Row():
291
+ audio = gr.Audio(type="numpy", streaming=False, autoplay=True, label="Generated audio response", show_label=True)
292
 
293
  clear_btn = gr.ClearButton([chatbot, audio])
294
 
 
312
  gr.Markdown("""
313
  This Space demonstrates how to speak to a chatbot, based solely on open-source models.
314
  It relies on 3 models:
315
+ 1. [Whisper-large-v2](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax) as an ASR model, to transcribe recorded audio to text. It is called through a [gradio client](https://www.gradio.app/docs/client).
316
  2. [Mistral-7b-instruct](https://huggingface.co/spaces/osanseviero/mistral-super-fast) as the chat model, the actual chat model. It is called from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference).
317
  3. [Coqui's XTTS](https://huggingface.co/spaces/coqui/xtts) as a TTS model, to generate the chatbot answers. This time, the model is hosted locally.
318
 
319
  Note:
320
  - By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml""")
321
  demo.queue()
322
+ demo.launch(debug=True)
requirements.txt CHANGED
@@ -53,8 +53,11 @@ encodec==0.1.*
53
  # deps for XTTS
54
  unidecode==1.3.*
55
  langid
56
- # Install tts
57
- git+https://github.com/coqui-ai/tts.git@43a7ca800b6508d95e084728a948846556f71a40
 
58
  deepspeed==0.8.3
59
  pydub
 
 
60
  gradio_client
 
53
  # deps for XTTS
54
  unidecode==1.3.*
55
  langid
56
+ # Install Coqui TTS
57
+ TTS==0.17.8
58
+ # Deepspeed for fast inference
59
  deepspeed==0.8.3
60
  pydub
61
+ librosa
62
+ ffmpeg-python
63
  gradio_client