Crystalcareai commited on
Commit
d842ce9
·
verified ·
1 Parent(s): d507e0c

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +6 -6
train.py CHANGED
@@ -17,7 +17,6 @@ dataset = load_dataset("HuggingFaceH4/deita-10k-v0-sft", split="train_sft")
17
  n_ahead_talk_global = 4
18
  n_passes_global = 2
19
  n_ahead_global = 12
20
- n_examples = 1_000
21
  full_batch_size = 8
22
  eval_and_logging_steps = 2
23
  save_steps = 100
@@ -64,7 +63,8 @@ def model_init(params):
64
  )
65
  print("Loaded model")
66
 
67
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_id,padding=False,truncation=True)
 
68
  tokenizer.pad_token_id = tokenizer.eos_token_id
69
 
70
  special_tokens_to_add = []
@@ -103,14 +103,14 @@ training_args = TrainingArguments(
103
  output_dir="./out",
104
  num_train_epochs=3,
105
  per_device_train_batch_size=1,
106
- gradient_checkpointing=False,
 
107
  optim="adamw_bnb_8bit",
108
  logging_steps=2,
109
  save_strategy="steps",
110
  save_steps=300,
111
-
112
  bf16=True,
113
- tf32=True,
114
  learning_rate=2e-4,
115
  max_grad_norm=0.3,
116
  warmup_ratio=0.00,
@@ -139,4 +139,4 @@ trainer = SFTTrainer(
139
  tokenizer=tokenizer,
140
  )
141
 
142
- trainer.train()
 
17
  n_ahead_talk_global = 4
18
  n_passes_global = 2
19
  n_ahead_global = 12
 
20
  full_batch_size = 8
21
  eval_and_logging_steps = 2
22
  save_steps = 100
 
63
  )
64
  print("Loaded model")
65
 
66
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
67
+ tokenizer.padding_side = "right"
68
  tokenizer.pad_token_id = tokenizer.eos_token_id
69
 
70
  special_tokens_to_add = []
 
103
  output_dir="./out",
104
  num_train_epochs=3,
105
  per_device_train_batch_size=1,
106
+ gradient_accumulation_steps=global_gradient_accumulation_steps,
107
+ gradient_checkpointing=True,
108
  optim="adamw_bnb_8bit",
109
  logging_steps=2,
110
  save_strategy="steps",
111
  save_steps=300,
 
112
  bf16=True,
113
+ tf32=False,
114
  learning_rate=2e-4,
115
  max_grad_norm=0.3,
116
  warmup_ratio=0.00,
 
139
  tokenizer=tokenizer,
140
  )
141
 
142
+ trainer.train()