Text-to-Audio · Transformers · musicgen · Inference Endpoints
mmomeni committed on
Commit 93bba1b · verified · 1 Parent(s): 91330cc

Update handler.py

Files changed (1)
  1. handler.py +25 -24
handler.py CHANGED
@@ -1,14 +1,14 @@
-from typing import Dict, List, Any
-from transformers import AutoProcessor, MusicgenForConditionalGeneration
-import torch
+from typing import Dict, Any
+from audiocraft.models import AudioGen
+# from audiocraft.data.audio import audio_write
 
 class EndpointHandler:
-    def __init__(self, path=""):
-        # load model and processor from path
-        self.processor = AutoProcessor.from_pretrained(path)
-        self.model = MusicgenForConditionalGeneration.from_pretrained(path, torch_dtype=torch.float16).to("cuda")
+    def __init__(self):
+        # Load the AudioGen model
+        self.model = AudioGen.get_pretrained('facebook/audiogen-medium')
+        self.model.set_generation_params(duration=5)  # Set default duration to 5 seconds
 
-    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """
         Args:
             data (:dict:):
@@ -16,23 +16,24 @@ class EndpointHandler:
         """
         # process input
         inputs = data.pop("inputs", data)
-        parameters = data.pop("parameters", None)
+        parameters = data.pop("parameters", {})
 
-        # preprocess
-        inputs = self.processor(
-            text=[inputs],
-            padding=True,
-            return_tensors="pt",).to("cuda")
+        # Update generation parameters if provided
+        if 'duration' in parameters:
+            self.model.set_generation_params(duration=parameters['duration'])
 
-        # pass inputs with all kwargs in data
-        if parameters is not None:
-            with torch.autocast("cuda"):
-                outputs = self.model.generate(**inputs, **parameters)
-        else:
-            with torch.autocast("cuda"):
-                outputs = self.model.generate(**inputs,)
+        # Generate audio from descriptions
+        descriptions = [inputs]
+        wav = self.model.generate(descriptions)
 
-        # postprocess the prediction
-        prediction = outputs[0].cpu().numpy().tolist()
-
-        return [{"generated_audio": prediction}]
+        # Convert the generated audio to a list format for JSON serialization
+        predictions = []
+        for idx, one_wav in enumerate(wav):
+            # Save the audio to a file (optional)
+            # audio_write(f'{idx}', one_wav.cpu(), self.model.sample_rate, strategy="loudness", loudness_compressor=True)
+
+            # Convert the tensor to a list
+            prediction = one_wav.cpu().numpy().tolist()
+            predictions.append(prediction)
+
+        return {"generated_audio": predictions}
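For reference, here is a minimal sketch of how the updated handler could be smoke-tested locally before deployment. It assumes the audiocraft package is installed and the facebook/audiogen-medium weights are downloadable; the module name handler and the payload layout mirror the committed file, while everything else is illustrative.

from handler import EndpointHandler

# Illustrative smoke test, not part of the commit.
handler = EndpointHandler()

# "inputs" is the text description; "parameters" is optional and the
# handler currently honors only "duration" (in seconds).
result = handler({
    "inputs": "dog barking in the rain",
    "parameters": {"duration": 3},
})

# One waveform per description, nested as [channels][samples].
audio = result["generated_audio"][0]
print(f"{len(audio)} channel(s), {len(audio[0])} samples")

Returning raw sample lists keeps the response JSON-serializable at the cost of large payloads (roughly duration × sample rate floats per channel); writing a WAV via the commented-out audio_write helper and returning it encoded would be a more compact alternative.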