Spaces:
Runtime error
Runtime error
feat: update model config + save optim
Browse files- seq2seq/run_seq2seq_flax.py +16 -5
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -44,6 +44,7 @@ import optax
|
|
44 |
import transformers
|
45 |
from filelock import FileLock
|
46 |
from flax import jax_utils, traverse_util
|
|
|
47 |
import flax.linen as nn
|
48 |
from flax.jax_utils import unreplicate
|
49 |
from flax.training import train_state
|
@@ -432,12 +433,13 @@ def main():
|
|
432 |
# Set up our new model config
|
433 |
config = BartConfig.from_pretrained(model_args.model_name_or_path)
|
434 |
config.tie_word_embeddings = False
|
435 |
-
config.decoder_start_token_id = BOS_TOKEN_ID
|
436 |
-
config.bos_token_id = BOS_TOKEN_ID # should not be used
|
437 |
config.pos_token_id = BOS_TOKEN_ID # should not be needed (as we generate until max_length)
|
438 |
config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
|
439 |
config.forced_bos_token_id = None # we don't need this token
|
440 |
config.forced_eos_token_id = None # we don't need this token
|
|
|
441 |
config.min_length = data_args.max_target_length
|
442 |
config.max_length = data_args.max_target_length
|
443 |
|
@@ -777,10 +779,17 @@ def main():
|
|
777 |
|
778 |
return eval_metrics
|
779 |
|
780 |
-
def run_save_model(step, epoch, eval_metrics=None):
|
781 |
if jax.process_index() == 0:
|
782 |
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
783 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
784 |
# save model locally
|
785 |
model.save_pretrained(
|
786 |
training_args.output_dir,
|
@@ -797,6 +806,8 @@ def main():
|
|
797 |
)
|
798 |
artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
|
799 |
artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
|
|
|
|
|
800 |
wandb.run.log_artifact(artifact)
|
801 |
|
802 |
# save to the hub
|
@@ -833,7 +844,7 @@ def main():
|
|
833 |
run_evaluation()
|
834 |
|
835 |
if global_step % data_args.save_model_steps == 0:
|
836 |
-
run_save_model(global_step, epoch)
|
837 |
|
838 |
# log final train metrics
|
839 |
wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
|
@@ -848,7 +859,7 @@ def main():
|
|
848 |
eval_metrics = run_evaluation()
|
849 |
|
850 |
# save checkpoint after each epoch and push checkpoint to the hub
|
851 |
-
run_save_model(global_step, epoch, eval_metrics)
|
852 |
|
853 |
|
854 |
# ======================== Prediction loop ==============================
|
|
|
44 |
import transformers
|
45 |
from filelock import FileLock
|
46 |
from flax import jax_utils, traverse_util
|
47 |
+
from flax.serialization import from_bytes, to_bytes
|
48 |
import flax.linen as nn
|
49 |
from flax.jax_utils import unreplicate
|
50 |
from flax.training import train_state
|
|
|
433 |
# Set up our new model config
|
434 |
config = BartConfig.from_pretrained(model_args.model_name_or_path)
|
435 |
config.tie_word_embeddings = False
|
436 |
+
config.decoder_start_token_id = BOS_TOKEN_ID # for first token
|
437 |
+
config.bos_token_id = BOS_TOKEN_ID # should not be used (due to forced_bos_token_id)
|
438 |
config.pos_token_id = BOS_TOKEN_ID # should not be needed (as we generate until max_length)
|
439 |
config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
|
440 |
config.forced_bos_token_id = None # we don't need this token
|
441 |
config.forced_eos_token_id = None # we don't need this token
|
442 |
+
config.force_bos_token_to_be_generated = False # otherwise it sets bos_token_id at loading
|
443 |
config.min_length = data_args.max_target_length
|
444 |
config.max_length = data_args.max_target_length
|
445 |
|
|
|
779 |
|
780 |
return eval_metrics
|
781 |
|
782 |
+
def run_save_model(state, step, epoch, eval_metrics=None):
|
783 |
if jax.process_index() == 0:
|
784 |
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
785 |
|
786 |
+
# save state
|
787 |
+
state = unreplicate(state)
|
788 |
+
with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
|
789 |
+
f.write(to_bytes(state.opt_state))
|
790 |
+
with (Path(training_args.output_dir) / 'training_state.json').open('wb') as f:
|
791 |
+
json.dump({'step': state.step.item()}, f)
|
792 |
+
|
793 |
# save model locally
|
794 |
model.save_pretrained(
|
795 |
training_args.output_dir,
|
|
|
806 |
)
|
807 |
artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
|
808 |
artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
|
809 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
|
810 |
+
artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
|
811 |
wandb.run.log_artifact(artifact)
|
812 |
|
813 |
# save to the hub
|
|
|
844 |
run_evaluation()
|
845 |
|
846 |
if global_step % data_args.save_model_steps == 0:
|
847 |
+
run_save_model(state, global_step, epoch)
|
848 |
|
849 |
# log final train metrics
|
850 |
wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
|
|
|
859 |
eval_metrics = run_evaluation()
|
860 |
|
861 |
# save checkpoint after each epoch and push checkpoint to the hub
|
862 |
+
run_save_model(state, global_step, epoch, eval_metrics)
|
863 |
|
864 |
|
865 |
# ======================== Prediction loop ==============================
|