MrD05 commited on
Commit
f546677
·
1 Parent(s): 6bc40c3

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +69 -0
handler.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
2
+ from langchain.llms import HuggingFacePipeline
3
+ from langchain import PromptTemplate, LLMChain
4
+ from typing import Dict, List, Any
5
+ import holidays
6
+
7
+ template = """{char_name}'s Persona: {char_persona}
8
+ <START>
9
+ {chat_history}
10
+ {char_name}: {char_greeting}
11
+ <END>
12
+ {user_name}: {user_input}
13
+ {char_name}: """
14
+
15
+ class EndpointHandler():
16
+
17
+ def __init__(self, path=""):
18
+ self.tokenizer = AutoTokenizer.from_pretrained(path)
19
+ self.model = AutoModelForCausalLM.from_pretrained(path, load_in_8bit = True, device_map = "auto")
20
+ self.local_llm = HuggingFacePipeline(
21
+ pipeline = pipeline(
22
+ "text-generation",
23
+ model = self.model,
24
+ tokenizer = self.tokenizer,
25
+ max_length = 2048,
26
+ temperature = 0.5,
27
+ top_p = 0.9,
28
+ top_k = 0,
29
+ repetition_penalty = 1.1,
30
+ pad_token_id = 50256,
31
+ num_return_sequences = 1
32
+ )
33
+ )
34
+ self.prompt_template = PromptTemplate(
35
+ template = template,
36
+ input_variables = [
37
+ "user_input",
38
+ "user_name",
39
+ "char_name",
40
+ "char_persona",
41
+ "char_greeting",
42
+ "chat_history"
43
+ ],
44
+ validate_template = True
45
+ )
46
+ self.llm_engine = LLMChain(
47
+ llm = self.local_llm,
48
+ prompt = self.prompt_template,
49
+ verbose = True
50
+ )
51
+
52
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
53
+ """
54
+ data args:
55
+ inputs (:obj: `str`)
56
+ date (:obj: `str`)
57
+ Return:
58
+ A :obj:`list` | `dict`: will be serialized and returned
59
+ """
60
+ inputs = data.pop("inputs", data)
61
+
62
+ return self.llm_engine.predict(
63
+ user_input = inputs["user_input"],
64
+ user_name = inputs["user_name"],
65
+ char_name = inputs["char_name"],
66
+ char_persona = inputs["char_persona"],
67
+ char_greeting = inputs["char_greeting"],
68
+ chat_history = inputs["chat_history"]
69
+ ).split("\n",1)[0]