File size: 1,746 Bytes
f77503f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a240179
42c394e
 
 
 
a240179
 
 
 
 
def1f31
a240179
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import base64
import os
import tempfile
from typing import Dict, List, Any

# from transformers import AutoProcessor, MusicgenForConditionalGeneration
# import torch
# import torchaudio
from audiocraft.models import AudioGen
from audiocraft.data.audio import audio_write

class EndpointHandler:
    """Inference handler serving AudioGen text-to-audio generation.

    The model is loaded once at construction. Each call does the following:

    * generates audio for the request's text prompt(s);
    * writes each generated clip to a WAV file with loudness normalization;
    * returns the files base64-encoded.
    """

    def __init__(self, path=""):
        # ``path`` is forwarded unchanged to ``AudioGen.get_pretrained``;
        # presumably the model repo/directory supplied by the serving
        # runtime — TODO confirm.
        self.model = AudioGen.get_pretrained(path)

    @staticmethod
    def _normalize_prompts(inputs: Any) -> Any:
        """Wrap a single prompt string in a list; return anything else unchanged.

        ``AudioGen.generate`` expects a list of descriptions. A bare string
        would be treated as a sequence of one-character prompts.
        """
        if isinstance(inputs, str):
            return [inputs]
        return inputs

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[str]]:
        """Generate audio for the request payload.

        Args:
            data (:dict:):
                The request payload. Both keys below are popped from ``data``.

                * ``"inputs"`` holds the text prompt(s): a single string or
                  a list of strings. If the key is absent, the payload itself
                  is used as the inputs.
                * ``"duration"`` is the length to generate, in seconds
                  (default 5). It is coerced to ``float``.

        Returns:
            ``{"generated_audio_files": [...]}`` with one base64-encoded WAV
            string per generated clip, in generation order.
        """
        inputs = self._normalize_prompts(data.pop("inputs", data))
        duration = float(data.pop("duration", 5))  # seconds to generate

        self.model.set_generation_params(duration=duration)
        outputs = self.model.generate(inputs)

        output_files = []
        # A per-request temporary directory keeps concurrent requests from
        # overwriting each other's files and removes them afterwards.
        with tempfile.TemporaryDirectory() as tmp_dir:
            for idx, one_wav in enumerate(outputs):
                # audio_write appends the format suffix to the stem it is
                # given and returns the path it actually wrote, so pass a
                # bare stem and read back from the returned path.
                stem = os.path.join(tmp_dir, f"generated_audio_{idx}")
                written_path = audio_write(
                    stem,
                    one_wav.cpu(),
                    self.model.sample_rate,
                    strategy="loudness",
                    loudness_compressor=True,
                )
                with open(written_path, "rb") as audio_file:
                    output_files.append(
                        base64.b64encode(audio_file.read()).decode("utf-8")
                    )

        return {"generated_audio_files": output_files}