Phoenixak99's picture
Update handler.py
06c68e1 verified
raw
history blame
2.21 kB
# handler.py
from typing import Dict, Any
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch
class EndpointHandler:
    """Inference-endpoint handler that turns a text prompt into MusicGen audio.

    The handler is constructed once with the model directory and is then
    called once per request with the decoded request payload.
    """

    # Seconds of audio requested when the payload does not specify a duration.
    DEFAULT_DURATION_S = 30

    # Decoder tokens per second of audio. Used only when the loaded model's
    # config does not expose ``audio_encoder.frame_rate``; MusicGen's codec
    # is documented as running at 50 Hz.
    FALLBACK_FRAME_RATE = 50

    def __init__(self, path=""):
        """Load the processor and the MusicGen model stored at *path*.

        Args:
            path: Model directory or identifier passed to ``from_pretrained``.
        """
        self.processor = AutoProcessor.from_pretrained(path)
        # device_map="auto" already places the weights on the available
        # device(s). The model is deliberately NOT moved again with .to():
        # that is redundant for a dispatched model and can fail when some
        # modules were offloaded.
        self.model = MusicgenForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate audio for one request.

        Args:
            data: Request payload. ``data["inputs"]`` is either a prompt
                string or a dict with optional ``"prompt"`` (str) and
                ``"duration"`` (seconds) keys; when ``"inputs"`` is missing,
                *data* itself is read as that dict. ``data["parameters"]``
                optionally holds keyword arguments for ``model.generate``
                that override the defaults, including ``max_new_tokens``.
                The payload is not modified.

        Returns:
            On success, a one-element list ``[{"generated_audio": ...}]``
            holding the generated tensor converted to nested Python lists.
            On any failure, ``{"error": "<message>"}``.
            NOTE: the success value is a list even though the annotation
            says ``Dict``; the shape is kept so existing clients keep working.
        """
        # Top-level boundary of the endpoint: every failure is reported to
        # the client as an error payload instead of propagating.
        try:
            prompt, duration, overrides = self._parse_request(data)

            device = self.model.device
            model_inputs = self.processor(
                text=[prompt],
                padding=True,
                return_tensors="pt",
            ).to(device)

            generation_params = {
                "do_sample": True,
                "guidance_scale": 3,
                "max_new_tokens": self._tokens_for_duration(duration),
            }
            # Caller-supplied parameters win over the defaults above.
            generation_params.update(overrides)

            if device.type == "cuda":
                # Mixed precision only applies on CUDA; torch.autocast is the
                # non-deprecated spelling of torch.cuda.amp.autocast.
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    audio_values = self.model.generate(
                        **model_inputs, **generation_params
                    )
            else:
                audio_values = self.model.generate(
                    **model_inputs, **generation_params
                )

            # Nested lists so the response can be JSON-serialized.
            audio_data = audio_values.cpu().numpy().tolist()
            return [{"generated_audio": audio_data}]
        except Exception as e:
            return {"error": str(e)}

    def _parse_request(self, data):
        """Return ``(prompt, duration_seconds, generate_overrides)`` from *data*."""
        # .get instead of .pop so the caller's payload is left intact.
        inputs = data.get("inputs", data)
        # An explicit ``"parameters": null`` is treated like a missing key.
        overrides = dict(data.get("parameters") or {})

        if isinstance(inputs, str):
            # Plain-string payload: the string is the prompt itself.
            prompt = inputs
            duration = self.DEFAULT_DURATION_S
        else:
            prompt = inputs.get("prompt", "")
            duration = inputs.get("duration", self.DEFAULT_DURATION_S)

        duration = float(duration)
        # The chained comparison also rejects NaN and infinity.
        if not 0 < duration < float("inf"):
            raise ValueError("duration must be a positive, finite number of seconds")
        return prompt, duration, overrides

    def _tokens_for_duration(self, duration):
        """Return the number of decoder tokens covering *duration* seconds."""
        # One token is generated per codec frame, so the budget is
        # seconds * frames-per-second, read from the loaded model's config.
        config = getattr(self.model, "config", None)
        codec_config = getattr(config, "audio_encoder", None)
        frame_rate = (
            getattr(codec_config, "frame_rate", None) or self.FALLBACK_FRAME_RATE
        )
        return max(1, int(duration * frame_rate))