File size: 4,980 Bytes
5bc67ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from typing import List
from queue import Queue

import torch


# def build_chat_input(model, tokenizer, messages: List[dict], max_new_tokens: int=0):
#     def _parse_messages(messages, split_role="user"):
#         system, rounds = "", []
#         round = []
#         for i, message in enumerate(messages):
#             if message["role"] == "system":
#                 assert i == 0
#                 system = message["content"]
#                 continue
#             if message["role"] == split_role and round:
#                 rounds.append(round)
#                 round = []
#             round.append(message)
#         if round:
#             rounds.append(round)
#         return system, rounds

#     max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens
#     max_input_tokens = model.config.model_max_length - max_new_tokens
#     system, rounds = _parse_messages(messages, split_role="user")
#     system_tokens = tokenizer.encode(system)
#     max_history_tokens = max_input_tokens - len(system_tokens)

#     history_tokens = []
#     for round in rounds[::-1]:
#         round_tokens = []
#         for message in round:
#             if message["role"] == "user":
#                 round_tokens.append(model.generation_config.user_token_id)
#             else:
#                 round_tokens.append(model.generation_config.assistant_token_id)
#             round_tokens.extend(tokenizer.encode(message["content"]))
#         if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
#             history_tokens = round_tokens + history_tokens  # concat left
#             if len(history_tokens) < max_history_tokens:
#                 continue
#         break

#     input_tokens = system_tokens + history_tokens
#     if messages[-1]["role"] != "assistant":
#         input_tokens.append(model.generation_config.assistant_token_id)
#     input_tokens = input_tokens[-max_input_tokens:]  # truncate left
#     return torch.LongTensor([input_tokens]).to(model.device)

# for HuatuoGPT2
def build_chat_input(model, tokenizer, messages: List[dict], max_new_tokens: int=0):
    """Encode a chat history into a HuatuoGPT2-style prompt tensor.

    Each message is rendered as ``"<问>:" + content + "\\n"`` when its role is
    ``"user"`` and as ``"<答>:" + content + "\\n"`` for any other role, then
    encoded with ``tokenizer.encode``. Messages are grouped into rounds that
    each start at a user message. Rounds are kept newest-first while they fit
    in the input budget (the newest round is always kept). A trailing
    ``"<答>:"`` prompt is appended unless the last message is already from the
    assistant, and the result is truncated from the left to fit the budget.

    Args:
        model: object exposing ``generation_config.max_new_tokens``,
            ``config.model_max_length`` and ``device``.
        tokenizer: object with an ``encode(str) -> list[int]`` method.
        messages: non-empty list of ``{"role": ..., "content": ...}`` dicts.
        max_new_tokens: tokens reserved for generation; ``0`` (the default)
            falls back to ``model.generation_config.max_new_tokens``.

    Returns:
        A ``torch.LongTensor`` of shape ``(1, seq_len)`` moved to ``model.device``.

    Raises:
        ValueError: if ``max_new_tokens`` leaves no room for input tokens.
    """
    user_prefix, assistant_prefix = '<问>:', '<答>:'
    sep = '\n'

    def _split_rounds(messages, split_role="user"):
        # Start a new round at every ``split_role`` message. System messages
        # are deliberately not special-cased: they stay in the current round
        # and are rendered with the assistant prefix below.
        rounds, current = [], []
        for message in messages:
            if message["role"] == split_role and current:
                rounds.append(current)
                current = []
            current.append(message)
        if current:
            rounds.append(current)
        return rounds

    max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens
    max_input_tokens = model.config.model_max_length - max_new_tokens
    if max_input_tokens <= 0:
        # ``tokens[-0:]`` would return the whole (untruncated) list and a
        # negative budget would slice from the wrong end, so fail loudly.
        raise ValueError(
            f"max_new_tokens ({max_new_tokens}) leaves no room for input within "
            f"model_max_length ({model.config.model_max_length})"
        )

    rounds = _split_rounds(messages, split_role="user")

    history_tokens = []
    for chat_round in reversed(rounds):
        round_tokens = []
        for message in chat_round:
            prefix = user_prefix if message["role"] == "user" else assistant_prefix
            round_tokens.extend(tokenizer.encode(prefix + message["content"] + sep))
        # The newest round is always kept (it is left-truncated below if too
        # long); older rounds are prepended only while they still fit.
        if history_tokens and len(history_tokens) + len(round_tokens) > max_input_tokens:
            break
        history_tokens = round_tokens + history_tokens  # concat left
        if len(history_tokens) >= max_input_tokens:
            break

    input_tokens = history_tokens
    if messages[-1]["role"] != "assistant":
        # Prompt the model to produce the next assistant turn.
        input_tokens.extend(tokenizer.encode(assistant_prefix))
    input_tokens = input_tokens[-max_input_tokens:]  # truncate left, keep newest tokens
    return torch.LongTensor([input_tokens]).to(model.device)


class TextIterStreamer:
    """Turn token ids pushed by a producer into an iterator of decoded text.

    A producer calls ``put`` with chunks of token ids and ``end`` when done.
    After every accepted ``put``, the decoding of *all* tokens received so far
    is queued (cumulative text, not just the new piece). Iterating the
    streamer yields those snapshots and stops once the ``end`` sentinel is
    consumed.

    Args:
        tokenizer: object providing ``decode(ids, skip_special_tokens=...)``.
        skip_prompt: when true, the first ``put`` call is ignored
            (presumably the prompt ids — confirm with the caller).
        skip_special_tokens: forwarded unchanged to ``tokenizer.decode``.
    """

    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.skip_special_tokens = skip_special_tokens
        # Every token id accepted so far; re-decoded in full on each put().
        self.tokens = []
        # Decoded snapshots for the consumer; None marks the end of stream.
        self.text_queue = Queue()
        # Cleared by the first put() when skip_prompt is enabled.
        self.next_tokens_are_prompt = True

    def put(self, value):
        """Accept a chunk of token ids and enqueue the updated decoded text."""
        if self.skip_prompt and self.next_tokens_are_prompt:
            # Swallow the first chunk without decoding it.
            self.next_tokens_are_prompt = False
            return
        # Inputs with more than one dimension contribute only their first row.
        ids = value[0] if len(value.shape) > 1 else value
        self.tokens += ids.tolist()
        snapshot = self.tokenizer.decode(
            self.tokens, skip_special_tokens=self.skip_special_tokens)
        self.text_queue.put(snapshot)

    def end(self):
        """Signal that no more tokens will arrive."""
        self.text_queue.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        """Block for the next decoded snapshot; stop at the end sentinel."""
        text = self.text_queue.get()
        if text is None:
            raise StopIteration
        return text