winglian committed on
Commit
2809f3f
·
1 Parent(s): 4ea9a66

pygmalion dataset prompts format, cached tokenized datasets should be hashed on the tokenizer too

Browse files
src/axolotl/prompt_strategies/alpaca_instruct.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
2
+ from axolotl.prompters import AlpacaPrompter, PromptStyle
3
+
4
+
5
def load(tokenizer, cfg):
    """Build the instruct-style Alpaca tokenizing strategy.

    Creates an ``AlpacaPrompter`` in ``PromptStyle.instruct`` mode and hands
    it, the tokenizer, ``cfg.train_on_inputs`` and ``cfg.sequence_len`` to
    ``AlpacaPromptTokenizingStrategy``.
    """
    prompter = AlpacaPrompter(PromptStyle.instruct)
    train_on_inputs = cfg.train_on_inputs
    sequence_len = cfg.sequence_len
    return AlpacaPromptTokenizingStrategy(
        prompter, tokenizer, train_on_inputs, sequence_len
    )
src/axolotl/prompt_strategies/pygmalion.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ from collections import defaultdict
4
+ from typing import Generator
5
+
6
+ from axolotl.prompt_tokenizers import PromptTokenizingStrategy
7
+
8
+ IGNORE_TOKEN_ID = -100
9
+
10
+
11
class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
    """Tokenize Pygmalion-style conversations turn by turn.

    Every turn produced by the prompter is rendered with a role tag
    (``<|system|>``, ``<|user|>`` or ``<|model|>``), tokenized on its own and
    concatenated into a single ``input_ids`` / ``attention_mask`` / ``labels``
    record. System and human turns are fully masked in ``labels`` with
    ``IGNORE_TOKEN_ID``; for bot turns only the leading ``<|model|>`` tag is
    masked.
    """

    # Class-level placeholder only; each instance rebinds this in __init__
    # with the token ids of the "<|model|>" tag.
    bot_prefix_token_ids = []

    def __init__(self, prompter, tokenizer, *args, **kwargs):
        """Initialise the strategy and pre-tokenize the ``<|model|>`` tag.

        Extra positional/keyword arguments are forwarded to the base class so
        that values such as ``train_on_inputs`` and ``sequence_len`` passed by
        ``load`` are not silently dropped.
        """
        super().__init__(prompter, tokenizer, *args, **kwargs)
        res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True)
        self.bot_prefix_token_ids = res["input_ids"]

    def tokenize_prompt(self, prompt):
        """Tokenize one conversation record.

        Args:
            prompt: mapping with a ``"conversations"`` entry that is handed to
                ``self.prompter.build_prompt``; each yielded item is unpacked
                as ``(role, message)``.

        Returns:
            dict with ``"input_ids"``, ``"attention_mask"`` and ``"labels"``
            lists of equal length.
        """
        result = {
            "input_ids": [],
            "attention_mask": [],
            "labels": [],
        }
        pad_token_id = self.tokenizer.pad_token_id
        for role, message in self.prompter.build_prompt(prompt["conversations"]):
            if role == "system":
                # this should include a bos token, no eos token,
                # strip trailing "\n<START>"
                suffix = "\n<START>"
                if message.endswith(suffix):
                    message = message[: -len(suffix)]
                res = self._tokenize(
                    "<|system|>" + "Persona: " + message.strip(),
                    add_eos_token=False,
                    strip_bos_token=False,
                )
                # everything from this is masked out from the labels
                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
            elif role == "human":
                res = self._tokenize(
                    "<|user|>" + " " + message.strip(),
                    add_eos_token=False,
                    strip_bos_token=True,
                )
                # everything from this is masked out from the labels
                labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
            elif role == "bot":
                # The tag is part of the tokenized text, so it must not be
                # prepended a second time.
                res = self._tokenize(
                    "<|model|>" + " " + message.strip(),
                    add_eos_token=True,
                    strip_bos_token=True,
                )
                # mask out the prefix token, rest is not masked out from
                # labels; clamp so labels never outgrow a truncated turn
                n_masked = min(len(self.bot_prefix_token_ids), len(res["input_ids"]))
                labels = [IGNORE_TOKEN_ID] * n_masked + res["input_ids"][n_masked:]
            else:
                # Unknown roles contribute no tokens; skipping avoids reusing
                # (or never having defined) the previous turn's labels.
                logging.warning("unknown role in conversation: %s", role)
                continue

            input_ids = res["input_ids"]
            result["input_ids"].extend(input_ids)
            result["attention_mask"].extend(
                1 if token_id != pad_token_id else 0 for token_id in input_ids
            )
            result["labels"].extend(labels)
        return result

    def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
        """Tokenize ``prompt`` and optionally append EOS / strip a leading BOS.

        Args:
            prompt: text to tokenize.
            add_eos_token: append ``eos_token_id`` when the encoding does not
                already end with it and is shorter than ``self.sequence_len``.
            strip_bos_token: drop the first token when it is ``bos_token_id``.

        Returns:
            The tokenizer output with ``"labels"`` set to a copy of
            ``"input_ids"``.
        """
        result = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.sequence_len,
            padding=False,
            return_tensors=None,
        )
        input_ids = result["input_ids"]
        # Check for emptiness first: indexing an empty encoding would raise.
        ends_with_eos = bool(input_ids) and input_ids[-1] == self.tokenizer.eos_token_id
        if add_eos_token and not ends_with_eos and len(input_ids) < self.sequence_len:
            result["input_ids"].append(self.tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        input_ids = result["input_ids"]
        if strip_bos_token and input_ids and input_ids[0] == self.tokenizer.bos_token_id:
            result["input_ids"] = input_ids[1:]
            result["attention_mask"] = result["attention_mask"][1:]

        result["labels"] = result["input_ids"].copy()
        return result
86
+
87
+
88
class PygmalionPrompter:
    """Prompter that passes conversation turns through as ``(role, value)`` pairs."""

    def __init__(self, *args, **kwargs):
        """Accept and ignore any arguments; the prompter keeps no state."""

    def build_prompt(self, source, *args, **kwargs) -> Generator[tuple, None, None]:
        """Yield ``(msg["role"], msg["value"])`` for every message in ``source``.

        Args:
            source: iterable of mappings, each providing ``"role"`` and
                ``"value"`` keys.
            *args, **kwargs: accepted and ignored.

        Raises:
            KeyError: if a message lacks ``"role"`` or ``"value"``.
        """
        for msg in source:
            yield msg["role"], msg["value"]
95
+
96
+
97
def load(tokenizer, cfg):
    """Build the Pygmalion tokenizing strategy.

    Creates a ``PygmalionPrompter`` and passes it, the tokenizer,
    ``cfg.train_on_inputs`` and ``cfg.sequence_len`` to
    ``PygmalionPromptTokenizingStrategy``.
    """
    prompter = PygmalionPrompter()
    train_on_inputs = cfg.train_on_inputs
    sequence_len = cfg.sequence_len
    return PygmalionPromptTokenizingStrategy(
        prompter, tokenizer, train_on_inputs, sequence_len
    )
src/axolotl/utils/data.py CHANGED
@@ -10,6 +10,7 @@ from datasets import (
10
  concatenate_datasets,
11
  )
12
  from huggingface_hub import hf_hub_download
 
13
 
14
  from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
15
  from axolotl.prompt_strategies import load
@@ -37,12 +38,14 @@ from axolotl.prompters import (
37
 
38
 
39
  def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
 
40
  ds_hash = str(
41
  md5(
42
  (
43
  str(cfg.sequence_len)
44
  + "@"
45
  + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
 
46
  ).encode("utf-8")
47
  ).hexdigest()
48
  )
@@ -192,7 +195,7 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
192
  return dataset
193
 
194
 
195
- def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
196
  max_packed_sequence_len = (
197
  cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
198
  )
@@ -200,6 +203,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
200
  max_packed_sequence_len, cfg.sequence_len
201
  ) # make sure we don't accidentally set it larger than sequence_len
202
 
 
203
  if cfg.max_packed_sequence_len is not None:
204
  # see if we can go ahead and load the stacked dataset
205
  seed = f"@{str(cfg.seed)}" if cfg.seed else ""
@@ -211,6 +215,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
211
  + str(max_packed_sequence_len)
212
  + seed
213
  + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
 
214
  ).encode("utf-8")
215
  ).hexdigest()
216
  )
@@ -238,6 +243,11 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
238
  )
239
  dataset = load_from_disk(str(prepared_ds_path))
240
  logging.info("Prepared packed dataset loaded from disk...")
 
 
 
 
 
241
  else:
242
  dataset = load_tokenized_prepared_datasets(
243
  tokenizer, cfg, default_dataset_prepared_path
 
10
  concatenate_datasets,
11
  )
12
  from huggingface_hub import hf_hub_download
13
+ from transformers import PreTrainedTokenizerBase
14
 
15
  from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
16
  from axolotl.prompt_strategies import load
 
38
 
39
 
40
  def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
41
+ tokenizer_name = tokenizer.__class__.__name__
42
  ds_hash = str(
43
  md5(
44
  (
45
  str(cfg.sequence_len)
46
  + "@"
47
  + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
48
+ + "|" + tokenizer_name
49
  ).encode("utf-8")
50
  ).hexdigest()
51
  )
 
195
  return dataset
196
 
197
 
198
+ def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path):
199
  max_packed_sequence_len = (
200
  cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
201
  )
 
203
  max_packed_sequence_len, cfg.sequence_len
204
  ) # make sure we don't accidentally set it larger than sequence_len
205
 
206
+ tokenizer_name = tokenizer.__class__.__name__
207
  if cfg.max_packed_sequence_len is not None:
208
  # see if we can go ahead and load the stacked dataset
209
  seed = f"@{str(cfg.seed)}" if cfg.seed else ""
 
215
  + str(max_packed_sequence_len)
216
  + seed
217
  + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
218
+ + "|" + tokenizer_name
219
  ).encode("utf-8")
220
  ).hexdigest()
221
  )
 
243
  )
244
  dataset = load_from_disk(str(prepared_ds_path))
245
  logging.info("Prepared packed dataset loaded from disk...")
246
+ if cfg.push_dataset_to_hub:
247
+ logging.info(
248
+ f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
249
+ )
250
+ dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True)
251
  else:
252
  dataset = load_tokenized_prepared_datasets(
253
  tokenizer, cfg, default_dataset_prepared_path