Spaces:
Runtime error
Runtime error
Merge pull request #25 from borisdayma/fix-config
Browse files- seq2seq/do_big_run.sh +5 -5
- seq2seq/do_small_run.sh +3 -3
- seq2seq/run_seq2seq_flax.py +21 -11
seq2seq/do_big_run.sh
CHANGED
@@ -1,16 +1,16 @@
|
|
1 |
python run_seq2seq_flax.py \
|
2 |
--max_source_length 128 \
|
3 |
-
--train_file /data/CC12M/encoded-small-train.tsv \
|
4 |
-
--validation_file /data/CC12M/encoded-small-valid.tsv \
|
5 |
--output_dir output \
|
6 |
--per_device_train_batch_size 56 \
|
7 |
--per_device_eval_batch_size 56 \
|
8 |
--preprocessing_num_workers 80 \
|
9 |
-
--warmup_steps
|
10 |
--gradient_accumulation_steps 8 \
|
11 |
--do_train \
|
12 |
--do_eval \
|
13 |
--adafactor \
|
14 |
-
--num_train_epochs
|
15 |
--log_model \
|
16 |
-
--learning_rate 0.
|
|
|
1 |
python run_seq2seq_flax.py \
|
2 |
--max_source_length 128 \
|
3 |
+
--train_file /data/CC12M/encoded-small-train.tsv \
|
4 |
+
--validation_file /data/CC12M/encoded-small-valid.tsv \
|
5 |
--output_dir output \
|
6 |
--per_device_train_batch_size 56 \
|
7 |
--per_device_eval_batch_size 56 \
|
8 |
--preprocessing_num_workers 80 \
|
9 |
+
--warmup_steps 250 \
|
10 |
--gradient_accumulation_steps 8 \
|
11 |
--do_train \
|
12 |
--do_eval \
|
13 |
--adafactor \
|
14 |
+
--num_train_epochs 6 \
|
15 |
--log_model \
|
16 |
+
--learning_rate 0.005
|
seq2seq/do_small_run.sh
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
python run_seq2seq_flax.py \
|
2 |
--max_source_length 128 \
|
3 |
-
--train_file /data/CC12M/encoded-small-train.tsv \
|
4 |
-
--validation_file /data/CC12M/encoded-small-valid.tsv \
|
5 |
--output_dir output \
|
6 |
--per_device_train_batch_size 56 \
|
7 |
--per_device_eval_batch_size 56 \
|
@@ -13,4 +13,4 @@ python run_seq2seq_flax.py \
|
|
13 |
--adafactor \
|
14 |
--num_train_epochs 1 \
|
15 |
--max_train_samples 20000 \
|
16 |
-
--learning_rate 0.
|
|
|
1 |
python run_seq2seq_flax.py \
|
2 |
--max_source_length 128 \
|
3 |
+
--train_file /data/CC12M/encoded-small-train.tsv \
|
4 |
+
--validation_file /data/CC12M/encoded-small-valid.tsv \
|
5 |
--output_dir output \
|
6 |
--per_device_train_batch_size 56 \
|
7 |
--per_device_eval_batch_size 56 \
|
|
|
13 |
--adafactor \
|
14 |
--num_train_epochs 1 \
|
15 |
--max_train_samples 20000 \
|
16 |
+
--learning_rate 0.005
|
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -31,6 +31,7 @@ from dataclasses import dataclass, field
|
|
31 |
from functools import partial
|
32 |
from pathlib import Path
|
33 |
from typing import Callable, Optional
|
|
|
34 |
|
35 |
import datasets
|
36 |
import nltk # Here to have a nice missing dependency error message early on
|
@@ -44,6 +45,7 @@ import optax
|
|
44 |
import transformers
|
45 |
from filelock import FileLock
|
46 |
from flax import jax_utils, traverse_util
|
|
|
47 |
import flax.linen as nn
|
48 |
from flax.jax_utils import unreplicate
|
49 |
from flax.training import train_state
|
@@ -282,8 +284,6 @@ class CustomFlaxBartModule(FlaxBartModule):
|
|
282 |
# the decoder has a different config
|
283 |
decoder_config = BartConfig(self.config.to_dict())
|
284 |
decoder_config.max_position_embeddings = OUTPUT_LENGTH
|
285 |
-
decoder_config.min_length = OUTPUT_LENGTH
|
286 |
-
decoder_config.max_length = OUTPUT_LENGTH
|
287 |
decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
|
288 |
self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
|
289 |
|
@@ -363,7 +363,7 @@ def main():
|
|
363 |
else:
|
364 |
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
365 |
|
366 |
-
logger.warning(f"eval_steps has been manually hardcoded") # TODO: remove it later, convenient for now
|
367 |
training_args.eval_steps = 400
|
368 |
|
369 |
if (
|
@@ -412,7 +412,7 @@ def main():
|
|
412 |
# (the dataset will be downloaded automatically from the datasets Hub).
|
413 |
#
|
414 |
data_files = {}
|
415 |
-
logger.warning(f"Datasets path have been manually hardcoded") # TODO: remove it later, convenient for now
|
416 |
if data_args.train_file is not None:
|
417 |
data_files["train"] = ["/data/CC3M/training-encoded.tsv", "/data/CC12M/encoded-train.tsv"]
|
418 |
if data_args.validation_file is not None:
|
@@ -434,14 +434,15 @@ def main():
|
|
434 |
# Set up our new model config
|
435 |
config = BartConfig.from_pretrained(model_args.model_name_or_path)
|
436 |
config.tie_word_embeddings = False
|
437 |
-
config.decoder_start_token_id = BOS_TOKEN_ID
|
438 |
-
config.bos_token_id = BOS_TOKEN_ID # should not be used
|
439 |
config.pos_token_id = BOS_TOKEN_ID # should not be needed (as we generate until max_length)
|
440 |
config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
|
441 |
config.forced_bos_token_id = None # we don't need this token
|
442 |
config.forced_eos_token_id = None # we don't need this token
|
443 |
-
|
444 |
-
|
|
|
445 |
|
446 |
print(f"TPUs: {jax.device_count()}")
|
447 |
assert jax.device_count() == 8, "TPUs in use, please check running processes"
|
@@ -779,7 +780,7 @@ def main():
|
|
779 |
|
780 |
return eval_metrics
|
781 |
|
782 |
-
def run_save_model(step, epoch, eval_metrics=None):
|
783 |
if jax.process_index() == 0:
|
784 |
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
785 |
|
@@ -789,6 +790,13 @@ def main():
|
|
789 |
params=params,
|
790 |
)
|
791 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
792 |
# save to W&B
|
793 |
if data_args.log_model:
|
794 |
metadata = {'step': step, 'epoch': epoch}
|
@@ -799,6 +807,8 @@ def main():
|
|
799 |
)
|
800 |
artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
|
801 |
artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
|
|
|
|
|
802 |
wandb.run.log_artifact(artifact)
|
803 |
|
804 |
# save to the hub
|
@@ -835,7 +845,7 @@ def main():
|
|
835 |
run_evaluation()
|
836 |
|
837 |
if global_step % data_args.save_model_steps == 0:
|
838 |
-
run_save_model(global_step, epoch)
|
839 |
|
840 |
# log final train metrics
|
841 |
wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
|
@@ -850,7 +860,7 @@ def main():
|
|
850 |
eval_metrics = run_evaluation()
|
851 |
|
852 |
# save checkpoint after each epoch and push checkpoint to the hub
|
853 |
-
run_save_model(global_step, epoch, eval_metrics)
|
854 |
|
855 |
|
856 |
# ======================== Prediction loop ==============================
|
|
|
31 |
from functools import partial
|
32 |
from pathlib import Path
|
33 |
from typing import Callable, Optional
|
34 |
+
import json
|
35 |
|
36 |
import datasets
|
37 |
import nltk # Here to have a nice missing dependency error message early on
|
|
|
45 |
import transformers
|
46 |
from filelock import FileLock
|
47 |
from flax import jax_utils, traverse_util
|
48 |
+
from flax.serialization import from_bytes, to_bytes
|
49 |
import flax.linen as nn
|
50 |
from flax.jax_utils import unreplicate
|
51 |
from flax.training import train_state
|
|
|
284 |
# the decoder has a different config
|
285 |
decoder_config = BartConfig(self.config.to_dict())
|
286 |
decoder_config.max_position_embeddings = OUTPUT_LENGTH
|
|
|
|
|
287 |
decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
|
288 |
self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
|
289 |
|
|
|
363 |
else:
|
364 |
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
365 |
|
366 |
+
logger.warning(f"WARNING: eval_steps has been manually hardcoded") # TODO: remove it later, convenient for now
|
367 |
training_args.eval_steps = 400
|
368 |
|
369 |
if (
|
|
|
412 |
# (the dataset will be downloaded automatically from the datasets Hub).
|
413 |
#
|
414 |
data_files = {}
|
415 |
+
logger.warning(f"WARNING: Datasets path have been manually hardcoded") # TODO: remove it later, convenient for now
|
416 |
if data_args.train_file is not None:
|
417 |
data_files["train"] = ["/data/CC3M/training-encoded.tsv", "/data/CC12M/encoded-train.tsv"]
|
418 |
if data_args.validation_file is not None:
|
|
|
434 |
# Set up our new model config
|
435 |
config = BartConfig.from_pretrained(model_args.model_name_or_path)
|
436 |
config.tie_word_embeddings = False
|
437 |
+
config.decoder_start_token_id = BOS_TOKEN_ID # for first token
|
438 |
+
config.bos_token_id = BOS_TOKEN_ID # should not be used (due to forced_bos_token_id)
|
439 |
config.pos_token_id = BOS_TOKEN_ID # should not be needed (as we generate until max_length)
|
440 |
config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
|
441 |
config.forced_bos_token_id = None # we don't need this token
|
442 |
config.forced_eos_token_id = None # we don't need this token
|
443 |
+
config.force_bos_token_to_be_generated = False # otherwise it sets bos_token_id at loading
|
444 |
+
config.min_length = data_args.max_target_length
|
445 |
+
config.max_length = data_args.max_target_length
|
446 |
|
447 |
print(f"TPUs: {jax.device_count()}")
|
448 |
assert jax.device_count() == 8, "TPUs in use, please check running processes"
|
|
|
780 |
|
781 |
return eval_metrics
|
782 |
|
783 |
+
def run_save_model(state, step, epoch, eval_metrics=None):
|
784 |
if jax.process_index() == 0:
|
785 |
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
786 |
|
|
|
790 |
params=params,
|
791 |
)
|
792 |
|
793 |
+
# save state
|
794 |
+
state = unreplicate(state)
|
795 |
+
with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
|
796 |
+
f.write(to_bytes(state.opt_state))
|
797 |
+
with (Path(training_args.output_dir) / 'training_state.json').open('w') as f:
|
798 |
+
json.dump({'step': state.step.item()}, f)
|
799 |
+
|
800 |
# save to W&B
|
801 |
if data_args.log_model:
|
802 |
metadata = {'step': step, 'epoch': epoch}
|
|
|
807 |
)
|
808 |
artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
|
809 |
artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
|
810 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
|
811 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
|
812 |
wandb.run.log_artifact(artifact)
|
813 |
|
814 |
# save to the hub
|
|
|
845 |
run_evaluation()
|
846 |
|
847 |
if global_step % data_args.save_model_steps == 0:
|
848 |
+
run_save_model(state, global_step, epoch)
|
849 |
|
850 |
# log final train metrics
|
851 |
wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
|
|
|
860 |
eval_metrics = run_evaluation()
|
861 |
|
862 |
# save checkpoint after each epoch and push checkpoint to the hub
|
863 |
+
run_save_model(state, global_step, epoch, eval_metrics)
|
864 |
|
865 |
|
866 |
# ======================== Prediction loop ==============================
|