ReLoRA implementation (with quantization) (#322)
* Experimental ReLoRA (+qlora) implementation
* Add CPU offload
* Remove local config
* Fix saving logic
* Remove redundant assert
* Fix logic errors
* Move ReLoRA into its own trainer class with a method override to create the proper scheduler
* Formatting & typing fixes
* Use safe_serialization
* Don't allow fsdp/deepspeed with ReLoRA
* Fix cpu-offload logic, enable multi gpu
* Document parameters and add comment
* Fix merge issue
* Smooth over some sharp edges
* Implement resume from checkpoint for relora
* Address review comments
* Fix saving logic
* Add necessary metadata to safetensors
---------
Co-authored-by: Wing Lian <[email protected]>
- README.md +6 -0
- scripts/finetune.py +23 -14
- src/axolotl/monkeypatch/relora.py +393 -0
- src/axolotl/utils/callbacks.py +3 -1
- src/axolotl/utils/config.py +13 -0
- src/axolotl/utils/trainer.py +53 -5
README.md
CHANGED
@@ -493,6 +493,12 @@ lora_modules_to_save:
|
|
493 |
lora_out_dir:
|
494 |
lora_fan_in_fan_out: false
|
495 |
|
|
|
|
|
|
|
|
|
|
|
|
|
496 |
# wandb configuration if you're using it
|
497 |
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
498 |
wandb_project: # your wandb project name
|
|
|
493 |
lora_out_dir:
|
494 |
lora_fan_in_fan_out: false
|
495 |
|
496 |
+
# ReLoRA configuration
|
497 |
+
# must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
498 |
+
relora_steps: # number of steps per ReLoRA restart
|
499 |
+
relora_warmup_steps: # number of per-restart warmup steps
|
500 |
+
relora_cpu_offload: # true to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
501 |
+
|
502 |
# wandb configuration if you're using it
|
503 |
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
504 |
wandb_project: # your wandb project name
|
scripts/finetune.py
CHANGED
@@ -242,6 +242,21 @@ def train(
|
|
242 |
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
243 |
return
|
244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
trainer = setup_trainer(
|
246 |
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
247 |
)
|
@@ -273,20 +288,6 @@ def train(
|
|
273 |
LOG.info("Starting trainer...")
|
274 |
if cfg.group_by_length:
|
275 |
LOG.info("hang tight... sorting dataset for group_by_length")
|
276 |
-
resume_from_checkpoint = cfg.resume_from_checkpoint
|
277 |
-
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
278 |
-
possible_checkpoints = [
|
279 |
-
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
280 |
-
]
|
281 |
-
if len(possible_checkpoints) > 0:
|
282 |
-
sorted_paths = sorted(
|
283 |
-
possible_checkpoints,
|
284 |
-
key=lambda path: int(path.split("-")[-1]),
|
285 |
-
)
|
286 |
-
resume_from_checkpoint = sorted_paths[-1]
|
287 |
-
LOG.info(
|
288 |
-
f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
|
289 |
-
)
|
290 |
|
291 |
if not Path(cfg.output_dir).is_dir():
|
292 |
os.makedirs(cfg.output_dir, exist_ok=True)
|
@@ -301,6 +302,13 @@ def train(
|
|
301 |
|
302 |
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
303 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
304 |
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
305 |
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
306 |
if cfg.fsdp:
|
@@ -308,6 +316,7 @@ def train(
|
|
308 |
elif cfg.local_rank == 0:
|
309 |
if cfg.flash_optimum:
|
310 |
model = BetterTransformer.reverse(model)
|
|
|
311 |
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
312 |
|
313 |
|
|
|
242 |
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
243 |
return
|
244 |
|
245 |
+
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
246 |
+
possible_checkpoints = [
|
247 |
+
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
248 |
+
]
|
249 |
+
if len(possible_checkpoints) > 0:
|
250 |
+
sorted_paths = sorted(
|
251 |
+
possible_checkpoints,
|
252 |
+
key=lambda path: int(path.split("-")[-1]),
|
253 |
+
)
|
254 |
+
cfg.resume_from_checkpoint = sorted_paths[-1]
|
255 |
+
LOG.info(
|
256 |
+
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
257 |
+
)
|
258 |
+
resume_from_checkpoint = cfg.resume_from_checkpoint
|
259 |
+
|
260 |
trainer = setup_trainer(
|
261 |
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
262 |
)
|
|
|
288 |
LOG.info("Starting trainer...")
|
289 |
if cfg.group_by_length:
|
290 |
LOG.info("hang tight... sorting dataset for group_by_length")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
291 |
|
292 |
if not Path(cfg.output_dir).is_dir():
|
293 |
os.makedirs(cfg.output_dir, exist_ok=True)
|
|
|
302 |
|
303 |
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
304 |
|
305 |
+
if cfg.relora_steps:
|
306 |
+
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
|
307 |
+
model = model.merge_and_unload()
|
308 |
+
else:
|
309 |
+
# final model weights have already been saved by `ReLoRACallback.on_train_end`
|
310 |
+
return
|
311 |
+
|
312 |
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
313 |
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
314 |
if cfg.fsdp:
|
|
|
316 |
elif cfg.local_rank == 0:
|
317 |
if cfg.flash_optimum:
|
318 |
model = BetterTransformer.reverse(model)
|
319 |
+
|
320 |
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
321 |
|
322 |
|
src/axolotl/monkeypatch/relora.py
ADDED
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune."""
|
2 |
+
import glob
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import os.path
|
6 |
+
import shutil
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Dict, List, Sequence
|
9 |
+
|
10 |
+
import bitsandbytes as bnb
|
11 |
+
import peft
|
12 |
+
import safetensors.torch as st
|
13 |
+
import torch
|
14 |
+
from huggingface_hub import snapshot_download
|
15 |
+
from torch.optim.lr_scheduler import LRScheduler
|
16 |
+
from torch.optim.optimizer import Optimizer
|
17 |
+
from transformers import (
|
18 |
+
TrainerCallback,
|
19 |
+
TrainerControl,
|
20 |
+
TrainerState,
|
21 |
+
TrainingArguments,
|
22 |
+
)
|
23 |
+
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
24 |
+
|
25 |
+
from axolotl.utils.dict import DictDefault
|
26 |
+
from axolotl.utils.distributed import is_main_process
|
27 |
+
|
28 |
+
LOG = logging.getLogger("axolotl.relora")
|
29 |
+
|
30 |
+
|
31 |
+
def reset_optimizer(optimizer: torch.optim.Optimizer):
    """Zero out the optimizer's per-parameter state in place.

    Every state entry of every parameter in ``optimizer.param_groups`` is
    reset, except entries whose key contains ``"qmap"``, which are left as-is
    (presumably the quantization maps of 8-bit optimizers -- TODO confirm).

    Tensor entries are replaced by ``torch.zeros_like`` of themselves; plain
    ``int``/``float`` entries (for example a scalar ``step`` counter) are reset
    to zero of the same type. Entries of any other type are left untouched
    instead of raising from ``torch.zeros_like``.

    Args:
        optimizer: the optimizer whose state should be cleared.
    """
    for group in optimizer.param_groups:
        for param in group["params"]:
            # Use `.get` instead of indexing: torch optimizers keep `state` in a
            # defaultdict, so indexing would insert an empty entry for every
            # parameter that has no state yet.
            param_state = optimizer.state.get(param)
            if not param_state:
                continue

            # Only values are replaced, never keys, so iterating `.items()`
            # while assigning is safe.
            for key, value in param_state.items():
                if "qmap" in key:
                    continue

                if torch.is_tensor(value):
                    param_state[key] = torch.zeros_like(value)
                elif isinstance(value, (int, float)):
                    param_state[key] = type(value)(0)
|
43 |
+
|
44 |
+
|
45 |
+
class ReLoRACallback(TrainerCallback):
    """Callback to merge LoRA weights into the base model and save full-weight checkpoints"""

    def __init__(self, cfg: DictDefault):
        """Read the ReLoRA settings from ``cfg`` and resolve the base model to a local path.

        Reads ``relora_steps``, ``relora_cpu_offload``, ``load_in_4bit``,
        ``load_in_8bit``, ``base_model`` and ``resume_from_checkpoint``. If
        ``base_model`` is not an existing local path it is fetched with
        ``snapshot_download`` and the returned directory is used instead.
        """
        self.relora_steps = cfg.relora_steps
        self.cpu_offload = cfg.relora_cpu_offload
        self.quantized = cfg.load_in_4bit or cfg.load_in_8bit
        # directory holding the most recent full (unquantized) copy of the weights;
        # quantized merges read their starting weights from here
        self.last_full_model = cfg.base_model
        self.resume_from_checkpoint = cfg.resume_from_checkpoint

        if not os.path.exists(self.last_full_model):
            self.last_full_model = str(Path(snapshot_download(cfg.base_model)))

        assert os.path.exists(
            self.last_full_model
        ), "for ReLORA base_model must be a local path"

        self.num_lora_restarts = 0
        self.need_full_save = False

    def on_train_begin(
        self,
        _args: TrainingArguments,
        _state: TrainerState,
        control: TrainerControl,
        model: peft.LoraModel,
        **_kwargs,
    ):
        """When resuming, reload the merged base weights stored in ``<checkpoint>/relora``."""
        if self.resume_from_checkpoint:
            weight_path = os.path.join(self.resume_from_checkpoint, "relora")
            if not os.path.exists(weight_path):
                LOG.warning(
                    "Resuming ReLoRA from checkpoint, but no full-weight save found"
                )
            else:
                LOG.info(f"Loading adjusted base weights from {weight_path}")
                load_weight_checkpoint(model, weight_path)
                if self.quantized:
                    # Quantized merges add the LoRA delta to the weights stored in
                    # `last_full_model`. Point it at the weights just loaded so the
                    # next merge builds on them instead of restarting from the
                    # original base model and dropping earlier merges.
                    self.last_full_model = weight_path
        return control

    def on_step_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model: peft.LoraModel,
        optimizer: torch.optim.Optimizer,
        **_kwargs,
    ):
        """Perform a ReLoRA restart every ``relora_steps`` steps (never at step 0).

        A restart merges the adapters into the base weights with re-initialised
        adapters, resets the optimizer state and bumps ``num_lora_restarts``.
        """
        if state.global_step > 0 and state.global_step % self.relora_steps == 0:
            checkpoint_folder = os.path.join(
                args.output_dir,
                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
                "relora",
            )

            with torch.no_grad():
                merge_and_save(
                    model,
                    self.last_full_model,
                    checkpoint_folder,
                    reinit=True,
                    quantized=self.quantized,
                    actually_save=is_main_process(),
                    cpu_offload=self.cpu_offload,
                )
            reset_optimizer(optimizer)

            if self.quantized:
                # the merged weights were written to this folder; the next merge starts from it
                self.last_full_model = checkpoint_folder
            self.num_lora_restarts += 1

        return control

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        model: peft.LoraModel,
        **_kwargs,
    ):
        """Make sure the checkpoint being saved carries a full-weight copy.

        Only acts once at least one restart interval has passed and the current
        step is not itself a multiple of ``relora_steps``.
        """
        checkpoint_folder = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", "relora"
        )
        if (
            state.global_step >= self.relora_steps
            and state.global_step % self.relora_steps != 0
        ):
            if self.quantized:
                if self.last_full_model != checkpoint_folder:
                    # ensure the latest full parameter save is in the latest checkpoint
                    # folder, so that automatic pruning of checkpoints does not remove it
                    LOG.info(f"moving last full parameter save to {checkpoint_folder}")
                    os.makedirs(checkpoint_folder, exist_ok=True)
                    chunks = glob.glob(
                        f"{self.last_full_model}/model*.safetensors"
                    ) + glob.glob(f"{self.last_full_model}/model*.index.json")
                    for path in chunks:
                        new_path = os.path.abspath(shutil.move(path, checkpoint_folder))
                        try:
                            # leave a link behind at the old location
                            os.symlink(new_path, path)
                        except OSError as err:
                            # probably on windows without permission to symlink;
                            # best effort only, the files themselves were moved
                            LOG.debug(f"could not symlink {path} -> {new_path}: {err}")

                self.last_full_model = checkpoint_folder
            else:
                model.model.save_pretrained(checkpoint_folder, safe_serialization=True)

        return control

    def on_log(
        self,
        _args: TrainingArguments,
        _state: TrainerState,
        control: TrainerControl,
        logs: Dict[str, float],
        **_kwargs,
    ):
        """Add the number of restarts performed so far to the logged values."""
        logs["num_lora_restarts"] = self.num_lora_restarts
        return control

    def on_train_end(
        self,
        args: TrainingArguments,
        _state: TrainerState,
        control: TrainerControl,
        model: peft.LoraModel,
        **_kwargs,
    ):
        """For quantized runs, merge the adapters one last time and save into ``args.output_dir``."""
        if self.quantized:
            # perform final merge and save
            with torch.no_grad():
                merge_and_save(
                    model,
                    self.last_full_model,
                    args.output_dir,
                    reinit=False,
                    quantized=self.quantized,
                    actually_save=is_main_process(),
                    cpu_offload=self.cpu_offload,
                )
        # no need to save if unquantized, as finetune.py will call merge_and_unload()
        return control
|
189 |
+
|
190 |
+
|
191 |
+
class ReLoRAScheduler(LRScheduler):
    """Wraps another scheduler to apply per-lora-restart learning rate warmups."""

    def __init__(
        self,
        optimizer: Optimizer,
        inner_schedule: LRScheduler,
        relora_steps: int,
        warmup_steps: int,
        min_lr_scale: float = 0.001,
    ) -> None:
        """Wrap ``inner_schedule``.

        Args:
            optimizer: optimizer whose learning rates are scheduled.
            inner_schedule: scheduler providing the un-scaled learning rates.
            relora_steps: length of one restart cycle, in scheduler steps.
            warmup_steps: number of steps after each restart over which the
                scale ramps linearly up to 1. A falsy or non-positive value
                disables the warmup.
            min_lr_scale: scale applied on the first step of each cycle.
        """
        self.inner_schedule = inner_schedule
        self.relora_steps = relora_steps
        self.warmup_steps = warmup_steps
        self.min_lr_scale = min_lr_scale

        # Forward `verbose` only when the wrapped scheduler has it switched on.
        # It is deprecated in recent torch, so always passing it through warns,
        # and the attribute may be absent altogether in some releases.
        init_kwargs = {}
        if getattr(inner_schedule, "verbose", False) is True:
            init_kwargs["verbose"] = True
        super().__init__(optimizer, inner_schedule.last_epoch, **init_kwargs)

    def get_lr(self) -> List[float]:
        """Return the inner schedule's learning rates multiplied by the warmup scale.

        The scale is 1 before the first restart (``step < relora_steps``).
        Afterwards it ramps linearly from ``min_lr_scale`` to 1 over the first
        ``warmup_steps`` steps of each ``relora_steps``-long cycle.
        """
        # keep the wrapped scheduler on the same step before querying it
        self.inner_schedule.last_epoch = self.last_epoch

        original = self.inner_schedule.get_lr()
        step = self.last_epoch
        warmup = self.warmup_steps or 0
        if step < self.relora_steps or warmup <= 0:
            # no restart has happened yet, or no warmup requested (also avoids
            # dividing by zero below)
            scale = 1
        else:
            cycle_t = min(1.0, (step % self.relora_steps) / warmup)
            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale

        if isinstance(original, Sequence):
            return [lr * scale for lr in original]
        # defensive fallback for an inner schedule that returns a bare number
        return original * scale
|
222 |
+
|
223 |
+
|
224 |
+
def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
    """Map weight names to the file inside ``path`` that stores them.

    ``model.safetensors`` is assumed when either that file or its
    ``.index.json`` exists in ``path``; otherwise ``pytorch_model.bin`` is
    assumed. If the matching ``<name>.index.json`` exists, its ``weight_map``
    is returned as-is. Without an index, every ``<module_name>.weight`` is
    mapped to the single weights file.
    """

    def _present(filename: str) -> bool:
        return os.path.exists(str(Path(path) / filename))

    weights_file = "model.safetensors"
    if not (_present(weights_file) or _present(f"{weights_file}.index.json")):
        weights_file = "pytorch_model.bin"

    index_file = f"{weights_file}.index.json"
    if not _present(index_file):
        # unsharded: everything lives in the one weights file
        return {(name + ".weight"): weights_file for name in module_names}

    with open(str(Path(path) / index_file), "r", encoding="utf-8") as handle:
        index = json.load(handle)
    return index["weight_map"]
|
237 |
+
|
238 |
+
|
239 |
+
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
    """Return the weight update contributed by ``layer``'s LoRA adapter, on ``device``.

    For the 8-bit and 4-bit LoRA linear layers the update is computed by hand
    as ``lora_B @ lora_A`` for the active adapter, transposed according to the
    layer's ``fan_in_fan_out`` flag (False when absent) and multiplied by the
    adapter's scaling. Every other layer type delegates to its own
    ``get_delta_weight()``.
    """
    quantized_layers = (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)
    if not isinstance(layer, quantized_layers):
        return layer.get_delta_weight().to(device)

    adapter_name = layer.active_adapter
    lora_b = layer.lora_B[adapter_name].weight.detach().to(device)
    lora_a = layer.lora_A[adapter_name].weight.detach().to(device)
    low_rank_product = peft.utils.transpose(
        lora_b @ lora_a,
        getattr(layer, "fan_in_fan_out", False),
    )
    return low_rank_product * layer.scaling[adapter_name]
|
252 |
+
|
253 |
+
|
254 |
+
def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
    """Collect every ``LoraLayer`` inside ``model.model``, keyed by module name.

    Module names containing ``"lora"`` are not considered, and names that
    cannot be resolved (``AttributeError``) are skipped.
    """
    candidate_names = [
        name for name, _module in model.model.named_modules() if "lora" not in name
    ]

    lora_layers: Dict[str, peft.tuners.lora.LoraLayer] = {}
    for name in candidate_names:
        try:
            # pylint: disable=protected-access
            _parent, module, _leaf_name = peft.utils._get_submodules(model.model, name)
        except AttributeError:
            continue
        else:
            if isinstance(module, peft.tuners.lora.LoraLayer):
                lora_layers[name] = module

    return lora_layers
|
269 |
+
|
270 |
+
|
271 |
+
def update_weights(
    target: peft.tuners.lora.LoraLayer, new_weight: torch.Tensor, reinit: bool, device
):
    """Replace ``target``'s base weight with ``new_weight`` and place it on ``device``.

    When ``reinit`` is true, the LoRA parameters of every adapter listed in
    ``lora_A`` and in ``lora_embedding_A`` are reset first.
    """
    if reinit:
        for adapters in (target.lora_A, target.lora_embedding_A):
            for adapter_name in adapters:
                target.reset_lora_parameters(adapter_name)

    if isinstance(target, peft.tuners.lora.Linear4bit):
        # Quantization of Linear4bit weights happens when the module moves from
        # cpu to gpu. So drop the old quantization state, install the new data
        # on the cpu and move the module to the target device. A faster path
        # would mean reaching into (or duplicating) PEFT internals, so this is
        # good enough for now.
        target.weight.quant_state = None
        target.weight.data = new_weight.cpu()
        target.to(device)
        return

    if isinstance(target, peft.tuners.lora.Linear8bitLt):
        target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device)
        return

    target.weight.data = new_weight.to(device)
|
292 |
+
|
293 |
+
|
294 |
+
def _safetensors_shard_name(shard_path: str) -> str:
    """Return the safetensors file name to write for the source shard ``shard_path``."""
    if not shard_path.startswith("pytorch_model"):
        return shard_path

    stem = shard_path.replace("pytorch_model", "model")
    # Strip the literal ".bin" suffix. `str.rstrip(".bin")` would remove any
    # trailing run of the characters '.', 'b', 'i', 'n' and can eat into the name.
    if stem.endswith(".bin"):
        stem = stem[: -len(".bin")]
    return stem + ".safetensors"


def merge_and_save(
    model: peft.LoraModel,
    model_src: str,
    model_dst: str,
    reinit: bool = False,
    quantized: bool = False,
    cpu_offload: bool = False,
    actually_save: bool = True,
):
    """Merge the LoRA adapters of ``model`` into its base weights.

    Unquantized: each layer's delta weight is added to ``weight.data`` in place
    and nothing is written to disk (``model_src``/``model_dst`` are unused).

    Quantized: for every shard of the checkpoint in ``model_src`` the stored
    weights are loaded, each LoRA layer's delta is added to its stored weight,
    the layer is updated with the result, and -- when ``actually_save`` is true
    -- the whole shard is written to ``model_dst`` as safetensors. An index
    file is written when there is more than one shard.

    Args:
        model: PEFT model whose LoRA layers are merged.
        model_src: directory holding the full-precision weights to start from.
        model_dst: directory receiving the merged weights.
        reinit: reset the LoRA parameters after merging.
        quantized: select the quantized (file-based) path described above.
        cpu_offload: do the weight arithmetic on the cpu instead of on the
            layer's own device.
        actually_save: write files; when false, only the in-memory model is
            updated.
    """
    modules = find_lora_modules(model)

    if not quantized:
        for target in modules.values():
            update = target.get_delta_weight(target.active_adapter).detach()
            target.weight.data += update

            if reinit:
                for adapter_name in target.lora_A:
                    target.reset_lora_parameters(adapter_name)
                for adapter_name in target.lora_embedding_A:
                    target.reset_lora_parameters(adapter_name)
        return

    os.makedirs(model_dst, exist_ok=True)
    shard_paths = sharded_paths(model_src, modules.keys())
    out_shard_paths = {}

    unique_shards = list(set(shard_paths.values()))
    for shard_path in unique_shards:
        out_tensors = {}
        if shard_path.endswith(".safetensors"):
            in_tensors = st.load_file(str(Path(model_src) / shard_path))
        else:
            # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
            # Loading onto the cpu keeps the source shard off the accelerator;
            # weights are moved to the compute device one at a time below.
            in_tensors = torch.load(Path(model_src) / shard_path, map_location="cpu")
            if "state_dict" in in_tensors:
                in_tensors = in_tensors["state_dict"]

        for module_name, target in modules.items():
            key = module_name + ".weight"
            if key not in shard_paths or shard_paths[key] != shard_path:
                continue

            orig_weight = in_tensors[key]
            old_dev = target.weight.device
            math_dev = "cpu" if cpu_offload else old_dev

            delta_weight = lora_delta_weight(target, math_dev)
            new_weight = orig_weight.to(math_dev) + delta_weight
            del delta_weight

            if actually_save:
                out_tensors[key] = new_weight.half().cpu()

            update_weights(target, new_weight, reinit=reinit, device=old_dev)

        if actually_save:
            out_shard_name = _safetensors_shard_name(shard_path)

            # copy through everything in the shard that was not merged above
            for tensor_name in in_tensors:
                if tensor_name not in out_tensors:
                    tensor = in_tensors[tensor_name]
                    # only floating point tensors are stored as fp16; casting
                    # integer/bool tensors to half would corrupt them
                    out_tensors[tensor_name] = (
                        tensor.half() if tensor.is_floating_point() else tensor
                    )
                out_shard_paths[tensor_name] = out_shard_name

            shard_fn = str(Path(model_dst) / out_shard_name)
            LOG.info(f"saving tensors to {shard_fn}")
            st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})

        # release the shard before loading the next one
        del in_tensors
        del out_tensors
        torch.cuda.empty_cache()

    if actually_save and len(unique_shards) > 1:
        with open(
            str(Path(model_dst, "model.safetensors.index.json")), "w", encoding="utf-8"
        ) as file:
            json.dump({"metadata": {}, "weight_map": out_shard_paths}, file)
|
375 |
+
|
376 |
+
|
377 |
+
def load_weight_checkpoint(model: peft.LoraModel, checkpoint_path: str):
    """Load base weights for every LoRA layer of ``model`` from ``checkpoint_path``.

    Each shard named by ``sharded_paths`` is read once with the safetensors
    loader, and every LoRA layer whose ``<name>.weight`` lives in that shard is
    updated on its current device without resetting the adapter.
    """
    lora_layers = find_lora_modules(model)
    weight_map = sharded_paths(checkpoint_path, lora_layers.keys())

    for shard_name in list(set(weight_map.values())):
        shard_tensors = st.load_file(os.path.join(checkpoint_path, shard_name))

        for layer_name, layer in lora_layers.items():
            weight_key = layer_name + ".weight"
            stored_here = weight_key in weight_map and weight_map[weight_key] == shard_name
            if stored_here:
                update_weights(
                    layer,
                    shard_tensors[weight_key],
                    reinit=False,
                    device=layer.weight.device,
                )
|
src/axolotl/utils/callbacks.py
CHANGED
@@ -33,7 +33,9 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
|
|
33 |
)
|
34 |
|
35 |
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
|
36 |
-
kwargs["model"].save_pretrained(
|
|
|
|
|
37 |
|
38 |
return control
|
39 |
|
|
|
33 |
)
|
34 |
|
35 |
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
|
36 |
+
kwargs["model"].save_pretrained(
|
37 |
+
peft_model_path, save_safetensors=args.save_safetensors
|
38 |
+
)
|
39 |
|
40 |
return control
|
41 |
|
src/axolotl/utils/config.py
CHANGED
@@ -126,6 +126,19 @@ def validate_config(cfg):
|
|
126 |
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
127 |
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
if cfg.trust_remote_code:
|
130 |
LOG.warning(
|
131 |
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
|
|
126 |
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
127 |
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
128 |
|
129 |
+
if cfg.relora_steps:
|
130 |
+
if cfg.adapter not in ("lora", "qlora"):
|
131 |
+
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
132 |
+
|
133 |
+
if cfg.fsdp:
|
134 |
+
raise ValueError("fsdp not supported with ReLoRA")
|
135 |
+
|
136 |
+
if cfg.deepspeed:
|
137 |
+
raise ValueError("deepspeed not supported with ReLoRA")
|
138 |
+
|
139 |
+
if cfg.lr_scheduler == "one_cycle":
|
140 |
+
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
|
141 |
+
|
142 |
if cfg.trust_remote_code:
|
143 |
LOG.warning(
|
144 |
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
src/axolotl/utils/trainer.py
CHANGED
@@ -24,6 +24,7 @@ from transformers.trainer_pt_utils import (
|
|
24 |
get_parameter_names,
|
25 |
)
|
26 |
|
|
|
27 |
from axolotl.utils.callbacks import (
|
28 |
GPUStatsCallback,
|
29 |
SaveBetterTransformerModelCallback,
|
@@ -127,6 +128,14 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|
127 |
default=1,
|
128 |
metadata={"help": "the multiplier for the max len for packed sequences"},
|
129 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
|
131 |
|
132 |
class AxolotlTrainer(Trainer):
|
@@ -265,6 +274,39 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
|
265 |
return self.lr_scheduler
|
266 |
|
267 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
268 |
def add_position_ids(sample):
|
269 |
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
|
270 |
return sample
|
@@ -517,6 +559,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
517 |
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
|
518 |
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
|
519 |
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
|
|
|
|
520 |
**training_arguments_kwargs,
|
521 |
)
|
522 |
|
@@ -589,6 +633,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
589 |
|
590 |
callbacks = []
|
591 |
callbacks.append(GPUStatsCallback(cfg))
|
|
|
|
|
|
|
|
|
592 |
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
593 |
if cfg.early_stopping_patience:
|
594 |
early_stop_cb = EarlyStoppingCallback(
|
@@ -633,11 +681,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
633 |
num_proc=32,
|
634 |
)
|
635 |
|
636 |
-
trainer_cls =
|
637 |
-
|
638 |
-
|
639 |
-
|
640 |
-
|
641 |
trainer = trainer_cls(
|
642 |
model=model,
|
643 |
train_dataset=train_dataset,
|
|
|
24 |
get_parameter_names,
|
25 |
)
|
26 |
|
27 |
+
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
28 |
from axolotl.utils.callbacks import (
|
29 |
GPUStatsCallback,
|
30 |
SaveBetterTransformerModelCallback,
|
|
|
128 |
default=1,
|
129 |
metadata={"help": "the multiplier for the max len for packed sequences"},
|
130 |
)
|
131 |
+
relora_steps: Optional[int] = field(
|
132 |
+
default=None,
|
133 |
+
metadata={"help": "how often to reset for ReLoRA"},
|
134 |
+
)
|
135 |
+
relora_warmup_steps: Optional[int] = field(
|
136 |
+
default=None,
|
137 |
+
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
138 |
+
)
|
139 |
|
140 |
|
141 |
class AxolotlTrainer(Trainer):
|
|
|
274 |
return self.lr_scheduler
|
275 |
|
276 |
|
277 |
+
class ReLoRATrainer(AxolotlTrainer):
    """
    Trainer subclass that wraps the learning rate scheduler in a ReLoRAScheduler
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # cleared so the scheduler is (re)built by `create_scheduler` below
        self.lr_scheduler = None

    def create_scheduler(
        self,
        num_training_steps: int,
        optimizer: Optional[torch.optim.Optimizer] = None,
    ):
        """Create the parent class's scheduler and wrap it for ReLoRA warmups.

        When ``args.relora_steps`` is set, the scheduler returned by the parent
        class is wrapped in a ``ReLoRAScheduler`` using
        ``args.relora_warmup_steps`` (10 when unset or zero). Otherwise the
        parent's scheduler is used unchanged.

        Returns:
            the scheduler stored in ``self.lr_scheduler``.
        """
        optimizer = self.optimizer if optimizer is None else optimizer
        lr_scheduler = super().create_scheduler(num_training_steps, optimizer)

        # If the parent hands back a scheduler that is already wrapped (for
        # example when this method runs more than once), reuse it rather than
        # stacking a second warmup on top.
        if self.args.relora_steps and not isinstance(lr_scheduler, ReLoRAScheduler):
            warmup_steps = (
                self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
            )
            self.lr_scheduler = ReLoRAScheduler(
                optimizer,
                lr_scheduler,
                self.args.relora_steps,
                warmup_steps,
            )
        else:
            self.lr_scheduler = lr_scheduler

        return self.lr_scheduler
|
308 |
+
|
309 |
+
|
310 |
def add_position_ids(sample):
|
311 |
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
|
312 |
return sample
|
|
|
559 |
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
|
560 |
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
|
561 |
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
562 |
+
relora_steps=cfg.relora_steps,
|
563 |
+
relora_warmup_steps=cfg.relora_warmup_steps,
|
564 |
**training_arguments_kwargs,
|
565 |
)
|
566 |
|
|
|
633 |
|
634 |
callbacks = []
|
635 |
callbacks.append(GPUStatsCallback(cfg))
|
636 |
+
|
637 |
+
if cfg.relora_steps:
|
638 |
+
callbacks.append(ReLoRACallback(cfg))
|
639 |
+
|
640 |
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
641 |
if cfg.early_stopping_patience:
|
642 |
early_stop_cb = EarlyStoppingCallback(
|
|
|
681 |
num_proc=32,
|
682 |
)
|
683 |
|
684 |
+
trainer_cls = AxolotlTrainer
|
685 |
+
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora"):
|
686 |
+
trainer_cls = OneCycleLRSchedulerTrainer
|
687 |
+
elif cfg.relora_steps:
|
688 |
+
trainer_cls = ReLoRATrainer
|
689 |
trainer = trainer_cls(
|
690 |
model=model,
|
691 |
train_dataset=train_dataset,
|