from typing import Any, Dict, List

from transformers import MarianMTModel, MarianTokenizer

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the MarianMT model and its tokenizer from the checkpoint
        # directory provided by the inference endpoint.
        self.model = MarianMTModel.from_pretrained(path)
        self.tokenizer = MarianTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """
        Args:
            data (dict): The request payload with an "inputs" key containing
                the text to translate (a string or a list of strings).

        Returns:
            List[Dict[str, str]]: One dict per input with the translated text
            under the "translation_text" key.
        """
        text = data.get("inputs", "")

        # Tokenize the input; padding=True lets a list of strings be batched.
        inputs = self.tokenizer(text, return_tensors="pt", padding=True)

        # Run generation to produce the translated token ids.
        translated = self.model.generate(**inputs)

        # Decode every generated sequence, not just the first, so batched
        # requests get one translation per input.
        translated_texts = self.tokenizer.batch_decode(translated, skip_special_tokens=True)

        return [{"translation_text": t} for t in translated_texts]
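

if __name__ == "__main__":
    # Minimal local smoke test, not part of the endpoint contract: it sketches
    # how the handler is called with an "inputs" payload. It assumes a Marian
    # checkpoint is available; "Helsinki-NLP/opus-mt-en-de" is only an example
    # id and should be replaced with whatever checkpoint the handler is
    # actually deployed with.
    handler = EndpointHandler(path="Helsinki-NLP/opus-mt-en-de")
    result = handler({"inputs": ["Hello world", "How are you?"]})
    print(result)  # e.g. [{'translation_text': '...'}, {'translation_text': '...'}]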