remove un-needed code, add validation
- scripts/finetune.py          +3 -0
- src/axolotl/utils/models.py  +0 -15
scripts/finetune.py
CHANGED

@@ -14,6 +14,7 @@ from attrdict import AttrDefault
 
 # add src to the pythonpath so we don't need to pip install this
 from axolotl.utils.tokenization import check_dataset_labels
+from axolotl.utils.validation import validate_config
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")

@@ -158,6 +159,8 @@ def train(
     cfg.fp16 = True
     cfg.bf16 = False
 
+    validate_config(cfg)
+
     # Load the model and tokenizer
     logging.info("loading model, tokenizer, and peft_config...")
     model, tokenizer, peft_config = load_model(
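The implementation of the new validation hook is not shown in this diff. As a rough, hypothetical sketch of the kind of checks a validate_config(cfg) helper could run before model loading (the fp16/bf16 exclusivity and 8-bit/adapter rules below are assumptions inferred from the surrounding lines, not axolotl's actual axolotl.utils.validation code):

# Hypothetical sketch -- not the real axolotl.utils.validation module.
def validate_config(cfg):
    """Fail fast on obviously inconsistent settings before loading the model."""
    # fp16 and bf16 are mutually exclusive mixed-precision modes (assumed rule,
    # suggested by the cfg.fp16 = True / cfg.bf16 = False lines in the hunk above).
    if cfg.fp16 and cfg.bf16:
        raise ValueError("fp16 and bf16 cannot both be enabled; pick one.")
    # 8-bit loading normally only makes sense alongside an adapter such as
    # lora/qlora (assumed rule, for illustration only).
    if cfg.load_in_8bit and not cfg.adapter:
        raise ValueError("load_in_8bit requires an adapter (e.g. lora or qlora).")

Calling it as train() does above (validate_config(cfg)) turns a misconfigured run into an immediate, readable error instead of a failure deep inside model loading.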
src/axolotl/utils/models.py
CHANGED

@@ -204,21 +204,6 @@ def load_model(
         **model_kwargs,
     )
 
-    """### Post-processing on the model
-    Finally, we need to apply some post-processing on the 8-bit model to enable training, let's freeze all our layers, and cast the layer-norm in `float32` for stability. We also cast the output of the last layer in `float32` for the same reasons.
-    """
-    # if cfg.adapter == "qlora":
-    #     for param in model.parameters():
-    #         param.requires_grad = False  # freeze the model - train adapters later
-    #         if param.ndim == 1:
-    #             # cast the small parameters (e.g. layernorm) to fp32 for stability
-    #             param.data = param.data.to(torch.float32)
-    # class CastOutputToFloat(nn.Linear):
-    #     def forward(self, x):
-    #         return super().forward(x).to(torch.float32)
-    #
-    # model.lm_head = CastOutputToFloat(model.lm_head.in_features, model.lm_head.out_features, model.lm_head.bias)
-
     if not tokenizer:
         try:
             if is_llama_derived_model and "LlamaTokenizer" in globals():
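The deleted block is the classic 8-bit post-processing recipe (freeze parameters, upcast layer norms to float32, cast the lm_head output to float32). The diff does not say how, or whether, that behaviour is handled elsewhere; if it is still needed, it is typically obtained from peft's prepare_model_for_kbit_training helper rather than hand-rolled. A minimal sketch, assuming a quantized model and the peft library:

# Sketch only: peft's helper performs the same post-processing the deleted
# comment block described (freeze params, upcast norms, fp32 output head).
from peft import prepare_model_for_kbit_training

# `model` is assumed to be a causal LM loaded with 8-bit or 4-bit quantization.
model = prepare_model_for_kbit_training(model)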