chargoddard winglian commited on
Commit
bde3c5a
·
unverified ·
1 Parent(s): 55c23c7

ReLoRA implementation (with quantization) (#322)

Browse files

* Experimental ReLoRA (+qlora) implementation

* Add CPU offload

* Remove local config

* Fix saving logic

* Remove redundant assert

* Fix logic errors

* Move ReLoRA into its own trainer class with a method override to create the proper scheduler

* Formatting & typing fixes

* Use safe_serialization

* Don't allow fsdp/deepspeed with ReLoRA

* Fix cpu-offload logic, enable multi gpu

* Document parameters and add comment

* Fix merge issue

* Smooth over some sharp edges

* Implement resume from checkpoint for relora

* Address review comments

* Fix saving logic

* Add necessary metadata to safetensors

---------

Co-authored-by: Wing Lian <[email protected]>

README.md CHANGED
@@ -493,6 +493,12 @@ lora_modules_to_save:
493
  lora_out_dir:
494
  lora_fan_in_fan_out: false
495
 
 
 
 
 
 
 
496
  # wandb configuration if you're using it
497
  wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
498
  wandb_project: # your wandb project name
 
493
  lora_out_dir:
494
  lora_fan_in_fan_out: false
495
 
496
+ # ReLoRA configuration
497
+ # must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
498
+ relora_steps: # number of steps per ReLoRA restart
499
+ relora_warmup_steps: # number of per-restart warmup steps
500
+ relora_cpu_offload: # true to perform lora weight merges on cpu during restarts, for modest gpu memory savings
501
+
502
  # wandb configuration if you're using it
503
  wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
504
  wandb_project: # your wandb project name
scripts/finetune.py CHANGED
@@ -242,6 +242,21 @@ def train(
242
  model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
243
  return
244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  trainer = setup_trainer(
246
  cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
247
  )
@@ -273,20 +288,6 @@ def train(
273
  LOG.info("Starting trainer...")
274
  if cfg.group_by_length:
275
  LOG.info("hang tight... sorting dataset for group_by_length")
276
- resume_from_checkpoint = cfg.resume_from_checkpoint
277
- if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
278
- possible_checkpoints = [
279
- str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
280
- ]
281
- if len(possible_checkpoints) > 0:
282
- sorted_paths = sorted(
283
- possible_checkpoints,
284
- key=lambda path: int(path.split("-")[-1]),
285
- )
286
- resume_from_checkpoint = sorted_paths[-1]
287
- LOG.info(
288
- f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
289
- )
290
 
291
  if not Path(cfg.output_dir).is_dir():
292
  os.makedirs(cfg.output_dir, exist_ok=True)
@@ -301,6 +302,13 @@ def train(
301
 
302
  LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
303
 
 
 
 
 
 
 
 
304
  # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
305
  # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
306
  if cfg.fsdp:
@@ -308,6 +316,7 @@ def train(
308
  elif cfg.local_rank == 0:
309
  if cfg.flash_optimum:
310
  model = BetterTransformer.reverse(model)
 
311
  model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
312
 
313
 
 
242
  model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
243
  return
244
 
245
+ if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
246
+ possible_checkpoints = [
247
+ str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
248
+ ]
249
+ if len(possible_checkpoints) > 0:
250
+ sorted_paths = sorted(
251
+ possible_checkpoints,
252
+ key=lambda path: int(path.split("-")[-1]),
253
+ )
254
+ cfg.resume_from_checkpoint = sorted_paths[-1]
255
+ LOG.info(
256
+ f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
257
+ )
258
+ resume_from_checkpoint = cfg.resume_from_checkpoint
259
+
260
  trainer = setup_trainer(
261
  cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
262
  )
 
288
  LOG.info("Starting trainer...")
289
  if cfg.group_by_length:
290
  LOG.info("hang tight... sorting dataset for group_by_length")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
  if not Path(cfg.output_dir).is_dir():
293
  os.makedirs(cfg.output_dir, exist_ok=True)
 
302
 
303
  LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
304
 
305
+ if cfg.relora_steps:
306
+ if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
307
+ model = model.merge_and_unload()
308
+ else:
309
+ # final model weights have already been saved by `ReLoRACallback.on_train_end`
310
+ return
311
+
312
  # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
313
  # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
314
  if cfg.fsdp:
 
316
  elif cfg.local_rank == 0:
317
  if cfg.flash_optimum:
318
  model = BetterTransformer.reverse(model)
319
+
320
  model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
321
 
322
 
src/axolotl/monkeypatch/relora.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune."""
2
+ import glob
3
+ import json
4
+ import logging
5
+ import os.path
6
+ import shutil
7
+ from pathlib import Path
8
+ from typing import Dict, List, Sequence
9
+
10
+ import bitsandbytes as bnb
11
+ import peft
12
+ import safetensors.torch as st
13
+ import torch
14
+ from huggingface_hub import snapshot_download
15
+ from torch.optim.lr_scheduler import LRScheduler
16
+ from torch.optim.optimizer import Optimizer
17
+ from transformers import (
18
+ TrainerCallback,
19
+ TrainerControl,
20
+ TrainerState,
21
+ TrainingArguments,
22
+ )
23
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
24
+
25
+ from axolotl.utils.dict import DictDefault
26
+ from axolotl.utils.distributed import is_main_process
27
+
28
+ LOG = logging.getLogger("axolotl.relora")
29
+
30
+
31
+ def reset_optimizer(optimizer: torch.optim.Optimizer):
32
+ for group in optimizer.param_groups:
33
+ for param in group["params"]:
34
+ param_state = optimizer.state[param]
35
+ for key in param_state:
36
+ if "qmap" in key:
37
+ continue
38
+
39
+ if key == "step" and isinstance(param_state[key], int):
40
+ param_state[key] = 0
41
+ else:
42
+ param_state[key] = torch.zeros_like(param_state[key])
43
+
44
+
45
+ class ReLoRACallback(TrainerCallback):
46
+ """Callback to merge LoRA weights into the base model and save full-weight checkpoints"""
47
+
48
+ def __init__(self, cfg: DictDefault):
49
+ self.relora_steps = cfg.relora_steps
50
+ self.cpu_offload = cfg.relora_cpu_offload
51
+ self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
52
+ self.last_full_model = cfg.base_model
53
+ self.resume_from_checkpoint = cfg.resume_from_checkpoint
54
+
55
+ if not os.path.exists(self.last_full_model):
56
+ self.last_full_model = str(Path(snapshot_download(cfg.base_model)))
57
+
58
+ assert os.path.exists(
59
+ self.last_full_model
60
+ ), "for ReLORA base_model must be a local path"
61
+
62
+ self.num_lora_restarts = 0
63
+ self.need_full_save = False
64
+
65
+ def on_train_begin(
66
+ self,
67
+ _args: TrainingArguments,
68
+ _state: TrainerState,
69
+ control: TrainerControl,
70
+ model: peft.LoraModel,
71
+ **_kwargs,
72
+ ):
73
+ if self.resume_from_checkpoint:
74
+ weight_path = os.path.join(self.resume_from_checkpoint, "relora")
75
+ if not os.path.exists(weight_path):
76
+ LOG.warning(
77
+ "Resuming ReLoRA from checkpoint, but no full-weight save found"
78
+ )
79
+ else:
80
+ LOG.info(f"Loading adjusted base weights from {weight_path}")
81
+ load_weight_checkpoint(model, weight_path)
82
+ return control
83
+
84
+ def on_step_begin(
85
+ self,
86
+ args: TrainingArguments,
87
+ state: TrainerState,
88
+ control: TrainerControl,
89
+ model: peft.LoraModel,
90
+ optimizer: torch.optim.Optimizer,
91
+ **_kwargs,
92
+ ):
93
+ if state.global_step > 0 and state.global_step % self.relora_steps == 0:
94
+ checkpoint_folder = os.path.join(
95
+ args.output_dir,
96
+ f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
97
+ "relora",
98
+ )
99
+
100
+ with torch.no_grad():
101
+ merge_and_save(
102
+ model,
103
+ self.last_full_model,
104
+ checkpoint_folder,
105
+ reinit=True,
106
+ quantized=self.quantized,
107
+ actually_save=is_main_process(),
108
+ cpu_offload=self.cpu_offload,
109
+ )
110
+ reset_optimizer(optimizer)
111
+
112
+ if self.quantized:
113
+ self.last_full_model = checkpoint_folder
114
+ self.num_lora_restarts += 1
115
+
116
+ return control
117
+
118
+ def on_save(
119
+ self,
120
+ args: TrainingArguments,
121
+ state: TrainerState,
122
+ control: TrainerControl,
123
+ model: peft.LoraModel,
124
+ **_kwargs,
125
+ ):
126
+ checkpoint_folder = os.path.join(
127
+ args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", "relora"
128
+ )
129
+ if (
130
+ state.global_step >= self.relora_steps
131
+ and state.global_step % self.relora_steps != 0
132
+ ):
133
+ if self.quantized:
134
+ if self.last_full_model != checkpoint_folder:
135
+ # ensure the latest full parameter save is in the latest checkpoint
136
+ # folder, so that automatic pruning of checkpoints does not remove it
137
+ LOG.info(f"moving last full parameter save to {checkpoint_folder}")
138
+ os.makedirs(checkpoint_folder, exist_ok=True)
139
+ chunks = glob.glob(
140
+ f"{self.last_full_model}/model*.safetensors"
141
+ ) + glob.glob(f"{self.last_full_model}/model*.index.json")
142
+ for path in chunks:
143
+ new_path = os.path.abspath(shutil.move(path, checkpoint_folder))
144
+ try:
145
+ os.symlink(new_path, path)
146
+ except OSError:
147
+ # probably on windows without permission to symlink
148
+ pass
149
+
150
+ self.last_full_model = checkpoint_folder
151
+ else:
152
+ model.model.save_pretrained(checkpoint_folder, safe_serialization=True)
153
+
154
+ return control
155
+
156
+ def on_log(
157
+ self,
158
+ _args: TrainingArguments,
159
+ _state: TrainerState,
160
+ control: TrainerControl,
161
+ logs: Dict[str, float],
162
+ **_kwargs,
163
+ ):
164
+ logs["num_lora_restarts"] = self.num_lora_restarts
165
+ return control
166
+
167
+ def on_train_end(
168
+ self,
169
+ args: TrainingArguments,
170
+ _state: TrainerState,
171
+ control: TrainerControl,
172
+ model: peft.LoraModel,
173
+ **_kwargs,
174
+ ):
175
+ if self.quantized:
176
+ # perform final merge and save
177
+ with torch.no_grad():
178
+ merge_and_save(
179
+ model,
180
+ self.last_full_model,
181
+ args.output_dir,
182
+ reinit=False,
183
+ quantized=self.quantized,
184
+ actually_save=is_main_process(),
185
+ cpu_offload=self.cpu_offload,
186
+ )
187
+ # no need to save if unquantized, as finetune.py will call merge_and_unload()
188
+ return control
189
+
190
+
191
+ class ReLoRAScheduler(LRScheduler):
192
+ """Wraps another scheduler to apply per-lora-restart learning rate warmups."""
193
+
194
+ def __init__(
195
+ self,
196
+ optimizer: Optimizer,
197
+ inner_schedule: LRScheduler,
198
+ relora_steps: int,
199
+ warmup_steps: int,
200
+ min_lr_scale: float = 0.001,
201
+ ) -> None:
202
+ self.inner_schedule = inner_schedule
203
+ self.relora_steps = relora_steps
204
+ self.warmup_steps = warmup_steps
205
+ self.min_lr_scale = min_lr_scale
206
+ super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
207
+
208
+ def get_lr(self) -> float:
209
+ self.inner_schedule.last_epoch = self.last_epoch
210
+
211
+ original = self.inner_schedule.get_lr()
212
+ step = self.last_epoch
213
+ if step < self.relora_steps:
214
+ scale = 1
215
+ else:
216
+ cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
217
+ scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
218
+
219
+ if isinstance(original, Sequence):
220
+ return [lr * scale for lr in original]
221
+ return original * scale
222
+
223
+
224
+ def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
225
+ model_name = "model.safetensors"
226
+ if not os.path.exists(str(Path(path) / model_name)) and not os.path.exists(
227
+ str(Path(path) / f"{model_name}.index.json")
228
+ ):
229
+ model_name = "pytorch_model.bin"
230
+
231
+ index_path = str(Path(path) / f"{model_name}.index.json")
232
+ if os.path.exists(index_path):
233
+ with open(index_path, "r", encoding="utf-8") as file:
234
+ data = json.load(file)
235
+ return data["weight_map"]
236
+ return {(module_name + ".weight"): model_name for module_name in module_names}
237
+
238
+
239
+ def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
240
+ if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
241
+ adapter = layer.active_adapter
242
+ return (
243
+ peft.utils.transpose(
244
+ layer.lora_B[adapter].weight.detach().to(device)
245
+ @ layer.lora_A[adapter].weight.detach().to(device),
246
+ getattr(layer, "fan_in_fan_out", False),
247
+ )
248
+ * layer.scaling[adapter]
249
+ )
250
+
251
+ return layer.get_delta_weight().to(device)
252
+
253
+
254
+ def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
255
+ modules: Dict[str, peft.tuners.lora.LoraLayer] = {}
256
+
257
+ key_list = [key for key, _ in model.model.named_modules() if "lora" not in key]
258
+ for key in key_list:
259
+ try:
260
+ # pylint: disable=protected-access
261
+ _parent, target, _target_name = peft.utils._get_submodules(model.model, key)
262
+ except AttributeError:
263
+ continue
264
+
265
+ if isinstance(target, peft.tuners.lora.LoraLayer):
266
+ modules[key] = target
267
+
268
+ return modules
269
+
270
+
271
+ def update_weights(
272
+ target: peft.tuners.lora.LoraLayer, new_weight: torch.Tensor, reinit: bool, device
273
+ ):
274
+ if reinit:
275
+ for adapter_name in target.lora_A:
276
+ target.reset_lora_parameters(adapter_name)
277
+ for adapter_name in target.lora_embedding_A:
278
+ target.reset_lora_parameters(adapter_name)
279
+
280
+ if isinstance(target, peft.tuners.lora.Linear4bit):
281
+ # This could be faster, but the quantization of Linear4bit weights occurs
282
+ # when the module is moved from cpu to gpu. Without meddling *too* deeply in
283
+ # PEFT's innards or maintaining a duplicate of that codepath, this is good
284
+ # enough for now.
285
+ target.weight.quant_state = None
286
+ target.weight.data = new_weight.cpu()
287
+ target.to(device)
288
+ elif isinstance(target, peft.tuners.lora.Linear8bitLt):
289
+ target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device)
290
+ else:
291
+ target.weight.data = new_weight.to(device)
292
+
293
+
294
+ def merge_and_save(
295
+ model: peft.LoraModel,
296
+ model_src: str,
297
+ model_dst: str,
298
+ reinit: bool = False,
299
+ quantized: bool = False,
300
+ cpu_offload: bool = False,
301
+ actually_save: bool = True,
302
+ ):
303
+ modules = find_lora_modules(model)
304
+
305
+ if not quantized:
306
+ for module_name, target in modules.items():
307
+ update = target.get_delta_weight(target.active_adapter).detach()
308
+ target.weight.data += update
309
+
310
+ if reinit:
311
+ for adapter_name in target.lora_A:
312
+ target.reset_lora_parameters(adapter_name)
313
+ for adapter_name in target.lora_embedding_A:
314
+ target.reset_lora_parameters(adapter_name)
315
+ return
316
+
317
+ os.makedirs(model_dst, exist_ok=True)
318
+ shard_paths = sharded_paths(model_src, modules.keys())
319
+ out_shard_paths = {}
320
+
321
+ unique_shards = list(set(shard_paths.values()))
322
+ for shard_path in unique_shards:
323
+ out_tensors = {}
324
+ if shard_path.endswith(".safetensors"):
325
+ in_tensors = st.load_file(str(Path(model_src) / shard_path))
326
+ else:
327
+ in_tensors = torch.load(Path(model_src) / shard_path)
328
+ if "state_dict" in in_tensors:
329
+ in_tensors = in_tensors["state_dict"]
330
+
331
+ for module_name, target in modules.items():
332
+ key = module_name + ".weight"
333
+ if key not in shard_paths or shard_paths[key] != shard_path:
334
+ continue
335
+
336
+ orig_weight = in_tensors[key]
337
+ old_dev = target.weight.device
338
+ math_dev = "cpu" if cpu_offload else old_dev
339
+
340
+ delta_weight = lora_delta_weight(target, math_dev)
341
+ new_weight = orig_weight.to(math_dev) + delta_weight
342
+ del delta_weight
343
+
344
+ if actually_save:
345
+ out_tensors[key] = new_weight.half().cpu()
346
+
347
+ update_weights(target, new_weight, reinit=reinit, device=old_dev)
348
+
349
+ if actually_save:
350
+ out_shard_name = shard_path
351
+ if out_shard_name.startswith("pytorch_model"):
352
+ out_shard_name = (
353
+ out_shard_name.replace("pytorch_model", "model").rstrip(".bin")
354
+ + ".safetensors"
355
+ )
356
+
357
+ for module_name in in_tensors:
358
+ if module_name not in out_tensors:
359
+ out_tensors[module_name] = in_tensors[module_name].half()
360
+ out_shard_paths[module_name] = out_shard_name
361
+
362
+ shard_fn = str(Path(model_dst) / out_shard_name)
363
+ LOG.info(f"saving tensors to {shard_fn}")
364
+ st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})
365
+
366
+ del in_tensors
367
+ del out_tensors
368
+ torch.cuda.empty_cache()
369
+
370
+ if actually_save and len(unique_shards) > 1:
371
+ with open(
372
+ str(Path(model_dst, "model.safetensors.index.json")), "w", encoding="utf-8"
373
+ ) as file:
374
+ json.dump({"metadata": {}, "weight_map": out_shard_paths}, file)
375
+
376
+
377
+ def load_weight_checkpoint(model: peft.LoraModel, checkpoint_path: str):
378
+ modules = find_lora_modules(model)
379
+ shard_paths = sharded_paths(checkpoint_path, modules.keys())
380
+ unique_shards = list(set(shard_paths.values()))
381
+
382
+ for shard_path in unique_shards:
383
+ tensors = st.load_file(os.path.join(checkpoint_path, shard_path))
384
+
385
+ for module_name, target in modules.items():
386
+ key = module_name + ".weight"
387
+ if key not in shard_paths or shard_paths[key] != shard_path:
388
+ continue
389
+
390
+ new_weight = tensors[key]
391
+ update_weights(
392
+ target, new_weight, reinit=False, device=target.weight.device
393
+ )
src/axolotl/utils/callbacks.py CHANGED
@@ -33,7 +33,9 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
33
  )
34
 
35
  peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
36
- kwargs["model"].save_pretrained(peft_model_path)
 
 
37
 
38
  return control
39
 
 
33
  )
34
 
35
  peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
36
+ kwargs["model"].save_pretrained(
37
+ peft_model_path, save_safetensors=args.save_safetensors
38
+ )
39
 
40
  return control
41
 
src/axolotl/utils/config.py CHANGED
@@ -126,6 +126,19 @@ def validate_config(cfg):
126
  if not cfg.load_in_8bit and cfg.adapter == "lora":
127
  LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  if cfg.trust_remote_code:
130
  LOG.warning(
131
  "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
 
126
  if not cfg.load_in_8bit and cfg.adapter == "lora":
127
  LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
128
 
129
+ if cfg.relora_steps:
130
+ if cfg.adapter not in ("lora", "qlora"):
131
+ raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
132
+
133
+ if cfg.fsdp:
134
+ raise ValueError("fsdp not supported with ReLoRA")
135
+
136
+ if cfg.deepspeed:
137
+ raise ValueError("deepspeed not supported with ReLoRA")
138
+
139
+ if cfg.lr_scheduler == "one_cycle":
140
+ raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
141
+
142
  if cfg.trust_remote_code:
143
  LOG.warning(
144
  "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
src/axolotl/utils/trainer.py CHANGED
@@ -24,6 +24,7 @@ from transformers.trainer_pt_utils import (
24
  get_parameter_names,
25
  )
26
 
 
27
  from axolotl.utils.callbacks import (
28
  GPUStatsCallback,
29
  SaveBetterTransformerModelCallback,
@@ -127,6 +128,14 @@ class AxolotlTrainingArguments(TrainingArguments):
127
  default=1,
128
  metadata={"help": "the multiplier for the max len for packed sequences"},
129
  )
 
 
 
 
 
 
 
 
130
 
131
 
132
  class AxolotlTrainer(Trainer):
@@ -265,6 +274,39 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
265
  return self.lr_scheduler
266
 
267
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  def add_position_ids(sample):
269
  sample["position_ids"] = torch.arange(len(sample["input_ids"]))
270
  return sample
@@ -517,6 +559,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
517
  weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
518
  sample_packing=cfg.sample_packing if cfg.sample_packing else False,
519
  sample_packing_seq_len_multiplier=cfg.micro_batch_size,
 
 
520
  **training_arguments_kwargs,
521
  )
522
 
@@ -589,6 +633,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
589
 
590
  callbacks = []
591
  callbacks.append(GPUStatsCallback(cfg))
 
 
 
 
592
  # TODO on_save callback to sync checkpoints to GCP/AWS in background
593
  if cfg.early_stopping_patience:
594
  early_stop_cb = EarlyStoppingCallback(
@@ -633,11 +681,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
633
  num_proc=32,
634
  )
635
 
636
- trainer_cls = (
637
- OneCycleLRSchedulerTrainer
638
- if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
639
- else AxolotlTrainer
640
- )
641
  trainer = trainer_cls(
642
  model=model,
643
  train_dataset=train_dataset,
 
24
  get_parameter_names,
25
  )
26
 
27
+ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
28
  from axolotl.utils.callbacks import (
29
  GPUStatsCallback,
30
  SaveBetterTransformerModelCallback,
 
128
  default=1,
129
  metadata={"help": "the multiplier for the max len for packed sequences"},
130
  )
131
+ relora_steps: Optional[int] = field(
132
+ default=None,
133
+ metadata={"help": "how often to reset for ReLoRA"},
134
+ )
135
+ relora_warmup_steps: Optional[int] = field(
136
+ default=None,
137
+ metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
138
+ )
139
 
140
 
141
  class AxolotlTrainer(Trainer):
 
274
  return self.lr_scheduler
275
 
276
 
277
+ class ReLoRATrainer(AxolotlTrainer):
278
+ """
279
+ Trainer subclass that uses the OneCycleLR scheduler
280
+ """
281
+
282
+ def __init__(self, *args, **kwargs):
283
+ super().__init__(*args, **kwargs)
284
+ self.lr_scheduler = None
285
+
286
+ def create_scheduler(
287
+ self,
288
+ num_training_steps: int,
289
+ optimizer: Optional[torch.optim.Optimizer] = None,
290
+ ):
291
+ optimizer = self.optimizer if optimizer is None else optimizer
292
+ lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
293
+
294
+ if self.args.relora_steps:
295
+ warmup_steps = (
296
+ self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
297
+ )
298
+ self.lr_scheduler = ReLoRAScheduler(
299
+ optimizer,
300
+ lr_scheduler,
301
+ self.args.relora_steps,
302
+ warmup_steps,
303
+ )
304
+ else:
305
+ self.lr_scheduler = lr_scheduler
306
+
307
+ return self.lr_scheduler
308
+
309
+
310
  def add_position_ids(sample):
311
  sample["position_ids"] = torch.arange(len(sample["input_ids"]))
312
  return sample
 
559
  weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
560
  sample_packing=cfg.sample_packing if cfg.sample_packing else False,
561
  sample_packing_seq_len_multiplier=cfg.micro_batch_size,
562
+ relora_steps=cfg.relora_steps,
563
+ relora_warmup_steps=cfg.relora_warmup_steps,
564
  **training_arguments_kwargs,
565
  )
566
 
 
633
 
634
  callbacks = []
635
  callbacks.append(GPUStatsCallback(cfg))
636
+
637
+ if cfg.relora_steps:
638
+ callbacks.append(ReLoRACallback(cfg))
639
+
640
  # TODO on_save callback to sync checkpoints to GCP/AWS in background
641
  if cfg.early_stopping_patience:
642
  early_stop_cb = EarlyStoppingCallback(
 
681
  num_proc=32,
682
  )
683
 
684
+ trainer_cls = AxolotlTrainer
685
+ if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora"):
686
+ trainer_cls = OneCycleLRSchedulerTrainer
687
+ elif cfg.relora_steps:
688
+ trainer_cls = ReLoRATrainer
689
  trainer = trainer_cls(
690
  model=model,
691
  train_dataset=train_dataset,