|
import base64
import os
import tempfile
from typing import Dict, List, Any

from audiocraft.data.audio import audio_write
from audiocraft.models import AudioGen
|
|
|
class EndpointHandler:
    """Callable inference handler that turns text prompts into audio clips.

    The handler loads an AudioGen checkpoint once at construction time and,
    on each call, generates one clip per prompt and returns the clips as
    base64-encoded WAV files.
    """

    # Clip length used when the payload does not carry a "duration" entry.
    DEFAULT_DURATION = 5

    def __init__(self, path=""):
        """Load the pretrained AudioGen model identified by ``path``."""
        self.model = AudioGen.get_pretrained(path)

    @staticmethod
    def _normalize_prompts(inputs: Any) -> List[str]:
        """Return ``inputs`` as a non-empty list of prompt strings.

        A bare string is wrapped in a one-element list so that it is treated
        as one prompt rather than being iterated character by character.

        Raises:
            ValueError: If ``inputs`` is neither a string nor a non-empty
                list/tuple of strings (for example when the payload had no
                "inputs" entry at all).
        """
        if isinstance(inputs, str):
            return [inputs]
        if (
            isinstance(inputs, (list, tuple))
            and inputs
            and all(isinstance(prompt, str) for prompt in inputs)
        ):
            return list(inputs)
        raise ValueError(
            "'inputs' must be a text prompt or a non-empty list of text prompts"
        )

    @staticmethod
    def _encode_file(file_path) -> str:
        """Read ``file_path`` as bytes and return its base64 text form."""
        with open(file_path, "rb") as audio_file:
            return base64.b64encode(audio_file.read()).decode("utf-8")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[str]]:
        """Generate audio for the prompts in ``data``.

        Args:
            data: The request payload. ``data["inputs"]`` holds the text
                prompt (a string) or prompts (a list of strings).
                ``data["duration"]`` is optional and is forwarded to
                ``set_generation_params(duration=...)``; it defaults to
                ``DEFAULT_DURATION``. The payload is not modified.

        Returns:
            A dict with the single key ``"generated_audio_files"`` whose value
            is a list of base64-encoded WAV files, one per prompt, in order.

        Raises:
            ValueError: If the payload carries no usable prompts.
        """
        # Read with .get() so the caller's payload dict is left untouched.
        prompts = self._normalize_prompts(data.get("inputs", data))
        duration = data.get("duration", self.DEFAULT_DURATION)

        self.model.set_generation_params(duration=duration)
        outputs = self.model.generate(prompts)

        output_files = []
        # A per-request scratch directory keeps concurrent requests from
        # overwriting each other's files and guarantees cleanup on exit.
        with tempfile.TemporaryDirectory() as scratch_dir:
            for idx, one_wav in enumerate(outputs):
                # Pass a stem without extension and open whatever path
                # audio_write reports back, instead of guessing the final
                # file name (audio_write adds the format suffix itself).
                stem = os.path.join(scratch_dir, f"generated_audio_{idx}")
                written_path = audio_write(
                    stem,
                    one_wav.cpu(),
                    self.model.sample_rate,
                    strategy="loudness",
                    loudness_compressor=True,
                )
                output_files.append(self._encode_file(written_path))

        return {"generated_audio_files": output_files}