winglian committed on
Commit 1f5d83e · 1 Parent(s): 7e81ca7

remove un-needed code, add validation

Files changed (2)
  1. scripts/finetune.py +3 -0
  2. src/axolotl/utils/models.py +0 -15
scripts/finetune.py CHANGED
@@ -14,6 +14,7 @@ from attrdict import AttrDefault
 
 # add src to the pythonpath so we don't need to pip install this
 from axolotl.utils.tokenization import check_dataset_labels
+from axolotl.utils.validation import validate_config
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
@@ -158,6 +159,8 @@ def train(
     cfg.fp16 = True
     cfg.bf16 = False
 
+    validate_config(cfg)
+
     # Load the model and tokenizer
     logging.info("loading model, tokenizer, and peft_config...")
     model, tokenizer, peft_config = load_model(
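The new call wires validate_config(cfg) in right before the model is loaded, but the implementation of axolotl.utils.validation.validate_config is not included in this commit's diff. A minimal sketch of the kind of check such a function could perform; the fp16/bf16 rule below is an assumption for illustration, not taken from this commit:

# Hypothetical sketch of validate_config; the real implementation is not shown in this commit.
def validate_config(cfg):
    # Assumption: fp16 and bf16 are mutually exclusive, mirroring the
    # cfg.fp16 / cfg.bf16 toggling visible in train() above.
    if cfg.fp16 and cfg.bf16:
        raise ValueError("fp16 and bf16 cannot both be enabled; choose one precision mode.")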
src/axolotl/utils/models.py CHANGED
@@ -204,21 +204,6 @@ def load_model(
         **model_kwargs,
     )
 
-    """### Post-processing on the model
-    Finally, we need to apply some post-processing on the 8-bit model to enable training, let's freeze all our layers, and cast the layer-norm in `float32` for stability. We also cast the output of the last layer in `float32` for the same reasons.
-    """
-    # if cfg.adapter == "qlora":
-    #     for param in model.parameters():
-    #         param.requires_grad = False  # freeze the model - train adapters later
-    #         if param.ndim == 1:
-    #             # cast the small parameters (e.g. layernorm) to fp32 for stability
-    #             param.data = param.data.to(torch.float32)
-    # class CastOutputToFloat(nn.Linear):
-    #     def forward(self, x):
-    #         return super().forward(x).to(torch.float32)
-    #
-    # model.lm_head = CastOutputToFloat(model.lm_head.in_features, model.lm_head.out_features, model.lm_head.bias)
-
     if not tokenizer:
         try:
             if is_llama_derived_model and "LlamaTokenizer" in globals():
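The deleted block was leftover notebook-style post-processing for 8-bit training: freeze the base weights and upcast layer norms and the LM head to float32. A minimal sketch of how that preparation is commonly done through peft's helper instead; whether load_model relies on this helper is an assumption, not something shown in this diff:

# Hypothetical sketch, not part of this commit: peft ships a helper that performs
# the same freeze/upcast steps as the removed commented-out code.
from peft import prepare_model_for_kbit_training  # older peft releases name it prepare_model_for_int8_training

model = prepare_model_for_kbit_training(model)  # freezes base params and upcasts norm layers for stability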