add support for v3-32
run_mlm_flax_stream.py  (+5, −25)
@@ -262,29 +262,6 @@ class FlaxDataCollatorForLanguageModeling:
         return inputs, labels


-@dataclass
-class SamplingArguments:
-    """
-    Arguments pertaining to how to perform sampling of the dataset.
-    """
-
-    perplexity_model: Optional[str] = field(
-        default="./es.arpa.bin", metadata={"help": "Path to KenLM model to use to get perplexity values."}
-    )
-    sampling_method: Optional[str] = field(
-        default=None, metadata={"help": "Sample using a 'step' or 'gaussian' perplexity function per document, or 'random'."}
-    )
-    sampling_factor: Optional[float] = field(
-        default=None, metadata={"help": "Sampling factor. Integers for step function, decimals for gaussian."}
-    )
-    boundaries: Optional[str] = field(
-        default="536394.99320948,662247.50212365,919250.87225178", metadata={"help": "Quartile boundaries"}
-    )
-
-    def __post_init__(self):
-        self.boundaries = [float(q.strip()) for q in self.boundaries.split(",")]
-
-
 def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
     num_samples = len(samples_idx)
     samples_to_remove = num_samples % batch_size
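The trailing context above is the start of generate_batch_splits. A minimal sketch of how the rest of that function plausibly continues, following the stock Flax MLM examples (everything past the two lines shown is assumed, not part of this diff):

import jax.numpy as jnp

def generate_batch_splits_sketch(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
    # Drop the tail that does not fill a full batch, then reshape the
    # shuffled sample indices into (num_batches, batch_size).
    num_samples = len(samples_idx)
    samples_to_remove = num_samples % batch_size
    if samples_to_remove != 0:
        samples_idx = samples_idx[:-samples_to_remove]
    sections_split = num_samples // batch_size
    return samples_idx.reshape((sections_split, batch_size))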
@@ -310,7 +287,9 @@ def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
         i += len(tokenized_samples["input_ids"])

         # concatenate tokenized samples to list
-        samples = {
+        samples = {
+            k: samples[k] + tokenized_samples[k] for k in ["input_ids", "attention_mask", "special_tokens_mask"]
+        }

     # Concatenated tokens are split to lists of length `max_seq_length`.
     # Note that the remainder of % max_seq_length is thrown away.
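The three `+` lines restore the accumulator dict that the previous commit left dangling. A self-contained sketch of the loop they live in, assuming the usual streaming pattern (accumulate tokenized samples until `num_samples` are seen, then chunk; `group_texts_sketch` and its arguments are illustrative names, not the script's):

def group_texts_sketch(tokenized_batches, num_samples, max_seq_length):
    samples = {k: [] for k in ["input_ids", "attention_mask", "special_tokens_mask"]}
    i = 0
    for tokenized_samples in tokenized_batches:
        i += len(tokenized_samples["input_ids"])
        # concatenate tokenized samples to list, as in the hunk above
        samples = {
            k: samples[k] + tokenized_samples[k]
            for k in ["input_ids", "attention_mask", "special_tokens_mask"]
        }
        if i >= num_samples:
            break

    # Concatenated tokens are split to lists of length `max_seq_length`;
    # the remainder modulo max_seq_length is thrown away.
    def chunk(list_of_seqs):
        flat = [tok for seq in list_of_seqs for tok in seq]
        usable = (len(flat) // max_seq_length) * max_seq_length
        return [flat[j : j + max_seq_length] for j in range(0, usable, max_seq_length)]

    return {k: chunk(v) for k, v in samples.items()}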
@@ -404,7 +383,7 @@ if __name__ == "__main__":
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.

-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
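The fixed line hands the three remaining dataclasses (ModelArguments and DataTrainingArguments are defined earlier in the script; TrainingArguments comes from transformers) to HfArgumentParser. For reference, this block typically continues like so in the transformers examples (the else-branch below is assumed, not shown in the diff):

import os
import sys

from transformers import HfArgumentParser

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
    # A single JSON file on the command line is parsed into all three dataclasses.
    model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()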
@@ -528,6 +507,7 @@ if __name__ == "__main__":

     # Data collator
     # This one will take care of randomly masking the tokens.
+    print("DATA COLLATOR")
     data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)

     # Initialize our training
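The added print is a plain debug marker; the collator built on the next line is what later turns grouped samples into masked model inputs. A minimal usage sketch, assuming the collator's __call__ matches the Flax MLM examples (`grouped_samples` is an illustrative name):

# grouped_samples: a list of per-sequence dicts with "input_ids" /
# "attention_mask" / "special_tokens_mask", as produced by the grouping step.
batch = data_collator(grouped_samples, pad_to_multiple_of=16)
# batch["input_ids"] holds the randomly masked tokens; batch["labels"] keeps
# the original ids at masked positions and -100 elsewhere, so unmasked
# positions are ignored by the loss.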