Crystalcareai commited on
Commit
8b9bd5a
·
verified ·
1 Parent(s): 9f98307

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +28 -22
train.py CHANGED
@@ -7,20 +7,19 @@ from transformers import TrainingArguments
7
  from trl import SFTTrainer
8
  from peft import LoraConfig
9
 
 
10
  import time
11
  random_seed = 42
12
  torch.manual_seed(random_seed)
13
  random.seed(random_seed)
14
 
15
- dataset = load_dataset("HuggingFaceH4/deita-10k-v0-sft", split="train_sft")
16
 
17
- n_ahead_talk_global = 2
18
  n_passes_global = 2
19
  n_ahead_global = 2
20
  n_examples = 0
21
- full_batch_size = 2
22
- eval_and_logging_steps = 2
23
- save_steps = 100
24
 
25
 
26
  def model_init(params):
@@ -39,7 +38,6 @@ def model_init(params):
39
  include_policy_loss = params.get("include_policy_loss", True)
40
  gumbel_detach = params.get("gumbel_detach", True)
41
  merged_talk_heads = params.get("merged_talk_heads", True)
42
- gradient_accumulation_steps = params.get("gradient_accumulation_steps", global_gradient_accumulation_steps)
43
  residual_think_head = params.get("residual_think_head", False)
44
  optimize_lm_head_only_at_start = params.get("optimize_lm_head_only_at_start", False)
45
 
@@ -48,7 +46,7 @@ def model_init(params):
48
  print("Loading model")
49
  model = AutoModelForCausalLM.from_pretrained(
50
  model_id,
51
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
52
  max_thoughts=n_ahead + n_ahead_talk + 1,
53
  merged_talk_heads=merged_talk_heads,
54
  merged_lm_and_talk_heads=False,
@@ -61,10 +59,12 @@ def model_init(params):
61
  use_weighted_talk_head=True,
62
  trust_remote_code=True,
63
  device_map="auto",
 
 
64
  )
65
  print("Loaded model")
66
 
67
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_id,padding=False,truncation=True)
68
  tokenizer.pad_token_id = tokenizer.eos_token_id
69
 
70
  special_tokens_to_add = []
@@ -76,6 +76,10 @@ def model_init(params):
76
  tokenizer.add_special_tokens({"additional_special_tokens": special_tokens_to_add})
77
  model.resize_token_embeddings(len(tokenizer))
78
  model.tokenizer = tokenizer
 
 
 
 
79
  model.gumbel_detach = gumbel_detach
80
  model.include_policy_loss = include_policy_loss
81
  model.use_end_thought_token = use_end_thought_token
@@ -83,40 +87,40 @@ def model_init(params):
83
  model.n_ahead = n_ahead
84
  model.n_ahead_talk = n_ahead_talk
85
  model.n_passes = n_passes
86
- model.n_tokens_print = gradient_accumulation_steps
87
- model.gradient_accumulation_steps = gradient_accumulation_steps
88
  model.residual_think_head = residual_think_head
89
  model.optimize_lm_head_only_at_start = optimize_lm_head_only_at_start
90
  model.gumbel_temperature = gumbel_temperature
91
  model.original_mode = original
92
  model.config_params = params
93
  model.run_start = int(time.time())
94
- model.kill_after = 100
95
  model.train()
96
  return model
97
 
98
-
99
- batch_size = full_batch_size // n_passes_global
100
- global_gradient_accumulation_steps = full_batch_size // batch_size
101
  run_id = int(time.time())
102
  training_args = TrainingArguments(
103
  output_dir="./out",
104
- num_train_epochs=3,
105
  per_device_train_batch_size=1,
106
  gradient_checkpointing=False,
107
- gradient_accumulation_steps=4,
108
- optim="adamw_torch_fused",
109
  logging_steps=1,
110
  save_strategy="steps",
111
  save_steps=300,
112
  bf16=True,
113
  tf32=False,
 
 
 
114
  # auto_find_batch_size=True
115
- learning_rate=2e-07,
116
- max_grad_norm=1.0, # Gradient clipping with a maximum gradient norm of 0.3
117
- warmup_steps=100,
118
  lr_scheduler_type="cosine",
119
  push_to_hub=False,
 
 
120
  )
121
 
122
  # peft_config = LoraConfig(
@@ -131,14 +135,16 @@ training_args = TrainingArguments(
131
 
132
  torch.autograd.set_detect_anomaly(True)
133
  model = model_init(None) # Initialize the model
134
- tokenizer = model.tokenizer
135
 
 
 
136
  trainer = SFTTrainer(
137
  args=training_args,
138
  train_dataset=dataset,
139
  model=model,
140
  # peft_config=peft_config,
141
  tokenizer=tokenizer,
 
142
  )
143
 
144
- trainer.train()
 
7
  from trl import SFTTrainer
8
  from peft import LoraConfig
9
 
10
+
11
  import time
12
  random_seed = 42
13
  torch.manual_seed(random_seed)
14
  random.seed(random_seed)
15
 
16
+ dataset = load_dataset("Crystalcareai/Self-Discover-MM-Instruct-openai", split="train_sft")
17
 
18
+ n_ahead_talk_global = 3
19
  n_passes_global = 2
20
  n_ahead_global = 2
21
  n_examples = 0
22
+
 
 
23
 
24
 
25
  def model_init(params):
 
38
  include_policy_loss = params.get("include_policy_loss", True)
39
  gumbel_detach = params.get("gumbel_detach", True)
40
  merged_talk_heads = params.get("merged_talk_heads", True)
 
41
  residual_think_head = params.get("residual_think_head", False)
42
  optimize_lm_head_only_at_start = params.get("optimize_lm_head_only_at_start", False)
43
 
 
46
  print("Loading model")
47
  model = AutoModelForCausalLM.from_pretrained(
48
  model_id,
49
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
50
  max_thoughts=n_ahead + n_ahead_talk + 1,
51
  merged_talk_heads=merged_talk_heads,
52
  merged_lm_and_talk_heads=False,
 
59
  use_weighted_talk_head=True,
60
  trust_remote_code=True,
61
  device_map="auto",
62
+ # load_in_4bit=True,
63
+ # attn_implementation="flash_attention_2",
64
  )
65
  print("Loaded model")
66
 
67
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_id,truncation=True,padding="left")
68
  tokenizer.pad_token_id = tokenizer.eos_token_id
69
 
70
  special_tokens_to_add = []
 
76
  tokenizer.add_special_tokens({"additional_special_tokens": special_tokens_to_add})
77
  model.resize_token_embeddings(len(tokenizer))
78
  model.tokenizer = tokenizer
79
+ for name, module in model.named_modules():
80
+ if "embed" in name:
81
+ print(module, flush=True)
82
+
83
  model.gumbel_detach = gumbel_detach
84
  model.include_policy_loss = include_policy_loss
85
  model.use_end_thought_token = use_end_thought_token
 
87
  model.n_ahead = n_ahead
88
  model.n_ahead_talk = n_ahead_talk
89
  model.n_passes = n_passes
 
 
90
  model.residual_think_head = residual_think_head
91
  model.optimize_lm_head_only_at_start = optimize_lm_head_only_at_start
92
  model.gumbel_temperature = gumbel_temperature
93
  model.original_mode = original
94
  model.config_params = params
95
  model.run_start = int(time.time())
 
96
  model.train()
97
  return model
98
 
99
+ max_seq_length = 1024
 
 
100
  run_id = int(time.time())
101
  training_args = TrainingArguments(
102
  output_dir="./out",
103
+ num_train_epochs=1.5,
104
  per_device_train_batch_size=1,
105
  gradient_checkpointing=False,
106
+ gradient_accumulation_steps=8,
107
+ optim="lion_32bit",
108
  logging_steps=1,
109
  save_strategy="steps",
110
  save_steps=300,
111
  bf16=True,
112
  tf32=False,
113
+ # epsilson=1e-05,
114
+ # beta1=0.9,
115
+ # beta2=0.95,
116
  # auto_find_batch_size=True
117
+ learning_rate=3e-07,
118
+ max_grad_norm=0.3, # Gradient clipping with a maximum gradient norm of 0.3
119
+ warmup_steps=10,
120
  lr_scheduler_type="cosine",
121
  push_to_hub=False,
122
+ report_to="wandb"
123
+
124
  )
125
 
126
  # peft_config = LoraConfig(
 
135
 
136
  torch.autograd.set_detect_anomaly(True)
137
  model = model_init(None) # Initialize the model
 
138
 
139
+ tokenizer = model.tokenizer
140
+
141
  trainer = SFTTrainer(
142
  args=training_args,
143
  train_dataset=dataset,
144
  model=model,
145
  # peft_config=peft_config,
146
  tokenizer=tokenizer,
147
+ max_seq_length=max_seq_length,
148
  )
149
 
150
+ trainer.train()