from typing import Any, Dict

from audiocraft.data.audio import audio_write
from audiocraft.models import AudioGen


class EndpointHandler:
    """Inference-endpoint handler that turns text prompts into audio with AudioGen.

    Loads a pretrained audiocraft AudioGen model once at construction and
    exposes a callable interface compatible with Hugging Face custom
    inference endpoints.
    """

    def __init__(self, path: str = ""):
        """Load the pretrained AudioGen model.

        Args:
            path: Model name or local directory passed to
                ``AudioGen.get_pretrained`` (e.g. "facebook/audiogen-medium").
        """
        self.model = AudioGen.get_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Generate audio for the given prompt(s) and write it to disk.

        Args:
            data: Payload dict with:
                - "inputs": list of text prompt strings (falls back to the
                  whole payload if the key is absent, matching the original
                  behavior).
                - "duration": seconds of audio to generate (default 5).

        Returns:
            Dict with key "generated_audio_file" holding the path of the
            written WAV file. Only the first generated sample is saved.
        """
        inputs = data.pop("inputs", data)   # list of prompt strings
        duration = data.pop("duration", 5)  # seconds to generate

        self.model.set_generation_params(duration=duration)
        outputs = self.model.generate(inputs)

        # BUG FIX: audio_write expects a torch.Tensor, not a numpy array —
        # do not call .numpy() on the sample before saving.
        wav = outputs[0].cpu()

        # BUG FIX: audio_write appends the ".wav" extension itself; passing
        # "generated_audio.wav" as the stem produced "generated_audio.wav.wav"
        # on disk while the handler returned a non-existent path.
        stem = "generated_audio"

        # Use the model's native sample rate instead of hard-coding 16000.
        audio_write(stem, wav, self.model.sample_rate)

        return {"generated_audio_file": f"{stem}.wav"}