File size: 1,746 Bytes
f77503f a240179 42c394e a240179 def1f31 a240179 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import base64
from typing import Dict, List, Any

# from transformers import AutoProcessor, MusicgenForConditionalGeneration
# import torch
# import torchaudio
from audiocraft.models import AudioGen
from audiocraft.data.audio import audio_write
class EndpointHandler:
    """Inference-endpoint handler wrapping audiocraft's AudioGen model.

    Generates audio clips from text prompts and returns them as
    base64-encoded WAV payloads.
    """

    def __init__(self, path=""):
        # Load the pretrained AudioGen model from the given checkpoint path.
        # (An earlier revision used transformers' MusicgenForConditionalGeneration;
        # kept here only as history, no longer needed.)
        self.model = AudioGen.get_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[str]]:
        """Generate one audio clip per text prompt in the payload.

        Args:
            data: Payload dict with "inputs" (a list of prompt strings) and
                an optional "duration" (seconds of audio to generate,
                default 5).

        Returns:
            {"generated_audio_files": [...]} with one base64-encoded WAV
            string per input prompt, in order.
        """
        # If "inputs" is absent, fall back to treating the whole payload as
        # the prompt list (preserves the original pop-default behavior).
        inputs = data.pop("inputs", data)
        duration = data.pop("duration", 5)
        self.model.set_generation_params(duration=duration)

        outputs = self.model.generate(inputs)

        # Save each clip with loudness normalization, then base64-encode it.
        output_files: List[str] = []
        for idx, one_wav in enumerate(outputs):
            # BUGFIX: audio_write() takes a *stem* and appends ".wav" itself.
            # Passing a name already ending in ".wav" produced
            # "generated_audio_0.wav.wav" while the read below looked for
            # "generated_audio_0.wav" and failed. Pass the bare stem and
            # read back "<stem>.wav".
            stem = f"generated_audio_{idx}"
            audio_write(
                stem,
                one_wav.cpu(),
                self.model.sample_rate,
                strategy="loudness",
                loudness_compressor=True,
            )
            with open(f"{stem}.wav", "rb") as audio_file:
                output_files.append(
                    base64.b64encode(audio_file.read()).decode("utf-8")
                )
        return {"generated_audio_files": output_files}