Vipitis commited on
Commit
6bf0547
·
1 Parent(s): db86d40

add finetuning script

Browse files
Files changed (1) hide show
  1. train.py +297 -0
train.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fine-Tune SantaCoder on code/text dataset
3
+ """
4
+ # copied from https://github.com/loubnabnl/santacoder-finetuning
5
+ # removed all parts related to FIM
6
+ # set --subset to default to None instead of "data" to avoid issues with my own datasets.
7
+ # added --resume_from_checkpoint to resume training from a checkpoint (untested)
8
+
9
+
10
+ import argparse
11
+ import os
12
+ import random
13
+ import sys
14
+
15
+ import numpy as np
16
+ import torch
17
+ from datasets import load_dataset
18
+ from torch.utils.data import IterableDataset
19
+ from torch.utils.data.dataloader import DataLoader
20
+ from tqdm import tqdm
21
+ from transformers import (
22
+ AutoModelForCausalLM,
23
+ AutoTokenizer,
24
+ Trainer,
25
+ TrainingArguments,
26
+ logging,
27
+ set_seed,
28
+ )
29
+
30
+ # import fim
31
+
32
+
33
+ def get_args():
34
+ parser = argparse.ArgumentParser()
35
+ parser.add_argument("--resume_from_checkpoint", type=str, default=None) #can pass a checkpoint dir to resume training
36
+ parser.add_argument("--model_path", type=str, default="bigcode/santacoder")
37
+ parser.add_argument("--dataset_name", type=str, default="bigcode/the-stack-dedup")
38
+ parser.add_argument("--subset", type=str, default=None) #None a bodge but not the solution
39
+ parser.add_argument("--split", type=str, default="train")
40
+ parser.add_argument("--size_valid_set", type=int, default=4000)
41
+ parser.add_argument("--streaming", action="store_true")
42
+ parser.add_argument("--shuffle_buffer", type=int, default=5000)
43
+ parser.add_argument("--data_column", type=str, default="content")
44
+
45
+ parser.add_argument("--seq_length", type=int, default=1024)
46
+ parser.add_argument("--max_steps", type=int, default=10000)
47
+ parser.add_argument("--batch_size", type=int, default=2)
48
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
49
+ parser.add_argument("--eos_token_id", type=int, default=49152)
50
+
51
+ parser.add_argument("--learning_rate", type=float, default=5e-5)
52
+ parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
53
+ parser.add_argument("--num_warmup_steps", type=int, default=100)
54
+ parser.add_argument("--weight_decay", type=float, default=0.05)
55
+
56
+ parser.add_argument("--local_rank", type=int, default=0)
57
+ parser.add_argument("--no_fp16", action="store_false")
58
+ parser.add_argument("--bf16", action="store_true")
59
+ parser.add_argument("--no_gradient_checkpointing", action="store_false")
60
+ parser.add_argument("--seed", type=int, default=0)
61
+ parser.add_argument("--num_workers", type=int, default=None)
62
+ parser.add_argument("--output_dir", type=str, default="./checkpoints")
63
+ parser.add_argument("--log_freq", default=1, type=int)
64
+ parser.add_argument("--eval_freq", default=1000, type=int)
65
+ parser.add_argument("--save_freq", default=1000, type=int)
66
+
67
+ # parser.add_argument("--fim_rate", type=float, default=0)
68
+ # parser.add_argument("--fim_spm_rate", type=float, default=0)
69
+ return parser.parse_args()
70
+
71
+
72
+ def chars_token_ratio(dataset, tokenizer, data_column, nb_examples=400):
73
+ """
74
+ Estimate the average number of characters per token in the dataset.
75
+ """
76
+ total_characters, total_tokens = 0, 0
77
+ for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
78
+ total_characters += len(example[data_column])
79
+ total_tokens += len(tokenizer(example[data_column]).tokens())
80
+
81
+ return total_characters / total_tokens
82
+
83
+
84
+ class ConstantLengthDataset(IterableDataset):
85
+ """
86
+ Iterable dataset that returns constant length chunks of tokens from stream of text files.
87
+ Args:
88
+ tokenizer (Tokenizer): The processor used for proccessing the data.
89
+ dataset (dataset.Dataset): Dataset with text files.
90
+ infinite (bool): If True the iterator is reset after dataset reaches end else stops.
91
+ seq_length (int): Length of token sequences to return.
92
+ num_of_sequences (int): Number of token sequences to keep in buffer.
93
+ chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
94
+ # fim_rate (float): Rate (0.0 to 1.0) that sample will be permuted with FIM.
95
+ # fim_spm_rate (float): Rate (0.0 to 1.0) of FIM permuations that will use SPM.
96
+ seed (int): Seed for random number generator.
97
+ """
98
+
99
+ def __init__(
100
+ self,
101
+ tokenizer,
102
+ dataset,
103
+ infinite=False,
104
+ seq_length=1024,
105
+ num_of_sequences=1024,
106
+ chars_per_token=3.6,
107
+ content_field="content",
108
+ # fim_rate=0.5,
109
+ # fim_spm_rate=0.5,
110
+ seed=0,
111
+ ):
112
+ self.tokenizer = tokenizer
113
+ self.concat_token_id = (
114
+ tokenizer.eos_token_id if tokenizer.eos_token_id else args.eos_token_id
115
+ )
116
+ self.dataset = dataset
117
+ self.seq_length = seq_length
118
+ self.infinite = infinite
119
+ self.current_size = 0
120
+ self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
121
+ self.content_field = content_field
122
+ # self.fim_rate = fim_rate
123
+ # self.fim_spm_rate = fim_spm_rate
124
+ self.seed = seed
125
+
126
+ # (
127
+ # self.suffix_tok_id,
128
+ # self.prefix_tok_id,
129
+ # self.middle_tok_id,
130
+ # self.pad_tok_id,
131
+ # ) = fim.get_fim_token_ids(self.tokenizer)
132
+ # if not self.suffix_tok_id and self.fim_rate > 0:
133
+ # print("FIM is not supported by tokenizer, disabling FIM")
134
+ # self.fim_rate = 0
135
+
136
+ def __iter__(self):
137
+ iterator = iter(self.dataset)
138
+ more_examples = True
139
+ while more_examples:
140
+ buffer, buffer_len = [], 0
141
+ while True:
142
+ if buffer_len >= self.max_buffer_size:
143
+ break
144
+ try:
145
+ buffer.append(next(iterator)[self.content_field])
146
+ buffer_len += len(buffer[-1])
147
+ except StopIteration:
148
+ if self.infinite:
149
+ iterator = iter(self.dataset)
150
+ else:
151
+ more_examples = False
152
+ break
153
+ tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
154
+ all_token_ids = []
155
+
156
+ np_rng = np.random.RandomState(seed=self.seed)
157
+ for tokenized_input in tokenized_inputs:
158
+ # optionally do FIM permutations
159
+ # if self.fim_rate > 0:
160
+ # tokenized_input, np_rng = fim.permute(
161
+ # tokenized_input,
162
+ # np_rng,
163
+ # self.suffix_tok_id,
164
+ # self.prefix_tok_id,
165
+ # self.middle_tok_id,
166
+ # self.pad_tok_id,
167
+ # fim_rate=self.fim_rate,
168
+ # fim_spm_rate=self.fim_spm_rate,
169
+ # truncate_or_pad=False,
170
+ # )
171
+
172
+ all_token_ids.extend(tokenized_input + [self.concat_token_id])
173
+ examples = []
174
+ for i in range(0, len(all_token_ids), self.seq_length):
175
+ input_ids = all_token_ids[i : i + self.seq_length]
176
+ if len(input_ids) == self.seq_length:
177
+ examples.append(input_ids)
178
+ random.shuffle(examples)
179
+ for example in examples:
180
+ self.current_size += 1
181
+ yield {
182
+ "input_ids": torch.LongTensor(example),
183
+ "labels": torch.LongTensor(example),
184
+ }
185
+
186
+ def create_datasets(tokenizer, args):
187
+ dataset = load_dataset(
188
+ args.dataset_name,
189
+ data_dir=args.subset,
190
+ split=args.split,
191
+ use_auth_token=True,
192
+ num_proc=args.num_workers if not args.streaming else None,
193
+ streaming=args.streaming,
194
+ )
195
+ if args.streaming:
196
+ print("Loading the dataset in streaming mode")
197
+ valid_data = dataset.take(args.size_valid_set)
198
+ train_data = dataset.skip(args.size_valid_set)
199
+ train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
200
+ else:
201
+ dataset = dataset.train_test_split(test_size=0.005, seed=args.seed)
202
+ train_data = dataset["train"]
203
+ valid_data = dataset["test"]
204
+ print(
205
+ f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}"
206
+ )
207
+ chars_per_token = chars_token_ratio(train_data, tokenizer, args.data_column)
208
+ print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
209
+ train_dataset = ConstantLengthDataset(
210
+ tokenizer,
211
+ train_data,
212
+ infinite=True,
213
+ seq_length=args.seq_length,
214
+ chars_per_token=chars_per_token,
215
+ content_field=args.data_column,
216
+ # fim_rate=args.fim_rate,
217
+ # fim_spm_rate=args.fim_spm_rate,
218
+ seed=args.seed,
219
+ )
220
+ valid_dataset = ConstantLengthDataset(
221
+ tokenizer,
222
+ valid_data,
223
+ infinite=False,
224
+ seq_length=args.seq_length,
225
+ chars_per_token=chars_per_token,
226
+ content_field=args.data_column,
227
+ # fim_rate=args.fim_rate,
228
+ # fim_spm_rate=args.fim_spm_rate,
229
+ seed=args.seed,
230
+ )
231
+
232
+ return train_dataset, valid_dataset
233
+
234
+
235
+ def run_training(args, train_data, val_data):
236
+ print("Loading the model")
237
+ # disable caching mechanism when using gradient checkpointing
238
+ model = AutoModelForCausalLM.from_pretrained(
239
+ args.model_path,
240
+ trust_remote_code=True,
241
+ use_cache=not args.no_gradient_checkpointing,
242
+ )
243
+ train_data.start_iteration = 0
244
+
245
+ print(f"Starting main loop")
246
+
247
+ training_args = TrainingArguments(
248
+ output_dir=args.output_dir,
249
+ dataloader_drop_last=True,
250
+ evaluation_strategy="steps",
251
+ max_steps=args.max_steps,
252
+ eval_steps=args.eval_freq,
253
+ save_steps=args.save_freq,
254
+ logging_steps=args.log_freq,
255
+ per_device_train_batch_size=args.batch_size,
256
+ per_device_eval_batch_size=args.batch_size,
257
+ learning_rate=args.learning_rate,
258
+ lr_scheduler_type=args.lr_scheduler_type,
259
+ warmup_steps=args.num_warmup_steps,
260
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
261
+ gradient_checkpointing=args.no_gradient_checkpointing,
262
+ fp16=args.no_fp16,
263
+ bf16=args.bf16,
264
+ weight_decay=args.weight_decay,
265
+ run_name=f"santacoder-{args.subset}",
266
+ # report_to="wandb", #I am not using that, so I just comment it out to avoid errors?
267
+ )
268
+
269
+ trainer = Trainer(
270
+ model=model, args=training_args, train_dataset=train_data, eval_dataset=val_data
271
+ )
272
+
273
+ print("Training...")
274
+ trainer.train(args.resume_from_checkpoint) #can resume here
275
+
276
+ print("Saving last checkpoint of the model")
277
+ model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))
278
+
279
+
280
+ def main(args):
281
+ tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_auth_token=True)
282
+
283
+ train_dataset, eval_dataset = create_datasets(tokenizer, args)
284
+
285
+ run_training(args, train_dataset, eval_dataset)
286
+
287
+
288
+ if __name__ == "__main__":
289
+ print(sys.argv) #to abort early
290
+ args = get_args()
291
+ print(args) #see if the file actually red?
292
+ set_seed(args.seed)
293
+ os.makedirs(args.output_dir, exist_ok=True)
294
+
295
+ logging.set_verbosity_info() #lower verbosity
296
+
297
+ main(args)