# kaido-2.7b/handler.py
from typing import Any, Dict

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.llms import HuggingFacePipeline
from langchain import PromptTemplate, LLMChain
template = """{char_name}'s Persona: {char_persona}
<START>
{chat_history}
{char_name}: {char_greeting}
<END>
{user_name}: {user_input}
{char_name}: """
class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the weights in 8-bit (requires bitsandbytes + accelerate)
        # and let accelerate spread them across the available devices.
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(
            path,
            load_in_8bit=True,
            device_map="auto",
        )
        # Wrap the transformers pipeline so LangChain can drive generation.
        local_llm = HuggingFacePipeline(
            pipeline=pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                max_length=2048,
                temperature=0.5,
                top_p=0.9,
                top_k=0,
                repetition_penalty=1.1,
                pad_token_id=50256,
                num_return_sequences=1,
            )
        )
        prompt_template = PromptTemplate(
            template=template,
            input_variables=[
                "user_input",
                "user_name",
                "char_name",
                "char_persona",
                "char_greeting",
                "chat_history",
            ],
            validate_template=True,
        )
        self.llm_engine = LLMChain(
            llm=local_llm,
            prompt=prompt_template,
        )
    def __call__(self, data: Dict[str, Any]) -> Any:
        # Inference Endpoints wrap the request payload as {"inputs": {...}};
        # fall back to the raw dict if the key is absent.
        inputs = data.pop("inputs", data)
        # Render the prompt and generate, then keep only the character's
        # first line so the model cannot keep speaking for the user.
        return self.llm_engine.predict(
            user_input=inputs["user_input"],
            user_name=inputs["user_name"],
            char_name=inputs["char_name"],
            char_persona=inputs["char_persona"],
            char_greeting=inputs["char_greeting"],
            chat_history=inputs["chat_history"],
        ).split("\n", 1)[0]