# Llama-3.1-8B-Chat / chat_class.py

import logging

from transformers import AutoTokenizer, AutoModelForCausalLM

logger = logging.getLogger(__name__)


class Chat:
    r"""Simple chat wrapper around the mathewhe/Llama-3.1-8B-Chat model."""

    def __init__(
        self,
        path="mathewhe/Llama-3.1-8B-Chat",
        device="cuda",
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path, device_map=device)
        self.messages = list()
        self.device = device
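        # Default decoding settings passed to model.generate() on every call.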
        self.gen_kwargs = {
            "min_new_tokens": 1,
            "max_new_tokens": 2048,
            "top_p": 0.8,
            "temperature": 0.8,
            "do_sample": True,
            "repetition_penalty": 1.1,
        }

    def reset(self):
        r"""Reset the chat message history."""
        self.messages = list()
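
    # Internal helper: renders `messages` with the tokenizer's chat template,
    # tokenizes the result, runs generation, and decodes only the newly
    # generated tokens (everything after the prompt).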
    def _inference(self, messages):
        chat = self.tokenizer.apply_chat_template(messages, tokenize=False)
        inputs = {
            k: v.to(self.device)
            for k, v in self.tokenizer(
                chat, return_tensors="pt", add_special_tokens=False
            ).items()
        }
        input_length = len(inputs["input_ids"][0])
        output = self.model.generate(**inputs, **self.gen_kwargs)
        response = self.tokenizer.decode(
            output[0].tolist()[input_length:],
            skip_special_tokens=True,
        )
        return response

    def message(self, message):
        r"""
        Add the message to the chat history and return a response.
        """
        self.messages.append({"role": "user", "content": message})
        # need to add caching of internal state!!
        response = self._inference(self.messages)
        self.messages.append({"role": "assistant", "content": response})
        return response

    def cli_chat(self):
        r"""
        For CLI-based chatting (with history).
        """
        asst_prompt = "Assistant: "
        user_prompt = "---> User: "
        print(f"{asst_prompt}Hi! How can I help you?\n")
        message = input(user_prompt)
        while not (message is None or message == ""):
            response = self.message(message)
            print(f"\n{asst_prompt}{response}\n")
            message = input(user_prompt)

    def instruct(self, message):
        r"""
        For single instruction-response interactions (without history).
        """
        messages = [{"role": "user", "content": message}]
        response = self._inference(messages)
        return response
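
# Example one-shot usage sketch (assumes a CUDA-capable GPU, per the default
# device="cuda" above); the prompt string is illustrative only:
#
#     chat = Chat()
#     print(chat.instruct("Explain what a chat template is in one sentence."))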


if __name__ == "__main__":
    chat = Chat()
    chat.cli_chat()