from transformers import AutoTokenizer, AutoModelForSeq2SeqLM import torch import json class EndpointHandler: def __init__(self, model_dir): # Cargar el modelo y el tokenizador desde el directorio del modelo self.tokenizer = AutoTokenizer.from_pretrained(model_dir) self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir) self.model.eval() # Configurar el modelo en modo de evaluaciĆ³n def preprocess(self, data): # Preprocesamiento de la entrada if isinstance(data, dict) and "inputs" in data: input_text = "Generate a valid JSON capturing data from this text: " + data["inputs"] else: raise ValueError("Esperando un diccionario con la clave 'inputs'") # TokenizaciĆ³n de la entrada tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding=True) return tokens def inference(self, tokens): # Realizar la inferencia with torch.no_grad(): outputs = self.model.generate(**tokens) return outputs def postprocess(self, outputs): # Decodificar la salida del modelo decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True) return {"generated_text": decoded_output} def __call__(self, data): # Llamada principal del handler para procesamiento completo tokens = self.preprocess(data) outputs = self.inference(tokens) result = self.postprocess(outputs) return result