Spaces:
Running
on
T4
Running
on
T4
from __future__ import annotations

import io
import os
import random
import re
import secrets
import threading
import time
import uuid

import cohere
import fasttext
import gradio as gr
import numpy as np
import replicate
import requests
import torch
from elevenlabs import save
from elevenlabs.client import ElevenLabs
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
from groq import Groq
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from TTS.api import TTS

from prompt_examples import TEXT_CHAT_EXAMPLES, IMG_GEN_PROMPT_EXAMPLES, AUDIO_EXAMPLES, TEXT_CHAT_EXAMPLES_LABELS, IMG_GEN_PROMPT_EXAMPLES_LABELS, AUDIO_EXAMPLES_LABELS
from preambles import CHAT_PREAMBLE, AUDIO_RESPONSE_PREAMBLE, IMG_DESCRIPTION_PREAMBLE
from constants import LID_LANGUAGES, NEETS_AI_LANGID_MAP, AYA_MODEL_NAME, BATCH_SIZE, USE_ELVENLABS, USE_REPLICATE
# API credentials come from environment variables (Space secrets); a missing
# variable leaves the corresponding name bound to None.
HF_API_TOKEN = os.environ.get("HF_API_KEY")
ELEVEN_LABS_KEY = os.environ.get("ELEVEN_LABS_KEY")
NEETS_AI_API_KEY = os.environ.get("NEETS_AI_API_KEY")
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
IMG_COHERE_API_KEY = os.environ.get("IMG_COHERE_API_KEY")
AUDIO_COHERE_API_KEY = os.environ.get("AUDIO_COHERE_API_KEY")
CHAT_COHERE_API_KEY = os.environ.get("CHAT_COHERE_API_KEY")

# Device used by the locally-run models (transformers ASR pipeline, XTTS).
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
# One Cohere client per app feature (image prompts, text chat, spoken replies),
# each created with its own API key and client_name tag.
img_prompt_client = cohere.Client(api_key=IMG_COHERE_API_KEY, client_name="c4ai-aya-expanse-img")
chat_client = cohere.Client(api_key=CHAT_COHERE_API_KEY, client_name="c4ai-aya-expanse-chat")
audio_response_client = cohere.Client(api_key=AUDIO_COHERE_API_KEY, client_name="c4ai-aya-expanse-audio")

# Groq client: used for Whisper speech-to-text.
groq_client = Groq(api_key=GROQ_API_KEY)

# ElevenLabs client: used for text-to-speech when USE_ELVENLABS is enabled.
eleven_labs_client = ElevenLabs(api_key=ELEVEN_LABS_KEY)
# Language identification: download Meta's fastText LID model from the Hub
# (cached locally by huggingface_hub) and load it once at import time.
lid_model_path = hf_hub_download(
    repo_id="facebook/fasttext-language-identification",
    filename="model.bin",
)
LID_model = fasttext.load_model(lid_model_path)
def predict_language(text):
    """Return the top language code predicted by the fastText LID model.

    Codes look like ``eng_Latn`` / ``jpn_Jpan``: the model's ``__label__``
    prefix is stripped. Newlines are replaced with spaces first because
    fastText's ``predict`` rejects multi-line input.
    """
    single_line = re.sub("\n", " ", text)
    labels, _scores = LID_model.predict(single_line)
    label_prefix = "__label__"
    lang_code = labels[0][len(label_prefix):]
    print("predicted language:", lang_code)
    return lang_code
# Image Generation util functions
def choose_img_prompt_examples(language):
    """Return one randomly chosen example image prompt for *language*."""
    candidates = IMG_GEN_PROMPT_EXAMPLES[language]
    return random.choice(candidates)
def get_hf_inference_api_response(payload, model_id, timeout=120):
    """POST *payload* to the Hugging Face Inference API for *model_id*.

    Args:
        payload: JSON-serialisable request body (e.g. ``{"inputs": prompt}``).
        model_id: Hub model id, e.g. ``black-forest-labs/FLUX.1-schnell``.
        timeout: seconds to wait for the server before giving up; without a
            timeout a stalled request would block the worker indefinitely.

    Returns:
        The raw response body as bytes (for the text-to-image model used by
        ``generate_image`` this is expected to be encoded image data).

    Raises:
        requests.HTTPError: on a non-2xx status (e.g. model loading, rate
            limit). Previously the JSON error body was returned as if it were
            an image and only failed later while decoding.
        requests.Timeout: if the server does not respond within *timeout*.
    """
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
    model_api_url = f"https://api-inference.huggingface.co/models/{model_id}"
    response = requests.post(model_api_url, headers=headers, json=payload, timeout=timeout)
    response.raise_for_status()
    return response.content
def replicate_api_inference(input_prompt):
    """Generate a single 1:1 image for *input_prompt* with FLUX Schnell on Replicate.

    Returns a PIL image opened from the first element of Replicate's output
    (presumably a file-like object in the replicate client version used — confirm).
    """
    flux_inputs = dict(
        prompt=input_prompt,
        go_fast=True,
        megapixels="1",
        num_outputs=1,
        aspect_ratio="1:1",
        output_format="jpg",
        output_quality=80,
        enable_safety_checker=True,
        safety_tolerance=1,
        num_inference_steps=4,
    )
    outputs = replicate.run("black-forest-labs/flux-schnell", input=flux_inputs)
    first_output = outputs[0]
    return Image.open(first_output)
def generate_image(input_prompt, model_id="black-forest-labs/FLUX.1-schnell"):
    """Render an image for *input_prompt*, or return None when the prompt is empty.

    With USE_REPLICATE set, Replicate is used directly and a PIL image is
    returned. Otherwise the HF Inference API is tried first (returning a numpy
    array); any failure there falls back to Replicate.
    """
    if not input_prompt:
        return None

    if USE_REPLICATE:
        print("using replicate for image generation")
        return replicate_api_inference(input_prompt)

    try:
        print("using HF inference API for image generation")
        raw_bytes = get_hf_inference_api_response({"inputs": input_prompt}, model_id)
        return np.array(Image.open(io.BytesIO(raw_bytes)))
    except Exception as err:
        print("HF API error:", err)
        # best-effort fallback: generate the image with Replicate instead
        return replicate_api_inference(input_prompt)
def generate_img_prompt(input_prompt):
    """Ask Aya for a short English image description based on *input_prompt*.

    Non-English prompts (per fastText LID) are first translated by Aya as part
    of the same request. Returns the description text, or None for an empty
    prompt.
    """
    if not input_prompt:
        return None

    # language detection runs on a cleaned, single-line copy of the prompt
    cleaned_prompt = clean_text(input_prompt, remove_bullets=True, remove_newline=True)
    lang_code = predict_language(cleaned_prompt)
    gr.Info("Generating Image", duration=2)

    if lang_code == "eng_Latn":
        message = (
            "Generate a detailed image description which can be used to generate an image using a text-to-image model based on the given input prompt:\n"
            f"Input Prompt: {input_prompt}\n"
            "Do not use more than 3-4 lines for the description.\n"
        )
    else:
        message = (
            "\n"
            "Translate the given input prompt to English.\n"
            f"Input Prompt: {input_prompt}\n"
            "Then based on the English translation of the prompt, generate a detailed image description which can be used to generate an image using a text-to-image model.\n"
            "Do not use more than 3-4 lines for the image description. Respond with only the image description.\n"
        )

    reply = img_prompt_client.chat(message=message, preamble=IMG_DESCRIPTION_PREAMBLE, model=AYA_MODEL_NAME)
    return reply.text
# Chat with Aya util functions
def choose_chat_examples(language):
    """Return one randomly chosen example chat prompt for *language*."""
    candidates = TEXT_CHAT_EXAMPLES[language]
    return random.choice(candidates)
def trigger_example(example, cid="", token=None, history=None):
    """Run a full chat turn for *example* and return its final state.

    The original called ``generate_aya_chat_response(example)`` without the
    required ``cid``/``token`` arguments and unpacked the returned generator
    into two names, so it always raised ``TypeError``. This drains the
    generator and keeps the last update instead.

    Args:
        example: the prompt to send.
        cid: conversation id; empty lets the chat function create a new one.
        token: session token forwarded to the chat function.
        history: optional flat user/assistant history list (appended in place).

    Returns:
        ``(chat, updated_history)``: Chatbot pairs and the flat history after
        the last streamed update (both empty if the stream produced nothing).
    """
    chat, updated_history = [], []
    for chat, updated_history, _cid in generate_aya_chat_response(example, cid, token, history):
        pass
    return chat, updated_history
def _history_to_chat_pairs(history):
    """Pair a flat [user, assistant, user, assistant, ...] list into Chatbot tuples."""
    return [
        (history[i].strip(), history[i + 1].strip())
        for i in range(0, len(history) - 1, 2)
    ]


def generate_aya_chat_response(user_message, cid, token, history=None):
    """Stream Aya's reply to *user_message*, yielding UI updates as text arrives.

    Args:
        user_message: the user's prompt.
        cid: Cohere conversation id; a new UUID is generated when empty or None.
        token: per-session token set by ``demo.load``; a missing token is only logged.
        history: flat list alternating user and assistant messages; the new
            turn is appended to it in place.

    Yields:
        ``(chat, history, cid)`` after every stream event, where ``chat`` is the
        list of ``(user, assistant)`` pairs shown in the Chatbot.
    """
    if not token:
        print("no token")
    if history is None:
        history = []
    # The original test `cid == "" or None` evaluated to None (falsy) for a
    # None cid, so None was never replaced.
    if not cid:
        cid = str(uuid.uuid4())
    print(f"cid: {cid} prompt:{user_message}")

    history.append(user_message)
    stream = chat_client.chat_stream(message=user_message, preamble=CHAT_PREAMBLE, conversation_id=cid, model=AYA_MODEL_NAME, connectors=[], temperature=0.3)
    # Reserve the assistant slot before reading any event. The original only
    # appended it on event #0, so if the first event already carried text the
    # user's message was overwritten by the reply.
    history.append(" ")

    output = ""
    chat = _history_to_chat_pairs(history)
    for response in stream:
        if response.event_type == "text-generation":
            output += response.text
        history[-1] = output
        chat = _history_to_chat_pairs(history)
        yield chat, history, cid
    # `chat` is always bound here, even if the stream produced no events.
    return chat, history, cid
def clear_chat():
    """Reset the chat tab: empty Chatbot, empty history state and a fresh conversation id."""
    fresh_cid = str(uuid.uuid4())
    empty_chat, empty_history = [], []
    return empty_chat, empty_history, fresh_cid
# Audio Pipeline util functions

# transformers ASR pipelines keyed by model name, so a non-Groq model is loaded
# once per process rather than on every request. (Two concurrent first calls may
# both build it; the second simply replaces the first.)
_ASR_PIPELINES = {}


def _get_asr_pipeline(model_name):
    """Return a cached transformers ASR pipeline for *model_name*, building it on first use."""
    pipe = _ASR_PIPELINES.get(model_name)
    if pipe is None:
        print("DEVICE:", DEVICE)
        pipe = pipeline(
            task="automatic-speech-recognition",
            model=model_name,
            chunk_length_s=30,
            device=DEVICE,  # pipeline() takes lowercase `device`; the original passed `DEVICE=`
        )
        _ASR_PIPELINES[model_name] = pipe
    return pipe


def transcribe_and_stream(inputs, model_name="groq_whisper", show_info="show_info", language="english"):
    """Transcribe the audio file *inputs* and stream the transcript to the UI.

    Args:
        inputs: path to the recorded/uploaded audio file (Gradio ``filepath``).
        model_name: ``"groq_whisper"`` uses Groq's hosted Whisper; any other
            value is treated as a transformers ASR model id run locally.
        show_info: pass ``"show_info"`` to show a "Processing Audio" toast.
        language: currently unused.

    Yields:
        Growing prefixes of the transcript (first 10 characters, then one more
        per step), ending with exactly one yield of the full text. For empty
        input a single ``""`` is yielded. The original returned without
        yielding, so Gradio left the previous transcript in place and the next
        chained step re-used it.
    """
    if not inputs:
        yield ""
        return

    if show_info == "show_info":
        gr.Info("Processing Audio", duration=1)

    if model_name != "groq_whisper":
        pipe = _get_asr_pipeline(model_name)
        text = pipe(inputs, batch_size=BATCH_SIZE, return_timestamps=True)["text"]
    else:
        text = groq_whisper_tts(inputs)

    # stream text output as a typing effect
    for end in range(min(10, len(text)), len(text) + 1):
        time.sleep(0.01)
        yield text[:end]
def aya_speech_text_response(text):
    """Stream Aya's spoken-style reply to the transcribed *text*.

    Yields the cumulative reply, passed through ``clean_text`` (bold markers
    removed), after each text chunk. Yields ``""`` when *text* is empty or the
    stream produces no text. The original used ``return ""`` inside this
    generator, which yields nothing. Gradio then kept the previous answer in
    the textbox, and the chained TTS step spoke that stale answer again.
    """
    if not text:
        yield ""
        return

    stream = audio_response_client.chat_stream(message=text, preamble=AUDIO_RESPONSE_PREAMBLE, model=AYA_MODEL_NAME)
    output = ""
    for event in stream:
        if event and event.event_type == "text-generation":
            output += event.text
            yield clean_text(output)

    if not output:
        yield ""
def clean_text(text, remove_bullets=False, remove_newline=False):
    """Strip Markdown bold markers and optionally bullets / line breaks.

    Args:
        text: the string to clean.
        remove_bullets: drop a leading ``"- "`` at the start of every line.
        remove_newline: replace every line break with a single space.

    Returns:
        The cleaned string.
    """
    result = text.replace("**", "")
    if remove_bullets:
        result = re.sub(r"^- ", "", result, flags=re.MULTILINE)
    if remove_newline:
        result = result.replace("\n", " ")
    return result
# Coqui XTTS v2 model, loaded lazily on the first Japanese request and then
# reused. The lock serialises loading and synthesis, because the model object
# is now shared between concurrent Gradio workers.
_XTTS_MODEL = None
_XTTS_LOCK = threading.Lock()


def _xtts_japanese_to_file(text, audio_path):
    """Synthesize Japanese *text* with XTTS v2 into *audio_path* (loading the model once)."""
    global _XTTS_MODEL
    with _XTTS_LOCK:
        if _XTTS_MODEL is None:
            print("DEVICE:", DEVICE)
            _XTTS_MODEL = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(DEVICE)
        _XTTS_MODEL.tts_to_file(text=text, speaker_wav="samples/ja-sample.wav", language="ja", file_path=audio_path)


def convert_text_to_speech(text, language="english"):
    """Turn *text* into speech and return the audio file path (None for empty text).

    With USE_ELVENLABS the ElevenLabs voice is used directly. Otherwise the
    text's language is detected. Japanese goes to local XTTS v2, because
    neets.ai has no Japanese voice; every other language goes to neets.ai.
    The original reloaded XTTS on every Japanese request and ran language
    detection even on the ElevenLabs path, where the result was unused.

    Args:
        text: the response text to speak.
        language: currently unused.
    """
    if not text:
        return None

    if USE_ELVENLABS:
        return elevenlabs_generate_audio(text)

    # clean text before doing language detection
    cleaned_text = clean_text(text, remove_bullets=True, remove_newline=True)
    text_lang_code = predict_language(cleaned_text)
    if text_lang_code != "jpn_Jpan":
        return neetsai_tts(text, text_lang_code)

    # NOTE(review): fixed output path is shared across concurrent requests.
    audio_path = "./output.wav"
    _xtts_japanese_to_file(text, audio_path)
    return audio_path
def elevenlabs_generate_audio(text):
    """Synthesize *text* with ElevenLabs (voice "River", model eleven_turbo_v2_5).

    The audio is saved to ``./audio.mp3`` and that path is returned.
    """
    audio_path = "./audio.mp3"
    # alternative model previously considered: "eleven_multilingual_v2"
    generated = eleven_labs_client.generate(text=text, voice="River", model="eleven_turbo_v2_5")
    save(generated, audio_path)
    return audio_path
def neetsai_tts(input_text, text_lang_code, timeout=60):
    """Synthesize *input_text* with the neets.ai VITS voice for its language.

    Args:
        input_text: text to speak.
        text_lang_code: fastText language code (e.g. ``deu_Latn``). Codes not in
            LID_LANGUAGES fall back to the English voice.
        timeout: seconds to wait for the neets.ai API before giving up.

    Returns:
        Path of the saved MP3 file (``neets_demo.mp3``).

    Raises:
        requests.HTTPError: if neets.ai returns an error status. Previously the
            error body was written into the .mp3 and handed to the player.
        requests.Timeout: if the API does not answer within *timeout*.
    """
    # use english voice as default for languages outside 23 languages of Aya Expanse
    language = LID_LANGUAGES.get(text_lang_code, "english")
    neets_vits_voice_id = f"vits-{NEETS_AI_LANGID_MAP[language]}"

    response = requests.post(
        "https://api.neets.ai/v1/tts",
        headers={
            "Content-Type": "application/json",
            "X-API-Key": NEETS_AI_API_KEY,
        },
        json={
            "text": input_text,
            "voice_id": neets_vits_voice_id,
            "params": {"model": "vits"},
        },
        timeout=timeout,
    )
    response.raise_for_status()

    # NOTE(review): fixed output path is shared across concurrent requests.
    audio_path = "neets_demo.mp3"
    with open(audio_path, "wb") as f:
        f.write(response.content)
    return audio_path
def groq_whisper_tts(filename):
    """Transcribe the audio file at *filename* with Groq's whisper-large-v3-turbo.

    Despite the name, this is speech-to-text. Returns the transcript string.
    """
    with open(filename, "rb") as audio_file:
        audio_bytes = audio_file.read()
    transcription = groq_client.audio.transcriptions.create(
        file=(filename, audio_bytes),
        model="whisper-large-v3-turbo",
        response_format="json",
        temperature=0.0,
    )
    transcript = transcription.text
    print("transcribed text:", transcript)
    print("********************************")
    return transcript
# setup gradio app theme: teal / blue / gray palette with large text
theme = gr.themes.Base(
    primary_hue=gr.themes.colors.teal,
    secondary_hue=gr.themes.colors.blue,
    neutral_hue=gr.themes.colors.gray,
    text_size=gr.themes.sizes.text_lg,
)
# Overrides: dark-teal primary buttons (same color on hover) and semi-bold
# block titles/labels.
theme = theme.set(
    button_primary_background_fill="#114A56",
    button_primary_background_fill_hover="#114A56",
    block_title_text_weight="600",
    block_label_text_weight="600",
    block_label_text_size="*text_md",
)
demo = gr.Blocks(theme=theme, analytics_enabled=False)

with demo:
    # Header: logo plus a short description of the model and its license.
    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            gr.Image("AyaExpanse.png", elem_id="logo-img", show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False)
        with gr.Column(scale=30):
            gr.Markdown("""C4AI Aya Expanse is a state-of-art model with highly advanced capabilities to connect the world across languages.
<br/>
You can use this space to chat, speak and visualize with Aya Expanse in 23 languages.
<br/>
**Model**: [aya-expanse-32B](https://huggingface.co/CohereForAI/aya-expanse-32b)
<br/>
**Developed by**: [Cohere for AI](https://cohere.com/research) and [Cohere](https://cohere.com/)
<br/>
**License**: [CC-BY-NC](https://cohere.com/c4ai-cc-by-nc-license), requires also adhering to [C4AI's Acceptable Use Policy](https://docs.cohere.com/docs/c4ai-acceptable-use-policy)
"""
            )

    # --- Chat with Aya ---
    with gr.TabItem("Chat with Aya") as chat_with_aya:
        cid = gr.State("")
        token = gr.State(value=None)

        with gr.Column():
            with gr.Row():
                chatbot = gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, height=300)
            with gr.Row():
                user_message = gr.Textbox(lines=1, placeholder="Ask anything in our 23 languages ...", label="Input", show_label=False)
                # hidden textbox that feeds the selected language into the examples loader
                msg_temp = gr.Textbox(visible=False)
            with gr.Row():
                submit_button = gr.Button("Submit", variant="primary")
                clear_button = gr.Button("Clear")

        history = gr.State([])

        user_message.submit(fn=generate_aya_chat_response, inputs=[user_message, cid, token, history], outputs=[chatbot, history, cid], concurrency_limit=32)
        submit_button.click(fn=generate_aya_chat_response, inputs=[user_message, cid, token, history], outputs=[chatbot, history, cid], concurrency_limit=32)
        clear_button.click(fn=clear_chat, inputs=None, outputs=[chatbot, history, cid], concurrency_limit=32)

        # Empty the input box after submit/clear. These handlers get no inputs,
        # so the callbacks must take no arguments. The original `lambda x: ...`
        # was called with zero args, raised TypeError, and never cleared the box.
        user_message.submit(lambda: gr.update(value=""), None, [user_message], queue=False)
        submit_button.click(lambda: gr.update(value=""), None, [user_message], queue=False)
        clear_button.click(lambda: gr.update(value=""), None, [user_message], queue=False)

        with gr.Row():
            gr.Examples(
                examples=[[lang] for lang in TEXT_CHAT_EXAMPLES.keys()],
                inputs=msg_temp,
                outputs=user_message,
                fn=choose_chat_examples,
                label="Load example prompt for:",
                examples_per_page=25,
                run_on_click=True
            )

    # --- Speak with Aya: end-to-end STT -> Aya -> TTS pipeline ---
    with gr.TabItem("Speak with Aya") as speak_with_aya:
        with gr.Row():
            with gr.Column():
                e2e_audio_file = gr.Audio(sources="microphone", type="filepath", min_length=None)
                e2_audio_submit_button = gr.Button(value="Get Aya's Response", variant="primary")
                clear_button_microphone = gr.ClearButton()
                gr.Examples(
                    examples=AUDIO_EXAMPLES,
                    inputs=e2e_audio_file,
                    cache_examples=False,
                    examples_per_page=25,
                    label="Load example audio for:",
                    example_labels=AUDIO_EXAMPLES_LABELS,
                )
            with gr.Column():
                e2e_audio_file_trans = gr.Textbox(lines=3, label="Your Input", autoscroll=False, show_copy_button=True, interactive=False)
                e2e_audio_file_aya_response = gr.Textbox(lines=3, label="Aya's Response", show_copy_button=True, container=True, interactive=False)
                e2e_aya_audio_response = gr.Audio(type="filepath", label="Aya's Audio Response")
        with gr.Accordion("See Details", open=False):
            gr.Markdown("To enable voice interaction with Aya Expanse, this space uses [Whisper large-v3-turbo](https://huggingface.co/openai/whisper-large-v3-turbo) and [Groq](https://groq.com/) for STT and [neets.ai](http://neets.ai/) for TTS.")

    # --- Visualize with Aya: prompt -> Aya image description -> FLUX image ---
    with gr.TabItem("Visualize with Aya") as visualize_with_aya:
        with gr.Row():
            with gr.Column():
                input_img_prompt = gr.Textbox(placeholder="Ask anything in our 23 languages ...", label="Describe an image", lines=3)
                submit_button_img = gr.Button(value="Submit", variant="primary")
                clear_button_img = gr.ClearButton()
            with gr.Column():
                generated_img = gr.Image(label="Generated Image", interactive=False)
                # hidden textbox that feeds the selected language into the examples loader
                input_prompt_lang = gr.Textbox(visible=False)
        with gr.Row():
            gr.Examples(
                examples=[[lang] for lang in IMG_GEN_PROMPT_EXAMPLES.keys()],
                inputs=input_prompt_lang,
                outputs=input_img_prompt,
                fn=choose_img_prompt_examples,
                label="Load example prompt for:",
                examples_per_page=25,
                run_on_click=True
            )
        # hidden holder for Aya's image description; its change event triggers generation
        generated_img_desc = gr.Textbox(label="Image Description generated by Aya", interactive=False, lines=3, visible=False)

        # increase spacing between examples and Accordion components
        with gr.Row():
            pass
        with gr.Row():
            pass
        with gr.Row():
            pass
        with gr.Row():
            with gr.Accordion("See Details", open=False):
                gr.Markdown("This space uses Aya Expanse for translating multilingual prompts and generating detailed image descriptions and [Flux Schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) for Image Generation.")

    # Image Generation
    clear_button_img.click(lambda: None, None, input_img_prompt)
    clear_button_img.click(lambda: None, None, generated_img_desc)
    clear_button_img.click(lambda: None, None, generated_img)

    submit_button_img.click(
        generate_img_prompt,
        inputs=[input_img_prompt],
        outputs=[generated_img_desc],
    )
    generated_img_desc.change(
        generate_image,
        inputs=[generated_img_desc],
        outputs=[generated_img],
        show_progress="full",
    )

    # Audio Pipeline
    clear_button_microphone.click(lambda: None, None, e2e_audio_file)
    clear_button_microphone.click(lambda: None, None, e2e_aya_audio_response)
    clear_button_microphone.click(lambda: None, None, e2e_audio_file_aya_response)
    clear_button_microphone.click(lambda: None, None, e2e_audio_file_trans)

    # transcribe -> Aya text reply -> speech, each step reading the previous step's output
    e2_audio_submit_button.click(
        transcribe_and_stream,
        inputs=[e2e_audio_file],
        outputs=[e2e_audio_file_trans],
        show_progress="full",
    ).then(
        aya_speech_text_response,
        inputs=[e2e_audio_file_trans],
        outputs=[e2e_audio_file_aya_response],
        show_progress="full",
    ).then(
        convert_text_to_speech,
        inputs=[e2e_audio_file_aya_response],
        outputs=[e2e_aya_audio_response],
        show_progress="full",
    )

    # per-page-load session token passed to the chat handler
    demo.load(lambda: secrets.token_hex(16), None, token)

demo.queue(api_open=False, max_size=20, default_concurrency_limit=4).launch(show_api=False, allowed_paths=['/home/user/app'])