from typing import Any, Dict

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.llms import HuggingFacePipeline
from langchain import PromptTemplate, LLMChain
template = """{char_name}'s Persona: {char_persona} | |
<START> | |
{chat_history} | |
{char_name}: {char_greeting} | |
<END> | |
{user_name}: {user_input} | |
{char_name}: """ | |
class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the tokenizer and model from the repository path supplied
        # by the Inference Endpoints runtime.
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(
            path, load_in_8bit=True, device_map="auto"
        )
        # Wrap the transformers text-generation pipeline so LangChain
        # can drive it as an LLM.
        local_llm = HuggingFacePipeline(
            pipeline=pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                max_length=2048,
                temperature=0.5,
                top_p=0.9,
                top_k=0,
                repetition_penalty=1.1,
                pad_token_id=50256,
                num_return_sequences=1,
            )
        )
        prompt_template = PromptTemplate(
            template=template,
            input_variables=[
                "user_input",
                "user_name",
                "char_name",
                "char_persona",
                "char_greeting",
                "chat_history",
            ],
            validate_template=True,
        )
        self.llm_engine = LLMChain(llm=local_llm, prompt=prompt_template)
    def __call__(self, data: Dict[str, Any]) -> str:
        # The request payload carries the template fields under "inputs"
        # when the client follows the standard endpoint schema.
        inputs = data.pop("inputs", data)
        # Run the chain, then keep only the first line of the completion
        # so the model does not continue the conversation by itself.
        return self.llm_engine.predict(
            user_input=inputs["user_input"],
            user_name=inputs["user_name"],
            char_name=inputs["char_name"],
            char_persona=inputs["char_persona"],
            char_greeting=inputs["char_greeting"],
            chat_history=inputs["chat_history"],
        ).split("\n", 1)[0]
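

# A minimal local smoke test, a sketch only: the checkpoint path
# "./model" and every payload value below are hypothetical stand-ins.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    response = handler(
        {
            "inputs": {
                "user_input": "How are you today?",
                "user_name": "Bob",
                "char_name": "Alice",
                "char_persona": "A cheerful assistant.",
                "char_greeting": "Hello there!",
                "chat_history": "",
            }
        }
    )
    print(response)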