boris committed on
Commit
a30dbd3
·
1 Parent(s): 5aaf9df

feat: update model config + save optim

Browse files
Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +16 -5
seq2seq/run_seq2seq_flax.py CHANGED
@@ -44,6 +44,7 @@ import optax
44
  import transformers
45
  from filelock import FileLock
46
  from flax import jax_utils, traverse_util
 
47
  import flax.linen as nn
48
  from flax.jax_utils import unreplicate
49
  from flax.training import train_state
@@ -432,12 +433,13 @@ def main():
432
  # Set up our new model config
433
  config = BartConfig.from_pretrained(model_args.model_name_or_path)
434
  config.tie_word_embeddings = False
435
- config.decoder_start_token_id = BOS_TOKEN_ID
436
- config.bos_token_id = BOS_TOKEN_ID # should not be used
437
  config.pos_token_id = BOS_TOKEN_ID # should not be needed (as we generate until max_length)
438
  config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
439
  config.forced_bos_token_id = None # we don't need this token
440
  config.forced_eos_token_id = None # we don't need this token
 
441
  config.min_length = data_args.max_target_length
442
  config.max_length = data_args.max_target_length
443
 
@@ -777,10 +779,17 @@ def main():
777
 
778
  return eval_metrics
779
 
780
- def run_save_model(step, epoch, eval_metrics=None):
781
  if jax.process_index() == 0:
782
  params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
783
 
 
 
 
 
 
 
 
784
  # save model locally
785
  model.save_pretrained(
786
  training_args.output_dir,
@@ -797,6 +806,8 @@ def main():
797
  )
798
  artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
799
  artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
 
 
800
  wandb.run.log_artifact(artifact)
801
 
802
  # save to the hub
@@ -833,7 +844,7 @@ def main():
833
  run_evaluation()
834
 
835
  if global_step % data_args.save_model_steps == 0:
836
- run_save_model(global_step, epoch)
837
 
838
  # log final train metrics
839
  wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
@@ -848,7 +859,7 @@ def main():
848
  eval_metrics = run_evaluation()
849
 
850
  # save checkpoint after each epoch and push checkpoint to the hub
851
- run_save_model(global_step, epoch, eval_metrics)
852
 
853
 
854
  # ======================== Prediction loop ==============================
 
44
  import transformers
45
  from filelock import FileLock
46
  from flax import jax_utils, traverse_util
47
+ from flax.serialization import from_bytes, to_bytes
48
  import flax.linen as nn
49
  from flax.jax_utils import unreplicate
50
  from flax.training import train_state
 
433
  # Set up our new model config
434
  config = BartConfig.from_pretrained(model_args.model_name_or_path)
435
  config.tie_word_embeddings = False
436
+ config.decoder_start_token_id = BOS_TOKEN_ID # for first token
437
+ config.bos_token_id = BOS_TOKEN_ID # should not be used (due to forced_bos_token_id)
438
  config.pos_token_id = BOS_TOKEN_ID # should not be needed (as we generate until max_length)
439
  config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
440
  config.forced_bos_token_id = None # we don't need this token
441
  config.forced_eos_token_id = None # we don't need this token
442
+ config.force_bos_token_to_be_generated = False # otherwise it sets bos_token_id at loading
443
  config.min_length = data_args.max_target_length
444
  config.max_length = data_args.max_target_length
445
 
 
779
 
780
  return eval_metrics
781
 
782
+ def run_save_model(state, step, epoch, eval_metrics=None):
783
  if jax.process_index() == 0:
784
  params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
785
 
786
+ # save state
787
+ state = unreplicate(state)
788
+ with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
789
+ f.write(to_bytes(state.opt_state))
790
+ with (Path(training_args.output_dir) / 'training_state.json').open('wb') as f:
791
+ json.dump({'step': state.step.item()}, f)
792
+
793
  # save model locally
794
  model.save_pretrained(
795
  training_args.output_dir,
 
806
  )
807
  artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
808
  artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
809
+ artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
810
+ artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
811
  wandb.run.log_artifact(artifact)
812
 
813
  # save to the hub
 
844
  run_evaluation()
845
 
846
  if global_step % data_args.save_model_steps == 0:
847
+ run_save_model(state, global_step, epoch)
848
 
849
  # log final train metrics
850
  wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
 
859
  eval_metrics = run_evaluation()
860
 
861
  # save checkpoint after each epoch and push checkpoint to the hub
862
+ run_save_model(state, global_step, epoch, eval_metrics)
863
 
864
 
865
  # ======================== Prediction loop ==============================