kaido-2.7b / handler.py
from typing import Any, Dict, List

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.llms import HuggingFacePipeline
from langchain import PromptTemplate, LLMChain
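# Prompt template filled in for each request: the character's persona, the
# running chat history, the character's greeting, and the latest user message.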
template = """{char_name}'s Persona: {char_persona}
<START>
{chat_history}
{char_name}: {char_greeting}
<END>
{user_name}: {user_input}
{char_name}: """
class EndpointHandler():
    def __init__(self, path=""):
        # Load the tokenizer and the model (8-bit weights, automatic device placement).
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path, load_in_8bit=True, device_map="auto")
        # Wrap a transformers text-generation pipeline so LangChain can drive it.
        self.local_llm = HuggingFacePipeline(
            pipeline=pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_length=2048,
                temperature=0.5,
                top_p=0.9,
                top_k=0,
                repetition_penalty=1.1,
                pad_token_id=50256,
                num_return_sequences=1
            )
        )
        self.prompt_template = PromptTemplate(
            template=template,
            input_variables=[
                "user_input",
                "user_name",
                "char_name",
                "char_persona",
                "char_greeting",
                "chat_history"
            ],
            validate_template=True
        )
        # Chain the prompt template and the local pipeline into a single engine.
        self.llm_engine = LLMChain(
            llm=self.local_llm,
            prompt=self.prompt_template,
            verbose=True
        )
    def __call__(self, data: Dict[str, Any]) -> str:
        """
        data args:
            inputs (:obj:`dict`): the template fields (user_input, user_name,
                char_name, char_persona, char_greeting, chat_history)
        Return:
            A :obj:`str`: the first line of the generated reply
        """
        inputs = data.pop("inputs", data)
        # Run the chain, then keep only the text up to the first newline so the
        # response contains a single reply from the character.
        return self.llm_engine.predict(
            user_input=inputs["user_input"],
            user_name=inputs["user_name"],
            char_name=inputs["char_name"],
            char_persona=inputs["char_persona"],
            char_greeting=inputs["char_greeting"],
            chat_history=inputs["chat_history"]
        ).split("\n", 1)[0]
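
# Example invocation (a minimal sketch: the payload values below are
# illustrative, and the path assumes the repo this handler ships with):
#
#   handler = EndpointHandler(path="MrD05/kaido-2.7b")
#   reply = handler({
#       "inputs": {
#           "user_input": "Hi, what are you up to today?",
#           "user_name": "User",
#           "char_name": "Kaido",
#           "char_persona": "A cheerful companion who loves small talk.",
#           "char_greeting": "Hey there! Great to see you.",
#           "chat_history": ""
#       }
#   })
#   print(reply)  # first line of the generated response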