import base64
import os
import tempfile
from typing import Any, Dict, List

from audiocraft.data.audio import audio_write
from audiocraft.models import AudioGen


class EndpointHandler:
    """Inference-endpoint handler that turns text prompts into audio with AudioGen.

    The model is loaded once at construction time. Each call generates one clip
    per prompt, writes it as a loudness-normalized WAV file, and returns the
    files base64-encoded.
    """

    def __init__(self, path=""):
        """Load the pretrained AudioGen model.

        Args:
            path: Model identifier or local directory handed to
                ``AudioGen.get_pretrained``.
        """
        self.model = AudioGen.get_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[str]]:
        """Generate audio for the prompt(s) in the request payload.

        Args:
            data: Request payload. ``"inputs"`` holds the text prompt(s), either a
                single string or a list of strings. If the key is absent, the
                payload itself is used as the input. ``"duration"`` is the clip
                length in seconds (default 5). Both keys are popped from ``data``.

        Returns:
            ``{"generated_audio_files": [...]}`` with one base64-encoded WAV file
            (UTF-8 string) per prompt, in prompt order.
        """
        inputs = data.pop("inputs", data)
        duration = data.pop("duration", 5)
        # generate() iterates its argument as a list of descriptions; a bare
        # string would otherwise be split into one prompt per character.
        prompts = [inputs] if isinstance(inputs, str) else inputs

        self.model.set_generation_params(duration=duration)
        outputs = self.model.generate(prompts)

        encoded_files = []
        # Use a private temporary directory so concurrent requests cannot
        # overwrite each other's files and nothing is left on disk afterwards.
        with tempfile.TemporaryDirectory() as tmp_dir:
            for idx, one_wav in enumerate(outputs):
                stem = os.path.join(tmp_dir, f"generated_audio_{idx}")
                # audio_write appends the ".wav" suffix itself and returns the
                # path it actually wrote. Passing a name already ending in ".wav"
                # would create "*.wav.wav" and break the read below.
                written_path = audio_write(
                    stem,
                    one_wav.cpu(),
                    self.model.sample_rate,
                    strategy="loudness",
                    loudness_compressor=True,
                )
                with open(written_path, "rb") as audio_file:
                    encoded_files.append(
                        base64.b64encode(audio_file.read()).decode("utf-8")
                    )

        return {"generated_audio_files": encoded_files}