from typing import Any, Dict, List

from langchain import LLMChain, PromptTemplate
from langchain.llms import HuggingFacePipeline
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Prompt layout: the character's persona, the prior chat history, the
# character's greeting, the user's latest message, and finally the character's
# name followed by ": " so that the model continues with the character's reply.
# NOTE(review): the segments are separated by single spaces here. Since
# `__call__` cuts the completion at the first "\n", the template was presumably
# meant to place each segment on its own line -- confirm against the prompt
# format the model was trained on before changing the literal.
template = """{char_name}'s Persona: {char_persona} {chat_history} {char_name}: {char_greeting} {user_name}: {user_input} {char_name}: """

# Placeholders used by `template`, in the order they are read from a request.
# Shared by the PromptTemplate declaration and by `EndpointHandler.__call__`
# so the two cannot drift apart.
_PROMPT_FIELDS = (
    "user_input",
    "user_name",
    "char_name",
    "char_persona",
    "char_greeting",
    "chat_history",
)


class EndpointHandler:
    """Callable request handler that generates one character chat reply.

    Construction loads the tokenizer and causal language model found at
    ``path`` and wires them into a LangChain ``LLMChain`` driven by the
    module-level ``template``. Calling the instance fills the template from
    the request and returns the first line of the generated text.
    """

    def __init__(self, path: str = "") -> None:
        """Load the model and build the generation chain.

        Args:
            path: Location passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            load_in_8bit=True,
            device_map="auto",
        )

        # NOTE(review): `temperature`, `top_p` and `top_k` are passed without
        # `do_sample=True`; confirm that sampling is actually enabled for this
        # model's generation config, otherwise these values may be ignored.
        # NOTE(review): `pad_token_id` is hard-coded to 50256 -- verify that
        # it matches the tokenizer loaded from `path`.
        generator = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_length=2048,
            temperature=0.5,
            top_p=0.9,
            top_k=0,
            repetition_penalty=1.1,
            pad_token_id=50256,
            num_return_sequences=1,
        )
        self.local_llm = HuggingFacePipeline(pipeline=generator)

        self.prompt_template = PromptTemplate(
            template=template,
            input_variables=list(_PROMPT_FIELDS),
            validate_template=True,
        )

        self.llm_engine = LLMChain(
            llm=self.local_llm,
            prompt=self.prompt_template,
            verbose=True,
        )

    def __call__(self, data: Dict[str, Any]) -> str:
        """Generate the character's next reply for one request.

        Args:
            data: Request payload. The prompt fields are read from
                ``data["inputs"]`` when that key is present, otherwise from
                ``data`` itself. The fields are ``user_input``, ``user_name``,
                ``char_name``, ``char_persona``, ``char_greeting`` and
                ``chat_history``. ``data`` is not modified.

        Returns:
            The text returned by the chain, truncated at its first newline.

        Raises:
            KeyError: If the fields mapping lacks one of the prompt fields.
        """
        # `.get` rather than `.pop`: the caller's payload is left untouched.
        inputs = data.get("inputs", data)

        prompt_values = {field: inputs[field] for field in _PROMPT_FIELDS}
        completion = self.llm_engine.predict(**prompt_values)

        # Keep only the first line so the reply stops at the end of the
        # character's turn.
        return completion.split("\n", 1)[0]