# Importing open_lm.hf registers the OpenLM architecture with transformers,
# so AutoTokenizer/AutoModelForCausalLM can resolve DCLM checkpoints.
from open_lm.hf import *
from transformers import AutoTokenizer, AutoModelForCausalLM


class Chat:
    def __init__(
        self,
        path="mathewhe/DCLM-7B-Chat",
        device="cuda",
    ):
        r"""
        Construct :class:`Chat`.

        Args:
            path (str): Model name or path.
            device (str): Model device.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Register the chat-turn delimiters as special tokens so they
        # tokenize atomically and can be stripped from decoded output.
        self.tokenizer.add_tokens(
            ["[ASST]", "[INST]", "[/ASST]", "[/INST]"],
            special_tokens=True,
        )
        # Respect the requested device instead of hard-coding "cuda".
        self.model = AutoModelForCausalLM.from_pretrained(path, device_map=device)
        self.messages = list()
        self.device = device
        # Default sampling parameters for generation.
        self.gen_kwargs = {
            "min_new_tokens": 1,
            "max_new_tokens": 2048,
            "top_p": 0.8,
            "temperature": 0.8,
            "do_sample": True,
            "repetition_penalty": 1.1,
        }

    def reset(self):
        r"""
        Clear the chat history.
        """
        self.messages = list()

    def _inference(self, messages):
        # Render the conversation with the tokenizer's chat template,
        # generate, and return only the newly generated text.
        chat = self.tokenizer.apply_chat_template(messages, tokenize=False)
        inputs = {
            k: v.to(self.device)
            for k, v in self.tokenizer(chat, return_tensors="pt").items()
        }
        input_length = len(inputs["input_ids"][0])
        output = self.model.generate(**inputs, **self.gen_kwargs)
        response = self.tokenizer.decode(
            output[0].tolist()[input_length:],
            skip_special_tokens=True,
        )
        # TODO: fix the chat template so the tokenizer handles this
        # leading space correctly.
        if response.startswith(" "):
            response = response[1:]
        return response
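
    # Note: the exact prompt string produced by apply_chat_template is
    # defined by the model repo's chat template. Given the delimiter tokens
    # registered in __init__, each turn is presumably rendered roughly as
    # follows (an assumption, not taken from the template itself):
    #
    #     [INST] {user message} [/INST]
    #     [ASST] {assistant response} [/ASST]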

    def message(self, message):
        r"""
        Add a user message to the chat history, then generate, record, and
        return the assistant's response.

        Args:
            message (str): The user message.
        """
        self.messages.append({"role": "user", "content": message})
        response = self._inference(self.messages)
        self.messages.append({"role": "assistant", "content": response})
        return response

    def cli_chat(self):
        r"""
        For CLI-based chatting (with history). Exit by entering an empty
        message.
        """
        asst_prompt = "Assistant: "
        user_prompt = "---> User: "
        print(f"{asst_prompt}Hi! How can I help you?\n")
        message = input(user_prompt)
        while message:
            response = self.message(message)
            print(f"\n{asst_prompt}{response}\n")
            message = input(user_prompt)

    def instruct(self, message):
        r"""
        For single instruction-response interactions (without history).

        Args:
            message (str): An instruction or one-off user message.
        """
        messages = [{"role": "user", "content": message}]
        response = self._inference(messages)
        return response
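

# Example usage, as a sketch (assumes the mathewhe/DCLM-7B-Chat weights can
# be downloaded and a CUDA device is available):
#
#     chat = Chat()
#
#     # One-off instruction, no history kept:
#     print(chat.instruct("Write a haiku about language models."))
#
#     # Multi-turn conversation with history:
#     print(chat.message("Hi! Who are you?"))
#     print(chat.message("What were we just talking about?"))
#     chat.reset()  # clear the history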

if __name__ == "__main__":
    chat = Chat()
    chat.cli_chat()