Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import os | |
# By using XTTS you agree to CPML license https://coqui.ai/cpml | |
os.environ["COQUI_TOS_AGREED"] = "1" | |
import gradio as gr | |
import numpy as np | |
import torch | |
import nltk # we'll use this to split into sentences | |
nltk.download('punkt') | |
import uuid | |
import ffmpeg | |
import librosa | |
import torchaudio | |
from TTS.api import TTS | |
from TTS.tts.configs.xtts_config import XttsConfig | |
from TTS.tts.models.xtts import Xtts | |
from TTS.utils.generic_utils import get_user_data_dir | |
# Constructing the high-level TTS wrapper triggers the model download when it
# is not already cached; the wrapper is not kept around afterwards.
print("Downloading if not downloaded Coqui XTTS V1")
TTS("tts_models/multilingual/multi-dataset/xtts_v1")
print("XTTS downloaded")

# From here on the Xtts model class is used directly for inference.
print("Loading XTTS")
model_path = os.path.join(
    get_user_data_dir("tts"),
    "tts_models--multilingual--multi-dataset--xtts_v1",
)

config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))

model = Xtts.init_from_config(config)
model.load_checkpoint(
    config,
    checkpoint_path=os.path.join(model_path, "model.pth"),
    vocab_path=os.path.join(model_path, "vocab.json"),
    eval=True,
    use_deepspeed=True,
)
# Move the loaded model to the CUDA device.
model.cuda()
print("Done loading TTS")
# Page title, Markdown header and CSS snippet for the Gradio page.
title = "Voice chat with Mistral 7B Instruct"
DESCRIPTION = """# Voice chat with Mistral 7B Instruct"""
css = """.toast-wrap { display: none !important } """

from huggingface_hub import HfApi

# The Hub API client is used to restart the Space on an unrecoverable error.
HF_TOKEN = os.environ.get("HF_TOKEN")
api = HfApi(token=HF_TOKEN)
repo_id = "ylacombe/voice-chat-with-lama"

# Default system prompt for the chat model.
system_message = (
    "\nYou are a helpful, respectful and honest assistant. "
    "Your answers are short, ideally a few words long, if it is possible. "
    "Always answer as helpfully as possible, while being safe."
    "\n\nIf a question does not make any sense, or is not factually coherent, "
    "explain why instead of answering something not correct. "
    "If you don't know the answer to a question, please don't share false information."
)

# Module-level sampling settings.
# NOTE(review): bot() calls generate() without passing these, so generate()'s
# own keyword defaults are what take effect - confirm which is intended.
temperature = 0.9
top_p = 0.6
repetition_penalty = 1.2
import gradio as gr | |
import os | |
import time | |
import gradio as gr | |
from transformers import pipeline | |
import numpy as np | |
from gradio_client import Client | |
from huggingface_hub import InferenceClient | |
# Speech-to-text goes through a hosted Whisper Space via gradio_client.
# The previous endpoint (https://sanchit-gandhi-whisper-large-v2.hf.space/)
# is down; this replacement may be time limited.
whisper_client = Client("https://sanchit-gandhi-whisper-jax.hf.space")

# Text generation goes through the Hugging Face inference client.
text_client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")
###### COQUI TTS FUNCTIONS ###### | |
def get_latents(speaker_wav):
    """Compute the XTTS conditioning data for a reference recording.

    Kept as a separate function so that voice cleanup/filtering can be added
    here later.

    Args:
        speaker_wav: Forwarded to the model as ``audio_path``.

    Returns:
        The tuple ``(gpt_cond_latent, diffusion_conditioning,
        speaker_embedding)`` produced by ``model.get_conditioning_latents``.
    """
    conditioning = model.get_conditioning_latents(audio_path=speaker_wav)
    gpt_latent, diffusion_cond, spk_embedding = conditioning
    return gpt_latent, diffusion_cond, spk_embedding
def format_prompt(message, history):
    """Build the instruction-formatted prompt for the chat model.

    Args:
        message: The new user message.
        history: Iterable of ``(user_prompt, bot_response)`` pairs from
            earlier turns.

    Returns:
        A string that starts with ``<s>``, contains one
        ``[INST] ... [/INST] ...</s> `` segment per past turn, and ends with
        the new message wrapped in ``[INST] ... [/INST]``.
    """
    past_turns = "".join(
        f"[INST] {user_prompt} [/INST] {bot_response}</s> "
        for user_prompt, bot_response in history
    )
    return f"<s>{past_turns}[INST] {message} [/INST]"
def generate(
    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    """Stream the chat model's reply to ``prompt``.

    This is a generator: after every received token it yields the text
    accumulated so far. If the request fails, a warning is raised in the UI
    and a fixed fallback sentence is yielded, so callers that iterate over
    the generator still receive something to show and speak.

    Args:
        prompt: The new user message.
        history: Earlier ``(user_prompt, bot_response)`` pairs, passed to
            ``format_prompt``.
        temperature: Sampling temperature; values below ``1e-2`` are raised
            to ``1e-2``.
        max_new_tokens: Forwarded to the text-generation call.
        top_p: Forwarded to the text-generation call (converted to float).
        repetition_penalty: Forwarded to the text-generation call.

    Yields:
        The accumulated reply text, or the fallback sentence on error.

    Returns:
        The last yielded text (only visible through ``StopIteration.value``
        or ``yield from``; a plain ``for`` loop ignores it).
    """
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )
    formatted_prompt = format_prompt(prompt, history)

    output = ""
    try:
        stream = text_client.text_generation(
            formatted_prompt,
            **generate_kwargs,
            stream=True,
            details=True,
            return_full_text=False,
        )
        for response in stream:
            output += response.token.text
            yield output
    except Exception as e:
        if "Too Many Requests" in str(e):
            print("ERROR: Too many requests on mistral client")
            gr.Warning("Unfortunately Mistral is unable to process")
            output = "Unfortuanately I am not able to process your request now !"
        else:
            print("Unhandled Exception: ", str(e))
            gr.Warning("Unfortunately Mistral is unable to process")
            output = "I do not know what happened but I could not understand you ."
        # A bare ``return output`` inside a generator never reaches a caller
        # that iterates with ``for``; yield the fallback so it is delivered.
        yield output

    return output
def transcribe(wav_path):
    """Transcribe an audio file through the remote Whisper client.

    Args:
        wav_path: File path or URL of the audio, sent as the endpoint's
            audio input.

    Returns:
        The first element of the endpoint's result with leading and trailing
        whitespace removed.
    """
    prediction = whisper_client.predict(
        wav_path,      # audio input (file path or URL)
        "transcribe",  # task selection
        False,         # return_timestamps=False for whisper-jax
        api_name="/predict",
    )
    transcript = prediction[0]
    return transcript.strip()
# Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text. | |
def add_text(history, text):
    """Append a typed user message to the chat history.

    Args:
        history: Current chat history, or ``None`` for an empty chat.
        text: The message typed by the user.

    Returns:
        A pair of the new history (with ``(text, None)`` appended as a
        not-yet-answered turn) and a Gradio update that clears the textbox
        and makes it non-interactive.
    """
    turns = history if history is not None else []
    turns = turns + [(text, None)]
    textbox_update = gr.update(value="", interactive=False)
    return turns, textbox_update
def add_file(history, file):
    """Transcribe a recording and append it to the chat history.

    If transcription raises, the error is printed, a warning is raised in
    the UI, and a fixed placeholder message is used as the user turn.

    Args:
        history: Current chat history, or ``None`` for an empty chat.
        file: The recorded audio, passed to ``transcribe``.

    Returns:
        The new history with ``(text, None)`` appended.
    """
    turns = history if history is not None else []

    try:
        text = transcribe(file)
        print("Transcribed text:", text)
    except Exception as e:
        print(str(e))
        gr.Warning("There was an issue with transcription, please try writing for now")
        # Fall back to a fixed message when transcription fails.
        text = "Transcription seems failed, please tell me a joke about chickens"

    return turns + [(text, None)]
def bot(history, system_prompt=""):
    """Stream the assistant's reply into the last turn of ``history``.

    The last row of ``history`` is taken as the pending user turn; its reply
    slot is filled progressively with the text yielded by ``generate`` and
    the whole history is yielded after every update.

    Args:
        history: Chat history whose last row holds the user message to
            answer. ``None`` is treated as an empty history.
        system_prompt: Accepted for interface compatibility. It is not
            forwarded to ``generate``, which takes no system prompt.

    Yields:
        The history with the last row's reply updated. For an empty history
        the empty list is yielded once and nothing is generated.
    """
    history = [] if history is None else history
    if not history:
        # No pending user turn: indexing history[-1] would raise IndexError.
        yield history
        return

    # Rows are created as tuples by add_text/add_file; make the last row a
    # mutable list so the reply slot can be assigned.
    history[-1] = list(history[-1])
    history[-1][1] = ""
    for partial_reply in generate(history[-1][0], history[:-1]):
        history[-1][1] = partial_reply
        yield history
def get_latents(speaker_wav):
    """Generate the speaker embedding and latents used for TTS.

    Args:
        speaker_wav: Reference audio, forwarded to the model as
            ``audio_path``.

    Returns:
        ``(gpt_cond_latent, diffusion_conditioning, speaker_embedding)`` as
        returned by ``model.get_conditioning_latents``.
    """
    latents = model.get_conditioning_latents(audio_path=speaker_wav)
    gpt_part, diffusion_part, speaker_part = latents
    return gpt_part, diffusion_part, speaker_part
# Conditioning data per voice name, computed once when the module is loaded.
latent_map = {"Female_Voice": get_latents("examples/female.wav")}
def get_voice(prompt, language, latent_tuple, suffix="0"):
    """Synthesize ``prompt`` and write it to ``output_<suffix>.wav``.

    Args:
        prompt: Text to speak.
        language: Language code passed to ``model.inference``.
        latent_tuple: ``(gpt_cond_latent, diffusion_conditioning,
            speaker_embedding)`` as produced by ``get_latents``.
        suffix: Value interpolated into the output file name, so several
            sentences can be written side by side.

    Returns:
        The name of the written WAV file.
    """
    # Sample rate used both for the real-time factor and for the saved file.
    sample_rate = 24000

    gpt_cond_latent, diffusion_conditioning, speaker_embedding = latent_tuple

    # Direct version
    t0 = time.time()
    out = model.inference(
        prompt,
        language,
        gpt_cond_latent,
        speaker_embedding,
        diffusion_conditioning,
    )
    inference_time = time.time() - t0
    print(f"I: Time to generate audio: {round(inference_time*1000)} milliseconds")

    wav = out["wav"]
    num_samples = wav.shape[-1]
    if num_samples:
        # Reuse the measured inference time instead of reading the clock a
        # second time, so the figure does not include the print above.
        real_time_factor = inference_time / num_samples * sample_rate
        print(f"Real-time factor (RTF): {real_time_factor}")

    wav_filename = f"output_{suffix}.wav"
    # unsqueeze(0) adds a leading dimension before saving.
    torchaudio.save(wav_filename, torch.tensor(wav).unsqueeze(0), sample_rate)
    return wav_filename
def generate_speech(history):
    """Speak the latest assistant reply sentence by sentence.

    The reply in the last row of ``history`` is split into sentences. Each
    sentence is synthesized with ``get_voice`` and its file name is yielded,
    then the function sleeps for the clip's duration before continuing with
    the next sentence. After the last sentence all clips are concatenated
    into ``combined.wav`` and a final update pointing at that file (with
    autoplay off) is yielded.

    Args:
        history: Chat history; the reply is read from ``history[-1][1]``.

    Yields:
        The WAV file name of each synthesized sentence, then a
        ``gr.Audio.update`` for the combined file. Nothing is yielded when
        there is no reply text.

    Raises:
        RuntimeError: Re-raised when synthesis fails with anything other
            than a CUDA device-side assert (which triggers a Space restart).
    """
    reply = history[-1][1] if history else None
    if not reply:
        # Nothing to speak (no turn yet, or the reply is None/empty).
        return

    sentences = nltk.sent_tokenize(reply.replace("\n", " ").strip())
    language = "en"
    wav_list = []

    for i, sentence in enumerate(sentences):
        # Sometimes the end-of-sequence marker leaks into the output.
        sentence = sentence.replace("</s>", "")
        if not sentence.strip():
            # Nothing left after removing the marker; indexing sentence[-1]
            # below would raise IndexError on an empty string.
            continue

        # Quick fix for the last character: it may produce weird sounds when
        # it is attached to the text, so separate it with a space.
        if sentence[-1] in ["!", "?", ".", ","]:
            sentence = sentence[:-1] + " " + sentence[-1]
        print("Sentence:", sentence)

        try:
            # Uses precomputed latents. This is not streaming but it is fast.
            # The index is used as the file suffix so the clips can be merged
            # into a single file at the end.
            # On mobile there is no autoplay support due to mobile security.
            wav = get_voice(sentence, language, latent_map["Female_Voice"], suffix=i)
            wav_list.append(wav)
            yield wav

            wait_time = librosa.get_duration(path=wav)
            print("Sleeping till audio end")
            time.sleep(wait_time)
        except RuntimeError as e:
            if "device-side assert" in str(e):
                # Nothing can be done on a CUDA device-side error; the Space
                # has to be restarted.
                print(f"Exit due to: Unrecoverable exception caused by prompt:{sentence}", flush=True)
                gr.Warning("Unhandled Exception encounter, please retry in a minute")
                print("Cuda device-assert Runtime encountered need restart")
                # HF Space specific: this error is unrecoverable.
                api.restart_space(repo_id=repo_id)
            else:
                print("RuntimeError: non device-side assert error:", str(e))
                raise

    if not wav_list:
        # No clip was produced, so there is nothing to concatenate.
        return

    # Every sentence has been auto-played; now produce a single concatenated
    # file. Requires the ffmpeg-python package.
    files_to_concat = [ffmpeg.input(w) for w in wav_list]
    combined_file_name = "combined.wav"
    ffmpeg.concat(*files_to_concat, v=0, a=1).output(combined_file_name).run(overwrite_output=True)

    # This function is a generator, so the update must be yielded; a returned
    # value would be discarded and never reach the audio component.
    yield gr.Audio.update(value=combined_file_name, autoplay=False)
# NOTE(review): the nesting of the rows below is reconstructed; confirm that
# the clear button belongs outside the second row.
with gr.Blocks(title=title) as demo:
    gr.Markdown(DESCRIPTION)

    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        avatar_images=('examples/lama.jpeg', 'examples/lama2.jpeg'),
        bubble_full_width=False,
    )

    # Input row: textbox, submit button and microphone recorder.
    with gr.Row():
        txt = gr.Textbox(
            scale=3,
            show_label=False,
            placeholder="Enter text and press enter, or speak to your microphone",
            container=False,
        )
        txt_btn = gr.Button(value="Submit text", scale=1)
        btn = gr.Audio(source="microphone", type="filepath", scale=4)

    # Output row: the spoken answer.
    with gr.Row():
        audio = gr.Audio(
            type="numpy",
            streaming=False,
            autoplay=True,
            label="Generated audio response",
            show_label=True,
        )
    clear_btn = gr.ClearButton([chatbot, audio])

    # Both text triggers (button click, then Enter in the textbox) run the
    # same chain: store the message, stream the reply, speak it, and finally
    # make the textbox interactive again.
    for text_trigger in (txt_btn.click, txt.submit):
        txt_msg = text_trigger(
            add_text, [chatbot, txt], [chatbot, txt], queue=False
        ).then(
            bot, chatbot, chatbot
        ).then(
            generate_speech, chatbot, audio
        )
        txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)

    # A finished recording is transcribed first, then follows the same
    # reply-and-speak chain.
    file_msg = btn.stop_recording(
        add_file, [chatbot, btn], [chatbot], queue=False
    ).then(
        bot, chatbot, chatbot
    ).then(
        generate_speech, chatbot, audio
    )

    gr.Markdown("""
This Space demonstrates how to speak to a chatbot, based solely on open-source models.
It relies on 3 models:
1. [Whisper-large-v2](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax) as an ASR model, to transcribe recorded audio to text. It is called through a [gradio client](https://www.gradio.app/docs/client).
2. [Mistral-7b-instruct](https://huggingface.co/spaces/osanseviero/mistral-super-fast) as the chat model, the actual chat model. It is called from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference).
3. [Coqui's XTTS](https://huggingface.co/spaces/coqui/xtts) as a TTS model, to generate the chatbot answers. This time, the model is hosted locally.
Note:
- By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml""")

demo.queue()
demo.launch(debug=True)