File size: 1,162 Bytes
f77503f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from typing import Dict, List, Any
# from transformers import AutoProcessor, MusicgenForConditionalGeneration
# import torch

# import torchaudio
from audiocraft.models import AudioGen
from audiocraft.data.audio import audio_write

class EndpointHandler:
    """Inference handler that generates audio clips from text prompts with AudioGen."""

    def __init__(self, path=""):
        """Load the AudioGen model.

        Args:
            path: Model identifier or local path, forwarded unchanged to
                ``AudioGen.get_pretrained``.
        """
        self.model = AudioGen.get_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate one audio clip per text prompt.

        Args:
            data: Request payload. ``data["inputs"]`` is a single prompt string
                or a list of prompt strings; the optional ``data["duration"]``
                is the clip length in seconds (default 5). The payload dict is
                not modified.

        Returns:
            A list with one ``{"generated_audio": <array>}`` entry per prompt,
            in the same order as the prompts.

        Raises:
            ValueError: if ``inputs`` is missing or contains no prompts.
        """
        if "inputs" not in data:
            raise ValueError("payload must contain an 'inputs' key")
        inputs = data["inputs"]
        # AudioGen.generate takes a list of descriptions and iterates it, so a
        # bare string is wrapped to be treated as a single prompt.
        prompts = [inputs] if isinstance(inputs, str) else list(inputs)
        if not prompts:
            raise ValueError("'inputs' must contain at least one prompt")
        # Coerce so a duration sent as "5" or 5 both work.
        duration = float(data.get("duration", 5))

        self.model.set_generation_params(duration=duration)
        outputs = self.model.generate(prompts)
        # Return every generated clip rather than only the first one.
        # NOTE(review): numpy arrays are not JSON-serializable by the stdlib
        # encoder; confirm the serving layer handles them or use .tolist().
        return [{"generated_audio": wav.cpu().numpy()} for wav in outputs]