winglian commited on
Commit
b21e4a2
·
unverified ·
1 Parent(s): 42f9642

split train from other cli options (#503)

Browse files
scripts/finetune.py CHANGED
@@ -4,9 +4,7 @@ import importlib
4
  import logging
5
  import os
6
  import random
7
- import signal
8
  import sys
9
- from dataclasses import dataclass, field
10
  from pathlib import Path
11
  from typing import Any, Dict, List, Optional, Union
12
 
@@ -17,17 +15,17 @@ import yaml
17
 
18
  # add src to the pythonpath so we don't need to pip install this
19
  from art import text2art
20
- from optimum.bettertransformer import BetterTransformer
21
  from transformers import GenerationConfig, TextStreamer
22
 
 
23
  from axolotl.logging_config import configure_logging
 
24
  from axolotl.utils.config import normalize_config, validate_config
25
  from axolotl.utils.data import prepare_dataset
26
  from axolotl.utils.dict import DictDefault
27
  from axolotl.utils.distributed import is_main_process
28
- from axolotl.utils.models import load_model, load_model_config, load_tokenizer
29
  from axolotl.utils.tokenization import check_dataset_labels
30
- from axolotl.utils.trainer import setup_trainer
31
  from axolotl.utils.wandb import setup_wandb_env_vars
32
 
33
  project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -40,26 +38,13 @@ LOG = logging.getLogger("axolotl.scripts")
40
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
41
 
42
 
43
- @dataclass
44
- class TrainerCliArgs:
45
- """
46
- dataclass representing the various non-training arguments
47
- """
48
-
49
- debug: bool = field(default=False)
50
- inference: bool = field(default=False)
51
- merge_lora: bool = field(default=False)
52
- prepare_ds_only: bool = field(default=False)
53
- prompter: Optional[str] = field(default=None)
54
- shard: bool = field(default=False)
55
-
56
-
57
  def print_axolotl_text_art(suffix=None):
58
  font = "nancyj"
59
  ascii_text = " axolotl"
60
  if suffix:
61
  ascii_text += f" x {suffix}"
62
  ascii_art = text2art(" axolotl", font=font)
 
63
  if is_main_process():
64
  print(ascii_art)
65
 
@@ -73,9 +58,45 @@ def get_multi_line_input() -> Optional[str]:
73
  return instruction
74
 
75
 
76
- def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
77
- if prompter == "None":
78
- prompter = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
80
 
81
  for token, symbol in default_tokens.items():
@@ -176,141 +197,6 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
176
  return not any(el in list2 for el in list1)
177
 
178
 
179
- def train(
180
- *,
181
- cfg: DictDefault,
182
- cli_args: TrainerCliArgs,
183
- ):
184
- # load the tokenizer first
185
- LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
186
- tokenizer = load_tokenizer(cfg)
187
-
188
- if not (
189
- cli_args.shard or cli_args.merge_lora or cli_args.inference
190
- ): # don't need to load dataset for these
191
- train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
192
-
193
- if cli_args.debug or cfg.debug:
194
- LOG.info("check_dataset_labels...")
195
- check_dataset_labels(
196
- train_dataset.select(
197
- [random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec
198
- ),
199
- tokenizer,
200
- )
201
-
202
- if cli_args.prepare_ds_only:
203
- LOG.info("Finished preparing dataset. Exiting...")
204
- return
205
-
206
- # Load the model and tokenizer
207
- LOG.info("loading model and (optionally) peft_config...")
208
- model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
209
-
210
- safe_serialization = cfg.save_safetensors is True
211
-
212
- if cli_args.merge_lora and cfg.adapter is not None:
213
- LOG.info("running merge of LoRA with base model")
214
- model = model.merge_and_unload()
215
- model.to(dtype=torch.float16)
216
-
217
- if cfg.local_rank == 0:
218
- LOG.info("saving merged model")
219
- model.save_pretrained(
220
- str(Path(cfg.output_dir) / "merged"),
221
- safe_serialization=safe_serialization,
222
- )
223
- tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
224
- return
225
-
226
- if cli_args.inference:
227
- LOG.debug("Running inference on model")
228
- do_inference(cfg, model, tokenizer, prompter=cli_args.prompter)
229
- return
230
-
231
- if cli_args.shard:
232
- LOG.debug("Re-saving model w/ sharding")
233
- model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
234
- return
235
-
236
- if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
237
- possible_checkpoints = [
238
- str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
239
- ]
240
- if len(possible_checkpoints) > 0:
241
- sorted_paths = sorted(
242
- possible_checkpoints,
243
- key=lambda path: int(path.split("-")[-1]),
244
- )
245
- cfg.resume_from_checkpoint = sorted_paths[-1]
246
- LOG.info(
247
- f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
248
- )
249
- resume_from_checkpoint = cfg.resume_from_checkpoint
250
-
251
- trainer = setup_trainer(
252
- cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
253
- )
254
-
255
- model.config.use_cache = False
256
-
257
- if torch.__version__ >= "2" and sys.platform != "win32":
258
- LOG.info("Compiling torch model")
259
- model = torch.compile(model)
260
-
261
- # go ahead and presave, so we have the adapter config available to inspect
262
- if peft_config:
263
- LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
264
- peft_config.save_pretrained(cfg.output_dir)
265
-
266
- # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
267
- if cfg.local_rank == 0:
268
-
269
- def terminate_handler(_, __, model):
270
- if cfg.flash_optimum:
271
- model = BetterTransformer.reverse(model)
272
- model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
273
- sys.exit(0)
274
-
275
- signal.signal(
276
- signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
277
- )
278
-
279
- LOG.info("Starting trainer...")
280
- if cfg.group_by_length:
281
- LOG.info("hang tight... sorting dataset for group_by_length")
282
-
283
- if not Path(cfg.output_dir).is_dir():
284
- os.makedirs(cfg.output_dir, exist_ok=True)
285
- tokenizer.save_pretrained(cfg.output_dir)
286
- if cfg.flash_optimum:
287
- with torch.backends.cuda.sdp_kernel(
288
- enable_flash=True, enable_math=True, enable_mem_efficient=True
289
- ):
290
- trainer.train(resume_from_checkpoint=resume_from_checkpoint)
291
- else:
292
- trainer.train(resume_from_checkpoint=resume_from_checkpoint)
293
-
294
- LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
295
-
296
- if cfg.relora_steps:
297
- if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
298
- model = model.merge_and_unload()
299
- else:
300
- # final model weights have already been saved by `ReLoRACallback.on_train_end`
301
- return
302
-
303
- # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
304
- # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
305
- if cfg.fsdp:
306
- trainer.save_model(cfg.output_dir)
307
- elif cfg.local_rank == 0:
308
- if cfg.flash_optimum:
309
- model = BetterTransformer.reverse(model)
310
-
311
- model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
312
-
313
-
314
  def load_cfg(config: Path = Path("examples/"), **kwargs):
315
  if Path(config).is_dir():
316
  config = choose_config(config)
@@ -347,15 +233,50 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
347
  return cfg
348
 
349
 
350
- def do_train(config: Path = Path("examples/"), **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
  print_axolotl_text_art()
352
  parsed_cfg = load_cfg(config, **kwargs)
353
  parser = transformers.HfArgumentParser((TrainerCliArgs))
354
  parsed_cli_args, _ = parser.parse_args_into_dataclasses(
355
  return_remaining_strings=True
356
  )
357
- train(cfg=parsed_cfg, cli_args=parsed_cli_args)
 
 
 
 
 
 
 
 
 
 
358
 
359
 
360
  if __name__ == "__main__":
361
- fire.Fire(do_train)
 
4
  import logging
5
  import os
6
  import random
 
7
  import sys
 
8
  from pathlib import Path
9
  from typing import Any, Dict, List, Optional, Union
10
 
 
15
 
16
  # add src to the pythonpath so we don't need to pip install this
17
  from art import text2art
 
18
  from transformers import GenerationConfig, TextStreamer
19
 
20
+ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
21
  from axolotl.logging_config import configure_logging
22
+ from axolotl.train import TrainDatasetMeta, train
23
  from axolotl.utils.config import normalize_config, validate_config
24
  from axolotl.utils.data import prepare_dataset
25
  from axolotl.utils.dict import DictDefault
26
  from axolotl.utils.distributed import is_main_process
27
+ from axolotl.utils.models import load_model_config, load_tokenizer
28
  from axolotl.utils.tokenization import check_dataset_labels
 
29
  from axolotl.utils.wandb import setup_wandb_env_vars
30
 
31
  project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 
38
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
39
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def print_axolotl_text_art(suffix=None):
42
  font = "nancyj"
43
  ascii_text = " axolotl"
44
  if suffix:
45
  ascii_text += f" x {suffix}"
46
  ascii_art = text2art(" axolotl", font=font)
47
+
48
  if is_main_process():
49
  print(ascii_art)
50
 
 
58
  return instruction
59
 
60
 
61
+ def do_merge_lora(
62
+ *,
63
+ cfg: DictDefault,
64
+ cli_args: TrainerCliArgs,
65
+ ):
66
+ model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
67
+ safe_serialization = cfg.save_safetensors is True
68
+
69
+ LOG.info("running merge of LoRA with base model")
70
+ model = model.merge_and_unload()
71
+ model.to(dtype=torch.float16)
72
+
73
+ if cfg.local_rank == 0:
74
+ LOG.info("saving merged model")
75
+ model.save_pretrained(
76
+ str(Path(cfg.output_dir) / "merged"),
77
+ safe_serialization=safe_serialization,
78
+ )
79
+ tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
80
+
81
+
82
+ def shard(
83
+ *,
84
+ cfg: DictDefault,
85
+ cli_args: TrainerCliArgs,
86
+ ):
87
+ model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
88
+ safe_serialization = cfg.save_safetensors is True
89
+ LOG.debug("Re-saving model w/ sharding")
90
+ model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
91
+
92
+
93
+ def do_inference(
94
+ *,
95
+ cfg: DictDefault,
96
+ cli_args: TrainerCliArgs,
97
+ ):
98
+ model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
99
+ prompter = cli_args.prompter
100
  default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
101
 
102
  for token, symbol in default_tokens.items():
 
197
  return not any(el in list2 for el in list1)
198
 
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  def load_cfg(config: Path = Path("examples/"), **kwargs):
201
  if Path(config).is_dir():
202
  config = choose_config(config)
 
233
  return cfg
234
 
235
 
236
+ def load_datasets(
237
+ *,
238
+ cfg: DictDefault,
239
+ cli_args: TrainerCliArgs,
240
+ ) -> TrainDatasetMeta:
241
+ tokenizer = load_tokenizer(cfg)
242
+
243
+ train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
244
+
245
+ if cli_args.debug or cfg.debug:
246
+ LOG.info("check_dataset_labels...")
247
+ check_dataset_labels(
248
+ train_dataset.select(
249
+ [random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec
250
+ ),
251
+ tokenizer,
252
+ )
253
+
254
+ return TrainDatasetMeta(
255
+ train_dataset=train_dataset,
256
+ eval_dataset=eval_dataset,
257
+ total_num_steps=total_num_steps,
258
+ )
259
+
260
+
261
+ def do_cli(config: Path = Path("examples/"), **kwargs):
262
  print_axolotl_text_art()
263
  parsed_cfg = load_cfg(config, **kwargs)
264
  parser = transformers.HfArgumentParser((TrainerCliArgs))
265
  parsed_cli_args, _ = parser.parse_args_into_dataclasses(
266
  return_remaining_strings=True
267
  )
268
+ if parsed_cli_args.inference:
269
+ do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
270
+ elif parsed_cli_args.merge_lora:
271
+ do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
272
+ elif parsed_cli_args.shard:
273
+ shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
274
+ else:
275
+ dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
276
+ if parsed_cli_args.prepare_ds_only:
277
+ return
278
+ train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
279
 
280
 
281
  if __name__ == "__main__":
282
+ fire.Fire(do_cli)
src/axolotl/common/__init__.py ADDED
File without changes
src/axolotl/common/cli.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ shared module for cli specific things
3
+ """
4
+
5
+ import logging
6
+ from dataclasses import dataclass, field
7
+ from typing import Optional
8
+
9
+ from axolotl.logging_config import configure_logging
10
+ from axolotl.utils.dict import DictDefault
11
+ from axolotl.utils.models import load_model, load_tokenizer
12
+
13
+ configure_logging()
14
+ LOG = logging.getLogger("axolotl.common.cli")
15
+
16
+
17
+ @dataclass
18
+ class TrainerCliArgs:
19
+ """
20
+ dataclass representing the various non-training arguments
21
+ """
22
+
23
+ debug: bool = field(default=False)
24
+ inference: bool = field(default=False)
25
+ merge_lora: bool = field(default=False)
26
+ prepare_ds_only: bool = field(default=False)
27
+ prompter: Optional[str] = field(default=None)
28
+ shard: bool = field(default=False)
29
+
30
+
31
+ def load_model_and_tokenizer(
32
+ *,
33
+ cfg: DictDefault,
34
+ cli_args: TrainerCliArgs,
35
+ ):
36
+ LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
37
+ tokenizer = load_tokenizer(cfg)
38
+ LOG.info("loading model and (optionally) peft_config...")
39
+ model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
40
+
41
+ return model, tokenizer
src/axolotl/train.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
2
+
3
+ import logging
4
+ import os
5
+ import signal
6
+ import sys
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from typing import Optional
10
+
11
+ import torch
12
+
13
+ # add src to the pythonpath so we don't need to pip install this
14
+ from datasets import Dataset
15
+ from optimum.bettertransformer import BetterTransformer
16
+
17
+ from axolotl.common.cli import TrainerCliArgs
18
+ from axolotl.logging_config import configure_logging
19
+ from axolotl.utils.dict import DictDefault
20
+ from axolotl.utils.models import load_model, load_tokenizer
21
+ from axolotl.utils.trainer import setup_trainer
22
+
23
+ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
24
+ src_dir = os.path.join(project_root, "src")
25
+ sys.path.insert(0, src_dir)
26
+
27
+ configure_logging()
28
+ LOG = logging.getLogger("axolotl.train")
29
+
30
+
31
+ @dataclass
32
+ class TrainDatasetMeta:
33
+ """
34
+ dataclass to capture the dataset specific options for training
35
+ """
36
+
37
+ train_dataset: Dataset
38
+ eval_dataset: Optional[Dataset] = None
39
+ total_num_steps: Optional[int] = None
40
+
41
+
42
+ def train(
43
+ *,
44
+ cfg: DictDefault,
45
+ cli_args: TrainerCliArgs,
46
+ dataset_meta: TrainDatasetMeta,
47
+ ):
48
+ # load the tokenizer first
49
+ LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
50
+ tokenizer = load_tokenizer(cfg)
51
+
52
+ train_dataset = dataset_meta.train_dataset
53
+ eval_dataset = dataset_meta.eval_dataset
54
+ total_num_steps = dataset_meta.total_num_steps
55
+
56
+ # Load the model and tokenizer
57
+ LOG.info("loading model and (optionally) peft_config...")
58
+ model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
59
+
60
+ safe_serialization = cfg.save_safetensors is True
61
+
62
+ if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
63
+ possible_checkpoints = [
64
+ str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
65
+ ]
66
+ if len(possible_checkpoints) > 0:
67
+ sorted_paths = sorted(
68
+ possible_checkpoints,
69
+ key=lambda path: int(path.split("-")[-1]),
70
+ )
71
+ cfg.resume_from_checkpoint = sorted_paths[-1]
72
+ LOG.info(
73
+ f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
74
+ )
75
+ resume_from_checkpoint = cfg.resume_from_checkpoint
76
+
77
+ trainer = setup_trainer(
78
+ cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
79
+ )
80
+
81
+ model.config.use_cache = False
82
+
83
+ if torch.__version__ >= "2" and sys.platform != "win32":
84
+ LOG.info("Compiling torch model")
85
+ model = torch.compile(model)
86
+
87
+ # go ahead and presave, so we have the adapter config available to inspect
88
+ if peft_config:
89
+ LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
90
+ peft_config.save_pretrained(cfg.output_dir)
91
+
92
+ # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
93
+ if cfg.local_rank == 0:
94
+
95
+ def terminate_handler(_, __, model):
96
+ if cfg.flash_optimum:
97
+ model = BetterTransformer.reverse(model)
98
+ model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
99
+ sys.exit(0)
100
+
101
+ signal.signal(
102
+ signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
103
+ )
104
+
105
+ LOG.info("Starting trainer...")
106
+ if cfg.group_by_length:
107
+ LOG.info("hang tight... sorting dataset for group_by_length")
108
+
109
+ if not Path(cfg.output_dir).is_dir():
110
+ os.makedirs(cfg.output_dir, exist_ok=True)
111
+ tokenizer.save_pretrained(cfg.output_dir)
112
+ if cfg.flash_optimum:
113
+ with torch.backends.cuda.sdp_kernel(
114
+ enable_flash=True, enable_math=True, enable_mem_efficient=True
115
+ ):
116
+ trainer.train(resume_from_checkpoint=resume_from_checkpoint)
117
+ else:
118
+ trainer.train(resume_from_checkpoint=resume_from_checkpoint)
119
+
120
+ LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
121
+
122
+ if cfg.relora_steps:
123
+ if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
124
+ model = model.merge_and_unload()
125
+ else:
126
+ # final model weights have already been saved by `ReLoRACallback.on_train_end`
127
+ return model, tokenizer
128
+
129
+ # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
130
+ # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
131
+ if cfg.fsdp:
132
+ trainer.save_model(cfg.output_dir)
133
+ elif cfg.local_rank == 0:
134
+ if cfg.flash_optimum:
135
+ model = BetterTransformer.reverse(model)
136
+
137
+ model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
138
+
139
+ return model, tokenizer