CoachCasey / handler.py
Dantinob's picture
Update handler.py
cd6f314 verified
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
class EndpointHandler:
    """Text-generation handler wrapping a causal language model.

    The tokenizer and model are loaded once in ``__init__``; each call
    tokenizes the request's prompt, samples a continuation, and returns
    the decoded text.
    """

    # Sampling settings applied to every request.
    MAX_NEW_TOKENS = 200
    TEMPERATURE = 0.7

    def __init__(self, model_dir):
        """Load the tokenizer and model stored under *model_dir*.

        :param model_dir: Path or identifier passed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForCausalLM.from_pretrained(model_dir)

        # Many causal-LM tokenizers ship without a pad token, and
        # ``padding=True`` raises in that case. Reusing EOS is the
        # conventional fallback for generation.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, data):
        """Generate a continuation for the prompt in *data*.

        :param data: Dictionary holding the prompt under the ``"inputs"`` key.
        :return: ``{"generated_text": <str>}`` on success, or
            ``{"error": "No input provided"}`` when the prompt is missing
            or empty. Only the first generated sequence is decoded.
        """
        prompt = data.get("inputs", "")
        if not prompt:
            return {"error": "No input provided"}

        encoded = self.tokenizer(prompt, return_tensors="pt", padding=True)
        # Keep the input tensors on whatever device the model lives on.
        encoded = encoded.to(self.model.device)

        with torch.no_grad():
            generated = self.model.generate(
                **encoded,
                # Budget only the continuation: ``max_length`` also counts
                # the prompt, so a long prompt would leave no room to generate.
                max_new_tokens=self.MAX_NEW_TOKENS,
                temperature=self.TEMPERATURE,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        text = self.tokenizer.decode(generated[0], skip_special_tokens=True)
        return {"generated_text": text}