Commit: add max_length
Browse files — run_mlm_flax_stream.py (+2 −1)
File changed: run_mlm_flax_stream.py
@@ -308,7 +308,7 @@ def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
     while i < num_total_tokens:
         tokenized_samples = next(train_iterator)
         i += len(tokenized_samples["input_ids"])
-
+        print(tokenized_samples)
         # concatenate tokenized samples to list
         samples = {k: samples[k] + tokenized_samples[k] for k in tokenized_samples.keys()}

@@ -505,6 +505,7 @@ if __name__ == "__main__":
         return tokenizer(
             examples[data_args.text_column_name],
             max_length=512,
+            truncation=True,
             return_special_tokens_mask=True
         )