# AUD_MED1.5B / handler.py
# Custom inference endpoint handler (author: selectmixer, commit a240179 "handler change1")
import base64
from typing import Any, Dict, List

# from transformers import AutoProcessor, MusicgenForConditionalGeneration
# import torch
# import torchaudio
from audiocraft.models import AudioGen
from audiocraft.data.audio import audio_write
class EndpointHandler:
    """Inference endpoint handler that turns text prompts into base64-encoded WAV audio."""

    def __init__(self, path: str = ""):
        """Load the AudioGen model from *path* (local directory or hub model id)."""
        # Earlier transformers-based loading kept for reference:
        # path = "jamesdon/audiogen-medium-endpoint"
        # self.processor = AutoProcessor.from_pretrained(path)
        # self.model = MusicgenForConditionalGeneration.from_pretrained(path).to("cuda")
        self.model = AudioGen.get_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[str]]:
        """
        Args:
            data (:dict:):
                The payload with the text prompt and generation parameters.
                "inputs": a prompt string or list of prompt strings.
                "duration": optional seconds of audio to generate (default 5).

        Returns:
            {"generated_audio_files": [...]} with one base64-encoded WAV
            string per input prompt, in prompt order.
        """
        # process input
        inputs = data.pop("inputs", data)  # prompt string or list of strings
        if isinstance(inputs, str):
            # generate() expects a batch; accept a bare string for convenience
            inputs = [inputs]
        duration = data.pop("duration", 5)  # seconds to generate
        self.model.set_generation_params(duration=duration)
        outputs = self.model.generate(inputs)

        # Save each generated audio file with loudness normalization and encode in base64.
        output_files = []
        for idx, one_wav in enumerate(outputs):
            # audio_write() expects a stem WITHOUT extension and appends the
            # format suffix itself, returning the path actually written.
            # (Passing "....wav" as before produced "....wav.wav" on disk and
            # the subsequent open() of the unsuffixed name failed.)
            written_path = audio_write(
                f"generated_audio_{idx}",
                one_wav.cpu(),
                self.model.sample_rate,
                strategy="loudness",
                loudness_compressor=True,
            )
            # Read the file back and encode it in base64 for the JSON response.
            with open(written_path, "rb") as audio_file:
                output_files.append(base64.b64encode(audio_file.read()).decode('utf-8'))
        return {"generated_audio_files": output_files}