winglian committed
Commit 0d4a7f4 · unverified · 2 Parent(s): af3aacb cc67862

Merge pull request #67 from OpenAccess-AI-Collective/refactor-tokenizer-load

Files changed (2):
  1. scripts/finetune.py +26 -13
  2. src/axolotl/utils/models.py +41 -42
scripts/finetune.py CHANGED
@@ -5,7 +5,7 @@ import random
 import signal
 import sys
 from pathlib import Path
-from typing import Optional
+from typing import Optional, List, Dict, Any, Union
 
 import fire
 import torch
@@ -21,7 +21,7 @@ src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)
 
 from axolotl.utils.data import load_prepare_datasets
-from axolotl.utils.models import load_model
+from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.wandb import setup_wandb_env_vars
 
@@ -117,6 +117,10 @@ def choose_config(path: Path):
     return chosen_file
 
 
+def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
+    return not any(el in list2 for el in list1)
+
+
 def train(
     config: Path = Path("configs/"),
     prepare_ds_only: bool = False,
@@ -161,13 +165,30 @@ def train(
 
     validate_config(cfg)
 
+    # load the tokenizer first
+    logging.info("loading tokenizer...")
+    tokenizer = load_tokenizer(
+        cfg.base_model_config,
+        cfg.tokenizer_type,
+        cfg
+    )
+
+    if check_not_in(["inference", "shard", "merge_lora"], kwargs):  # don't need to load dataset for these
+        train_dataset, eval_dataset = load_prepare_datasets(
+            tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+        )
+
+    if prepare_ds_only:
+        logging.info("Finished preparing dataset. Exiting...")
+        return
+
     # Load the model and tokenizer
-    logging.info("loading model, tokenizer, and peft_config...")
-    model, tokenizer, peft_config = load_model(
+    logging.info("loading model and peft_config...")
+    model, peft_config = load_model(
         cfg.base_model,
         cfg.base_model_config,
         cfg.model_type,
-        cfg.tokenizer_type,
+        tokenizer,
         cfg,
         adapter=cfg.adapter,
         inference=("inference" in kwargs),
@@ -192,10 +213,6 @@ def train(
         model.save_pretrained(cfg.output_dir)
         return
 
-    train_dataset, eval_dataset = load_prepare_datasets(
-        tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-    )
-
    if cfg.debug:
        logging.info("check_dataset_labels...")
        check_dataset_labels(
@@ -205,10 +222,6 @@ def train(
            tokenizer,
        )
 
-    if prepare_ds_only:
-        logging.info("Finished preparing dataset. Exiting...")
-        return
-
    trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
 
    model.config.use_cache = False
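
For reference, a minimal sketch of what the new check_not_in guard does. The function body is copied from the diff above; the example kwargs dicts (e.g. the "lr" key) are only illustrative, and when check_not_in returns False the dataset loading step is skipped:

# standalone illustration of the helper added in this commit
from typing import Any, Dict, List, Union

def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
    # True only when none of the names in list1 appear in list2
    return not any(el in list2 for el in list1)

print(check_not_in(["inference", "shard", "merge_lora"], {"lr": 1e-4}))         # True -> datasets are loaded
print(check_not_in(["inference", "shard", "merge_lora"], {"inference": True}))  # False -> dataset loading is skipped
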
src/axolotl/utils/models.py CHANGED
@@ -7,7 +7,6 @@ from typing import Optional, Tuple, TYPE_CHECKING
 import bitsandbytes as bnb
 import torch
 import transformers
-from torch import nn
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -34,20 +33,56 @@ if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
 
 
+def load_tokenizer(
+    base_model_config,
+    tokenizer_type,
+    cfg,
+):
+    if tokenizer_type:
+        tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
+            base_model_config,
+            trust_remote_code=cfg.trust_remote_code or False,
+        )
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(
+            base_model_config,
+            trust_remote_code=cfg.trust_remote_code or False,
+        )
+
+    logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
+    logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
+    logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
+    logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
+
+    if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
+        tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
+
+    if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+    if cfg.special_tokens:
+        for k, v in cfg.special_tokens.items():
+            tokenizer.add_special_tokens({k: v})
+    if cfg.tokens:
+        tokenizer.add_tokens(list(cfg.tokens))
+
+    return tokenizer
+
+
 def load_model(
     base_model,
     base_model_config,
     model_type,
-    tokenizer_type,
+    tokenizer,
     cfg,
     adapter="lora",
     inference=False,
 ):
-    # type: (str, str, str, str, AttrDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]]
+    # type: (str, str, str, str, AttrDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
 
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
-    tokenizer = None
     is_llama_derived_model = "llama" in base_model or (
         cfg.model_type and "llama" in cfg.model_type.lower()
     )
@@ -122,7 +157,7 @@ def load_model(
            model_path = str(cache_model_path)
        except:
            model_path = cfg.base_model
-        model, tokenizer = load_llama_model_4bit_low_ram(
+        model, _ = load_llama_model_4bit_low_ram(
            base_model_config if base_model_config else base_model,
            model_path,
            device_map=cfg.device_map,
@@ -207,42 +242,6 @@ def load_model(
        **model_kwargs,
    )
 
-    if not tokenizer:
-        try:
-            if is_llama_derived_model and "LlamaTokenizer" in globals():
-                tokenizer = LlamaTokenizer.from_pretrained(
-                    base_model_config,
-                    trust_remote_code=True if cfg.trust_remote_code is True else False,
-                )
-            else:
-                tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
-                    base_model_config,
-                    trust_remote_code=True if cfg.trust_remote_code is True else False,
-                )
-        except:
-            tokenizer = AutoTokenizer.from_pretrained(
-                base_model_config,
-                trust_remote_code=True if cfg.trust_remote_code is True else False,
-            )
-
-    logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
-    logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
-    logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
-    logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
-
-    if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
-        tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
-
-    if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
-        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-    if cfg.special_tokens:
-        for k, v in cfg.special_tokens.items():
-            tokenizer.add_special_tokens({k: v})
-    if cfg.tokens:
-        tokenizer.add_tokens(list(cfg.tokens))
-
    embeddings_len = math.ceil(len(tokenizer) / 32) * 32
    model.resize_token_embeddings(embeddings_len)
 
@@ -291,7 +290,7 @@ def load_model(
    model.config.use_cache = False
 
    # TODO resume_from_checkpoint handling
-    return model, tokenizer, lora_config
+    return model, lora_config
 
 
 def load_adapter(model, cfg, adapter):
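
Taken together, the refactor means callers load the tokenizer once with load_tokenizer and pass the tokenizer object (not a tokenizer type string) into load_model, which now returns only the model and the PEFT config. A minimal sketch of the new call order, assuming cfg is the attribute-style config that axolotl builds from the YAML file (the attribute values shown in comments are placeholders, not defaults from this commit):

# hypothetical driver code mirroring the updated train() flow in scripts/finetune.py
from axolotl.utils.models import load_model, load_tokenizer

tokenizer = load_tokenizer(
    cfg.base_model_config,  # e.g. a Hugging Face repo id (placeholder)
    cfg.tokenizer_type,     # e.g. "LlamaTokenizer", or None to fall back to AutoTokenizer
    cfg,
)

# the already-loaded tokenizer replaces the old tokenizer_type argument
model, peft_config = load_model(
    cfg.base_model,
    cfg.base_model_config,
    cfg.model_type,
    tokenizer,
    cfg,
    adapter=cfg.adapter,
)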