Update handler.py
Browse files- handler.py +25 -24
handler.py
CHANGED
@@ -1,14 +1,14 @@
|
|
1 |
-
from typing import Dict,
|
2 |
-
from
|
3 |
-
import
|
4 |
|
5 |
class EndpointHandler:
|
6 |
-
def __init__(self
|
7 |
-
#
|
8 |
-
self.
|
9 |
-
self.model
|
10 |
|
11 |
-
def __call__(self, data: Dict[str, Any]) -> Dict[str,
|
12 |
"""
|
13 |
Args:
|
14 |
data (:dict:):
|
@@ -16,23 +16,24 @@ class EndpointHandler:
|
|
16 |
"""
|
17 |
# process input
|
18 |
inputs = data.pop("inputs", data)
|
19 |
-
parameters = data.pop("parameters",
|
20 |
|
21 |
-
#
|
22 |
-
|
23 |
-
|
24 |
-
padding=True,
|
25 |
-
return_tensors="pt",).to("cuda")
|
26 |
|
27 |
-
#
|
28 |
-
|
29 |
-
|
30 |
-
outputs = self.model.generate(**inputs, **parameters)
|
31 |
-
else:
|
32 |
-
with torch.autocast("cuda"):
|
33 |
-
outputs = self.model.generate(**inputs,)
|
34 |
|
35 |
-
#
|
36 |
-
|
|
|
|
|
|
|
37 |
|
38 |
-
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Any
|
2 |
+
from audiocraft.models import AudioGen
|
3 |
+
# from audiocraft.data.audio import audio_write
|
4 |
|
5 |
class EndpointHandler:
    """Inference endpoint wrapper around Meta's AudioGen text-to-audio model.

    Loads the pretrained ``facebook/audiogen-medium`` checkpoint once at
    start-up and turns request payloads into JSON-serializable waveforms.
    """

    # Clip length in seconds used when the request supplies no "duration".
    DEFAULT_DURATION = 5

    def __init__(self):
        # Load the AudioGen model once; the instance is reused across requests.
        self.model = AudioGen.get_pretrained('facebook/audiogen-medium')
        self.model.set_generation_params(duration=self.DEFAULT_DURATION)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate audio for one or more text descriptions.

        Args:
            data (:dict:): request payload with keys
                - "inputs": a text description (str) or a list of descriptions
                - "parameters" (optional): dict; may contain "duration",
                  the clip length in seconds.

        Returns:
            dict: ``{"generated_audio": [...]}`` where each entry is one
            generated waveform converted to nested lists of floats so the
            response is JSON-serializable.
        """
        # process input
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {}) or {}

        # Always (re)apply the duration. The model object is shared across
        # requests, so a per-request override must not leak into later calls
        # (the previous code only set it when "duration" was present, leaving
        # the last override active for every following request).
        self.model.set_generation_params(
            duration=parameters.get('duration', self.DEFAULT_DURATION)
        )

        # Accept either a single description or a list of descriptions;
        # model.generate expects a list either way.
        descriptions = inputs if isinstance(inputs, list) else [inputs]
        wav = self.model.generate(descriptions)

        # Optionally the clips could be written to disk instead, e.g.:
        # audio_write(f'{idx}', one_wav.cpu(), self.model.sample_rate,
        #             strategy="loudness", loudness_compressor=True)

        # Convert each waveform tensor to nested lists for JSON serialization.
        predictions = [one_wav.cpu().numpy().tolist() for one_wav in wav]

        return {"generated_audio": predictions}