Jose Benitez commited on
Commit
e244774
·
1 Parent(s): ed5138d

add endpoint handler

Browse files
Files changed (1) hide show
  1. handler.py +150 -0
handler.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Dict, Any, List, Union
3
+ from transformers import VitsModel, AutoTokenizer
4
+ import numpy as np
5
+
6
+ class EndpointHandler:
7
+ def __init__(self, path="joselobenitezg/mms-grn-tts", device=None):
8
+ """Initialize the VITS TTS model and tokenizer.
9
+
10
+ Args:
11
+ path (str): HuggingFace model path
12
+ device (str, optional): Device to run the model on ('cuda', 'cpu', or specific cuda device)
13
+ """
14
+ # Device management
15
+ self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
16
+
17
+ try:
18
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
19
+ self.model = VitsModel.from_pretrained(path).to(self.device)
20
+ self.sampling_rate = self.model.config.sampling_rate
21
+ except Exception as e:
22
+ raise RuntimeError(f"Failed to load model and tokenizer: {str(e)}")
23
+
24
+ # Set maximum input length
25
+ self.max_input_length = 200
26
+
27
+ print(f"Model loaded on {self.device}")
28
+
29
+ def validate_input(self, text: Union[str, List[str]]) -> List[str]:
30
+ """Validate and preprocess input text.
31
+
32
+ Args:
33
+ text: Input text or list of texts
34
+
35
+ Returns:
36
+ List[str]: Validated and processed text list
37
+
38
+ Raises:
39
+ ValueError: If input validation fails
40
+ """
41
+ # Convert single string to list
42
+ if isinstance(text, str):
43
+ text = [text]
44
+ elif isinstance(text, list):
45
+ if not all(isinstance(t, str) for t in text):
46
+ raise ValueError("All elements in the input list must be strings")
47
+ else:
48
+ raise ValueError("Input must be a string or list of strings")
49
+
50
+ # Validate each text
51
+ for t in text:
52
+ if not t.strip():
53
+ raise ValueError("Empty text is not allowed")
54
+ if len(t) > self.max_input_length:
55
+ raise ValueError(f"Input text exceeds maximum length of {self.max_input_length}")
56
+
57
+ return text
58
+
59
+ def batch_process(self, texts: List[str], batch_size: int = 8) -> List[Dict[str, Any]]:
60
+ """Process multiple texts in batches.
61
+
62
+ Args:
63
+ texts (List[str]): List of texts to process
64
+ batch_size (int): Size of each batch
65
+
66
+ Returns:
67
+ List[Dict[str, Any]]: List of results for each text
68
+ """
69
+ results = []
70
+
71
+ for i in range(0, len(texts), batch_size):
72
+ batch_texts = texts[i:i + batch_size]
73
+ # Tokenize batch
74
+ inputs = self.tokenizer(batch_texts, padding=True, return_tensors="pt")
75
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
76
+
77
+ try:
78
+ with torch.no_grad():
79
+ outputs = self.model(**inputs).waveform
80
+
81
+ for waveform in outputs:
82
+ # Move to CPU and convert to numpy
83
+ waveform_np = waveform.cpu().numpy()
84
+ results.append({
85
+ "waveform": waveform_np.tolist(),
86
+ "sampling_rate": self.sampling_rate
87
+ })
88
+ except Exception as e:
89
+ raise RuntimeError(f"Error during batch processing: {str(e)}")
90
+
91
+ return results
92
+
93
+ def __call__(self, data: Union[Dict[str, Any], str, List[str]]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
94
+ """Process the input text and generate audio.
95
+
96
+ Args:
97
+ data: Input data in one of these formats:
98
+ - Dict[str, Any]: {"inputs": "text" or ["text1", "text2"], "batch_size": int}
99
+ - str: Direct text input
100
+ - List[str]: List of texts to process
101
+
102
+ Returns:
103
+ Union[Dict[str, Any], List[Dict[str, Any]]]: Dictionary or list of dictionaries
104
+ containing the audio waveform(s) and sampling rate
105
+ """
106
+ try:
107
+ # Handle different input types
108
+ if isinstance(data, dict):
109
+ text = data.get("inputs", "")
110
+ batch_size = data.get("batch_size", 8)
111
+ elif isinstance(data, (str, list)):
112
+ text = data
113
+ batch_size = 8
114
+ else:
115
+ raise ValueError(f"Unsupported input type: {type(data)}")
116
+
117
+ # Validate input
118
+ texts = self.validate_input(text)
119
+
120
+ # Single input case
121
+ if len(texts) == 1:
122
+ inputs = self.tokenizer(texts[0], return_tensors="pt")
123
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
124
+
125
+ with torch.no_grad():
126
+ output = self.model(**inputs).waveform
127
+ waveform = output.cpu().squeeze().numpy()
128
+
129
+ return {
130
+ "waveform": waveform.tolist(),
131
+ "sampling_rate": self.sampling_rate
132
+ }
133
+
134
+ # Multiple inputs case
135
+ else:
136
+ return self.batch_process(texts, batch_size)
137
+
138
+ except Exception as e:
139
+ error_msg = f"Error processing input: {str(e)}"
140
+ print(error_msg) # Log the error
141
+ raise RuntimeError(error_msg)
142
+
143
+ def cleanup(self):
144
+ """Cleanup resources when shutting down."""
145
+ try:
146
+ # Clear CUDA cache if using GPU
147
+ if 'cuda' in self.device:
148
+ torch.cuda.empty_cache()
149
+ except Exception as e:
150
+ print(f"Error during cleanup: {str(e)}")