winglian commited on
Commit
9105935
·
1 Parent(s): 7748f3d

support for multi line inference input, log sweep over learning rates

Browse files
scripts/finetune.py CHANGED
@@ -1,5 +1,7 @@
 
1
  import logging
2
  import os
 
3
  import random
4
  import signal
5
  import sys
@@ -44,18 +46,20 @@ def choose_device(cfg):
44
  cfg.device_map = {"": cfg.device}
45
 
46
 
47
- def do_inference(cfg, model, tokenizer):
48
  tokenizer.add_special_tokens({"unk_token": "<unk>"})
49
  tokenizer.add_special_tokens({"bos_token": "<s>"})
50
  tokenizer.add_special_tokens({"eos_token": "</s>"})
51
 
52
- from axolotl.prompters import ReflectAlpacaPrompter
53
 
54
  while True:
55
- instruction = str(input("Give me an instruction: "))
 
 
56
  if not instruction:
57
  return
58
- prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
59
  batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
60
 
61
  model.eval()
@@ -162,6 +166,10 @@ def train(
162
  do_inference(cfg, model, tokenizer)
163
  return
164
 
 
 
 
 
165
  train_dataset, eval_dataset = load_prepare_datasets(
166
  tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
167
  )
@@ -207,12 +215,11 @@ def train(
207
  logging.info(f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}")
208
  trainer.train(resume_from_checkpoint=resume_from_checkpoint)
209
 
210
- if cfg.local_rank == 0:
211
- # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
212
- logging.info(
213
- f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
214
- )
215
- model.save_pretrained(cfg.output_dir)
216
 
217
 
218
  if __name__ == "__main__":
 
1
+ import importlib
2
  import logging
3
  import os
4
+ import pathlib
5
  import random
6
  import signal
7
  import sys
 
46
  cfg.device_map = {"": cfg.device}
47
 
48
 
49
+ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
50
  tokenizer.add_special_tokens({"unk_token": "<unk>"})
51
  tokenizer.add_special_tokens({"bos_token": "<s>"})
52
  tokenizer.add_special_tokens({"eos_token": "</s>"})
53
 
54
+ prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)
55
 
56
  while True:
57
+ # support for multiline inputs
58
+ print("Give me an instruction (Ctrl + D to finish): ")
59
+ instruction = pathlib.Path("/proc/self/fd/0").read_text()
60
  if not instruction:
61
  return
62
+ prompt = prompter_module().build_prompt(instruction=instruction)
63
  batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
64
 
65
  model.eval()
 
166
  do_inference(cfg, model, tokenizer)
167
  return
168
 
169
+ if "shard" in kwargs:
170
+ model.save_pretrained(cfg.output_dir)
171
+ return
172
+
173
  train_dataset, eval_dataset = load_prepare_datasets(
174
  tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
175
  )
 
215
  logging.info(f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}")
216
  trainer.train(resume_from_checkpoint=resume_from_checkpoint)
217
 
218
+ logging.info(
219
+ f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
220
+ )
221
+ # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
222
+ trainer.save_model(cfg.output_dir)
 
223
 
224
 
225
  if __name__ == "__main__":
src/axolotl/utils/schedulers.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.optim.lr_scheduler import LRScheduler
2
+
3
+
4
+ class InterpolatingLogScheduler(LRScheduler):
5
+ def __init__(self, optimizer, num_steps, min_lr, max_lr, last_epoch=-1):
6
+ """A scheduler that interpolates learning rates in a logarithmic fashion
7
+
8
+ Args:
9
+ - optimizer: pytorch optimizer
10
+ - num_steps: int, the number of steps over which to increase from the min_lr to the max_lr
11
+ - min_lr: float, the minimum learning rate
12
+ - max_lr: float, the maximum learning rate
13
+
14
+ Usage:
15
+ fc = nn.Linear(1,1)
16
+ optimizer = optim.Adam(fc.parameters())
17
+ lr_scheduler = InterpolatingLogScheduler(optimizer, num_steps=400, min_lr=1e-6, max_lr=1e-4)
18
+ """
19
+ self.num_steps = num_steps
20
+ self.min_lr = min_lr
21
+ self.max_lr = max_lr
22
+ self.q = (max_lr / min_lr) ** (1 / num_steps - 1)
23
+ super().__init__(optimizer, last_epoch)
24
+
25
+ def get_lr(self):
26
+ if self.last_epoch == 0:
27
+ lr = self.min_lr
28
+ elif self.last_epoch < self.num_steps:
29
+ # FIXME, not perfect as we need to account for number of steps are in an epoch, etc
30
+ lr = self.min_lr * (self.q ** self.last_epoch)
31
+ else:
32
+ lr = self.max_lr
33
+
34
+ return [lr for _ in self.base_lrs]
src/axolotl/utils/trainer.py CHANGED
@@ -12,6 +12,8 @@ from torch.optim.lr_scheduler import OneCycleLR
12
  from transformers import EarlyStoppingCallback
13
  from transformers.trainer_pt_utils import get_parameter_names
14
 
 
 
15
 
16
  def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
17
  total_num_steps = int(
@@ -27,11 +29,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
27
  if cfg.logging_steps is not None
28
  else max(min(int(0.005 * total_num_steps), 10), 1)
29
  )
30
- save_steps = eval_steps = (
31
  cfg.save_steps
32
  if cfg.save_steps is not None
33
  else min(int(0.05 * total_num_steps), 200)
34
  )
 
 
 
 
 
35
 
36
  training_arguments_kwargs = {}
37
  if cfg.bf16 == "full":
@@ -95,7 +102,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
95
  report_to="wandb" if cfg.use_wandb else None,
96
  run_name=cfg.wandb_run_id if cfg.use_wandb else None,
97
  optim=cfg.optimizer if cfg.optimizer else None,
98
- lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler else None,
99
  weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
100
  **training_arguments_kwargs,
101
  )
@@ -147,8 +154,16 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
147
  optimizer,
148
  cfg.learning_rate,
149
  total_steps=total_num_steps,
 
150
  **lr_scheduler_kwargs,
151
  )
 
 
 
 
 
 
 
152
  else:
153
  lr_scheduler = transformers.get_cosine_schedule_with_warmup(
154
  optimizer,
 
12
  from transformers import EarlyStoppingCallback
13
  from transformers.trainer_pt_utils import get_parameter_names
14
 
15
+ from axolotl.utils.schedulers import InterpolatingLogScheduler
16
+
17
 
18
  def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
19
  total_num_steps = int(
 
29
  if cfg.logging_steps is not None
30
  else max(min(int(0.005 * total_num_steps), 10), 1)
31
  )
32
+ save_steps = (
33
  cfg.save_steps
34
  if cfg.save_steps is not None
35
  else min(int(0.05 * total_num_steps), 200)
36
  )
37
+ eval_steps = (
38
+ cfg.eval_steps
39
+ if cfg.eval_steps is not None and save_steps % cfg.eval_steps == 0
40
+ else save_steps
41
+ )
42
 
43
  training_arguments_kwargs = {}
44
  if cfg.bf16 == "full":
 
102
  report_to="wandb" if cfg.use_wandb else None,
103
  run_name=cfg.wandb_run_id if cfg.use_wandb else None,
104
  optim=cfg.optimizer if cfg.optimizer else None,
105
+ lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler not in ("one_cycle", "log_sweep") else "cosine",
106
  weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
107
  **training_arguments_kwargs,
108
  )
 
154
  optimizer,
155
  cfg.learning_rate,
156
  total_steps=total_num_steps,
157
+ epochs=cfg.num_epochs,
158
  **lr_scheduler_kwargs,
159
  )
160
+ elif cfg.lr_scheduler == "log_sweep":
161
+ lr_scheduler = InterpolatingLogScheduler(
162
+ optimizer,
163
+ cfg.warmup_steps,
164
+ cfg.log_sweep_min_lr if cfg.log_sweep_min_lr else 1e-10,
165
+ cfg.log_sweep_max_lr if cfg.log_sweep_max_lr else 10,
166
+ )
167
  else:
168
  lr_scheduler = transformers.get_cosine_schedule_with_warmup(
169
  optimizer,