|
import logging |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class Chat:
    r"""Stateful chat wrapper around a causal LM with a chat template.

    Keeps a running ``messages`` history (role/content dicts) and generates
    responses with Hugging Face ``generate``.
    """

    def __init__(
        self,
        path="mathewhe/Llama-3.1-8B-Chat",
        device="cuda",
    ):
        r"""Load the tokenizer and model.

        Args:
            path: Hugging Face model id or local path.
            device: Device string passed to ``device_map`` (e.g. ``"cuda"``).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path, device_map=device)

        self.messages = list()
        self.device = device
        self.gen_kwargs = {
            "min_new_tokens": 1,
            "max_new_tokens": 2048,
            "top_p": 0.8,
            "temperature": 0.8,
            "do_sample": True,
            "repetition_penalty": 1.1,
            # Llama-family models define no pad token; without this,
            # generate() warns and falls back implicitly. Make it explicit.
            "pad_token_id": self.tokenizer.eos_token_id,
        }

    def reset(self):
        r"""Reset the chat message history."""
        self.messages = list()

    def _inference(self, messages):
        r"""Generate a response for ``messages``.

        Args:
            messages: List of ``{"role": ..., "content": ...}`` dicts.

        Returns:
            The decoded assistant response (special tokens stripped).
        """
        # add_generation_prompt=True appends the assistant header so the
        # model produces a reply instead of continuing the user's turn.
        # (The original call omitted this, which breaks chat inference.)
        chat = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # add_special_tokens=False: the chat template already inserted BOS etc.
        inputs = {
            k: v.to(self.device)
            for k, v in self.tokenizer(
                chat, return_tensors="pt", add_special_tokens=False
            ).items()
        }
        input_length = len(inputs["input_ids"][0])
        output = self.model.generate(**inputs, **self.gen_kwargs)
        # Drop the prompt tokens; decode only the newly generated tail.
        response = self.tokenizer.decode(
            output[0].tolist()[input_length:],
            skip_special_tokens=True,
        )
        return response

    def message(self, message):
        r"""
        Add the message to the chat history and return a response.
        """
        self.messages.append({"role": "user", "content": message})
        try:
            response = self._inference(self.messages)
        except Exception:
            # Keep the history consistent: drop the dangling user turn if
            # generation failed, then let the caller see the error.
            self.messages.pop()
            raise
        self.messages.append({"role": "assistant", "content": response})
        return response

    def cli_chat(self):
        r"""
        For CLI-based chatting (with history).
        """
        asst_prompt = "Assistant: "
        user_prompt = "---> User: "

        print(f"{asst_prompt}Hi! How can I help you?\n")
        # input() always returns a str, so an empty string ends the session
        # (the original also tested for an impossible None).
        while message := input(user_prompt):
            response = self.message(message)
            print(f"\n{asst_prompt}{response}\n")

    def instruct(self, message):
        r"""
        For single instruction-response interactions (without history).
        """
        messages = [{"role": "user", "content": message}]
        response = self._inference(messages)
        return response
|
|
|
|
|
def _main():
    """Entry point: start an interactive CLI chat with default settings."""
    Chat().cli_chat()


if __name__ == "__main__":
    _main()
|
|