versae committed on
Commit
75469bd
0 Parent(s):

Training dump

Browse files
Files changed (49) hide show
  1. .gitattributes +19 -0
  2. .gitignore +4 -0
  3. configs/base/config.json +25 -0
  4. configs/base/tokenizer.json +0 -0
  5. configs/large/config.json +25 -0
  6. configs/large/tokenizer.json +0 -0
  7. mc4/README.md +525 -0
  8. mc4/dummy/af/0.0.0/dummy_data.zip +0 -0
  9. mc4/mc4.py +426 -0
  10. mc4/mc4.py.lock +0 -0
  11. outputs/checkpoints/checkpoint-140001/config.json +25 -0
  12. outputs/checkpoints/checkpoint-140001/data_collator.joblib +3 -0
  13. outputs/checkpoints/checkpoint-140001/flax_model.msgpack +3 -0
  14. outputs/checkpoints/checkpoint-140001/optimizer_state.msgpack +3 -0
  15. outputs/checkpoints/checkpoint-140001/training_args.joblib +3 -0
  16. outputs/checkpoints/checkpoint-140001/training_state.json +1 -0
  17. outputs/checkpoints/checkpoint-150001/config.json +25 -0
  18. outputs/checkpoints/checkpoint-150001/data_collator.joblib +3 -0
  19. outputs/checkpoints/checkpoint-150001/flax_model.msgpack +3 -0
  20. outputs/checkpoints/checkpoint-150001/optimizer_state.msgpack +3 -0
  21. outputs/checkpoints/checkpoint-150001/training_args.joblib +3 -0
  22. outputs/checkpoints/checkpoint-150001/training_state.json +1 -0
  23. outputs/checkpoints/checkpoint-160001/config.json +25 -0
  24. outputs/checkpoints/checkpoint-160001/data_collator.joblib +3 -0
  25. outputs/checkpoints/checkpoint-160001/flax_model.msgpack +3 -0
  26. outputs/checkpoints/checkpoint-160001/optimizer_state.msgpack +3 -0
  27. outputs/checkpoints/checkpoint-160001/training_args.joblib +3 -0
  28. outputs/checkpoints/checkpoint-160001/training_state.json +1 -0
  29. outputs/checkpoints/checkpoint-170001/config.json +25 -0
  30. outputs/checkpoints/checkpoint-170001/data_collator.joblib +3 -0
  31. outputs/checkpoints/checkpoint-170001/flax_model.msgpack +3 -0
  32. outputs/checkpoints/checkpoint-170001/optimizer_state.msgpack +3 -0
  33. outputs/checkpoints/checkpoint-170001/training_args.joblib +3 -0
  34. outputs/checkpoints/checkpoint-170001/training_state.json +1 -0
  35. outputs/checkpoints/checkpoint-180001/config.json +25 -0
  36. outputs/checkpoints/checkpoint-180001/data_collator.joblib +3 -0
  37. outputs/checkpoints/checkpoint-180001/flax_model.msgpack +3 -0
  38. outputs/checkpoints/checkpoint-180001/optimizer_state.msgpack +3 -0
  39. outputs/checkpoints/checkpoint-180001/training_args.joblib +3 -0
  40. outputs/checkpoints/checkpoint-180001/training_state.json +1 -0
  41. outputs/config.json +25 -0
  42. outputs/data_collator.joblib +3 -0
  43. outputs/events.out.tfevents.1626172316.underestimate.4022703.3.v2 +3 -0
  44. outputs/flax_model.msgpack +3 -0
  45. outputs/optimizer_state.msgpack +3 -0
  46. outputs/training_args.joblib +3 -0
  47. outputs/training_state.json +1 -0
  48. run_mlm_flax_stream.py +722 -0
  49. run_stream.sh +27 -0
.gitattributes ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
2
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.h5 filter=lfs diff=lfs merge=lfs -text
5
+ *.tflite filter=lfs diff=lfs merge=lfs -text
6
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.ot filter=lfs diff=lfs merge=lfs -text
8
+ *.onnx filter=lfs diff=lfs merge=lfs -text
9
+ *.arrow filter=lfs diff=lfs merge=lfs -text
10
+ *.ftz filter=lfs diff=lfs merge=lfs -text
11
+ *.joblib filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.pb filter=lfs diff=lfs merge=lfs -text
15
+ *.pt filter=lfs diff=lfs merge=lfs -text
16
+ *.pth filter=lfs diff=lfs merge=lfs -text
17
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
18
+ *.log filter=lfs diff=lfs merge=lfs -text
19
+ *.wandb filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ run*.log
2
+ debug*.log
3
+ run*.wandb
4
+ wandb/
configs/base/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RobertaForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "eos_token_id": 2,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 514,
16
+ "model_type": "roberta",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 12,
19
+ "pad_token_id": 1,
20
+ "position_embedding_type": "absolute",
21
+ "transformers_version": "4.9.0.dev0",
22
+ "type_vocab_size": 1,
23
+ "use_cache": true,
24
+ "vocab_size": 50265
25
+ }
configs/base/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
configs/large/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RobertaForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "eos_token_id": 2,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 1024,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 4096,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 514,
16
+ "model_type": "roberta",
17
+ "num_attention_heads": 16,
18
+ "num_hidden_layers": 24,
19
+ "pad_token_id": 1,
20
+ "position_embedding_type": "absolute",
21
+ "transformers_version": "4.9.0.dev0",
22
+ "type_vocab_size": 1,
23
+ "use_cache": true,
24
+ "vocab_size": 50265
25
+ }
configs/large/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
mc4/README.md ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ pretty_name: mC4
3
+ annotations_creators:
4
+ - no-annotation
5
+ language_creators:
6
+ - found
7
+ languages:
8
+ - af
9
+ - am
10
+ - ar
11
+ - az
12
+ - be
13
+ - bg
14
+ - bg-Latn
15
+ - bn
16
+ - ca
17
+ - ceb
18
+ - co
19
+ - cs
20
+ - cy
21
+ - da
22
+ - de
23
+ - el
24
+ - el-Latn
25
+ - en
26
+ - eo
27
+ - es
28
+ - et
29
+ - eu
30
+ - fa
31
+ - fi
32
+ - fil
33
+ - fr
34
+ - fy
35
+ - ga
36
+ - gd
37
+ - gl
38
+ - gu
39
+ - ha
40
+ - haw
41
+ - hi
42
+ - hi-Latn
43
+ - hmn
44
+ - ht
45
+ - hu
46
+ - hy
47
+ - id
48
+ - ig
49
+ - is
50
+ - it
51
+ - iw
52
+ - ja
53
+ - ja-Latn
54
+ - jv
55
+ - ka
56
+ - kk
57
+ - km
58
+ - kn
59
+ - ko
60
+ - ku
61
+ - ky
62
+ - la
63
+ - lb
64
+ - lo
65
+ - lt
66
+ - lv
67
+ - mg
68
+ - mi
69
+ - mk
70
+ - ml
71
+ - mn
72
+ - mr
73
+ - ms
74
+ - mt
75
+ - my
76
+ - ne
77
+ - nl
78
+ - "no"
79
+ - ny
80
+ - pa
81
+ - pl
82
+ - ps
83
+ - pt
84
+ - ro
85
+ - ru
86
+ - ru-Latn
87
+ - sd
88
+ - si
89
+ - sk
90
+ - sl
91
+ - sm
92
+ - sn
93
+ - so
94
+ - sq
95
+ - sr
96
+ - st
97
+ - su
98
+ - sv
99
+ - sw
100
+ - ta
101
+ - te
102
+ - tg
103
+ - th
104
+ - tr
105
+ - uk
106
+ - und
107
+ - ur
108
+ - uz
109
+ - vi
110
+ - xh
111
+ - yi
112
+ - yo
113
+ - zh
114
+ - zh-Latn
115
+ - zu
116
+ licenses:
117
+ - odc-by-1.0
118
+ multilinguality:
119
+ - multilingual
120
+ size_categories:
121
+ - n<1K
122
+ - 1K<n<10K
123
+ - 10K<n<100K
124
+ - 100K<n<1M
125
+ - 1M<n<10M
126
+ - 10M<n<100M
127
+ - 100M<n<1B
128
+ - 1B<n<10B
129
+ source_datasets:
130
+ - original
131
+ task_categories:
132
+ - sequence-modeling
133
+ task_ids:
134
+ - language-modeling
135
+ paperswithcode_id: mc4
136
+ ---
137
+
138
+ # Dataset Card for mC4
139
+
140
+ ## Table of Contents
141
+
142
+ - [Dataset Card for mC4](#dataset-card-for-mc4)
143
+ - [Table of Contents](#table-of-contents)
144
+ - [Dataset Description](#dataset-description)
145
+ - [Dataset Summary](#dataset-summary)
146
+ - [Supported Tasks and Leaderboards](#supported-tasks-and-leaderboards)
147
+ - [Languages](#languages)
148
+ - [Dataset Structure](#dataset-structure)
149
+ - [Data Instances](#data-instances)
150
+ - [Data Fields](#data-fields)
151
+ - [Data Splits](#data-splits)
152
+ - [Dataset Creation](#dataset-creation)
153
+ - [Curation Rationale](#curation-rationale)
154
+ - [Source Data](#source-data)
155
+ - [Initial Data Collection and Normalization](#initial-data-collection-and-normalization)
156
+ - [Who are the source language producers?](#who-are-the-source-language-producers)
157
+ - [Annotations](#annotations)
158
+ - [Annotation process](#annotation-process)
159
+ - [Who are the annotators?](#who-are-the-annotators)
160
+ - [Personal and Sensitive Information](#personal-and-sensitive-information)
161
+ - [Considerations for Using the Data](#considerations-for-using-the-data)
162
+ - [Social Impact of Dataset](#social-impact-of-dataset)
163
+ - [Discussion of Biases](#discussion-of-biases)
164
+ - [Other Known Limitations](#other-known-limitations)
165
+ - [Additional Information](#additional-information)
166
+ - [Dataset Curators](#dataset-curators)
167
+ - [Licensing Information](#licensing-information)
168
+ - [Citation Information](#citation-information)
169
+ - [Contributions](#contributions)
170
+
171
+ ## Dataset Description
172
+
173
+ - **Homepage:** https://huggingface.co/datasets/allenai/c4
174
+ - **Paper:** https://arxiv.org/abs/1910.10683
175
+
176
+ ### Dataset Summary
177
+
178
+ A multilingual colossal, cleaned version of Common Crawl's web crawl corpus. Based on Common Crawl dataset: "https://commoncrawl.org".
179
+
180
+ This is the version prepared by AllenAI, hosted at this address: https://huggingface.co/datasets/allenai/c4
181
+
182
+ 108 languages are available and are reported in the table below.
183
+
184
+ Note that the languages that end with "-Latn" are simply romanized variants, i.e. written using the Latin script.
185
+
186
+ | language code | language name |
187
+ |:----------------|:---------------------|
188
+ | af | Afrikaans |
189
+ | am | Amharic |
190
+ | ar | Arabic |
191
+ | az | Azerbaijani |
192
+ | be | Belarusian |
193
+ | bg | Bulgarian |
194
+ | bg-Latn | Bulgarian (Latin) |
195
+ | bn | Bangla |
196
+ | ca | Catalan |
197
+ | ceb | Cebuano |
198
+ | co | Corsican |
199
+ | cs | Czech |
200
+ | cy | Welsh |
201
+ | da | Danish |
202
+ | de | German |
203
+ | el | Greek |
204
+ | el-Latn | Greek (Latin) |
205
+ | en | English |
206
+ | eo | Esperanto |
207
+ | es | Spanish |
208
+ | et | Estonian |
209
+ | eu | Basque |
210
+ | fa | Persian |
211
+ | fi | Finnish |
212
+ | fil | Filipino |
213
+ | fr | French |
214
+ | fy | Western Frisian |
215
+ | ga | Irish |
216
+ | gd | Scottish Gaelic |
217
+ | gl | Galician |
218
+ | gu | Gujarati |
219
+ | ha | Hausa |
220
+ | haw | Hawaiian |
221
+ | hi | Hindi |
222
+ | hi-Latn | Hindi (Latin script) |
223
+ | hmn | Hmong, Mong |
224
+ | ht | Haitian |
225
+ | hu | Hungarian |
226
+ | hy | Armenian |
227
+ | id | Indonesian |
228
+ | ig | Igbo |
229
+ | is | Icelandic |
230
+ | it | Italian |
231
+ | iw | former Hebrew |
232
+ | ja | Japanese |
233
+ | ja-Latn | Japanese (Latin) |
234
+ | jv | Javanese |
235
+ | ka | Georgian |
236
+ | kk | Kazakh |
237
+ | km | Khmer |
238
+ | kn | Kannada |
239
+ | ko | Korean |
240
+ | ku | Kurdish |
241
+ | ky | Kyrgyz |
242
+ | la | Latin |
243
+ | lb | Luxembourgish |
244
+ | lo | Lao |
245
+ | lt | Lithuanian |
246
+ | lv | Latvian |
247
+ | mg | Malagasy |
248
+ | mi | Maori |
249
+ | mk | Macedonian |
250
+ | ml | Malayalam |
251
+ | mn | Mongolian |
252
+ | mr | Marathi |
253
+ | ms | Malay |
254
+ | mt | Maltese |
255
+ | my | Burmese |
256
+ | ne | Nepali |
257
+ | nl | Dutch |
258
+ | no | Norwegian |
259
+ | ny | Nyanja |
260
+ | pa | Punjabi |
261
+ | pl | Polish |
262
+ | ps | Pashto |
263
+ | pt | Portuguese |
264
+ | ro | Romanian |
265
+ | ru | Russian |
266
+ | ru-Latn | Russian (Latin) |
267
+ | sd | Sindhi |
268
+ | si | Sinhala |
269
+ | sk | Slovak |
270
+ | sl | Slovenian |
271
+ | sm | Samoan |
272
+ | sn | Shona |
273
+ | so | Somali |
274
+ | sq | Albanian |
275
+ | sr | Serbian |
276
+ | st | Southern Sotho |
277
+ | su | Sundanese |
278
+ | sv | Swedish |
279
+ | sw | Swahili |
280
+ | ta | Tamil |
281
+ | te | Telugu |
282
+ | tg | Tajik |
283
+ | th | Thai |
284
+ | tr | Turkish |
285
+ | uk | Ukrainian |
286
+ | und | Unknown language |
287
+ | ur | Urdu |
288
+ | uz | Uzbek |
289
+ | vi | Vietnamese |
290
+ | xh | Xhosa |
291
+ | yi | Yiddish |
292
+ | yo | Yoruba |
293
+ | zh | Chinese |
294
+ | zh-Latn | Chinese (Latin) |
295
+ | zu | Zulu |
296
+
297
+ You can load the mC4 subset of any language like this:
298
+
299
+ ```python
300
+ from datasets import load_dataset
301
+
302
+ en_mc4 = load_dataset("mc4", "en")
303
+ ```
304
+
305
+ You can even specify a list of languages:
306
+
307
+ ```python
308
+ from datasets import load_dataset
309
+
310
+ mc4_subset_with_five_languages = load_dataset("mc4", languages=["en", "fr", "es", "de", "zh"])
311
+ ```
312
+
313
+ ### Supported Tasks and Leaderboards
314
+
315
+ mC4 is mainly intended to pretrain language models and word representations.
316
+
317
+ ### Languages
318
+
319
+ The dataset supports 108 languages.
320
+
321
+ ## Dataset Structure
322
+
323
+ ### Data Instances
324
+
325
+ An example from the `en` config is:
326
+
327
+ ```
328
+ {'timestamp': '2018-06-24T01:32:39Z',
329
+ 'text': 'Farm Resources in Plumas County\nShow Beginning Farmer Organizations & Professionals (304)\nThere are 304 resources serving Plumas County in the following categories:\nMap of Beginning Farmer Organizations & Professionals serving Plumas County\nVictoria Fisher - Office Manager - Loyalton, CA\nAmy Lynn Rasband - UCCE Plumas-Sierra Administrative Assistant II - Quincy , CA\nShow Farm Income Opportunities Organizations & Professionals (353)\nThere are 353 resources serving Plumas County in the following categories:\nFarm Ranch And Forest Retailers (18)\nMap of Farm Income Opportunities Organizations & Professionals serving Plumas County\nWarner Valley Wildlife Area - Plumas County\nShow Farm Resources Organizations & Professionals (297)\nThere are 297 resources serving Plumas County in the following categories:\nMap of Farm Resources Organizations & Professionals serving Plumas County\nThere are 57 resources serving Plumas County in the following categories:\nMap of Organic Certification Organizations & Professionals serving Plumas County',
330
+ 'url': 'http://www.californialandcan.org/Plumas/Farm-Resources/'}
331
+ ```
332
+
333
+ ### Data Fields
334
+
335
+ The data have several fields:
336
+
337
+ - `url`: url of the source as a string
338
+ - `text`: text content as a string
339
+ - `timestamp`: timestamp as a string
340
+
341
+ ### Data Splits
342
+
343
+ To build mC4, the authors used [CLD3](https://github.com/google/cld3) to identify over 100 languages. The resulting mC4 subsets for each language are reported in this table:
344
+
345
+ | config | train | validation |
346
+ |:---------|:--------|:-------------|
347
+ | af | ? | ? |
348
+ | am | ? | ? |
349
+ | ar | ? | ? |
350
+ | az | ? | ? |
351
+ | be | ? | ? |
352
+ | bg | ? | ? |
353
+ | bg-Latn | ? | ? |
354
+ | bn | ? | ? |
355
+ | ca | ? | ? |
356
+ | ceb | ? | ? |
357
+ | co | ? | ? |
358
+ | cs | ? | ? |
359
+ | cy | ? | ? |
360
+ | da | ? | ? |
361
+ | de | ? | ? |
362
+ | el | ? | ? |
363
+ | el-Latn | ? | ? |
364
+ | en | ? | ? |
365
+ | eo | ? | ? |
366
+ | es | ? | ? |
367
+ | et | ? | ? |
368
+ | eu | ? | ? |
369
+ | fa | ? | ? |
370
+ | fi | ? | ? |
371
+ | fil | ? | ? |
372
+ | fr | ? | ? |
373
+ | fy | ? | ? |
374
+ | ga | ? | ? |
375
+ | gd | ? | ? |
376
+ | gl | ? | ? |
377
+ | gu | ? | ? |
378
+ | ha | ? | ? |
379
+ | haw | ? | ? |
380
+ | hi | ? | ? |
381
+ | hi-Latn | ? | ? |
382
+ | hmn | ? | ? |
383
+ | ht | ? | ? |
384
+ | hu | ? | ? |
385
+ | hy | ? | ? |
386
+ | id | ? | ? |
387
+ | ig | ? | ? |
388
+ | is | ? | ? |
389
+ | it | ? | ? |
390
+ | iw | ? | ? |
391
+ | ja | ? | ? |
392
+ | ja-Latn | ? | ? |
393
+ | jv | ? | ? |
394
+ | ka | ? | ? |
395
+ | kk | ? | ? |
396
+ | km | ? | ? |
397
+ | kn | ? | ? |
398
+ | ko | ? | ? |
399
+ | ku | ? | ? |
400
+ | ky | ? | ? |
401
+ | la | ? | ? |
402
+ | lb | ? | ? |
403
+ | lo | ? | ? |
404
+ | lt | ? | ? |
405
+ | lv | ? | ? |
406
+ | mg | ? | ? |
407
+ | mi | ? | ? |
408
+ | mk | ? | ? |
409
+ | ml | ? | ? |
410
+ | mn | ? | ? |
411
+ | mr | ? | ? |
412
+ | ms | ? | ? |
413
+ | mt | ? | ? |
414
+ | my | ? | ? |
415
+ | ne | ? | ? |
416
+ | nl | ? | ? |
417
+ | no | ? | ? |
418
+ | ny | ? | ? |
419
+ | pa | ? | ? |
420
+ | pl | ? | ? |
421
+ | ps | ? | ? |
422
+ | pt | ? | ? |
423
+ | ro | ? | ? |
424
+ | ru | ? | ? |
425
+ | ru-Latn | ? | ? |
426
+ | sd | ? | ? |
427
+ | si | ? | ? |
428
+ | sk | ? | ? |
429
+ | sl | ? | ? |
430
+ | sm | ? | ? |
431
+ | sn | ? | ? |
432
+ | so | ? | ? |
433
+ | sq | ? | ? |
434
+ | sr | ? | ? |
435
+ | st | ? | ? |
436
+ | su | ? | ? |
437
+ | sv | ? | ? |
438
+ | sw | ? | ? |
439
+ | ta | ? | ? |
440
+ | te | ? | ? |
441
+ | tg | ? | ? |
442
+ | th | ? | ? |
443
+ | tr | ? | ? |
444
+ | uk | ? | ? |
445
+ | und | ? | ? |
446
+ | ur | ? | ? |
447
+ | uz | ? | ? |
448
+ | vi | ? | ? |
449
+ | xh | ? | ? |
450
+ | yi | ? | ? |
451
+ | yo | ? | ? |
452
+ | zh | ? | ? |
453
+ | zh-Latn | ? | ? |
454
+ | zu | ? | ? |
455
+
456
+ ## Dataset Creation
457
+
458
+ ### Curation Rationale
459
+
460
+ [More Information Needed]
461
+
462
+ ### Source Data
463
+
464
+ #### Initial Data Collection and Normalization
465
+
466
+ [More Information Needed]
467
+
468
+ #### Who are the source language producers?
469
+
470
+ [More Information Needed]
471
+
472
+ ### Annotations
473
+
474
+ #### Annotation process
475
+
476
+ [More Information Needed]
477
+
478
+ #### Who are the annotators?
479
+
480
+ [More Information Needed]
481
+
482
+ ### Personal and Sensitive Information
483
+
484
+ [More Information Needed]
485
+
486
+ ## Considerations for Using the Data
487
+
488
+ ### Social Impact of Dataset
489
+
490
+ [More Information Needed]
491
+
492
+ ### Discussion of Biases
493
+
494
+ [More Information Needed]
495
+
496
+ ### Other Known Limitations
497
+
498
+ [More Information Needed]
499
+
500
+ ## Additional Information
501
+
502
+ ### Dataset Curators
503
+
504
+ [More Information Needed]
505
+
506
+ ### Licensing Information
507
+
508
+ AllenAI are releasing this dataset under the terms of ODC-BY. By using this, you are also bound by the Common Crawl terms of use in respect of the content contained in the dataset.
509
+
510
+ ### Citation Information
511
+
512
+ ```
513
+ @article{2019t5,
514
+ author = {Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu},
515
+ title = {Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer},
516
+ journal = {arXiv e-prints},
517
+ year = {2019},
518
+ archivePrefix = {arXiv},
519
+ eprint = {1910.10683},
520
+ }
521
+ ```
522
+
523
+ ### Contributions
524
+
525
+ Thanks to [@dirkgr](https://github.com/dirkgr) and [@lhoestq](https://github.com/lhoestq) for adding this dataset.
mc4/dummy/af/0.0.0/dummy_data.zip ADDED
Binary file (8.54 kB). View file
 
mc4/mc4.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """mC4 dataset based on Common Crawl."""
2
+
3
+
4
+ import gzip
5
+ import json
6
+
7
+ import datasets
8
+ import kenlm
9
+ import numpy as np
10
+ from numpy.random import default_rng
11
+
12
+
13
+ logger = datasets.logging.get_logger(__name__)
14
+
15
+
16
+ _DESCRIPTION = """\
17
+ A colossal, cleaned version of Common Crawl's web crawl corpus.
18
+
19
+ Based on Common Crawl dataset: "https://commoncrawl.org".
20
+
21
+ This is the processed version of Google's mC4 dataset by AllenAI.
22
+ """
23
+
24
+ _CITATION = """
25
+ @article{2019t5,
26
+ author = {Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu},
27
+ title = {Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer},
28
+ journal = {arXiv e-prints},
29
+ year = {2019},
30
+ archivePrefix = {arXiv},
31
+ eprint = {1910.10683},
32
+ }
33
+ """
34
+
35
+ _URL = "https://github.com/allenai/allennlp/discussions/5056"
36
+
37
+ _DATA_URL = "https://huggingface.co/datasets/allenai/c4/resolve/1ddc917116b730e1859edef32896ec5c16be51d0/multilingual/c4-{language}{split_suffix}.tfrecord-{index:05d}-of-{n_shards:05d}.json.gz"
38
+
39
+ _LANGUAGES = [
40
+ "af",
41
+ "am",
42
+ "ar",
43
+ "az",
44
+ "be",
45
+ "bg",
46
+ "bg-Latn",
47
+ "bn",
48
+ "ca",
49
+ "ceb",
50
+ "co",
51
+ "cs",
52
+ "cy",
53
+ "da",
54
+ "de",
55
+ "el",
56
+ "el-Latn",
57
+ "en",
58
+ "eo",
59
+ "es",
60
+ "et",
61
+ "eu",
62
+ "fa",
63
+ "fi",
64
+ "fil",
65
+ "fr",
66
+ "fy",
67
+ "ga",
68
+ "gd",
69
+ "gl",
70
+ "gu",
71
+ "ha",
72
+ "haw",
73
+ "hi",
74
+ "hi-Latn",
75
+ "hmn",
76
+ "ht",
77
+ "hu",
78
+ "hy",
79
+ "id",
80
+ "ig",
81
+ "is",
82
+ "it",
83
+ "iw",
84
+ "ja",
85
+ "ja-Latn",
86
+ "jv",
87
+ "ka",
88
+ "kk",
89
+ "km",
90
+ "kn",
91
+ "ko",
92
+ "ku",
93
+ "ky",
94
+ "la",
95
+ "lb",
96
+ "lo",
97
+ "lt",
98
+ "lv",
99
+ "mg",
100
+ "mi",
101
+ "mk",
102
+ "ml",
103
+ "mn",
104
+ "mr",
105
+ "ms",
106
+ "mt",
107
+ "my",
108
+ "ne",
109
+ "nl",
110
+ "no",
111
+ "ny",
112
+ "pa",
113
+ "pl",
114
+ "ps",
115
+ "pt",
116
+ "ro",
117
+ "ru",
118
+ "ru-Latn",
119
+ "sd",
120
+ "si",
121
+ "sk",
122
+ "sl",
123
+ "sm",
124
+ "sn",
125
+ "so",
126
+ "sq",
127
+ "sr",
128
+ "st",
129
+ "su",
130
+ "sv",
131
+ "sw",
132
+ "ta",
133
+ "te",
134
+ "tg",
135
+ "th",
136
+ "tr",
137
+ "uk",
138
+ "und",
139
+ "ur",
140
+ "uz",
141
+ "vi",
142
+ "xh",
143
+ "yi",
144
+ "yo",
145
+ "zh",
146
+ "zh-Latn",
147
+ "zu",
148
+ ]
149
+
150
+ _N_SHARDS_PER_SPLIT = {
151
+ "af": {"train": 64, "validation": 1},
152
+ "am": {"train": 16, "validation": 1},
153
+ "ar": {"train": 1024, "validation": 4},
154
+ "az": {"train": 256, "validation": 1},
155
+ "be": {"train": 128, "validation": 1},
156
+ "bg": {"train": 1024, "validation": 1},
157
+ "bg-Latn": {"train": 4, "validation": 1},
158
+ "bn": {"train": 512, "validation": 1},
159
+ "ca": {"train": 512, "validation": 1},
160
+ "ceb": {"train": 8, "validation": 1},
161
+ "co": {"train": 8, "validation": 1},
162
+ "cs": {"train": 1024, "validation": 2},
163
+ "cy": {"train": 256, "validation": 1},
164
+ "da": {"train": 1024, "validation": 1},
165
+ "de": {"train": 2048, "validation": 16},
166
+ "el": {"train": 1024, "validation": 2},
167
+ "el-Latn": {"train": 16, "validation": 1},
168
+ "en": {"train": 11264, "validation": 128},
169
+ "eo": {"train": 32, "validation": 1},
170
+ "es": {"train": 2048, "validation": 16},
171
+ "et": {"train": 256, "validation": 1},
172
+ "eu": {"train": 64, "validation": 1},
173
+ "fa": {"train": 1024, "validation": 2},
174
+ "fi": {"train": 1024, "validation": 1},
175
+ "fil": {"train": 64, "validation": 1},
176
+ "fr": {"train": 2048, "validation": 16},
177
+ "fy": {"train": 16, "validation": 1},
178
+ "ga": {"train": 16, "validation": 1},
179
+ "gd": {"train": 16, "validation": 1},
180
+ "gl": {"train": 128, "validation": 1},
181
+ "gu": {"train": 64, "validation": 1},
182
+ "ha": {"train": 8, "validation": 1},
183
+ "haw": {"train": 2, "validation": 1},
184
+ "hi": {"train": 1024, "validation": 2},
185
+ "hi-Latn": {"train": 16, "validation": 1},
186
+ "hmn": {"train": 8, "validation": 1},
187
+ "ht": {"train": 8, "validation": 1},
188
+ "hu": {"train": 1024, "validation": 2},
189
+ "hy": {"train": 128, "validation": 1},
190
+ "id": {"train": 1024, "validation": 4},
191
+ "ig": {"train": 4, "validation": 1},
192
+ "is": {"train": 128, "validation": 1},
193
+ "it": {"train": 1024, "validation": 8},
194
+ "iw": {"train": 1024, "validation": 1},
195
+ "ja": {"train": 1024, "validation": 8},
196
+ "ja-Latn": {"train": 8, "validation": 1},
197
+ "jv": {"train": 8, "validation": 1},
198
+ "ka": {"train": 256, "validation": 1},
199
+ "kk": {"train": 256, "validation": 1},
200
+ "km": {"train": 64, "validation": 1},
201
+ "kn": {"train": 64, "validation": 1},
202
+ "ko": {"train": 1024, "validation": 1},
203
+ "ku": {"train": 16, "validation": 1},
204
+ "ky": {"train": 64, "validation": 1},
205
+ "la": {"train": 64, "validation": 1},
206
+ "lb": {"train": 32, "validation": 1},
207
+ "lo": {"train": 8, "validation": 1},
208
+ "lt": {"train": 512, "validation": 1},
209
+ "lv": {"train": 256, "validation": 1},
210
+ "mg": {"train": 8, "validation": 1},
211
+ "mi": {"train": 4, "validation": 1},
212
+ "mk": {"train": 128, "validation": 1},
213
+ "ml": {"train": 128, "validation": 1},
214
+ "mn": {"train": 128, "validation": 1},
215
+ "mr": {"train": 1024, "validation": 1},
216
+ "ms": {"train": 512, "validation": 1},
217
+ "mt": {"train": 128, "validation": 1},
218
+ "my": {"train": 64, "validation": 1},
219
+ "ne": {"train": 256, "validation": 1},
220
+ "nl": {"train": 1024, "validation": 4},
221
+ "no": {"train": 1024, "validation": 1},
222
+ "ny": {"train": 4, "validation": 1},
223
+ "pa": {"train": 32, "validation": 1},
224
+ "pl": {"train": 1024, "validation": 4},
225
+ "ps": {"train": 16, "validation": 1},
226
+ "pt": {"train": 1024, "validation": 4},
227
+ "ro": {"train": 1024, "validation": 2},
228
+ "ru": {"train": 4096, "validation": 32},
229
+ "ru-Latn": {"train": 32, "validation": 1},
230
+ "sd": {"train": 64, "validation": 1},
231
+ "si": {"train": 64, "validation": 1},
232
+ "sk": {"train": 512, "validation": 1},
233
+ "sl": {"train": 256, "validation": 1},
234
+ "sm": {"train": 4, "validation": 1},
235
+ "sn": {"train": 8, "validation": 1},
236
+ "so": {"train": 64, "validation": 1},
237
+ "sq": {"train": 128, "validation": 1},
238
+ "sr": {"train": 256, "validation": 1},
239
+ "st": {"train": 2, "validation": 1},
240
+ "su": {"train": 4, "validation": 1},
241
+ "sv": {"train": 1024, "validation": 2},
242
+ "sw": {"train": 32, "validation": 1},
243
+ "ta": {"train": 256, "validation": 1},
244
+ "te": {"train": 128, "validation": 1},
245
+ "tg": {"train": 64, "validation": 1},
246
+ "th": {"train": 1024, "validation": 1},
247
+ "tr": {"train": 1024, "validation": 4},
248
+ "uk": {"train": 1024, "validation": 2},
249
+ "und": {"train": 3072, "validation": 32},
250
+ "ur": {"train": 128, "validation": 1},
251
+ "uz": {"train": 32, "validation": 1},
252
+ "vi": {"train": 1024, "validation": 4},
253
+ "xh": {"train": 2, "validation": 1},
254
+ "yi": {"train": 16, "validation": 1},
255
+ "yo": {"train": 2, "validation": 1},
256
+ "zh": {"train": 1024, "validation": 2},
257
+ "zh-Latn": {"train": 8, "validation": 1},
258
+ "zu": {"train": 8, "validation": 1},
259
+ }
260
+
261
+
262
class Mc4Config(datasets.BuilderConfig):
    """Builder configuration describing which mC4 languages to load."""

    def __init__(self, *args, languages, **kwargs):
        """Create a configuration named after the requested languages.

        The configuration name is the language codes joined with ``+``
        (e.g. ``["en", "fr"]`` becomes ``"en+fr"``).

        Args:
            *args: positional arguments forwarded to the parent class.
            languages (:obj:`List[str]`): list of languages to load
            **kwargs: keyword arguments forwarded to the parent class.
        """
        config_name = "+".join(languages)
        super().__init__(*args, name=config_name, **kwargs)
        self.languages = languages
277
+
278
+
279
class Mc4(datasets.GeneratorBasedBuilder):
    """mC4, a colossal, cleaned version of Common Crawl's web crawl corpus.

    In addition to the regular builder arguments, the constructor accepts
    these optional keyword arguments:

    * ``data_files``: mapping with optional ``"train"`` / ``"validation"``
      entries (a path or a list/tuple of paths) used instead of downloading
      the shards.
    * ``sampling_method``: ``"random"``, ``"gaussian"`` or any other truthy
      value (treated as the step method). When falsy, no sampling is done.
    * ``perplexity_model``: path handed to ``kenlm.Model``; required for
      every sampling method except ``"random"``.
    * ``sampling_factor``: forwarded as ``factor`` to the sampling function;
      ``None`` selects that function's default.
    * ``boundaries``: three perplexity boundaries forwarded to the sampling
      function; ``None`` selects the built-in defaults.
    * ``seed``: seed for the NumPy random generator used when sampling.
    """

    BUILDER_CONFIGS = [Mc4Config(languages=[lang]) for lang in _LANGUAGES]
    BUILDER_CONFIG_CLASS = Mc4Config

    # Default perplexity boundaries used by the step and gaussian samplers.
    # NOTE(review): presumably perplexity quartiles measured on a reference
    # sample -- confirm where these numbers come from.
    _DEFAULT_BOUNDARIES = (536394.99320948, 662247.50212365, 919250.87225178)
    _DEFAULT_STEP_FACTOR = 1.5e5
    _DEFAULT_GAUSSIAN_FACTOR = 0.78
    _DEFAULT_RANDOM_FACTOR = 0.5

    def __init__(self, *args, writer_batch_size=None, **kwargs):
        """Pop the sampling options from ``kwargs`` and set up the sampler.

        Raises:
            ValueError: if a perplexity-based sampling method is requested
                without ``perplexity_model``.
        """
        # ``or {}`` guards against an explicit ``data_files=None``, which
        # would otherwise break the ``in`` checks in ``_split_generators``.
        self.data_files = kwargs.pop("data_files", None) or {}
        self.sampling_method = kwargs.pop("sampling_method", None)
        self.perplexity_model = kwargs.pop("perplexity_model", None)
        self.sampling_factor = kwargs.pop("sampling_factor", None)
        self.boundaries = kwargs.pop("boundaries", None)
        self.seed = kwargs.pop("seed", None)
        if self.sampling_method:
            # default_rng(None) draws fresh entropy, same as default_rng().
            self.rng = default_rng(self.seed)
            if self.sampling_method == "random":
                self.should_keep_doc = self._should_keep_doc_random
            else:
                if not self.perplexity_model:
                    raise ValueError(
                        "perplexity_model is required when sampling_method is "
                        f"{self.sampling_method!r}"
                    )
                # Loading 5-gram model
                # http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
                logger.info("loading model = %s", self.perplexity_model)
                self.pp_model = kenlm.Model(self.perplexity_model)
                if self.sampling_method == "gaussian":
                    self.should_keep_doc = self._should_keep_doc_gaussian
                else:
                    self.should_keep_doc = self._should_keep_doc_step
        super().__init__(*args, writer_batch_size=writer_batch_size, **kwargs)

    def get_perplexity(self, doc):
        """Return the perplexity of ``doc`` under the loaded KenLM model.

        Each line is scored separately; the log10 scores are summed and
        normalised by the total number of whitespace tokens plus one per line.
        """
        doc_log_score, doc_length = 0, 0
        for line in doc.split("\n"):
            doc_log_score += self.pp_model.score(line)
            doc_length += len(line.split()) + 1
        return 10.0 ** (-doc_log_score / doc_length)

    def _should_keep_doc_step(self, doc, factor=_DEFAULT_STEP_FACTOR, boundaries=None):
        """Decide whether to keep ``doc`` using a per-range step probability.

        The keep probability is ``factor`` divided by the width of the
        boundary range the document's perplexity falls into.

        Args:
            doc: document text.
            factor: numerator of the keep probability; ``None`` means the
                default.
            boundaries: three increasing perplexity boundaries; ``None``
                means the defaults.
        """
        if factor is None:
            # Callers pass ``sampling_factor`` through even when unset.
            factor = self._DEFAULT_STEP_FACTOR
        if boundaries is None:
            boundaries = self._DEFAULT_BOUNDARIES
        perplexity = self.get_perplexity(doc)
        if perplexity <= boundaries[0]:
            quartile_range = boundaries[0]
        elif perplexity < boundaries[1]:
            quartile_range = boundaries[1] - boundaries[0]
        elif perplexity < boundaries[2]:
            # Also covers perplexity == boundaries[1], which previously
            # matched no branch and left quartile_range undefined.
            quartile_range = boundaries[2] - boundaries[1]
        else:
            quartile_range = 10 * boundaries[2]
        probability = factor / quartile_range
        return self.rng.uniform() < probability

    def _should_keep_doc_gaussian(self, doc, factor=_DEFAULT_GAUSSIAN_FACTOR, boundaries=None):
        """Decide whether to keep ``doc`` using a gaussian-shaped weight.

        The keep probability is ``factor * exp(-9/2 * ((p - m) / m) ** 2)``
        where ``p`` is the document perplexity and ``m`` is the middle
        boundary.

        Args:
            doc: document text.
            factor: scale of the keep probability; ``None`` means the default.
            boundaries: three perplexity boundaries (only the middle one is
                used); ``None`` means the default.
        """
        if factor is None:
            factor = self._DEFAULT_GAUSSIAN_FACTOR
        perplexity = self.get_perplexity(doc)
        m = boundaries[1] if boundaries is not None else self._DEFAULT_BOUNDARIES[1]
        exponential = np.exp(-9 / 2 * ((perplexity - m) / m) ** 2)
        weighted_perplexity = factor * exponential
        return self.rng.uniform() < weighted_perplexity

    def _should_keep_doc_random(self, doc, factor=None, boundaries=None):
        """Keep ``doc`` with probability ``factor`` (0.5 when ``None``).

        ``doc`` and ``boundaries`` are accepted only so that all samplers
        share one call signature.
        """
        if factor is None:
            factor = self._DEFAULT_RANDOM_FACTOR
        return self.rng.uniform() <= factor

    def _info(self):
        """Describe the dataset: three string features (text, timestamp, url)."""
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=datasets.Features(
                {
                    "text": datasets.Value("string"),
                    "timestamp": datasets.Value("string"),
                    "url": datasets.Value("string"),
                }
            ),
            supervised_keys=None,
            homepage=_URL,
            citation=_CITATION,
        )

    def _split_urls(self, split):
        """Return the shard URLs of ``split`` for every configured language."""
        split_suffix = "-validation" if split == "validation" else ""
        urls = []
        for lang in self.config.languages:
            n_shards = _N_SHARDS_PER_SPLIT[lang][split]
            for index in range(n_shards):
                urls.append(
                    _DATA_URL.format(
                        # Use the individual language, not the config name:
                        # for multi-language configs the name is "en+fr".
                        language=lang,
                        split_suffix=split_suffix,
                        index=index,
                        n_shards=n_shards,
                    )
                )
        return urls

    def _resolve_split_files(self, split, dl_manager):
        """Return user-provided files for ``split`` or download its shards."""
        if split in self.data_files:
            files = self.data_files[split]
            if not isinstance(files, (tuple, list)):
                files = [files]
            return files
        return dl_manager.download(self._split_urls(split))

    def _split_generators(self, dl_manager):
        """Build the train and validation split generators."""
        train_downloaded_files = self._resolve_split_files("train", dl_manager)
        validation_downloaded_files = self._resolve_split_files("validation", dl_manager)
        return [
            datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": train_downloaded_files}),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION, gen_kwargs={"filepaths": validation_downloaded_files}
            ),
        ]

    @staticmethod
    def _iter_json_lines(lines):
        """Yield one decoded JSON object per non-blank line."""
        for line in lines:
            # Lines read from a file keep their newline, so a plain
            # truthiness test would not skip blank lines.
            if line.strip():
                yield json.loads(line)

    def _generate_examples(self, filepaths):
        """This function returns the examples in the raw (text) form by iterating on all the files.

        Files whose name ends with ``jsonl`` are read as plain text and are
        never sampled; every other file is read as gzip-compressed JSON lines
        and filtered with ``should_keep_doc`` when a sampling method is set.
        """
        id_ = 0
        for filepath in filepaths:
            logger.info("generating examples from = %s", filepath)
            if filepath.endswith("jsonl"):
                with open(filepath, "r", encoding="utf-8") as f:
                    for example in self._iter_json_lines(f):
                        yield id_, example
                        id_ += 1
            else:
                # GzipFile.close() does not close a file object it was given,
                # so the raw handle needs its own context manager.
                # NOTE(review): the two-step open() + gzip.open() is kept on
                # purpose -- presumably so `datasets` streaming can patch
                # `open`; confirm before collapsing it.
                with open(filepath, "rb") as raw, gzip.open(raw, "rt", encoding="utf-8") as f:
                    if self.sampling_method:
                        logger.info("sampling method = %s", self.sampling_method)
                    for example in self._iter_json_lines(f):
                        if self.sampling_method and not self.should_keep_doc(
                            example["text"],
                            factor=self.sampling_factor,
                            boundaries=self.boundaries,
                        ):
                            continue
                        yield id_, example
                        id_ += 1
mc4/mc4.py.lock ADDED
File without changes
outputs/checkpoints/checkpoint-140001/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RobertaForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "eos_token_id": 2,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 514,
16
+ "model_type": "roberta",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 12,
19
+ "pad_token_id": 1,
20
+ "position_embedding_type": "absolute",
21
+ "transformers_version": "4.9.0.dev0",
22
+ "type_vocab_size": 1,
23
+ "use_cache": true,
24
+ "vocab_size": 50265
25
+ }
outputs/checkpoints/checkpoint-140001/data_collator.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e02a6e9cfa63cb321cac9402efd29841b652999fcbf787800ae050e747b161ee
3
+ size 1471394
outputs/checkpoints/checkpoint-140001/flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb3b6443b0b4e0fd6b95f7409525ddde51fb73dd99318041f2fecda9f547f5a6
3
+ size 249750019
outputs/checkpoints/checkpoint-140001/optimizer_state.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:73ce4d1287008fdfac801ca7df44a0debe3e41f901970f3132f0cd49d2ad6bd0
3
+ size 499500278
outputs/checkpoints/checkpoint-140001/training_args.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8bc14fe16573d318dd510c7cfb42ebb7cc87b4dcf77e99247a2d1605cffd772b
3
+ size 1876
outputs/checkpoints/checkpoint-140001/training_state.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"step": 140001}
outputs/checkpoints/checkpoint-150001/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RobertaForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "eos_token_id": 2,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 514,
16
+ "model_type": "roberta",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 12,
19
+ "pad_token_id": 1,
20
+ "position_embedding_type": "absolute",
21
+ "transformers_version": "4.9.0.dev0",
22
+ "type_vocab_size": 1,
23
+ "use_cache": true,
24
+ "vocab_size": 50265
25
+ }
outputs/checkpoints/checkpoint-150001/data_collator.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e02a6e9cfa63cb321cac9402efd29841b652999fcbf787800ae050e747b161ee
3
+ size 1471394
outputs/checkpoints/checkpoint-150001/flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9f2a38ac6c111d01809dd28ae9078aab932064126a7de753ce0d88bd60421e4
3
+ size 249750019
outputs/checkpoints/checkpoint-150001/optimizer_state.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84f53f9b574ccfb97696f637d71903b9762ef2718c656bea201e5aeb9078c328
3
+ size 499500278
outputs/checkpoints/checkpoint-150001/training_args.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8bc14fe16573d318dd510c7cfb42ebb7cc87b4dcf77e99247a2d1605cffd772b
3
+ size 1876
outputs/checkpoints/checkpoint-150001/training_state.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"step": 150001}
outputs/checkpoints/checkpoint-160001/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RobertaForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "eos_token_id": 2,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 514,
16
+ "model_type": "roberta",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 12,
19
+ "pad_token_id": 1,
20
+ "position_embedding_type": "absolute",
21
+ "transformers_version": "4.9.0.dev0",
22
+ "type_vocab_size": 1,
23
+ "use_cache": true,
24
+ "vocab_size": 50265
25
+ }
outputs/checkpoints/checkpoint-160001/data_collator.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e02a6e9cfa63cb321cac9402efd29841b652999fcbf787800ae050e747b161ee
3
+ size 1471394
outputs/checkpoints/checkpoint-160001/flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b86d26169d8fb7bb58ae7fecd67ca557a0affc93bf2d5b5947af0070ee894ab9
3
+ size 249750019
outputs/checkpoints/checkpoint-160001/optimizer_state.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea3a8f65ea9c3c6c3606f1167c4e54049784fa8b2a5ee3f4936563ecd4f811b6
3
+ size 499500278
outputs/checkpoints/checkpoint-160001/training_args.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8bc14fe16573d318dd510c7cfb42ebb7cc87b4dcf77e99247a2d1605cffd772b
3
+ size 1876
outputs/checkpoints/checkpoint-160001/training_state.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"step": 160001}
outputs/checkpoints/checkpoint-170001/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RobertaForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "eos_token_id": 2,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 514,
16
+ "model_type": "roberta",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 12,
19
+ "pad_token_id": 1,
20
+ "position_embedding_type": "absolute",
21
+ "transformers_version": "4.9.0.dev0",
22
+ "type_vocab_size": 1,
23
+ "use_cache": true,
24
+ "vocab_size": 50265
25
+ }
outputs/checkpoints/checkpoint-170001/data_collator.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e02a6e9cfa63cb321cac9402efd29841b652999fcbf787800ae050e747b161ee
3
+ size 1471394
outputs/checkpoints/checkpoint-170001/flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c40291527e2cf6e418cf78bb9cd4eec53ac716230987ad7a0a447bf0ce041d4c
3
+ size 249750019
outputs/checkpoints/checkpoint-170001/optimizer_state.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90dbe4fe7d7694dd86d13e9b075953620aa4dabb4fdc2023b6ede17aa720848e
3
+ size 499500278
outputs/checkpoints/checkpoint-170001/training_args.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8bc14fe16573d318dd510c7cfb42ebb7cc87b4dcf77e99247a2d1605cffd772b
3
+ size 1876
outputs/checkpoints/checkpoint-170001/training_state.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"step": 170001}
outputs/checkpoints/checkpoint-180001/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RobertaForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "eos_token_id": 2,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 514,
16
+ "model_type": "roberta",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 12,
19
+ "pad_token_id": 1,
20
+ "position_embedding_type": "absolute",
21
+ "transformers_version": "4.9.0.dev0",
22
+ "type_vocab_size": 1,
23
+ "use_cache": true,
24
+ "vocab_size": 50265
25
+ }
outputs/checkpoints/checkpoint-180001/data_collator.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e02a6e9cfa63cb321cac9402efd29841b652999fcbf787800ae050e747b161ee
3
+ size 1471394
outputs/checkpoints/checkpoint-180001/flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:393c37966461709fe51a3b3f84befb7fa7e5030025856d171308efd40dbbc7da
3
+ size 249750019
outputs/checkpoints/checkpoint-180001/optimizer_state.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a33cad417a7e78eaafc1c041f93fd54ad9f63869d01e1351bac4abcd58e4eeb
3
+ size 499500278
outputs/checkpoints/checkpoint-180001/training_args.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8bc14fe16573d318dd510c7cfb42ebb7cc87b4dcf77e99247a2d1605cffd772b
3
+ size 1876
outputs/checkpoints/checkpoint-180001/training_state.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"step": 180001}
outputs/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RobertaForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "eos_token_id": 2,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 514,
16
+ "model_type": "roberta",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 12,
19
+ "pad_token_id": 1,
20
+ "position_embedding_type": "absolute",
21
+ "transformers_version": "4.9.0.dev0",
22
+ "type_vocab_size": 1,
23
+ "use_cache": true,
24
+ "vocab_size": 50265
25
+ }
outputs/data_collator.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e02a6e9cfa63cb321cac9402efd29841b652999fcbf787800ae050e747b161ee
3
+ size 1471394
outputs/events.out.tfevents.1626172316.underestimate.4022703.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54e7a88ae2dc3c9128df68ad99b735f3ae87946bc9753da8eb080eb7379dc4d3
3
+ size 26964023
outputs/flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:393c37966461709fe51a3b3f84befb7fa7e5030025856d171308efd40dbbc7da
3
+ size 249750019
outputs/optimizer_state.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a33cad417a7e78eaafc1c041f93fd54ad9f63869d01e1351bac4abcd58e4eeb
3
+ size 499500278
outputs/training_args.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8bc14fe16573d318dd510c7cfb42ebb7cc87b4dcf77e99247a2d1605cffd772b
3
+ size 1876
outputs/training_state.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"step": 180001}
run_mlm_flax_stream.py ADDED
@@ -0,0 +1,722 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
18
+ text file or a dataset.
19
+
20
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
+ https://huggingface.co/models?filter=masked-lm
22
+ """
23
+ import logging
24
+ import json
25
+ import os
26
+ import shutil
27
+ import sys
28
+ import time
29
+ from collections import defaultdict
30
+ from dataclasses import dataclass, field
31
+
32
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
33
+ import joblib
34
+ from pathlib import Path
35
+ from typing import Dict, List, Optional, Tuple
36
+
37
+ import datasets
38
+ import numpy as np
39
+ from datasets import load_dataset
40
+ from tqdm import tqdm
41
+
42
+ import flax
43
+ import jax
44
+ import jax.numpy as jnp
45
+ import kenlm # pip install https://github.com/kpu/kenlm/archive/master.zip
46
+ import optax
47
+ from flax import jax_utils, traverse_util
48
+ from flax.serialization import from_bytes, to_bytes
49
+ from flax.training import train_state
50
+ from flax.training.common_utils import get_metrics, onehot, shard
51
+ from transformers import (
52
+ CONFIG_MAPPING,
53
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
54
+ AutoConfig,
55
+ AutoTokenizer,
56
+ FlaxAutoModelForMaskedLM,
57
+ HfArgumentParser,
58
+ PreTrainedTokenizerBase,
59
+ TensorType,
60
+ TrainingArguments,
61
+ is_tensorboard_available,
62
+ set_seed,
63
+ )
64
+
65
+
66
# Compare the release components numerically. Comparing the version strings
# directly is lexicographic, which ranks "1.10.0" below "1.8.0" and would
# wrongly reject newer releases. Non-numeric pieces (e.g. "dev0") are ignored.
_datasets_release = tuple(int(part) for part in datasets.__version__.split(".")[:3] if part.isdigit())
if _datasets_release <= (1, 8, 0):
    raise ValueError("Make sure to upgrade `datasets` to a version >= 1.9.0 to use dataset streaming")


# Keys of the Flax masked-LM mapping and the `model_type` of each of them;
# MODEL_TYPES is what the `--model_type` help text lists.
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
72
+
73
+
74
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    # Checkpoint used to initialise weights; leave unset to train from scratch.
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            # The two sentences previously ran together without a space.
            "help": "The model checkpoint for weights initialization. "
            "Don't set if you want to train a model from scratch."
        },
    )
    # Looked up in CONFIG_MAPPING by the script when no config/checkpoint is given.
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
        },
    )
110
+
111
@dataclass
class DataTrainingArguments:
    """
    Options describing the data fed to the model for training and evaluation.

    At least one of `dataset_name`, `train_file` or `validation_file` must be
    given; file arguments are checked against a fixed list of extensions.
    """

    # --- data sources ---
    dataset_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
    )
    train_file: Optional[str] = field(
        default=None,
        metadata={"help": "The input training data file (a text file)."},
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    train_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
    )
    validation_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
    )

    # --- preprocessing ---
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    max_seq_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated. Default to the max input length of the model."
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    mlm_probability: float = field(
        default=0.15,
        metadata={"help": "Ratio of tokens to mask for masked language modeling loss"},
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    line_by_line: bool = field(
        default=False,
        metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
    )
    text_column_name: str = field(
        default="text",
        metadata={"help": "The name of the column to retrieve the training text."},
    )

    # --- streaming / schedule ---
    shuffle_buffer_size: int = field(
        default=10000,
        metadata={"help": "The number of examples to pre-load for shuffling."},
    )
    num_train_steps: int = field(default=50000, metadata={"help": "The number of training steps."})
    num_eval_samples: int = field(default=50000, metadata={"help": "The number of samples to be used for evaluation"})

    def __post_init__(self):
        """Validate that a data source exists and that file extensions are accepted."""
        nothing_given = self.dataset_name is None and self.train_file is None and self.validation_file is None
        if nothing_given:
            raise ValueError("Need either a dataset name or a training/validation file.")

        accepted = ["csv", "json", "jsonl", "txt", "gz"]
        for arg_name in ("train_file", "validation_file"):
            location = getattr(self, arg_name)
            if location is None:
                continue
            # Only the text after the last dot is inspected.
            suffix = location.split(".")[-1]
            assert suffix in accepted, f"`{arg_name}` should be a csv, a json (lines) or a txt file."
189
+
190
+
191
@flax.struct.dataclass
class FlaxDataCollatorForLanguageModeling:
    """
    Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
    are not all of the same length.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
            The probability with which to (randomly) mask tokens in the input.

    .. note::

        For best performance, this data collator should be used with a dataset having items that are dictionaries or
        BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
        :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
        argument :obj:`return_special_tokens_mask=True`. When the key is absent the mask is recomputed from the
        token ids through the tokenizer.
    """

    tokenizer: PreTrainedTokenizerBase
    mlm_probability: float = 0.15

    def __post_init__(self):
        # Masked language modeling cannot work without a mask token.
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. "
                "You should pass `mlm=False` to train on causal language modeling instead."
            )

    def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
        """Pad `examples` into one batch and replace `input_ids` with masked inputs plus `labels`."""
        # Handle dict or lists with proper padding and conversion to tensor.
        batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)

        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)

        batch["input_ids"], batch["labels"] = self.mask_tokens(
            batch["input_ids"], special_tokens_mask=special_tokens_mask
        )
        return batch

    def mask_tokens(
        self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        `inputs` is modified in place and also returned. `labels` holds the original ids at the positions selected
        for prediction and -100 everywhere else. `special_tokens_mask` may be None, in which case it is derived from
        the token ids with `tokenizer.get_special_tokens_mask`.
        """
        labels = inputs.copy()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = np.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            # `__call__` passes None when the examples carry no precomputed mask;
            # previously this crashed on `.astype`.
            special_tokens_mask = np.array(
                [
                    self.tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)
                    for ids in labels.tolist()
                ],
                dtype="bool",
            ).reshape(labels.shape)
        special_tokens_mask = special_tokens_mask.astype("bool")

        # Special tokens are never selected for prediction.
        probability_matrix[special_tokens_mask] = 0.0
        masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        # (half of the 20% of selected positions that were not replaced above).
        indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
        indices_random &= masked_indices & ~indices_replaced

        random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels
261
+
262
+
263
@dataclass
class SamplingArguments:
    """
    Arguments pertaining to how to perform sampling of the dataset.

    After initialisation `boundaries` holds a list of floats parsed from the
    comma-separated string that was passed in, or None if None was passed.
    """

    perplexity_model: Optional[str] = field(
        default="./es.arpa.bin", metadata={"help": "Path to KenLM model to use to get perplexity values."}
    )
    sampling_method: Optional[str] = field(
        default=None, metadata={"help": "Sample using a 'step' or 'gaussian' perplexity function per document, or 'random'."}
    )
    sampling_factor: Optional[float] = field(
        default=None, metadata={"help": "Sampling factor. Integers for step function, decimals for gaussian."}
    )
    boundaries: Optional[str] = field(
        default="536394.99320948,662247.50212365,919250.87225178", metadata={"help": "Quartile boundaries"}
    )

    def __post_init__(self):
        # The field is declared Optional, so tolerate None instead of failing
        # with AttributeError on `.split`.
        if self.boundaries is None:
            return
        self.boundaries = [float(q.strip()) for q in self.boundaries.split(",")]
284
+
285
+
286
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
    """Split `samples_idx` into consecutive batches of exactly `batch_size` items.

    Trailing indices that do not fill a whole batch are dropped.

    Args:
        samples_idx: array of sample indices.
        batch_size: number of indices per batch.

    Returns:
        A list of arrays, one per full batch; an empty list when there are
        fewer than `batch_size` indices.
    """
    num_samples = len(samples_idx)
    sections_split = num_samples // batch_size
    if sections_split == 0:
        # np.split cannot divide into zero sections; there is no full batch.
        return []

    samples_to_remove = num_samples % batch_size
    if samples_to_remove != 0:
        samples_idx = samples_idx[:-samples_to_remove]
    return np.split(samples_idx, sections_split)
295
+
296
+
297
def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
    """
    The training iterator is advanced so that after groupifying the samples,
    `num_samples` of length `max_seq_length` are returned.

    Args:
        train_iterator: iterator yielding mappings from column name to a list
            of values; each item must contain an ``"input_ids"`` entry.
        num_samples: number of grouped sequences to produce.
        max_seq_length: length of every grouped sequence.

    Returns:
        A dict mapping each column name to a list of `num_samples` slices of
        length `max_seq_length`.

    Raises:
        StopIteration: if the iterator runs out before enough tokens are read.
    """
    num_total_tokens = max_seq_length * num_samples
    samples = defaultdict(list)

    collected = 0
    while collected < num_total_tokens:
        tokenized_samples = next(train_iterator)
        collected += len(tokenized_samples["input_ids"])

        # Extend in place: rebuilding every list by concatenation on each
        # iteration copies all tokens gathered so far.
        for key in tokenized_samples.keys():
            samples[key].extend(tokenized_samples[key])

    # Concatenated tokens are split to lists of length `max_seq_length`.
    # Note that the remainder beyond `num_total_tokens` is thrown away.
    return {
        key: [tokens[start : start + max_seq_length] for start in range(0, num_total_tokens, max_seq_length)]
        for key, tokens in samples.items()
    }
324
+
325
+
326
def write_train_metric(summary_writer, train_metrics, train_time, step):
    """Log the elapsed training time and every buffered training metric.

    `train_metrics` is passed through `get_metrics`; each resulting series is
    written under the tag ``train_<key>`` so that its last value lands on
    `step` and earlier values on the steps immediately before it.
    """
    summary_writer.scalar("train_time", train_time, step)

    collected = get_metrics(train_metrics)
    for name, series in collected.items():
        label = f"train_{name}"
        # Step assigned to the first value of the series.
        first_step = step - len(series) + 1
        for offset, value in enumerate(series):
            summary_writer.scalar(label, value, first_step + offset)
334
+
335
+
336
def write_eval_metric(summary_writer, eval_metrics, step):
    """Write each evaluation metric to `summary_writer` under ``eval_<name>`` at `step`."""
    for name, score in eval_metrics.items():
        label = f"eval_{name}"
        summary_writer.scalar(label, score, step)
339
+
340
+
341
def save_checkpoint_files(state, data_collator, training_args, save_dir):
    """Write the auxiliary checkpoint files for `state` into `save_dir`.

    Four files are produced, in this order: ``optimizer_state.msgpack``
    (serialized ``opt_state`` of the unreplicated state),
    ``training_args.joblib``, ``data_collator.joblib`` and
    ``training_state.json`` (the current step). `save_dir` must already exist.
    """

    def target(filename):
        return os.path.join(save_dir, filename)

    single_state = jax_utils.unreplicate(state)

    optimizer_blob = to_bytes(single_state.opt_state)
    with open(target("optimizer_state.msgpack"), "wb") as out:
        out.write(optimizer_blob)

    joblib.dump(training_args, target("training_args.joblib"))
    joblib.dump(data_collator, target("data_collator.joblib"))

    progress = {"step": single_state.step.item()}
    with open(target("training_state.json"), "w") as out:
        out.write(json.dumps(progress))
349
+
350
+
351
def rotate_checkpoints(path, max_checkpoints=5):
    """Keep only the `max_checkpoints` most recently modified entries of `path`.

    Entries are ranked by modification time, newest first; everything past the
    first `max_checkpoints` is deleted. Directories are removed recursively,
    and entries for which that fails with OSError are removed as plain files.
    """
    entries = sorted(Path(path).iterdir(), key=os.path.getmtime)
    entries.reverse()  # newest first
    if len(entries) <= max_checkpoints:
        return

    for stale in entries[max_checkpoints:]:
        try:
            shutil.rmtree(stale)
        except OSError:
            os.remove(stale)
359
+
360
+
361
+ if __name__ == "__main__":
362
+ # See all possible arguments in src/transformers/training_args.py
363
+ # or by passing the --help flag to this script.
364
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
365
+
366
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, SamplingArguments))
367
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
368
+ # If we pass only one argument to the script and it's the path to a json file,
369
+ # let's parse it to get our arguments.
370
+ model_args, data_args, training_args, sampling_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
371
+ else:
372
+ model_args, data_args, training_args, sampling_args = parser.parse_args_into_dataclasses()
373
+
374
+ if (
375
+ os.path.exists(training_args.output_dir)
376
+ and os.listdir(training_args.output_dir)
377
+ and training_args.do_train
378
+ and not training_args.overwrite_output_dir
379
+ ):
380
+ raise ValueError(
381
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
382
+ "Use --overwrite_output_dir to overcome."
383
+ )
384
+
385
+ # Setup logging
386
+ logging.basicConfig(
387
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
388
+ level="INFO",
389
+ datefmt="[%X]",
390
+ )
391
+
392
+ # Log on each process the small summary:
393
+ logger = logging.getLogger(__name__)
394
+ logger.warning(
395
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
396
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
397
+ )
398
+
399
+ # Set the verbosity to info of the Transformers logger (on main process only):
400
+ logger.info(f"Training/evaluation parameters {training_args}")
401
+
402
+ # Set seed before initializing model.
403
+ set_seed(training_args.seed)
404
+
405
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
406
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
407
+ # (the dataset will be downloaded automatically from the datasets Hub).
408
+ #
409
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
410
+ # 'text' is found. You can easily tweak this behavior (see below).
411
+ if data_args.dataset_name is not None:
412
+ # Downloading and loading a dataset from the hub.
413
+ filepaths = {}
414
+ if data_args.train_file:
415
+ filepaths["train"] = data_args.train_file
416
+ if data_args.validation_file:
417
+ filepaths["validation"] = data_args.validation_file
418
+ try:
419
+ dataset = load_dataset(
420
+ data_args.dataset_name,
421
+ data_args.dataset_config_name,
422
+ cache_dir=model_args.cache_dir,
423
+ streaming=True,
424
+ split="train",
425
+ sampling_method=sampling_args.sampling_method,
426
+ sampling_factor=sampling_args.sampling_factor,
427
+ boundaries=sampling_args.boundaries,
428
+ perplexity_model=sampling_args.perplexity_model,
429
+ seed=training_args.seed,
430
+ data_files=filepaths,
431
+ )
432
+ except Exception as exc:
433
+ logger.warning(
434
+ f"Unable to load local dataset with perplexity sampling support. Using huggingface.co/datasets/{data_args.dataset_name}: {exc}"
435
+ )
436
+ dataset = load_dataset(
437
+ data_args.dataset_name,
438
+ data_args.dataset_config_name,
439
+ cache_dir=model_args.cache_dir,
440
+ streaming=True,
441
+ split="train",
442
+ )
443
+
444
+ if model_args.config_name:
445
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
446
+ elif model_args.model_name_or_path:
447
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
448
+ else:
449
+ config = CONFIG_MAPPING[model_args.model_type]()
450
+ logger.warning("You are instantiating a new config instance from scratch.")
451
+
452
+ if model_args.tokenizer_name:
453
+ tokenizer = AutoTokenizer.from_pretrained(
454
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
455
+ )
456
+ elif model_args.model_name_or_path:
457
+ tokenizer = AutoTokenizer.from_pretrained(
458
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
459
+ )
460
+ else:
461
+ raise ValueError(
462
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
463
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
464
+ )
465
+
466
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
467
+ # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
468
+ # efficient when it receives the `special_tokens_mask`.
469
def tokenize_function(examples):
    """Tokenize the configured text column of a batch of examples.

    The special-tokens mask is requested together with the token ids because
    the language-modeling data collator is more efficient when it receives it.
    """
    texts = examples[data_args.text_column_name]
    return tokenizer(texts, return_special_tokens_mask=True)
474
+
475
+ tokenized_datasets = dataset.map(
476
+ tokenize_function,
477
+ batched=True,
478
+ )
479
+
480
+ shuffle_seed = training_args.seed
481
+ tokenized_datasets = tokenized_datasets.shuffle(buffer_size=data_args.shuffle_buffer_size, seed=shuffle_seed)
482
+
483
+ # Enable tensorboard only on the master node
484
+ has_tensorboard = is_tensorboard_available()
485
+ if has_tensorboard and jax.process_index() == 0:
486
+ try:
487
+ from flax.metrics.tensorboard import SummaryWriter
488
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
489
+ # Enable Weight&Biases
490
+ import wandb
491
+ wandb.init(
492
+ entity='wandb',
493
+ project='hf-flax-bertin-roberta-es',
494
+ sync_tensorboard=True,
495
+ )
496
+ wandb.config.update(training_args)
497
+ wandb.config.update(model_args)
498
+ wandb.config.update(data_args)
499
+ except ImportError as ie:
500
+ has_tensorboard = False
501
+ logger.warning(
502
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
503
+ )
504
+ else:
505
+ logger.warning(
506
+ "Unable to display metrics through TensorBoard because the package is not installed: "
507
+ "Please run pip install tensorboard to enable."
508
+ )
509
+
510
+ # Data collator
511
+ # This one will take care of randomly masking the tokens.
512
+ data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
513
+
514
+ # Initialize our training
515
+ rng = jax.random.PRNGKey(training_args.seed)
516
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
517
+
518
+ if model_args.model_name_or_path:
519
+ model = FlaxAutoModelForMaskedLM.from_pretrained(
520
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
521
+ )
522
+ else:
523
+ model = FlaxAutoModelForMaskedLM.from_config(
524
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
525
+ )
526
+
527
+ # Store some constants
528
+ num_epochs = int(training_args.num_train_epochs)
529
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
530
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
531
+
532
+ # define number steps per stream epoch
533
+ num_train_steps = data_args.num_train_steps
534
+
535
+ # Create learning rate schedule
536
+ warmup_fn = optax.linear_schedule(
537
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
538
+ )
539
+ decay_fn = optax.linear_schedule(
540
+ init_value=training_args.learning_rate,
541
+ end_value=0,
542
+ transition_steps=num_train_steps - training_args.warmup_steps,
543
+ )
544
+ linear_decay_lr_schedule_fn = optax.join_schedules(
545
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
546
+ )
547
+
548
+ # We use Optax's "masking" functionality to not apply weight decay
549
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
550
+ # mask boolean with the same structure as the parameters.
551
+ # The mask is True for parameters that should be decayed.
552
+ # Note that this mask is specifically adapted for FlaxBERT-like models.
553
+ # For other models, one should correct the layer norm parameter naming
554
+ # accordingly.
555
def decay_mask_fn(params):
    """Build the weight-decay mask for ``params``.

    Returns a nested dict of booleans with the same structure as ``params``:
    True where weight decay should be applied, False for any parameter named
    "bias" and for the "scale" parameter of a "LayerNorm" module.
    """

    def _is_decayed(path):
        # `path` is the tuple of nested keys leading to one parameter.
        is_bias = path[-1] == "bias"
        is_layer_norm_scale = path[-2:] == ("LayerNorm", "scale")
        return not (is_bias or is_layer_norm_scale)

    mask_by_path = {path: _is_decayed(path) for path in traverse_util.flatten_dict(params)}
    return traverse_util.unflatten_dict(mask_by_path)
559
+
560
+ # create AdamW optimizer
561
+ adamw = optax.adamw(
562
+ learning_rate=linear_decay_lr_schedule_fn,
563
+ b1=training_args.adam_beta1,
564
+ b2=training_args.adam_beta2,
565
+ eps=training_args.adam_epsilon,
566
+ weight_decay=training_args.weight_decay,
567
+ mask=decay_mask_fn,
568
+ )
569
+
570
+ # Setup train state
571
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
572
+
573
+ # Define gradient update step fn
574
def train_step(state, batch, dropout_rng):
    """Run a single gradient update on one device's share of the batch.

    Meant to be wrapped by ``jax.pmap`` with axis name "batch": gradients and
    metrics are averaged over that axis.

    Args:
        state: train state exposing ``apply_fn``, ``params``, ``step`` and
            ``apply_gradients``.
        batch: dict of model inputs plus a "labels" entry; "labels" is popped
            from the dict before the remaining entries are fed to the model.
        dropout_rng: PRNG key; it is split into the key used for this step
            and the key handed back for the next one.

    Returns:
        Tuple ``(new_state, metrics, new_dropout_rng)`` where ``metrics``
        holds the averaged "loss" and the "learning_rate" for ``state.step``.
    """
    step_rng, next_rng = jax.random.split(dropout_rng)

    def loss_fn(params):
        targets = batch.pop("labels")
        logits = state.apply_fn(**batch, params=params, dropout_rng=step_rng, train=True)[0]

        # Only positions whose label is > 0 contribute to the loss
        # (presumably the collator gives every other position a non-positive
        # label -- confirm against the data collator).
        weights = jnp.where(targets > 0, 1.0, 0.0)
        token_loss = optax.softmax_cross_entropy(logits, onehot(targets, logits.shape[-1])) * weights

        # Mean over the contributing positions.
        return token_loss.sum() / weights.sum()

    loss, grad = jax.value_and_grad(loss_fn)(state.params)
    grad = jax.lax.pmean(grad, "batch")
    updated_state = state.apply_gradients(grads=grad)

    # The learning rate is looked up with the pre-update step counter.
    step_metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    metrics = jax.lax.pmean(step_metrics, axis_name="batch")

    return updated_state, metrics, next_rng
601
+
602
+ # Create parallel version of the train step
603
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
604
+
605
+ # Define eval fn
606
def eval_step(params, batch):
    """Compute summed evaluation metrics for one device's share of a batch.

    Meant to be wrapped by ``jax.pmap`` with axis name "batch": the returned
    values are summed over that axis.

    Args:
        params: model parameters used for the forward pass.
        batch: dict of model inputs plus a "labels" entry; "labels" is popped
            from the dict before the remaining entries are fed to the model.

    Returns:
        Dict with "loss" (summed, not averaged), "accuracy" (number of correct
        predictions) and "normalizer" (number of positions counted), so the
        caller can divide by "normalizer" after accumulating several batches.
    """
    targets = batch.pop("labels")
    logits = model(**batch, params=params, train=False)[0]

    # Only positions whose label is > 0 are scored.
    weights = jnp.where(targets > 0, 1.0, 0.0)
    token_loss = optax.softmax_cross_entropy(logits, onehot(targets, logits.shape[-1])) * weights

    # A position is a hit when the arg-max prediction equals its label.
    predictions = jnp.argmax(logits, axis=-1)
    hits = jnp.equal(predictions, targets) * weights

    totals = {"loss": token_loss.sum(), "accuracy": hits.sum(), "normalizer": weights.sum()}
    return jax.lax.psum(totals, axis_name="batch")
623
+
624
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
625
+
626
+ # Replicate the train state on each device
627
+ state = jax_utils.replicate(state)
628
+
629
+ train_time = 0
630
+ train_start = time.time()
631
+ train_metrics = []
632
+ eval_metrics = []
633
+
634
+ training_iter = iter(tokenized_datasets)
635
+
636
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
637
+ eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
638
+
639
+ steps = tqdm(range(num_train_steps), desc="Training...", position=0)
640
+ for step in range(num_train_steps):
641
+ # ======================== Training ================================
642
+ try:
643
+ samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
644
+ except StopIteration:
645
+ # Once the end of the dataset stream is reached, the training iterator
646
+ # is reinitialized and reshuffled and a new eval dataset is randomly chosen.
647
+ shuffle_seed += 1
648
+ tokenized_datasets.set_epoch(shuffle_seed)
649
+
650
+ training_iter = iter(tokenized_datasets)
651
+
652
+ eval_dataset = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
653
+ samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
654
+
655
+ # process input samples
656
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
657
+
658
+ # Model forward
659
+ model_inputs = shard(model_inputs.data)
660
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
661
+
662
+ train_metrics.append(train_metric)
663
+
664
+ if step % training_args.logging_steps == 0 and step > 0:
665
+ steps.write(
666
+ f"Step... ({step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
667
+ )
668
+ train_time += time.time() - train_start
669
+ if has_tensorboard and jax.process_index() == 0:
670
+ write_train_metric(summary_writer, train_metrics, train_time, step)
671
+ train_metrics = []
672
+
673
+ # ======================== Evaluating ==============================
674
+ if step % training_args.eval_steps == 0 and step > 0:
675
+ eval_samples_idx = jnp.arange(data_args.num_eval_samples)
676
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
677
+
678
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=1)):
679
+ # process input samples
680
+ batch_eval_samples = {k: [v[idx] for idx in batch_idx] for k, v in eval_samples.items()}
681
+ model_inputs = data_collator(batch_eval_samples, pad_to_multiple_of=16)
682
+
683
+ # Model forward
684
+ model_inputs = shard(model_inputs.data)
685
+ metrics = p_eval_step(state.params, model_inputs)
686
+ eval_metrics.append(metrics)
687
+
688
+ # normalize eval metrics
689
+ eval_metrics = get_metrics(eval_metrics)
690
+ eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
691
+ eval_normalizer = eval_metrics.pop("normalizer")
692
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
693
+
694
+ # Update progress bar
695
+ steps.desc = f"Step... ({step + 1}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
696
+
697
+ if has_tensorboard and jax.process_index() == 0:
698
+ write_eval_metric(summary_writer, eval_metrics, step)
699
+ eval_metrics = []
700
+
701
+ # save checkpoint every save_steps
702
+ if step % training_args.save_steps == 0 and step > 0 and jax.process_index() == 0:
703
+ logger.info(f"Saving checkpoint at {step + 1} steps")
704
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
705
+ model.save_pretrained(
706
+ training_args.output_dir,
707
+ params=params,
708
+ push_to_hub=training_args.push_to_hub,
709
+ commit_message=f"Saving weights and logs of step {step + 1}",
710
+ )
711
+ save_checkpoint_files(state, data_collator, training_args, training_args.output_dir)
712
+ checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step + 1}"
713
+ checkpoints_dir.mkdir(parents=True, exist_ok=True)
714
+ model.save_pretrained(checkpoints_dir, params=params,)
715
+ save_checkpoint_files(state, data_collator, training_args, checkpoints_dir)
716
+ rotate_checkpoints(
717
+ Path(training_args.output_dir) / "checkpoints",
718
+ max_checkpoints=training_args.save_total_limit
719
+ )
720
+
721
+ # update tqdm bar
722
+ steps.update(1)
run_stream.sh ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch masked-language-model training of the RoBERTa base configuration on
# the Spanish ("es") mC4 dataset script, logging all output to run_stream.log.
# Hyperparameters from https://arxiv.org/pdf/1907.11692.pdf for base model

# Make the pipeline below report a failure of the training process instead of
# the exit status of `tee`. Probed in a subshell first because not every
# /bin/sh supports the option.
if (set -o pipefail) 2>/dev/null; then
    set -o pipefail
fi

# Print the number of accelerator devices JAX can see before starting.
python -c "import jax; print('TPUs', jax.device_count())"

python ./run_mlm_flax_stream.py \
    --output_dir="./outputs" \
    --model_type="roberta" \
    --config_name="./configs/base" \
    --tokenizer_name="./configs/base" \
    --dataset_name="./mc4" \
    --dataset_config_name="es" \
    --train_file="../mc4-es-train-50M-steps.jsonl" \
    --max_seq_length="128" \
    --pad_to_max_length \
    --per_device_train_batch_size="256" \
    --per_device_eval_batch_size="256" \
    --adam_beta1="0.9" \
    --adam_beta2="0.98" \
    --adam_epsilon="1e-6" \
    --learning_rate="6e-4" \
    --weight_decay="0.01" \
    --save_steps="10000" \
    --save_total_limit="5" \
    --warmup_steps="24000" \
    --overwrite_output_dir \
    --num_train_steps="250000" \
    --eval_steps="10000" \
    --dtype="bfloat16" \
    --logging_steps="500" 2>&1 | tee run_stream.log