qgyd2021 commited on
Commit
03ee06f
·
1 Parent(s): c525f79

[update]add model

Browse files
examples/chinese_chitchat/step_1_prepare_data.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ from itertools import chain
5
+ import os
6
+ from pathlib import Path
7
+ import platform
8
+
9
+ if platform.system() == "Windows":
10
+ from project_settings import project_path
11
+ else:
12
+ project_path = os.path.abspath("./")
13
+ project_path = Path(project_path)
14
+
15
+ from datasets import load_dataset, concatenate_datasets, IterableDataset, Dataset
16
+
17
+
18
+ def get_args():
19
+ parser = argparse.ArgumentParser()
20
+ parser.add_argument("--dataset_path", default="qgyd2021/chinese_chitchat", type=str)
21
+ parser.add_argument("--dataset_split", default=None, type=str)
22
+ parser.add_argument(
23
+ "--dataset_cache_dir",
24
+ default=(project_path / "hub_datasets").as_posix(),
25
+ type=str
26
+ )
27
+ parser.add_argument("--dataset_streaming", default=False, type=bool)
28
+ parser.add_argument("--valid_dataset_size", default=10000, type=int)
29
+
30
+ parser.add_argument(
31
+ "--num_workers",
32
+ default=None if platform.system() == "Windows" else os.cpu_count() // 2,
33
+ type=str
34
+ )
35
+
36
+ parser.add_argument("--seed", default=3407, type=str, help="https://arxiv.org/abs/2109.08203")
37
+
38
+ args = parser.parse_args()
39
+ return args
40
+
41
+
42
+ def main():
43
+ args = get_args()
44
+
45
+ names = [
46
+ "qingyun", "chatterbot", "douban", "ptt", "subtitle", "tieba", "weibo", "xiaohuangji"
47
+ ]
48
+ dataset_list = list()
49
+ for name in names:
50
+ dataset_dict = load_dataset(
51
+ path=args.dataset_path,
52
+ name=name,
53
+ split=args.dataset_split,
54
+ cache_dir=args.dataset_cache_dir,
55
+ num_proc=args.num_workers if not args.dataset_streaming else None,
56
+ streaming=args.dataset_streaming,
57
+ )
58
+
59
+ dataset = dataset_dict["train"]
60
+ dataset_list.append(dataset)
61
+
62
+ dataset = concatenate_datasets(dataset_list)
63
+
64
+ if args.dataset_streaming:
65
+ valid_dataset = dataset.take(args.valid_dataset_size)
66
+ train_dataset = dataset.skip(args.valid_dataset_size)
67
+ train_dataset = train_dataset.shuffle(buffer_size=args.shuffle_buffer_size, seed=None)
68
+ else:
69
+ dataset = dataset.train_test_split(test_size=args.valid_dataset_size, seed=None)
70
+ train_dataset = dataset["train"]
71
+ valid_dataset = dataset["test"]
72
+
73
+ print(train_dataset)
74
+ print(valid_dataset)
75
+ return
76
+
77
+
78
+ if __name__ == '__main__':
79
+ main()
examples/chinese_chitchat/step_2_train_model.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from dataclasses import dataclass, field
4
+ import os
5
+ from pathlib import Path
6
+ import platform
7
+ import re
8
+ from typing import Dict, List, Optional, Union
9
+
10
+ if platform.system() == "Windows":
11
+ from project_settings import project_path
12
+ else:
13
+ project_path = os.path.abspath("./")
14
+ project_path = Path(project_path)
15
+
16
+ hf_hub_cache = (project_path / "cache/huggingface/hub").as_posix()
17
+
18
+ os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache
19
+
20
+ from datasets import concatenate_datasets, load_dataset
21
+ import huggingface_hub
22
+ import torch
23
+ import torch.multiprocessing as mp
24
+ from transformers import HfArgumentParser
25
+ from transformers.data.data_collator import DataCollatorForLanguageModeling
26
+ from transformers.models.auto import AutoModelForCausalLM, AutoTokenizer
27
+ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
28
+ from transformers.trainer import Trainer
29
+ from transformers.trainer_callback import EarlyStoppingCallback
30
+ from transformers.training_args import TrainingArguments
31
+
32
+
33
+ @dataclass
34
+ class ScriptArguments:
35
+ # dataset
36
+ dataset_path: str = field(default="qgyd2021/chinese_chitchat")
37
+ dataset_name: str = field(default=None)
38
+ dataset_split: str = field(default=None)
39
+ dataset_cache_dir: str = field(default=(project_path / "hub_datasets").as_posix())
40
+ dataset_streaming: bool = field(default=False)
41
+ num_workers: int = field(default=None if platform.system() == "Windows" else os.cpu_count() // 2)
42
+
43
+ valid_dataset_size: int = field(default=10000)
44
+ seed: int = field(default=3407)
45
+
46
+ # model
47
+ # pretrained_model_name_or_path: str = field(
48
+ # default="uer/gpt2-chinese-cluecorpussmall" if platform.system() != "Windows" else (project_path / "pretrained_models/gpt2-chinese-cluecorpussmall").as_posix()
49
+ # )
50
+ pretrained_model_name_or_path: str = field(
51
+ default="qgyd2021/chinese_chitchat"
52
+ )
53
+ hf_token: str = field(default="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")
54
+
55
+
56
+ def get_args():
57
+ parser = HfArgumentParser(ScriptArguments)
58
+ args = parser.parse_args_into_dataclasses(return_remaining_strings=True)[0]
59
+ return args
60
+
61
+
62
+ def train_model(local_rank, world_size, args):
63
+ os.environ["RANK"] = f"{local_rank}"
64
+ os.environ["LOCAL_RANK"] = f"{local_rank}"
65
+ os.environ["WORLD_SIZE"] = f"{world_size}"
66
+ os.environ["MASTER_ADDR"] = "localhost"
67
+ os.environ["MASTER_PORT"] = "12355"
68
+
69
+ huggingface_hub.login(token=args.hf_token)
70
+
71
+ # dataset
72
+ names = [
73
+ # "qingyun", "chatterbot",
74
+ # "douban", "ptt", "subtitle", "tieba", "weibo",
75
+ "xiaohuangji"
76
+ ]
77
+ dataset_list = list()
78
+ for name in names:
79
+ dataset_dict = load_dataset(
80
+ path=args.dataset_path,
81
+ name=name,
82
+ split=args.dataset_split,
83
+ cache_dir=args.dataset_cache_dir,
84
+ # num_proc=args.num_workers if not args.dataset_streaming else None,
85
+ streaming=args.dataset_streaming,
86
+ )
87
+
88
+ dataset = dataset_dict["train"]
89
+ dataset_list.append(dataset)
90
+
91
+ dataset = concatenate_datasets(dataset_list)
92
+
93
+ if args.dataset_streaming:
94
+ valid_dataset = dataset.take(args.valid_dataset_size)
95
+ train_dataset = dataset.skip(args.valid_dataset_size)
96
+ train_dataset = train_dataset.shuffle(buffer_size=args.shuffle_buffer_size, seed=args.seed)
97
+ else:
98
+ dataset = dataset.train_test_split(test_size=args.valid_dataset_size, seed=args.seed)
99
+ train_dataset = dataset["train"]
100
+ valid_dataset = dataset["test"]
101
+
102
+ # pretrained model
103
+ model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(args.pretrained_model_name_or_path)
104
+ tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path)
105
+
106
+ # map
107
+ def encode(examples: dict):
108
+ conversation_ = examples.pop("conversation")
109
+
110
+ utterances = list()
111
+ for row_ in conversation_:
112
+ message_ = row_["message"]
113
+ utterance = tokenizer.sep_token.join(message_)
114
+ utterances.append(utterance)
115
+
116
+ utterances = tokenizer.__call__(
117
+ text=utterances,
118
+ truncation=True,
119
+ padding="longest",
120
+ max_length=1024,
121
+ return_special_tokens_mask=True,
122
+ )
123
+ return utterances
124
+
125
+ train_dataset = train_dataset.map(
126
+ encode,
127
+ batched=True,
128
+ drop_last_batch=True,
129
+ batch_size=10,
130
+ num_proc=args.num_workers if not args.dataset_streaming else None,
131
+ cache_file_name="train.cache"
132
+ )
133
+ valid_dataset = valid_dataset.map(
134
+ encode,
135
+ batched=True,
136
+ drop_last_batch=True,
137
+ batch_size=10,
138
+ num_proc=args.num_workers if not args.dataset_streaming else None,
139
+ cache_file_name="valid.cache"
140
+ )
141
+ dataset_info = f"""
142
+ train dataset: {len(train_dataset)}
143
+ valid dataset: {len(valid_dataset)}
144
+ """
145
+ dataset_info = re.sub(r"[\u0020]{4,}", "", dataset_info)
146
+ print(dataset_info)
147
+
148
+ data_collator = DataCollatorForLanguageModeling(
149
+ tokenizer=tokenizer, mlm=False
150
+ )
151
+
152
+ # training_args
153
+ training_args = TrainingArguments(
154
+ output_dir="output_dir",
155
+ evaluation_strategy="steps",
156
+ per_device_train_batch_size=16,
157
+ gradient_accumulation_steps=4,
158
+ learning_rate=2e-4,
159
+ weight_decay=0,
160
+ max_grad_norm=1.0,
161
+ num_train_epochs=40.0,
162
+ warmup_steps=10000,
163
+ logging_steps=1000,
164
+ save_strategy="steps",
165
+ save_steps=1000,
166
+ save_total_limit=2,
167
+ no_cuda=False,
168
+ fp16=True if torch.cuda.is_available() else False,
169
+ local_rank=local_rank,
170
+ ddp_backend="nccl",
171
+ remove_unused_columns=True,
172
+ load_best_model_at_end=True,
173
+ metric_for_best_model="loss",
174
+ greater_is_better=False,
175
+ report_to="tensorboard",
176
+ push_to_hub=True,
177
+ hub_model_id="chinese_chitchat",
178
+ hub_strategy="every_save",
179
+ gradient_checkpointing=True,
180
+ )
181
+
182
+ partial_state_str = f"""
183
+ distributed_type: {training_args.distributed_state.distributed_type}
184
+ local_process_index: {training_args.distributed_state.local_process_index}
185
+ num_processes: {training_args.distributed_state.num_processes}
186
+ process_index: {training_args.distributed_state.process_index}
187
+ device: {training_args.distributed_state.device}
188
+ """
189
+ partial_state_str = re.sub(r"[\u0020]{4,}", "", partial_state_str)
190
+ print(partial_state_str)
191
+
192
+ environ = f"""
193
+ RANK: {os.environ.get("RANK", -1)}
194
+ WORLD_SIZE: {os.environ.get("WORLD_SIZE", -1)}
195
+ LOCAL_RANK: {os.environ.get("LOCAL_RANK", -1)}
196
+ """
197
+ environ = re.sub(r"[\u0020]{4,}", "", environ)
198
+ print(environ)
199
+
200
+ callbacks = [
201
+ EarlyStoppingCallback(early_stopping_patience=5)
202
+ ]
203
+
204
+ trainer = Trainer(
205
+ model=model,
206
+ args=training_args,
207
+ data_collator=data_collator,
208
+ train_dataset=train_dataset,
209
+ eval_dataset=valid_dataset,
210
+ tokenizer=tokenizer,
211
+ callbacks=callbacks
212
+ )
213
+ train_result = trainer.train()
214
+
215
+ # 保存最好的 checkpoint
216
+ final_save_path = os.path.join(training_args.output_dir, "final")
217
+ trainer.save_model(final_save_path) # Saves the tokenizer too
218
+ # 保存训练指标
219
+ metrics = train_result.metrics
220
+ trainer.log_metrics("train", metrics)
221
+ trainer.save_metrics("train", metrics)
222
+ trainer.save_state()
223
+
224
+ tokenizer.save_pretrained(final_save_path)
225
+ return
226
+
227
+
228
+ def train_on_cpu():
229
+ args = get_args()
230
+
231
+ train_model(0, 1, args)
232
+ return
233
+
234
+
235
+ def train_on_kaggle_notebook():
236
+ """
237
+ train on kaggle notebook with GPU T4 x2
238
+
239
+ from shutil import copyfile
240
+ copyfile(src = "../input/tempdataset/step_2_train_model.py", dst = "../working/step_2_train_model.py")
241
+
242
+ import step_2_train_model
243
+ step_2_train_model.train_on_kaggle_notebook()
244
+
245
+ """
246
+ args = get_args()
247
+
248
+ world_size = torch.cuda.device_count()
249
+ print("world_size: {}".format(world_size))
250
+
251
+ mp.spawn(train_model,
252
+ args=(world_size, args),
253
+ nprocs=world_size,
254
+ join=True)
255
+
256
+ return
257
+
258
+
259
+ if __name__ == '__main__':
260
+ train_on_cpu()
examples/{exercises/chinese_porn_novel → chinese_porn_novel}/1.prepare_data.py RENAMED
File without changes
examples/{exercises/chinese_porn_novel → chinese_porn_novel}/2.train_model.py RENAMED
File without changes
examples/{exercises/chinese_porn_novel → chinese_porn_novel}/3.test_model.py RENAMED
File without changes
examples/{exercises/chinese_porn_novel → chinese_porn_novel}/README.md RENAMED
File without changes
examples/{exercises/chinese_porn_novel → chinese_porn_novel}/run.sh RENAMED
File without changes
examples/{exercises/chinese_porn_novel → chinese_porn_novel}/stop.sh RENAMED
File without changes
examples/lib_service_4chan/step_1_prepare_data.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ import platform
6
+
7
+ from datasets import load_dataset
8
+
9
+ from project_settings import project_path
10
+
11
+
12
+ def get_args():
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument("--dataset_path", default="qgyd2021/lip_service_4chan", type=str)
15
+ parser.add_argument("--dataset_name", default="moss_003_sft_data_10", type=str)
16
+ parser.add_argument("--dataset_split", default=None, type=str)
17
+ parser.add_argument(
18
+ "--dataset_cache_dir",
19
+ default=(project_path / "hub_datasets").as_posix(),
20
+ type=str
21
+ )
22
+ parser.add_argument("--dataset_streaming", default=False, type=bool)
23
+ parser.add_argument(
24
+ "--num_workers",
25
+ default=None if platform.system() == "Windows" else os.cpu_count() // 2,
26
+ type=str
27
+ )
28
+
29
+ args = parser.parse_args()
30
+ return args
31
+
32
+
33
+ def main():
34
+ args = get_args()
35
+
36
+ dataset_dict = load_dataset(
37
+ path=args.dataset_path,
38
+ name=args.dataset_name,
39
+ split=args.dataset_split,
40
+ cache_dir=args.dataset_cache_dir,
41
+ num_proc=args.num_workers if not args.dataset_streaming else None,
42
+ streaming=args.dataset_streaming,
43
+ )
44
+ print(dataset_dict)
45
+
46
+ dataset = dataset_dict["train"]
47
+
48
+ if args.dataset_streaming:
49
+ valid_dataset = dataset.take(args.valid_dataset_size)
50
+ train_dataset = dataset.skip(args.valid_dataset_size)
51
+ train_dataset = train_dataset.shuffle(buffer_size=args.shuffle_buffer_size, seed=None)
52
+ else:
53
+ dataset = dataset.train_test_split(test_size=10000, seed=None)
54
+ train_dataset = dataset["train"]
55
+ valid_dataset = dataset["test"]
56
+
57
+ print(train_dataset)
58
+ print(valid_dataset)
59
+ return
60
+
61
+
62
+ if __name__ == '__main__':
63
+ main()
examples/lib_service_4chan/step_2_train_model.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from dataclasses import dataclass, field
4
+ import os
5
+ from pathlib import Path
6
+ import platform
7
+ import re
8
+ from typing import Dict, List, Optional, Union
9
+
10
+ if platform.system() == "Windows":
11
+ from project_settings import project_path
12
+ else:
13
+ project_path = os.path.abspath("./")
14
+ project_path = Path(project_path)
15
+
16
+ hf_hub_cache = (project_path / "cache/huggingface/hub").as_posix()
17
+
18
+ os.environ["HUGGINGFACE_HUB_CACHE"] = hf_hub_cache
19
+
20
+ from datasets import load_dataset
21
+ import huggingface_hub
22
+ import torch
23
+ import torch.multiprocessing as mp
24
+ from transformers import HfArgumentParser
25
+ from transformers.data.data_collator import DataCollatorForLanguageModeling
26
+ from transformers.models.auto import AutoModelForCausalLM, AutoTokenizer
27
+ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
28
+ from transformers.trainer import Trainer
29
+ from transformers.trainer_callback import EarlyStoppingCallback
30
+ from transformers.training_args import TrainingArguments
31
+
32
+
33
+ @dataclass
34
+ class ScriptArguments:
35
+ # dataset
36
+ dataset_path: str = field(default="qgyd2021/lip_service_4chan")
37
+ dataset_name: str = field(default=None)
38
+ dataset_split: str = field(default=None)
39
+ dataset_cache_dir: str = field(default=(project_path / "hub_datasets").as_posix())
40
+ dataset_streaming: bool = field(default=False)
41
+ num_workers: int = field(default=None if platform.system() == "Windows" else os.cpu_count() // 2)
42
+
43
+ # model
44
+ pretrained_model_name_or_path: str = field(
45
+ default="uer/gpt2-chinese-cluecorpussmall"
46
+ )
47
+ # pretrained_model_name_or_path: str = field(
48
+ # default=(project_path / "pretrained_models/gpt2-chinese-cluecorpussmall").as_posix()
49
+ # )
50
+
51
+ hf_token: str = field(default=None)
52
+
53
+
54
+ def get_args():
55
+ parser = HfArgumentParser(ScriptArguments)
56
+ args = parser.parse_args_into_dataclasses(return_remaining_strings=True)[0]
57
+ return args
58
+
59
+
60
+ def train_model(local_rank, world_size, args):
61
+ os.environ["RANK"] = f"{local_rank}"
62
+ os.environ["LOCAL_RANK"] = f"{local_rank}"
63
+ os.environ["WORLD_SIZE"] = f"{world_size}"
64
+ os.environ["MASTER_ADDR"] = "localhost"
65
+ os.environ["MASTER_PORT"] = "12355"
66
+
67
+ huggingface_hub.login(token=args.hf_token)
68
+
69
+ # dataset
70
+ dataset_dict = load_dataset(
71
+ path=args.dataset_path,
72
+ name=args.dataset_name,
73
+ split=args.dataset_split,
74
+ cache_dir=args.dataset_cache_dir,
75
+ # num_proc=args.num_workers if not args.dataset_streaming else None,
76
+ streaming=args.dataset_streaming,
77
+ )
78
+ print(dataset_dict)
79
+
80
+ dataset = dataset_dict["train"]
81
+
82
+ if args.dataset_streaming:
83
+ valid_dataset = dataset.take(args.valid_dataset_size)
84
+ train_dataset = dataset.skip(args.valid_dataset_size)
85
+ train_dataset = train_dataset.shuffle(buffer_size=args.shuffle_buffer_size, seed=None)
86
+ else:
87
+ dataset = dataset.train_test_split(test_size=4000, seed=None)
88
+ train_dataset = dataset["train"]
89
+ valid_dataset = dataset["test"]
90
+
91
+ # pretrained model
92
+ model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(args.pretrained_model_name_or_path)
93
+ tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path)
94
+
95
+ # map
96
+ def encode(examples: dict):
97
+ questions_ = examples.pop("question")
98
+ answers_ = examples.pop("answer")
99
+
100
+ utterances = list()
101
+ for question, answer in zip(questions_, answers_):
102
+ if not isinstance(question, str):
103
+ continue
104
+ if not isinstance(answer, str):
105
+ continue
106
+ utterance = question + tokenizer.sep_token + answer
107
+ utterances.append(utterance)
108
+
109
+ utterances = tokenizer.__call__(
110
+ text=utterances,
111
+ truncation=True,
112
+ padding="longest",
113
+ max_length=512,
114
+ return_special_tokens_mask=True,
115
+ )
116
+ return utterances
117
+
118
+ train_dataset = train_dataset.map(
119
+ encode,
120
+ batched=True,
121
+ drop_last_batch=True,
122
+ batch_size=10,
123
+ num_proc=None,
124
+ cache_file_name="train.cache"
125
+ )
126
+ valid_dataset = valid_dataset.map(
127
+ encode,
128
+ batched=True,
129
+ drop_last_batch=True,
130
+ batch_size=10,
131
+ num_proc=None,
132
+ cache_file_name="valid.cache"
133
+ )
134
+ dataset_info = f"""
135
+ train dataset: {len(train_dataset)}
136
+ valid dataset: {len(valid_dataset)}
137
+ """
138
+ dataset_info = re.sub(r"[\u0020]{4,}", "", dataset_info)
139
+ print(dataset_info)
140
+
141
+ # for k, v in model.named_parameters():
142
+ # if k.__contains__(".bias"):
143
+ # v.requires_grad = True
144
+ # else:
145
+ # v.requires_grad = False
146
+
147
+ # for k, v in model.named_parameters():
148
+ # if v.requires_grad is True:
149
+ # print(k)
150
+
151
+ data_collator = DataCollatorForLanguageModeling(
152
+ tokenizer=tokenizer, mlm=False
153
+ )
154
+
155
+ # training_args
156
+ training_args = TrainingArguments(
157
+ output_dir="output_dir",
158
+ evaluation_strategy="steps",
159
+ per_device_train_batch_size=8,
160
+ gradient_accumulation_steps=4,
161
+ learning_rate=2e-4,
162
+ weight_decay=0,
163
+ max_grad_norm=1.0,
164
+ num_train_epochs=1.0,
165
+ warmup_steps=1000,
166
+ logging_steps=100,
167
+ save_strategy="steps",
168
+ save_steps=100,
169
+ save_total_limit=2,
170
+ no_cuda=False,
171
+ fp16=True if torch.cuda.is_available() else False,
172
+ local_rank=local_rank,
173
+ ddp_backend="nccl",
174
+ remove_unused_columns=True,
175
+ load_best_model_at_end=True,
176
+ metric_for_best_model="loss",
177
+ greater_is_better=False,
178
+ report_to="tensorboard",
179
+ push_to_hub=True,
180
+ hub_model_id="lib_service_4chan",
181
+ hub_strategy="every_save",
182
+ gradient_checkpointing=True,
183
+ )
184
+
185
+ partial_state_str = f"""
186
+ distributed_type: {training_args.distributed_state.distributed_type}
187
+ local_process_index: {training_args.distributed_state.local_process_index}
188
+ num_processes: {training_args.distributed_state.num_processes}
189
+ process_index: {training_args.distributed_state.process_index}
190
+ device: {training_args.distributed_state.device}
191
+ """
192
+ partial_state_str = re.sub(r"[\u0020]{4,}", "", partial_state_str)
193
+ print(partial_state_str)
194
+
195
+ environ = f"""
196
+ RANK: {os.environ.get("RANK", -1)}
197
+ WORLD_SIZE: {os.environ.get("WORLD_SIZE", -1)}
198
+ LOCAL_RANK: {os.environ.get("LOCAL_RANK", -1)}
199
+ """
200
+ environ = re.sub(r"[\u0020]{4,}", "", environ)
201
+ print(environ)
202
+
203
+ callbacks = [
204
+ EarlyStoppingCallback(early_stopping_patience=5)
205
+ ]
206
+
207
+ trainer = Trainer(
208
+ model=model,
209
+ args=training_args,
210
+ data_collator=data_collator,
211
+ train_dataset=train_dataset,
212
+ eval_dataset=valid_dataset,
213
+ tokenizer=tokenizer,
214
+ callbacks=callbacks
215
+ )
216
+ train_result = trainer.train()
217
+
218
+ # 保存最好的 checkpoint
219
+ final_save_path = os.path.join(training_args.output_dir, "final")
220
+ trainer.save_model(final_save_path) # Saves the tokenizer too
221
+ # 保存训练指标
222
+ metrics = train_result.metrics
223
+ trainer.log_metrics("train", metrics)
224
+ trainer.save_metrics("train", metrics)
225
+ trainer.save_state()
226
+
227
+ tokenizer.save_pretrained(final_save_path)
228
+ return
229
+
230
+
231
+ def train_on_cpu():
232
+ args = get_args()
233
+
234
+ train_model(0, 1, args)
235
+ return
236
+
237
+
238
+ def train_on_kaggle_notebook():
239
+ """
240
+ train on kaggle notebook with GPU T4 x2
241
+
242
+ from shutil import copyfile
243
+ copyfile(src = "../input/tempdataset/step_2_train_model.py", dst = "../working/step_2_train_model.py")
244
+
245
+ import step_2_train_model
246
+ step_2_train_model.train_on_kaggle_notebook()
247
+
248
+ """
249
+ args = get_args()
250
+
251
+ world_size = torch.cuda.device_count()
252
+ print("world_size: {}".format(world_size))
253
+
254
+ mp.spawn(train_model,
255
+ args=(world_size, args),
256
+ nprocs=world_size,
257
+ join=True)
258
+
259
+ return
260
+
261
+
262
+ if __name__ == '__main__':
263
+ train_on_cpu()