from open_lm.hf import *
from transformers import AutoTokenizer, AutoModelForCausalLM
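# Note: the open_lm.hf wildcard import above registers open_lm's Hugging Face
# config/model classes with transformers' Auto* factories, which is what lets
# AutoModelForCausalLM resolve and load DCLM checkpoints such as
# mathewhe/DCLM-7B-Chat.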


class Chat:
    def __init__(
        self,
        path="mathewhe/DCLM-7B-Chat",
        device="cuda",
    ):
        r"""
        Construct :class:`Chat`\.

        Args:
            path (str): Model name or path.
            device (str): Model device.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
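        # Register the chat-turn delimiters as special tokens so they are
        # tokenized atomically and stripped by skip_special_tokens on decode.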
        self.tokenizer.add_tokens(
            ["[ASST]", "[INST]", "[/ASST]", "[/INST]"],
            special_tokens=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(path, device_map=device)

        self.messages = list()
        self.device = device
        self.gen_kwargs = {
            "min_new_tokens": 1,
            "max_new_tokens": 2048,
            "top_p": 0.8,
            "temperature": 0.8,
            "do_sample": True,
            "repetition_penalty": 1.1,
        }

    def reset(self):
        self.messages = list()

    def _inference(self, messages):
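        # Render the running conversation into a single prompt string via the
        # tokenizer's chat template, then tokenize it for generation.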
        chat = self.tokenizer.apply_chat_template(messages, tokenize=False)
        inputs = {
            k: v.to(self.device)
            for k, v in self.tokenizer(chat, return_tensors="pt").items()
        }
        input_length = len(inputs["input_ids"][0])
        output = self.model.generate(**inputs, **self.gen_kwargs)
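        # Decode only the newly generated tokens, skipping the prompt and any
        # special tokens (including the [INST]/[ASST] markers).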
        response = self.tokenizer.decode(
            output[0].tolist()[input_length:],
            skip_special_tokens=True,
        )
        # TODO: handle the leading space in the tokenizer/chat template
        # instead of stripping it here.
        if response.startswith(" "):
            response = response[1:]
        return response

    def message(self, message):
        r"""
        Add a user message to the chat history, then generate, record, and
        return the assistant's response.

        Args:
            message (str): The user message.
        """
        self.messages.append({"role": "user", "content": message})
        response = self._inference(self.messages)
        self.messages.append({"role": "assistant", "content": response})
        return response

    def cli_chat(self):
        r"""
        For CLI-based chatting (with history).
        """
        asst_prompt = "Assistant: "
        user_prompt = "---> User: "
        print(f"{asst_prompt}Hi! How can I help you?\n")
        message = input(user_prompt)
        while message:  # an empty line ends the session
            response = self.message(message)
            print(f"\n{asst_prompt}{response}\n")
            message = input(user_prompt)

    def instruct(self, message):
        r"""
        For single instruction-response interactions (without history).

        Args:
            message (str): An instruction or one-off user message.
        """
        messages = [{"role": "user", "content": message}]
        response = self._inference(messages)
        return response


if __name__ == "__main__":
    chat = Chat()
    chat.cli_chat()
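
# Example of non-interactive use (a sketch; the module name "chat_class" is an
# assumption and depends on how this file is saved):
#
#   from chat_class import Chat
#
#   chat = Chat(device="cuda")
#   print(chat.instruct("Summarize the DCLM project in one sentence."))  # no history
#   print(chat.message("Hi! What can you do?"))  # history kept across calls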