"""
utils.py
Functions:
- get_script: Get the dialogue from the LLM.
- call_llm: Call the LLM with the given prompt and dialogue format.
- get_audio: Get the audio from the TTS model from HF Spaces.
"""
import os
import requests
import time
from gradio_client import Client
from openai import OpenAI
from pydantic import ValidationError
from bark import SAMPLE_RATE, generate_audio, preload_models
from scipy.io.wavfile import write as write_wav
# Fireworks-hosted Llama 3.1 405B used for all LLM calls below.
MODEL_ID = "accounts/fireworks/models/llama-v3p1-405b-instruct"
# Jina reader proxy: prefixing a URL fetches the page as plain text.
JINA_URL = "https://r.jina.ai/"
# Fireworks exposes an OpenAI-compatible endpoint, so the OpenAI SDK is reused.
client = OpenAI(
base_url="https://api.fireworks.ai/inference/v1",
api_key=os.getenv("FIREWORKS_API_KEY"),
)
# Gradio client for the hosted MeloTTS Space (standard-quality TTS path).
hf_client = Client("mrfakename/MeloTTS")
# download and load all models
preload_models()
def generate_script(system_prompt: str, input_text: str, output_model):
    """Get the dialogue from the LLM.

    Requests a dialogue constrained to ``output_model``'s JSON schema. If the
    first response fails Pydantic validation, the call is retried once with
    the validation error appended to the prompt. The resulting draft is then
    sent back to the LLM for a single refinement pass.
    """

    def _parse(resp):
        # Validate the raw JSON payload against the expected dialogue model.
        return output_model.model_validate_json(resp.choices[0].message.content)

    try:
        draft = _parse(call_llm(system_prompt, input_text, output_model))
    except ValidationError as err:
        # Retry once, telling the model exactly what was wrong with its JSON.
        retry_prompt = (
            f"{system_prompt}\n\nPlease return a VALID JSON object. "
            f"This was the earlier error: Failed to parse dialogue JSON: {err}"
        )
        draft = _parse(call_llm(retry_prompt, input_text, output_model))

    # Second pass: ask the LLM to polish its own first draft.
    refine_prompt = (
        f"{system_prompt}\n\nHere is the first draft of the dialogue you "
        f"provided:\n\n{draft}."
    )
    refined = call_llm(refine_prompt, "Please improve the dialogue.", output_model)
    return _parse(refined)
def call_llm(system_prompt: str, text: str, dialogue_format):
    """Call the LLM with the given prompt and dialogue format.

    Returns the raw chat completion; the response is constrained to a JSON
    object matching ``dialogue_format``'s schema via structured output.
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text},
    ]
    json_schema = {
        "type": "json_object",
        "schema": dialogue_format.model_json_schema(),
    }
    return client.chat.completions.create(
        model=MODEL_ID,
        messages=conversation,
        max_tokens=16_384,
        temperature=0.1,  # low temperature keeps the JSON output stable
        response_format=json_schema,
    )
def parse_url(url: str) -> str:
    """Parse the given URL and return the text content.

    The URL is fetched through the Jina reader proxy (``JINA_URL``), which
    returns the page rendered as plain text.

    Raises:
        requests.HTTPError: if the reader responds with an error status.
        requests.Timeout: if the request exceeds 60 seconds.
    """
    response = requests.get(f"{JINA_URL}{url}", timeout=60)
    # Fail loudly on HTTP errors instead of silently returning an
    # error-page body as if it were the article text.
    response.raise_for_status()
    return response.text
def generate_podcast_audio(text: str, speaker: str, language: str, use_advanced_audio: bool) -> str:
    """Get the audio from the TTS model and return a playable file path.

    Args:
        text: The line of dialogue to synthesize.
        speaker: Either "Host (Jane)" or "Guest" — selects voice/accent.
        language: Language code (e.g. "EN") used for voice selection.
        use_advanced_audio: If True, synthesize locally with Bark;
            otherwise call the hosted MeloTTS Space.

    Returns:
        Path to the generated audio file (Bark) or the path returned by
        the MeloTTS Space.
    """
    if use_advanced_audio:
        # Bark path: pick speaker 1 for the host, speaker 3 for the guest.
        voice = "1" if speaker == "Host (Jane)" else "3"
        audio_array = generate_audio(text, history_prompt=f"v2/{language}_speaker_{voice}")
        # BUG FIX: write_wav emits WAV-format data, so the file must be
        # named .wav — the previous .mp3 extension mislabeled the container.
        file_path = f"audio_{language}_{speaker}.wav"
        write_wav(file_path, SAMPLE_RATE, audio_array)
        return file_path

    # MeloTTS path: map speaker/language to the Space's accent and speed.
    if speaker == "Guest":
        accent = "EN-US" if language == "EN" else language
        speed = 0.9
    else:  # host
        accent = "EN-Default" if language == "EN" else language
        speed = 1
    if language != "EN" and speaker != "Guest":
        speed = 1.1

    # The hosted Space is flaky; retry up to 3 times before giving up.
    for attempt in range(3):
        try:
            return hf_client.predict(
                text=text,
                language=language,
                speaker=accent,
                speed=speed,
                api_name="/synthesize",
            )
        except Exception:
            if attempt == 2:  # Last attempt
                raise  # Re-raise the last exception if all attempts fail
            time.sleep(1)  # Wait for 1 second before retrying
|