from typing import Any, Dict, List

import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration


class EndpointHandler:
    def __init__(self, path=""):
        # Load the processor and model from the specified path
        self.processor = AutoProcessor.from_pretrained(path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path, torch_dtype=torch.float16
        ).to("cuda")
        self.sampling_rate = self.model.config.audio_encoder.sampling_rate

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data (dict): The payload with the text prompt and generation parameters.
        """
        # Extract inputs and parameters from the payload
        inputs = data.get("inputs", {})
        prompt = inputs.get("prompt", "")
        duration = inputs.get("duration", 10)
        parameters = data.get("parameters", {})

        # Validate the prompt; keep the error shape consistent with the
        # success path (a list of dicts)
        if not prompt:
            return [{"error": "No prompt provided."}]

        # Tokenize the prompt; the processor returns a BatchEncoding,
        # not bare input IDs, so move the whole batch to the GPU
        model_inputs = self.processor(
            text=[prompt],
            padding=True,
            return_tensors="pt",
        ).to("cuda")

        # Set generation parameters; MusicGen generates ~50 audio tokens
        # per second, so scale the requested duration accordingly.
        # Caller-supplied parameters override the computed default.
        gen_kwargs = {
            "max_new_tokens": int(duration * 50),
            **parameters,
        }

        # Generate audio under autocast to match the float16 weights
        with torch.autocast("cuda"):
            outputs = self.model.generate(**model_inputs, **gen_kwargs)

        # Convert the output audio tensor to a list of lists (channel-wise)
        audio_tensor = outputs[0].cpu()  # Shape: [num_channels, seq_len]
        audio_list = audio_tensor.numpy().tolist()  # [[channel1_data], [channel2_data]]

        return [
            {
                "generated_audio": audio_list,
                "sample_rate": self.sampling_rate,
            }
        ]
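

# --- Local smoke test (sketch) ---
# A minimal way to exercise the handler outside the endpoint, assuming a CUDA
# device is available. The checkpoint name "facebook/musicgen-small" and the
# generation parameters below are illustrative substitutions, not part of the
# handler's contract; any MusicGen checkpoint path should work.
if __name__ == "__main__":
    handler = EndpointHandler(path="facebook/musicgen-small")
    result = handler(
        {
            "inputs": {"prompt": "lo-fi hip hop beat with mellow piano", "duration": 5},
            "parameters": {"do_sample": True, "guidance_scale": 3.0},
        }
    )
    print("sample_rate:", result[0]["sample_rate"])
    print("channels:", len(result[0]["generated_audio"]))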