tomaarsen HF staff committed on
Commit
f1d3ae0
·
1 Parent(s): 13f4da1

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +174 -0
train.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from transformers import TrainingArguments
3
+
4
+ from span_marker import SpanMarkerModel, Trainer
5
+
6
+
7
def main() -> None:
    """Train a SpanMarker NER model on MultiNERD and evaluate it on the test split.

    Loads the ``Babelscape/multinerd`` dataset, fine-tunes a SpanMarker model
    on top of the ``xlm-roberta-base`` encoder for one epoch, saves the final
    model to ``<output_dir>/checkpoint-final``, and then computes and saves
    the metrics for the test split.
    """
    # Load the dataset, ensure "tokens" and "ner_tags" columns, and get a list of labels
    dataset = "Babelscape/multinerd"
    train_dataset = load_dataset(dataset, split="train")
    # Use a 3000-sample shuffled subset of the validation split for in-training
    # evaluation; the fixed seed keeps that subset identical across runs so the
    # evaluation metrics of different runs are comparable.
    eval_dataset = (
        load_dataset(dataset, split="validation")
        .shuffle(seed=42)
        .select(range(3000))
    )
    labels = [
        "O",
        "B-PER",
        "I-PER",
        "B-ORG",
        "I-ORG",
        "B-LOC",
        "I-LOC",
        "B-ANIM",
        "I-ANIM",
        "B-BIO",
        "I-BIO",
        "B-CEL",
        "I-CEL",
        "B-DIS",
        "I-DIS",
        "B-EVE",
        "I-EVE",
        "B-FOOD",
        "I-FOOD",
        "B-INST",
        "I-INST",
        "B-MEDIA",
        "I-MEDIA",
        "B-MYTH",
        "I-MYTH",
        "B-PLANT",
        "I-PLANT",
        "B-TIME",
        "I-TIME",
        "B-VEHI",
        "I-VEHI",
    ]

    # Initialize a SpanMarker model using a pretrained BERT-style encoder
    model_name = "xlm-roberta-base"
    model = SpanMarkerModel.from_pretrained(
        model_name,
        labels=labels,
        # SpanMarker hyperparameters:
        model_max_length=256,
        marker_max_length=128,
        entity_max_length=6,
    )

    # Single source of truth for where checkpoints and the final model are written.
    output_dir = "models/span_marker_xlm_roberta_base_multinerd"

    # Prepare the 🤗 transformers training arguments
    args = TrainingArguments(
        output_dir=output_dir,
        # Training Hyperparameters:
        learning_rate=1e-5,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        # gradient_accumulation_steps=2,
        num_train_epochs=1,
        weight_decay=0.01,
        warmup_ratio=0.1,
        bf16=True,  # Replace `bf16` with `fp16` if your hardware can't use bf16.
        # Other Training parameters
        logging_first_step=True,
        logging_steps=50,
        evaluation_strategy="steps",
        save_strategy="steps",
        eval_steps=1000,
        save_total_limit=2,
        dataloader_num_workers=2,
    )

    # Initialize the trainer using our model, training args & dataset, and train
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()
    trainer.save_model(f"{output_dir}/checkpoint-final")

    test_dataset = load_dataset(dataset, split="test")
    # Compute & save the metrics on the test set
    metrics = trainer.evaluate(test_dataset, metric_key_prefix="test")
    trainer.save_metrics("test", metrics)
93
+
94
+
95
# Start training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
97
+
98
+ """
99
+ This SpanMarker model will ignore 2.239322% of all annotated entities in the train dataset. This is caused by the SpanMarkerModel maximum entity length of 6 words and the maximum model input length of 256 tokens.
100
+ These are the frequencies of the missed entities due to maximum entity length out of 4111958 total entities:
101
+ - 35814 missed entities with 7 words (0.870972%)
102
+ - 21246 missed entities with 8 words (0.516688%)
103
+ - 12680 missed entities with 9 words (0.308369%)
104
+ - 7308 missed entities with 10 words (0.177726%)
105
+ - 4414 missed entities with 11 words (0.107345%)
106
+ - 2474 missed entities with 12 words (0.060166%)
107
+ - 1894 missed entities with 13 words (0.046061%)
108
+ - 1130 missed entities with 14 words (0.027481%)
109
+ - 744 missed entities with 15 words (0.018094%)
110
+ - 582 missed entities with 16 words (0.014154%)
111
+ - 344 missed entities with 17 words (0.008366%)
112
+ - 226 missed entities with 18 words (0.005496%)
113
+ - 84 missed entities with 19 words (0.002043%)
114
+ - 46 missed entities with 20 words (0.001119%)
115
+ - 20 missed entities with 21 words (0.000486%)
116
+ - 20 missed entities with 22 words (0.000486%)
117
+ - 12 missed entities with 23 words (0.000292%)
118
+ - 18 missed entities with 24 words (0.000438%)
119
+ - 2 missed entities with 25 words (0.000049%)
120
+ - 4 missed entities with 26 words (0.000097%)
121
+ - 4 missed entities with 27 words (0.000097%)
122
+ - 2 missed entities with 31 words (0.000049%)
123
+ - 8 missed entities with 32 words (0.000195%)
124
+ - 6 missed entities with 33 words (0.000146%)
125
+ - 2 missed entities with 34 words (0.000049%)
126
+ - 4 missed entities with 36 words (0.000097%)
127
+ - 8 missed entities with 37 words (0.000195%)
128
+ - 2 missed entities with 38 words (0.000049%)
129
+ - 2 missed entities with 41 words (0.000049%)
130
+ - 2 missed entities with 72 words (0.000049%)
131
+ Additionally, a total of 2978 (0.072423%) entities were missed due to the maximum input length.
132
+
133
+ This SpanMarker model won't be able to predict 2.501087% of all annotated entities in the evaluation dataset. This is caused by the SpanMarkerModel maximum entity length of 6 words.
134
+ These are the frequencies of the missed entities due to maximum entity length out of 4598 total entities:
135
+ - 45 missed entities with 7 words (0.978686%)
136
+ - 27 missed entities with 8 words (0.587212%)
137
+ - 21 missed entities with 9 words (0.456720%)
138
+ - 9 missed entities with 10 words (0.195737%)
139
+ - 3 missed entities with 12 words (0.065246%)
140
+ - 4 missed entities with 13 words (0.086994%)
141
+ - 3 missed entities with 14 words (0.065246%)
142
+ - 1 missed entities with 15 words (0.021749%)
143
+ - 1 missed entities with 16 words (0.021749%)
144
+ - 1 missed entities with 20 words (0.021749%)
145
+ """
146
+
147
+ """
148
+ wandb: Run summary:
149
+ wandb: eval/loss 0.00594
150
+ wandb: eval/overall_accuracy 0.98181
151
+ wandb: eval/overall_f1 0.90333
152
+ wandb: eval/overall_precision 0.91259
153
+ wandb: eval/overall_recall 0.89427
154
+ wandb: eval/runtime 21.4308
155
+ wandb: eval/samples_per_second 154.171
156
+ wandb: eval/steps_per_second 4.853
157
+ wandb: test/loss 0.00559
158
+ wandb: test/overall_accuracy 0.98247
159
+ wandb: test/overall_f1 0.91314
160
+ wandb: test/overall_precision 0.91994
161
+ wandb: test/overall_recall 0.90643
162
+ wandb: test/runtime 2202.6894
163
+ wandb: test/samples_per_second 169.652
164
+ wandb: test/steps_per_second 5.302
165
+ wandb: train/epoch 1.0
166
+ wandb: train/global_step 93223
167
+ wandb: train/learning_rate 0.0
168
+ wandb: train/loss 0.0049
169
+ wandb: train/total_flos 7.851073325660897e+17
170
+ wandb: train/train_loss 0.01782
171
+ wandb: train/train_runtime 41756.9748
172
+ wandb: train/train_samples_per_second 71.44
173
+ wandb: train/train_steps_per_second 2.233
174
+ """