Hansimov committed on
Commit
4a9e46f
·
1 Parent(s): e491840

:recycle: [Refactor] Use official chat_template to replace manual parsing for several models, and remove unused message split functions

Browse files
Files changed (2) hide show
  1. messagers/message_composer.py +18 -157
  2. requirements.txt +1 -0
messagers/message_composer.py CHANGED
@@ -3,7 +3,7 @@ from pprint import pprint
3
 
4
  from transformers import AutoTokenizer
5
 
6
- from constants.models import AVAILABLE_MODELS
7
  from utils.logger import logger
8
 
9
 
@@ -13,6 +13,7 @@ class MessageComposer:
13
  self.model = model
14
  else:
15
  self.model = "mixtral-8x7b"
 
16
  self.system_roles = ["system"]
17
  self.inst_roles = ["user", "system", "inst"]
18
  self.answer_roles = ["assistant", "bot", "answer", "model"]
@@ -46,12 +47,16 @@ class MessageComposer:
46
  return concat_messages
47
 
48
  def merge(self, messages) -> str:
 
 
 
 
 
 
 
49
  # Mistral and Mixtral:
50
  # <s> [INST] Instruction [/INST] Model answer </s> [INST] Follow-up instruction [/INST]
51
 
52
- # OpenChat:
53
- # GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi<|end_of_turn|>GPT4 Correct User: How are you today?<|end_of_turn|>GPT4 Correct Assistant:
54
-
55
  # Nous Mixtral:
56
  # <|im_start|>system
57
  # You are "Hermes 2".<|im_end|>
@@ -59,6 +64,9 @@ class MessageComposer:
59
  # Hello, who are you?<|im_end|>
60
  # <|im_start|>assistant
61
 
 
 
 
62
  # Google Gemma-it
63
  # <start_of_turn>user
64
  # How does the brain work?<end_of_turn>
@@ -83,44 +91,6 @@ class MessageComposer:
83
  self.cached_str = f"[INST] {content} [/INST]"
84
  if self.cached_str:
85
  self.merged_str += f"{self.cached_str}"
86
- # https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO#prompt-format
87
- elif self.model in ["nous-mixtral-8x7b"]:
88
- self.merged_str_list = []
89
- for message in self.messages:
90
- role = message["role"]
91
- content = message["content"]
92
- if role not in ["system", "user", "assistant"]:
93
- role = self.default_role
94
- message_line = f"<|im_start|>{role}\n{content}<|im_end|>"
95
- self.merged_str_list.append(message_line)
96
- self.merged_str_list.append("<|im_start|>assistant")
97
- self.merged_str = "\n".join(self.merged_str_list)
98
- # https://huggingface.co/openchat/openchat-3.5-0106
99
- elif self.model in ["openchat-3.5"]:
100
- tokenizer = AutoTokenizer.from_pretrained("openchat/openchat-3.5-0106")
101
- self.merged_str = tokenizer.apply_chat_template(
102
- messages, tokenize=False, add_generation_prompt=True
103
- )
104
- # self.messages = self.concat_messages_by_role(messages)
105
- # self.merged_str_list = []
106
- # self.end_of_turn = "<|end_of_turn|>"
107
- # for message in self.messages:
108
- # role = message["role"]
109
- # content = message["content"]
110
- # if role in self.inst_roles:
111
- # self.merged_str_list.append(
112
- # f"GPT4 Correct User:\n{content}{self.end_of_turn}"
113
- # )
114
- # elif role in self.answer_roles:
115
- # self.merged_str_list.append(
116
- # f"GPT4 Correct Assistant:\n{content}{self.end_of_turn}"
117
- # )
118
- # else:
119
- # self.merged_str_list.append(
120
- # f"GPT4 Correct User: {content}{self.end_of_turn}"
121
- # )
122
- # self.merged_str_list.append(f"GPT4 Correct Assistant:\n")
123
- # self.merged_str = "\n".join(self.merged_str_list)
124
  # https://huggingface.co/google/gemma-7b-it#chat-template
125
  elif self.model in ["gemma-7b"]:
126
  self.messages = self.concat_messages_by_role(messages)
@@ -144,122 +114,17 @@ class MessageComposer:
144
  )
145
  self.merged_str_list.append(f"{self.start_of_turn}model\n")
146
  self.merged_str = "\n".join(self.merged_str_list)
 
 
 
147
  else:
148
- self.merged_str = "\n".join(
149
- [
150
- f'`{message["role"]}`:\n{message["content"]}\n'
151
- for message in self.messages
152
- ]
153
  )
154
 
155
  return self.merged_str
156
 
157
- def convert_pair_matches_to_messages(self, pair_matches_list):
158
- messages = []
159
- if len(pair_matches_list) <= 0:
160
- messages = [
161
- {
162
- "role": "user",
163
- "content": self.merged_str,
164
- }
165
- ]
166
- else:
167
- for match in pair_matches_list:
168
- inst = match.group("inst")
169
- answer = match.group("answer")
170
- messages.extend(
171
- [
172
- {"role": "user", "content": inst.strip()},
173
- {"role": "assistant", "content": answer.strip()},
174
- ]
175
- )
176
- return messages
177
-
178
- def append_last_instruction_to_messages(self, inst_matches_list, pair_matches_list):
179
- if len(inst_matches_list) > len(pair_matches_list):
180
- self.messages.extend(
181
- [
182
- {
183
- "role": "user",
184
- "content": inst_matches_list[-1].group("inst").strip(),
185
- }
186
- ]
187
- )
188
-
189
- def split(self, merged_str) -> list:
190
- self.merged_str = merged_str
191
- self.messages = []
192
-
193
- if self.model in ["mixtral-8x7b", "mistral-7b"]:
194
- pair_pattern = (
195
- r"<s>\s*\[INST\](?P<inst>[\s\S]*?)\[/INST\](?P<answer>[\s\S]*?)</s>"
196
- )
197
- pair_matches = re.finditer(pair_pattern, self.merged_str, re.MULTILINE)
198
- pair_matches_list = list(pair_matches)
199
-
200
- self.messages = self.convert_pair_matches_to_messages(pair_matches_list)
201
-
202
- inst_pattern = r"\[INST\](?P<inst>[\s\S]*?)\[/INST\]"
203
- inst_matches = re.finditer(inst_pattern, self.merged_str, re.MULTILINE)
204
- inst_matches_list = list(inst_matches)
205
-
206
- self.append_last_instruction_to_messages(
207
- inst_matches_list, pair_matches_list
208
- )
209
- elif self.model in ["nous-mixtral-8x7b"]:
210
- # https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO#prompt-format
211
- # message_pattern = r"<\|im_start\|>(?P<role>system|user|assistant)[\s\n]*(?P<content>[\s\S]*?)<\|im_end\|>"
212
- message_pattern = r"<\|im_start\|>(?P<role>system|user|assistant)[\s\n]*(?P<content>[\s\S]*?)<\|im_end\|>"
213
- message_matches = re.finditer(
214
- message_pattern, self.merged_str, flags=re.MULTILINE | re.IGNORECASE
215
- )
216
- message_matches_list = list(message_matches)
217
- logger.note(f"message_matches_list: {message_matches_list}")
218
- for match in message_matches_list:
219
- role = match.group("role")
220
- content = match.group("content")
221
- self.messages.append({"role": role, "content": content.strip()})
222
- elif self.model in ["openchat-3.5"]:
223
- pair_pattern = r"GPT4 Correct User:(?P<inst>[\s\S]*?)<\|end_of_turn\|>\s*GPT4 Correct Assistant:(?P<answer>[\s\S]*?)<\|end_of_turn\|>"
224
- pair_matches = re.finditer(
225
- pair_pattern, self.merged_str, flags=re.MULTILINE | re.IGNORECASE
226
- )
227
- pair_matches_list = list(pair_matches)
228
- self.messages = self.convert_pair_matches_to_messages(pair_matches_list)
229
- inst_pattern = r"GPT4 Correct User:(?P<inst>[\s\S]*?)<\|end_of_turn\|>"
230
- inst_matches = re.finditer(
231
- inst_pattern, self.merged_str, flags=re.MULTILINE | re.IGNORECASE
232
- )
233
- inst_matches_list = list(inst_matches)
234
- self.append_last_instruction_to_messages(
235
- inst_matches_list, pair_matches_list
236
- )
237
- # https://huggingface.co/google/gemma-7b-it#chat-template
238
- elif self.model in ["gemma-7b"]:
239
- pair_pattern = r"<start_of_turn>user[\s\n]*(?P<inst>[\s\S]*?)<end_of_turn>[\s\n]*<start_of_turn>model(?P<answer>[\s\S]*?)<end_of_turn>"
240
- pair_matches = re.finditer(
241
- pair_pattern, self.merged_str, flags=re.MULTILINE | re.IGNORECASE
242
- )
243
- pair_matches_list = list(pair_matches)
244
- self.messages = self.convert_pair_matches_to_messages(pair_matches_list)
245
- inst_pattern = r"<start_of_turn>user\n(?P<inst>[\s\S]*?)<end_of_turn>"
246
- inst_matches = re.finditer(
247
- inst_pattern, self.merged_str, flags=re.MULTILINE | re.IGNORECASE
248
- )
249
- inst_matches_list = list(inst_matches)
250
- self.append_last_instruction_to_messages(
251
- inst_matches_list, pair_matches_list
252
- )
253
- else:
254
- self.messages = [
255
- {
256
- "role": "user",
257
- "content": self.merged_str,
258
- }
259
- ]
260
-
261
- return self.messages
262
-
263
 
264
  if __name__ == "__main__":
265
  # model = "mixtral-8x7b"
@@ -287,9 +152,5 @@ if __name__ == "__main__":
287
  merged_str = composer.merge(messages)
288
  logger.note("merged_str:")
289
  logger.mesg(merged_str)
290
- logger.note("splitted messages:")
291
- pprint(composer.split(merged_str))
292
- # logger.note("merged merged_str:")
293
- # logger.mesg(composer.merge(composer.split(merged_str)))
294
 
295
  # python -m messagers.message_composer
 
3
 
4
  from transformers import AutoTokenizer
5
 
6
+ from constants.models import AVAILABLE_MODELS, MODEL_MAP
7
  from utils.logger import logger
8
 
9
 
 
13
  self.model = model
14
  else:
15
  self.model = "mixtral-8x7b"
16
+ self.model_fullname = MODEL_MAP[self.model]
17
  self.system_roles = ["system"]
18
  self.inst_roles = ["user", "system", "inst"]
19
  self.answer_roles = ["assistant", "bot", "answer", "model"]
 
47
  return concat_messages
48
 
49
  def merge(self, messages) -> str:
50
+ # Templates for Chat Models
51
+ # - https://huggingface.co/docs/transformers/main/en/chat_templating
52
+ # - https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1#instruction-format
53
+ # - https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO#prompt-format
54
+ # - https://huggingface.co/openchat/openchat-3.5-0106
55
+ # - https://huggingface.co/google/gemma-7b-it#chat-template
56
+
57
  # Mistral and Mixtral:
58
  # <s> [INST] Instruction [/INST] Model answer </s> [INST] Follow-up instruction [/INST]
59
 
 
 
 
60
  # Nous Mixtral:
61
  # <|im_start|>system
62
  # You are "Hermes 2".<|im_end|>
 
64
  # Hello, who are you?<|im_end|>
65
  # <|im_start|>assistant
66
 
67
+ # OpenChat:
68
+ # GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi<|end_of_turn|>GPT4 Correct User: How are you today?<|end_of_turn|>GPT4 Correct Assistant:
69
+
70
  # Google Gemma-it
71
  # <start_of_turn>user
72
  # How does the brain work?<end_of_turn>
 
91
  self.cached_str = f"[INST] {content} [/INST]"
92
  if self.cached_str:
93
  self.merged_str += f"{self.cached_str}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  # https://huggingface.co/google/gemma-7b-it#chat-template
95
  elif self.model in ["gemma-7b"]:
96
  self.messages = self.concat_messages_by_role(messages)
 
114
  )
115
  self.merged_str_list.append(f"{self.start_of_turn}model\n")
116
  self.merged_str = "\n".join(self.merged_str_list)
117
+ # https://huggingface.co/NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO#prompt-format
118
+ # https://huggingface.co/openchat/openchat-3.5-0106
119
+ # elif self.model in ["openchat-3.5", "nous-mixtral-8x7b"]:
120
  else:
121
+ tokenizer = AutoTokenizer.from_pretrained("openchat/openchat-3.5-0106")
122
+ self.merged_str = tokenizer.apply_chat_template(
123
+ messages, tokenize=False, add_generation_prompt=True
 
 
124
  )
125
 
126
  return self.merged_str
127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  if __name__ == "__main__":
130
  # model = "mixtral-8x7b"
 
152
  merged_str = composer.merge(messages)
153
  logger.note("merged_str:")
154
  logger.mesg(merged_str)
 
 
 
 
155
 
156
  # python -m messagers.message_composer
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
  aiohttp
2
  fastapi
3
  httpx
 
4
  markdown2[all]
5
  openai
6
  pydantic
 
1
  aiohttp
2
  fastapi
3
  httpx
4
+ jinja2
5
  markdown2[all]
6
  openai
7
  pydantic