"""
utils.py

Functions:
- get_script: Get the dialogue from the LLM.
- call_llm: Call the LLM with the given prompt and dialogue format.
- get_audio: Get the audio from the TTS model from HF Spaces.
"""

import os
import time

import requests
from bark import SAMPLE_RATE, generate_audio, preload_models
from gradio_client import Client
from openai import OpenAI
from pydantic import ValidationError
from scipy.io.wavfile import write as write_wav

MODEL_ID = "accounts/fireworks/models/llama-v3p1-405b-instruct"
JINA_URL = "https://r.jina.ai/"

client = OpenAI(
    base_url="https://api.fireworks.ai/inference/v1",
    api_key=os.getenv("FIREWORKS_API_KEY"),
)

hf_client = Client("mrfakename/MeloTTS")

# download and load all models
preload_models()


def generate_script(system_prompt: str, input_text: str, output_model):
    """Get the dialogue from the LLM."""
    # Load as python object
    try:
        response = call_llm(system_prompt, input_text, output_model)
        dialogue = output_model.model_validate_json(response.choices[0].message.content)
    except ValidationError as e:
        error_message = f"Failed to parse dialogue JSON: {e}"
        system_prompt_with_error = f"{system_prompt}\n\nPlease return a VALID JSON object. This was the earlier error: {error_message}"
        response = call_llm(system_prompt_with_error, input_text, output_model)
        dialogue = output_model.model_validate_json(response.choices[0].message.content)

    # Call the LLM again to improve the dialogue
    system_prompt_with_dialogue = f"{system_prompt}\n\nHere is the first draft of the dialogue you provided:\n\n{dialogue}."
    response = call_llm(
        system_prompt_with_dialogue, "Please improve the dialogue.", output_model
    )
    improved_dialogue = output_model.model_validate_json(
        response.choices[0].message.content
    )
    return improved_dialogue


def call_llm(system_prompt: str, text: str, dialogue_format):
    """Call the LLM with the given prompt and dialogue format."""
    response = client.chat.completions.create(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": text},
        ],
        model=MODEL_ID,
        max_tokens=16_384,
        temperature=0.1,
        response_format={
            "type": "json_object",
            "schema": dialogue_format.model_json_schema(),
        },
    )
    return response


def parse_url(url: str) -> str:
    """Parse the given URL and return the text content."""
    full_url = f"{JINA_URL}{url}"
    response = requests.get(full_url, timeout=60)
    return response.text


def generate_podcast_audio(
    text: str, speaker: str, language: str, use_advanced_audio: bool
) -> str:
    """Generate audio for one line of dialogue and return the path or audio result."""
    if use_advanced_audio:
        # Bark voice presets are lowercase, e.g. "v2/en_speaker_1"
        history_prompt = (
            f"v2/{language.lower()}_speaker_{'1' if speaker == 'Host (Jane)' else '3'}"
        )
        audio_array = generate_audio(text, history_prompt=history_prompt)

        # save audio to disk; scipy's write_wav produces WAV data, so use a .wav extension
        file_path = f"audio_{language}_{speaker}.wav"
        write_wav(file_path, SAMPLE_RATE, audio_array)
        return file_path
    else:
        if speaker == "Guest":
            accent = "EN-US" if language == "EN" else language
            speed = 0.9
        else:  # host
            accent = "EN-Default" if language == "EN" else language
            speed = 1
        if language != "EN" and speaker != "Guest":
            # Non-English hosts speak slightly faster
            speed = 1.1

        # Generate audio
        for attempt in range(3):
            try:
                result = hf_client.predict(
                    text=text,
                    language=language,
                    speaker=accent,
                    speed=speed,
                    api_name="/synthesize",
                )
                return result
            except Exception:
                if attempt == 2:  # Last attempt
                    raise  # Re-raise the last exception if all attempts fail
                time.sleep(1)  # Wait for 1 second before retrying
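

# Usage sketch (illustrative only).
#
# The Pydantic classes and system prompt below are hypothetical stand-ins for the
# schema and prompt the app actually passes in as `output_model` and `system_prompt`;
# only `parse_url`, `generate_script`, and `generate_podcast_audio` above are real.
# Running this requires a FIREWORKS_API_KEY and access to the MeloTTS Space.
if __name__ == "__main__":
    from typing import List, Literal

    from pydantic import BaseModel

    class DialogueItem(BaseModel):
        """One line of the podcast script (hypothetical example schema)."""

        speaker: Literal["Host (Jane)", "Guest"]
        text: str

    class Dialogue(BaseModel):
        """The full podcast script returned by the LLM (hypothetical example schema)."""

        dialogue: List[DialogueItem]

    SYSTEM_PROMPT = "You are a podcast script writer. Return JSON matching the schema."

    # Fetch source text, generate a validated script, then synthesize each line.
    source_text = parse_url("https://example.com/article")
    script = generate_script(SYSTEM_PROMPT, source_text, Dialogue)

    for i, item in enumerate(script.dialogue):
        path = generate_podcast_audio(
            item.text, item.speaker, language="EN", use_advanced_audio=False
        )
        print(f"Line {i}: {item.speaker} -> {path}")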