boris committed on
Commit
85c1b8e
·
1 Parent(s): e6c2573

feat: handle data in separate file

Browse files
Files changed (2) hide show
  1. dalle_mini/data.py +269 -0
  2. dev/seq2seq/run_seq2seq_flax.py +32 -220
dalle_mini/data.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from datasets import load_dataset
3
+ import numpy as np
4
+ import jax
5
+ import jax.numpy as jnp
6
+ from flax.training.common_utils import shard
7
+ from .text import TextNormalizer
8
+
9
+
10
@dataclass
class Dataset:
    """Load, preprocess and batch the train/eval splits used for seq2seq training.

    Construction loads the dataset immediately (see ``__post_init__``).
    ``preprocess`` then normalizes/tokenizes the selected splits and
    ``dataloader`` yields sharded batches from them.

    ``train_dataset`` / ``eval_dataset`` only exist as instance attributes once
    the matching split has been loaded (``do_train`` / ``do_eval``), and
    ``rng_dataset`` only once ``preprocess`` has run in non-streaming mode.
    Code in this class relies on ``hasattr`` to detect which ones are present.
    """

    dataset_repo_or_path: str
    train_file: str = None
    validation_file: str = None
    # NOTE(review): `dataset_type` is stored but not read anywhere in this class.
    dataset_type: str = "dataset"
    streaming: bool = True
    use_auth_token: bool = False
    text_column: str = "caption"
    encoding_column: str = "encoding"
    max_source_length: int = 128
    max_train_samples: int = None
    max_eval_samples: int = None
    preprocessing_num_workers: int = None
    overwrite_cache: bool = False
    do_train: bool = False
    do_eval: bool = True
    seed_dataset: int = None
    # These three MUST be annotated: without an annotation `@dataclass` does not
    # treat them as fields and leaves a `Field` object behind as a class
    # attribute, which makes every `hasattr(self, ...)` check below always true.
    # With an annotation and no default, the class attribute is removed and the
    # name only exists once it is assigned on the instance.
    # `repr`/`compare` are disabled because the attributes may be unset.
    train_dataset: object = field(init=False, repr=False, compare=False)
    eval_dataset: object = field(init=False, repr=False, compare=False)
    rng_dataset: object = field(init=False, repr=False, compare=False)

    def __post_init__(self):
        """Load the dataset and keep the splits requested by `do_train` / `do_eval`.

        Raises:
            ValueError: if a requested split ("train" or "validation") is not
                present in the loaded dataset.
        """
        # define data_files
        if self.train_file is not None or self.validation_file is not None:
            data_files = {
                "train": self.train_file,
                "validation": self.validation_file,
            }
        else:
            data_files = None

        # load dataset
        dataset = load_dataset(
            self.dataset_repo_or_path,
            data_files=data_files,
            streaming=self.streaming,
            use_auth_token=self.use_auth_token,
        )
        if self.do_train:
            if "train" not in dataset:
                raise ValueError("Training requires a training dataset")
            self.train_dataset = self._truncate(
                dataset["train"], self.max_train_samples
            )
        if self.do_eval:
            if "validation" not in dataset:
                raise ValueError("Evaluating requires a validation dataset")
            self.eval_dataset = self._truncate(
                dataset["validation"], self.max_eval_samples
            )

    def _truncate(self, ds, max_samples):
        """Return `ds` limited to its first `max_samples` items (no-op when `None`)."""
        if max_samples is None:
            return ds
        # streaming datasets are truncated with `take`, others with `select`
        if self.streaming:
            return ds.take(max_samples)
        return ds.select(range(max_samples))

    def preprocess(self, tokenizer, decoder_start_token_id, normalize_text):
        """Optionally normalize the captions, then tokenize every loaded split.

        Args:
            tokenizer: callable forwarded to `preprocess_function`.
            decoder_start_token_id: token id forwarded to `preprocess_function`.
            normalize_text: when truthy, captions are first passed through a
                `TextNormalizer`.

        Returns:
            `self`, so that `dataset = dataset.preprocess(...)` keeps a usable
            object (the splits are also updated in place).
        """
        if self.streaming:
            # we need to shuffle early in streaming mode
            if hasattr(self, "train_dataset"):
                self.train_dataset = self.train_dataset.shuffle(1000, self.seed_dataset)
        else:
            # prepare rng for later shuffling
            if self.seed_dataset is None:
                self.seed_dataset = np.random.get_state()[1][0]
            self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)

        loaded_splits = [
            name for name in ("train_dataset", "eval_dataset") if hasattr(self, name)
        ]

        # normalize text
        if normalize_text:
            # The boolean parameter `normalize_text` shadows the module-level
            # function of the same name, so the function has to be looked up
            # explicitly; otherwise the flag itself would be passed to `map`.
            normalize_fn = globals()["normalize_text"]
            text_normalizer = TextNormalizer()
            normalize_kwargs = {
                "text_column": self.text_column,
                "text_normalizer": text_normalizer,
            }
            for name in loaded_splits:
                ds = getattr(self, name)
                if self.streaming:
                    ds = ds.map(normalize_fn, fn_kwargs=normalize_kwargs)
                else:
                    ds = ds.map(
                        normalize_fn,
                        fn_kwargs=normalize_kwargs,
                        num_proc=self.preprocessing_num_workers,
                        load_from_cache_file=not self.overwrite_cache,
                        desc="Normalizing datasets",
                    )
                setattr(self, name, ds)

        # preprocess
        preprocess_kwargs = {
            "tokenizer": tokenizer,
            "text_column": self.text_column,
            "encoding_column": self.encoding_column,
            "max_source_length": self.max_source_length,
            "decoder_start_token_id": decoder_start_token_id,
        }
        for name in loaded_splits:
            ds = getattr(self, name)
            if self.streaming:
                ds = ds.map(
                    preprocess_function,
                    batched=True,
                    fn_kwargs=preprocess_kwargs,
                )
            else:
                ds = ds.map(
                    preprocess_function,
                    batched=True,
                    fn_kwargs=preprocess_kwargs,
                    # column names come from the dataset object, not from the
                    # attribute-name string used to look it up
                    remove_columns=ds.column_names,
                    num_proc=self.preprocessing_num_workers,
                    load_from_cache_file=not self.overwrite_cache,
                    desc="Preprocessing datasets",
                )
            setattr(self, name, ds)

        return self

    def dataloader(self, split, batch_size, epoch=None):
        """Return a generator of sharded batches for `split`.

        Args:
            split: "train" or "eval".
            batch_size: number of items per yielded batch.
            epoch: forwarded to `set_epoch` on the streaming train split;
                skipped when `None`.

        Raises:
            ValueError: if `split` is neither "train" nor "eval".
        """

        def _dataloader_datasets_non_streaming(dataset, batch_size, rng=None):
            """
            Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
            Items are shuffled when `rng` is given, otherwise kept in order.
            """
            steps_per_epoch = len(dataset) // batch_size

            if rng is not None:
                batch_idx = jax.random.permutation(rng, len(dataset))
            else:
                batch_idx = jnp.arange(len(dataset))

            # Skip incomplete batch.
            batch_idx = batch_idx[: steps_per_epoch * batch_size]
            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

            for idx in batch_idx:
                batch = dataset[idx]
                batch = {k: jnp.array(v) for k, v in batch.items()}
                batch = shard(batch)
                yield batch

        def _dataloader_datasets_streaming(dataset, batch_size):
            """Accumulate streamed items into batches of `batch_size`; a trailing partial batch is dropped."""
            keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
            batch = {k: [] for k in keys}
            for item in dataset:
                # Only collect the expected keys: any other column still
                # present in `item` would otherwise raise a KeyError on `batch`.
                for k in keys:
                    batch[k].append(item[k])
                if len(batch[keys[0]]) == batch_size:
                    batch = {k: jnp.array(v) for k, v in batch.items()}
                    batch = shard(batch)
                    yield batch
                    batch = {k: [] for k in keys}

        if split == "train":
            ds = self.train_dataset
        elif split == "eval":
            ds = self.eval_dataset
        else:
            raise ValueError(f'split must be "train" or "eval", got {split}')

        if self.streaming:
            # NOTE(review): `set_epoch(None)` is presumably invalid, so it is
            # skipped when the caller does not pass an epoch — confirm.
            if split == "train" and epoch is not None:
                ds.set_epoch(epoch)
            return _dataloader_datasets_streaming(ds, batch_size)

        # eval batches are not shuffled, so no rng is drawn for them
        input_rng = None
        if split == "train":
            self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
        return _dataloader_datasets_non_streaming(ds, batch_size, input_rng)

    @property
    def length(self):
        """Return `(len_train_dataset, len_eval_dataset)`; an entry is `None` when unknown."""
        len_train_dataset, len_eval_dataset = None, None
        if self.streaming:
            # we don't know the length, let's just assume max_samples if defined
            if self.max_train_samples is not None:
                len_train_dataset = self.max_train_samples
            if self.max_eval_samples is not None:
                len_eval_dataset = self.max_eval_samples
        else:
            len_train_dataset = (
                len(self.train_dataset) if hasattr(self, "train_dataset") else None
            )
            len_eval_dataset = (
                len(self.eval_dataset) if hasattr(self, "eval_dataset") else None
            )
        return len_train_dataset, len_eval_dataset
221
+
222
+
223
def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.

    The first column is set to `decoder_start_token_id` and the last token of
    each row is dropped. `input_ids` is indexed as a 2-D array (rows, tokens).
    The result keeps the dtype of `input_ids`, so integer token ids stay
    integers instead of being promoted to float.
    """
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id
    return shifted_input_ids
231
+
232
+
233
def normalize_text(example, text_column, text_normalizer):
    """Replace `example[text_column]` with its normalized form and return `example`.

    `text_normalizer` is called with the current value of the column; the
    example is modified in place.
    """
    raw_text = example[text_column]
    example[text_column] = text_normalizer(raw_text)
    return example
236
+
237
+
238
def preprocess_function(
    examples,
    tokenizer,
    text_column,
    encoding_column,
    max_source_length,
    decoder_start_token_id,
):
    """Tokenize the text column and attach the decoder targets.

    The returned object is whatever `tokenizer` returns, extended with:
      - "labels": `examples[encoding_column]` converted with `np.asarray`
      - "decoder_input_ids": the labels shifted one position to the right,
        starting with `decoder_start_token_id`
    """
    # Pad to a fixed length: jitted functions need fixed-size inputs.
    model_inputs = tokenizer(
        examples[text_column],
        max_length=max_source_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

    # Labels are the target indices; they are needed alongside
    # decoder_input_ids by the compute_loss function.
    targets = np.asarray(examples[encoding_column])
    model_inputs["labels"] = targets

    # Decoder inputs: the same targets with bos prepended and the last token removed.
    model_inputs["decoder_input_ids"] = shift_tokens_right(
        targets, decoder_start_token_id
    )

    return model_inputs
dev/seq2seq/run_seq2seq_flax.py CHANGED
@@ -28,8 +28,7 @@ from typing import Callable, Optional
28
  import json
29
 
30
  import datasets
31
- import numpy as np
32
- from datasets import Dataset, load_dataset
33
  from tqdm import tqdm
34
 
35
  import jax
@@ -40,7 +39,7 @@ from flax import jax_utils, traverse_util
40
  from flax.serialization import from_bytes, to_bytes
41
  from flax.jax_utils import unreplicate
42
  from flax.training import train_state
43
- from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
44
  from transformers import (
45
  AutoTokenizer,
46
  HfArgumentParser,
@@ -49,7 +48,7 @@ from transformers.models.bart.modeling_flax_bart import BartConfig
49
 
50
  import wandb
51
 
52
- from dalle_mini.text import TextNormalizer
53
  from dalle_mini.model import CustomFlaxBartForConditionalGeneration
54
 
55
  logger = logging.getLogger(__name__)
@@ -120,18 +119,21 @@ class DataTrainingArguments:
120
  "help": "The name of the column in the datasets containing the image encodings."
121
  },
122
  )
123
- dataset_repo_or_path: Optional[str] = field(
124
  default=None,
125
  metadata={"help": "The dataset repository containing encoded files."},
126
  )
127
  train_file: Optional[str] = field(
128
- default=None, metadata={"help": "The input training data file (a text file)."}
 
129
  )
130
  validation_file: Optional[str] = field(
131
  default=None,
132
- metadata={
133
- "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
134
- },
 
 
135
  )
136
  # data loading should not be a bottleneck so we use "streaming" mode by default
137
  streaming: bool = field(
@@ -177,6 +179,13 @@ class DataTrainingArguments:
177
  "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
178
  },
179
  )
 
 
 
 
 
 
 
180
 
181
  def __post_init__(self):
182
  if self.dataset_repo_or_path is None:
@@ -277,13 +286,6 @@ class TrainingArguments:
277
  "help": "Random seed for the model that will be set at the beginning of training."
278
  },
279
  )
280
- # default seed of None ensures we don't repeat the same items if script was interrupted during an epoch
281
- seed_dataset: int = field(
282
- default=None,
283
- metadata={
284
- "help": "Random seed for the dataset that will be set at the beginning of training."
285
- },
286
- )
287
 
288
  push_to_hub: bool = field(
289
  default=False,
@@ -327,45 +329,6 @@ class TrainState(train_state.TrainState):
327
  )
328
 
329
 
330
- def data_loader(
331
- dataset: Dataset,
332
- batch_size: int,
333
- rng: jax.random.PRNGKey = None,
334
- ):
335
- """
336
- Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
337
- Shuffle batches if `shuffle` is `True`.
338
- """
339
- steps_per_epoch = len(dataset) // batch_size
340
-
341
- if rng is not None:
342
- batch_idx = jax.random.permutation(rng, len(dataset))
343
- else:
344
- batch_idx = jnp.arange(len(dataset))
345
-
346
- batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
347
- batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
348
-
349
- for idx in batch_idx:
350
- batch = dataset[idx]
351
- batch = {k: jnp.array(v) for k, v in batch.items()}
352
- batch = shard(batch)
353
- yield batch
354
-
355
-
356
- def data_loader_streaming(dataset: Dataset, batch_size: int):
357
- keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
358
- batch = {k: [] for k in keys}
359
- for item in dataset:
360
- for k, v in item.items():
361
- batch[k].append(v)
362
- if len(batch[keys[0]]) == batch_size:
363
- batch = {k: jnp.array(v) for k, v in batch.items()}
364
- batch = shard(batch)
365
- yield batch
366
- batch = {k: [] for k in keys}
367
-
368
-
369
  def create_learning_rate_fn(
370
  num_warmup_steps: int,
371
  learning_rate: float,
@@ -447,18 +410,8 @@ def main():
447
  logger.info(f"Training/evaluation parameters {training_args}")
448
 
449
  # Load dataset
450
- if data_args.train_file is not None or data_args.validation_file is not None:
451
- data_files = {
452
- "train": data_args.train_file,
453
- "validation": data_args.validation_file,
454
- }
455
- else:
456
- data_files = None
457
- dataset = load_dataset(
458
- data_args.dataset_repo_or_path,
459
- data_files=data_files,
460
- streaming=data_args.streaming,
461
- use_auth_token=data_args.use_auth_token,
462
  )
463
 
464
  # Set up wandb run
@@ -552,141 +505,17 @@ def main():
552
  use_fast=True,
553
  )
554
 
555
- print(f"TPUs: {jax.device_count()}")
556
  assert jax.device_count() == 8, "TPUs in use, please check running processes"
557
 
558
  # Preprocessing the datasets.
559
- # We need to tokenize inputs and targets.
560
-
561
- # Get the column names for input/target.
562
- text_column = data_args.text_column
563
- encoding_column = data_args.encoding_column
564
-
565
- def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
566
- """
567
- Shift input ids one token to the right.
568
- """
569
- shifted_input_ids = np.zeros(input_ids.shape)
570
- shifted_input_ids[:, 1:] = input_ids[:, :-1]
571
- shifted_input_ids[:, 0] = decoder_start_token_id
572
- return shifted_input_ids
573
-
574
- text_normalizer = TextNormalizer() if model.config.normalize_text else None
575
-
576
- def normalize_text(example):
577
- example[text_column] = text_normalizer(example[text_column])
578
- return example
579
-
580
- def preprocess_function(examples):
581
- inputs = examples[text_column]
582
- # Setting padding="max_length" as we need fixed length inputs for jitted functions
583
- model_inputs = tokenizer(
584
- inputs,
585
- max_length=data_args.max_source_length,
586
- padding="max_length",
587
- truncation=True,
588
- return_tensors="np",
589
- )
590
-
591
- # set up targets
592
- # Note: labels correspond to our target indices
593
- # decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
594
- labels = examples[encoding_column]
595
- labels = np.asarray(labels)
596
-
597
- # We need the labels, in addition to the decoder_input_ids, for the compute_loss function
598
- model_inputs["labels"] = labels
599
-
600
- # In our case, this prepends the bos token and removes the last one
601
- decoder_input_ids = shift_tokens_right(
602
- labels, model.config.decoder_start_token_id
603
- )
604
- model_inputs["decoder_input_ids"] = decoder_input_ids
605
-
606
- return model_inputs
607
-
608
- if training_args.do_train:
609
- if "train" not in dataset:
610
- raise ValueError("--do_train requires a train dataset")
611
- train_dataset = dataset["train"]
612
- if data_args.max_train_samples is not None:
613
- train_dataset = (
614
- train_dataset.take(data_args.max_train_samples)
615
- if data_args.streaming
616
- else train_dataset.select(range(data_args.max_train_samples))
617
- )
618
- if data_args.streaming:
619
- train_dataset = train_dataset.shuffle(1000, training_args.seed_dataset)
620
- else:
621
- seed_dataset = (
622
- training_args.seed_dataset
623
- if training_args.seed_dataset is not None
624
- else np.random.get_state()[1][0]
625
- )
626
- rng_dataset = jax.random.PRNGKey(seed_dataset)
627
- if model.config.normalize_text:
628
- train_dataset = (
629
- train_dataset.map(normalize_text)
630
- if data_args.streaming
631
- else train_dataset.map(
632
- normalize_text,
633
- num_proc=data_args.preprocessing_num_workers,
634
- load_from_cache_file=not data_args.overwrite_cache,
635
- desc="Normalizing the validation dataset",
636
- )
637
- )
638
- train_dataset = (
639
- train_dataset.map(
640
- preprocess_function,
641
- batched=True,
642
- )
643
- if data_args.streaming
644
- else train_dataset.map(
645
- preprocess_function,
646
- batched=True,
647
- num_proc=data_args.preprocessing_num_workers,
648
- remove_columns=train_dataset.column_names,
649
- load_from_cache_file=not data_args.overwrite_cache,
650
- desc="Running tokenizer on validation dataset",
651
- )
652
- )
653
 
654
- if training_args.do_eval:
655
- if "validation" not in dataset:
656
- raise ValueError("--do_eval requires a validation dataset")
657
- eval_dataset = dataset["validation"]
658
- if data_args.max_eval_samples is not None:
659
- eval_dataset = (
660
- eval_dataset.take(data_args.max_train_samples)
661
- if data_args.streaming
662
- else eval_dataset.select(range(data_args.max_train_samples))
663
- )
664
- if model.config.normalize_text:
665
- eval_dataset = (
666
- eval_dataset.map(normalize_text)
667
- if data_args.streaming
668
- else eval_dataset.map(
669
- normalize_text,
670
- num_proc=data_args.preprocessing_num_workers,
671
- load_from_cache_file=not data_args.overwrite_cache,
672
- desc="Normalizing the validation dataset",
673
- )
674
- )
675
- eval_dataset = (
676
- eval_dataset.map(
677
- preprocess_function,
678
- batched=True,
679
- )
680
- if data_args.streaming
681
- else eval_dataset.map(
682
- preprocess_function,
683
- batched=True,
684
- num_proc=data_args.preprocessing_num_workers,
685
- remove_columns=eval_dataset.column_names,
686
- load_from_cache_file=not data_args.overwrite_cache,
687
- desc="Running tokenizer on validation dataset",
688
- )
689
- )
690
 
691
  # Initialize our training
692
  rng = jax.random.PRNGKey(training_args.seed_model)
@@ -699,16 +528,7 @@ def main():
699
  )
700
  batch_size_per_update = train_batch_size * training_args.gradient_accumulation_steps
701
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
702
- len_train_dataset, len_eval_dataset = None, None
703
- if data_args.streaming:
704
- # we don't know the length, let's just assume max_samples if defined
705
- if data_args.max_train_samples is not None:
706
- len_train_dataset = data_args.max_train_samples
707
- if data_args.max_eval_samples is not None:
708
- len_eval_dataset = data_args.max_eval_samples
709
- else:
710
- len_train_dataset = len(train_dataset)
711
- len_eval_dataset = len(eval_dataset)
712
  steps_per_epoch = (
713
  len_train_dataset // train_batch_size if len_train_dataset is not None else None
714
  )
@@ -854,8 +674,8 @@ def main():
854
  # add interesting config parameters
855
  wandb.config.update(
856
  {
857
- "len_train": len_train_dataset,
858
- "len_eval": len_eval_dataset,
859
  "batch_size_per_update": batch_size_per_update,
860
  }
861
  )
@@ -867,10 +687,7 @@ def main():
867
  # ======================== Evaluating ==============================
868
  eval_metrics = []
869
  if training_args.do_eval:
870
- if data_args.streaming:
871
- eval_loader = data_loader_streaming(eval_dataset, eval_batch_size)
872
- else:
873
- eval_loader = data_loader(eval_dataset, eval_batch_size)
874
  eval_steps = (
875
  len_eval_dataset // eval_batch_size
876
  if len_eval_dataset is not None
@@ -985,12 +802,7 @@ def main():
985
  wandb_log({"train/epoch": epoch}, step=unreplicate(state.step))
986
 
987
  # Generate an epoch by shuffling sampling indices from the train dataset
988
- if data_args.streaming:
989
- train_dataset.set_epoch(epoch) # shuffle dataset
990
- train_loader = data_loader_streaming(train_dataset, train_batch_size)
991
- else:
992
- rng_dataset, input_rng = jax.random.split(rng_dataset)
993
- train_loader = data_loader(train_dataset, train_batch_size, rng=input_rng)
994
  # train
995
  for batch in tqdm(
996
  train_loader,
 
28
  import json
29
 
30
  import datasets
31
+ from datasets import Dataset
 
32
  from tqdm import tqdm
33
 
34
  import jax
 
39
  from flax.serialization import from_bytes, to_bytes
40
  from flax.jax_utils import unreplicate
41
  from flax.training import train_state
42
+ from flax.training.common_utils import get_metrics, onehot, shard_prng_key
43
  from transformers import (
44
  AutoTokenizer,
45
  HfArgumentParser,
 
48
 
49
  import wandb
50
 
51
+ from dalle_mini.data import Dataset
52
  from dalle_mini.model import CustomFlaxBartForConditionalGeneration
53
 
54
  logger = logging.getLogger(__name__)
 
119
  "help": "The name of the column in the datasets containing the image encodings."
120
  },
121
  )
122
+ dataset_repo_or_path: str = field(
123
  default=None,
124
  metadata={"help": "The dataset repository containing encoded files."},
125
  )
126
  train_file: Optional[str] = field(
127
+ default=None,
128
+ metadata={"help": "The input training data file (glob acceptable)."},
129
  )
130
  validation_file: Optional[str] = field(
131
  default=None,
132
+ metadata={"help": "An optional input evaluation data file (glob acceptable)."},
133
+ )
134
+ dataset_type: str = field(
135
+ default="datasets",
136
+ metadata={"help": "Either 🤗 'dataset' (default) or 'webdataset'."},
137
  )
138
  # data loading should not be a bottleneck so we use "streaming" mode by default
139
  streaming: bool = field(
 
179
  "help": "Overwrite the cached training and evaluation sets. Not used in streaming mode."
180
  },
181
  )
182
+ # default seed of None ensures we don't repeat the same items if script was interrupted during an epoch
183
+ seed_dataset: int = field(
184
+ default=None,
185
+ metadata={
186
+ "help": "Random seed for the dataset that will be set at the beginning of training."
187
+ },
188
+ )
189
 
190
  def __post_init__(self):
191
  if self.dataset_repo_or_path is None:
 
286
  "help": "Random seed for the model that will be set at the beginning of training."
287
  },
288
  )
 
 
 
 
 
 
 
289
 
290
  push_to_hub: bool = field(
291
  default=False,
 
329
  )
330
 
331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  def create_learning_rate_fn(
333
  num_warmup_steps: int,
334
  learning_rate: float,
 
410
  logger.info(f"Training/evaluation parameters {training_args}")
411
 
412
  # Load dataset
413
+ dataset = Dataset(
414
+ **data_args, do_train=training_args.do_train, do_eval=training_args.do_eval
 
 
 
 
 
 
 
 
 
 
415
  )
416
 
417
  # Set up wandb run
 
505
  use_fast=True,
506
  )
507
 
508
+ logger.info(f"TPUs: {jax.device_count()}")
509
  assert jax.device_count() == 8, "TPUs in use, please check running processes"
510
 
511
  # Preprocessing the datasets.
512
+ # We need to normalize and tokenize inputs and targets.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
 
514
+ dataset = dataset.preprocess(
515
+ tokenizer=tokenizer,
516
+ decoder_start_token_id=model.config.decoder_start_token_id,
517
+ normalize_text=model.config.normalize_text,
518
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
519
 
520
  # Initialize our training
521
  rng = jax.random.PRNGKey(training_args.seed_model)
 
528
  )
529
  batch_size_per_update = train_batch_size * training_args.gradient_accumulation_steps
530
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
531
+ len_train_dataset, len_eval_dataset = dataset.length
 
 
 
 
 
 
 
 
 
532
  steps_per_epoch = (
533
  len_train_dataset // train_batch_size if len_train_dataset is not None else None
534
  )
 
674
  # add interesting config parameters
675
  wandb.config.update(
676
  {
677
+ "len_train_dataset": len_train_dataset,
678
+ "len_eval_dataset": len_eval_dataset,
679
  "batch_size_per_update": batch_size_per_update,
680
  }
681
  )
 
687
  # ======================== Evaluating ==============================
688
  eval_metrics = []
689
  if training_args.do_eval:
690
+ eval_loader = dataset.dataloader("eval", eval_batch_size)
 
 
 
691
  eval_steps = (
692
  len_eval_dataset // eval_batch_size
693
  if len_eval_dataset is not None
 
802
  wandb_log({"train/epoch": epoch}, step=unreplicate(state.step))
803
 
804
  # Generate an epoch by shuffling sampling indices from the train dataset
805
+ train_loader = dataset.dataloader("train", train_batch_size)
 
 
 
 
 
806
  # train
807
  for batch in tqdm(
808
  train_loader,