Crystalcareai committed
Commit d507e0c · verified · 1 Parent(s): 705c81f

Create train.py

Files changed (1)
  1. train.py +142 -0
train.py ADDED
@@ -0,0 +1,142 @@
import random
import time

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer

# TF32 matmuls trade a little precision for speed on Ampere+ GPUs.
torch.backends.cuda.matmul.allow_tf32 = True

random_seed = 42
torch.manual_seed(random_seed)
random.seed(random_seed)

dataset = load_dataset("HuggingFaceH4/deita-10k-v0-sft", split="train_sft")
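# deita-10k-v0-sft is a multi-turn chat SFT dataset; the SFTTrainer at the
# bottom of the script is expected to handle formatting and tokenization of
# its conversations.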

n_ahead_talk_global = 4
n_passes_global = 2
n_ahead_global = 12
n_examples = 1_000  # defined but not used below
full_batch_size = 8
eval_and_logging_steps = 2  # defined but not used below
save_steps = 100  # defined but not used below; TrainingArguments hardcodes save_steps=300
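# In the Quiet-STaR setup this model follows, n_ahead is roughly the number of
# hidden "thought" tokens rolled out per position, n_ahead_talk the lookahead
# used by the talk head, and n_passes the number of thought rollouts per
# example (batch_size below divides full_batch_size by n_passes_global to
# compensate).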


def model_init(params):
    original = False
    if params is None:
        params = {}
    else:
        # Presumably a hyperparameter-search trial object whose chosen
        # settings live in its .params attribute.
        params = params.params
    n_ahead = params.get("n_ahead", n_ahead_global if not original else 1)
    n_ahead_talk = params.get("n_ahead_talk", n_ahead_talk_global if not original else 1)
    n_passes = params.get("n_passes", n_passes_global if not original else 1)
    gumbel_temperature = params.get("gumbel_temperature", 1)
    use_start_thought_token = params.get("use_start_thought_token", True)
    use_end_thought_token = params.get("use_end_thought_token", True)
    include_policy_loss = params.get("include_policy_loss", True)
    gumbel_detach = params.get("gumbel_detach", True)
    merged_talk_heads = params.get("merged_talk_heads", True)
    # global_gradient_accumulation_steps is defined at module level further
    # down; it exists by the time model_init is actually called.
    gradient_accumulation_steps = params.get("gradient_accumulation_steps", global_gradient_accumulation_steps)
    residual_think_head = params.get("residual_think_head", False)
    optimize_lm_head_only_at_start = params.get("optimize_lm_head_only_at_start", False)

    model_id = "Crystalcareai/Quiet-Star-Custom"
    tokenizer_id = "Crystalcareai/Quiet-Star-Custom"
    print("Loading model")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        max_thoughts=n_ahead + n_ahead_talk + 1,
        merged_talk_heads=merged_talk_heads,
        merged_lm_and_talk_heads=False,
        merged_lm_and_think_heads=True,
        use_concat_talk_head=True,
        use_shallow_think=True,
        use_shallow_talk=False,
        use_complex_think_head=False,
        use_complex_talk_head=True,
        use_weighted_talk_head=True,
        trust_remote_code=True,
        load_in_4bit=True,
    )
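    # Everything from max_thoughts onward appears to be a custom
    # Quiet-Star-Custom config option consumed by the model's remote code
    # (hence trust_remote_code=True); load_in_4bit additionally quantizes the
    # base weights with bitsandbytes.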
    print("Loaded model")

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, padding=False, truncation=True)
    tokenizer.pad_token_id = tokenizer.eos_token_id

    special_tokens_to_add = []
    if model.use_start_thought_token:
        special_tokens_to_add.append("<|startthought|>")
    if model.use_end_thought_token:
        special_tokens_to_add.append("<|endthought|>")
    if special_tokens_to_add:
        tokenizer.add_special_tokens({"additional_special_tokens": special_tokens_to_add})
        model.resize_token_embeddings(len(tokenizer))
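    # <|startthought|> and <|endthought|> delimit the model's internal thought
    # segments; they are new vocabulary entries, so the embedding matrix is
    # resized to match the enlarged tokenizer.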
    model.tokenizer = tokenizer
    model.gumbel_detach = gumbel_detach
    model.include_policy_loss = include_policy_loss
    model.use_end_thought_token = use_end_thought_token
    model.use_start_thought_token = use_start_thought_token
    model.n_ahead = n_ahead
    model.n_ahead_talk = n_ahead_talk
    model.n_passes = n_passes
    model.n_tokens_print = gradient_accumulation_steps
    model.gradient_accumulation_steps = gradient_accumulation_steps
    model.residual_think_head = residual_think_head
    model.optimize_lm_head_only_at_start = optimize_lm_head_only_at_start
    model.gumbel_temperature = gumbel_temperature
    model.original_mode = original
    model.config_params = params
    model.run_start = int(time.time())
    model.kill_after = 100  # custom knob from the remote code; presumably caps the run length
    model.train()
    return model
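
# model_init follows the Trainer's model_init(trial) signature, so the same
# function could drive a hyperparameter search; here it is called once with
# None (see below), which takes the params = {} branch.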


batch_size = full_batch_size // n_passes_global
global_gradient_accumulation_steps = full_batch_size // batch_size
run_id = int(time.time())  # defined but not used below
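# With full_batch_size = 8 and n_passes_global = 2 this yields batch_size = 4
# and global_gradient_accumulation_steps = 2. Note that TrainingArguments below
# never sets gradient_accumulation_steps; the value computed here only reaches
# the model through model_init.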
training_args = TrainingArguments(
    output_dir="./out",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_checkpointing=False,
    optim="adamw_bnb_8bit",
    logging_steps=2,
    save_strategy="steps",
    save_steps=300,
    bf16=True,
    tf32=True,
    learning_rate=2e-4,
    max_grad_norm=0.3,
    warmup_ratio=0.00,
    lr_scheduler_type="constant",
    push_to_hub=False,
)
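# bf16=True and tf32=True both assume an Ampere-or-newer NVIDIA GPU, and
# optim="adamw_bnb_8bit" keeps optimizer state in 8 bits via bitsandbytes.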

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=32,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
    use_dora=False,  # DoRA is available in peft but disabled here
)
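# LoRA adapters are attached to every attention and MLP projection of the base
# model; with r=32 and lora_alpha=16 the standard update scaling is alpha/r = 0.5.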

model = model_init(None)  # build a single model instance (no search trial)
tokenizer = model.tokenizer

trainer = SFTTrainer(
    args=training_args,
    train_dataset=dataset,
    model=model,
    peft_config=peft_config,
    tokenizer=tokenizer,
)

trainer.train()
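# Checkpoints land in ./out every 300 steps per save_steps above; after
# training, the LoRA adapter can also be saved explicitly with the standard
# Trainer API, e.g. trainer.save_model("./out/final").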