boris commited on
Commit
754f876
·
1 Parent(s): 5e244d0

feat: save model frequently

Browse files
Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +15 -1
seq2seq/run_seq2seq_flax.py CHANGED
@@ -231,6 +231,12 @@ class DataTrainingArguments:
231
  log_model: bool = field(
232
  default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
233
  )
 
 
 
 
 
 
234
 
235
  def __post_init__(self):
236
  if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@@ -340,7 +346,7 @@ def wandb_log(metrics, step=None, prefix=None):
340
  if jax.process_index() == 0:
341
  log_metrics = {f'{prefix}/{k}' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
342
  if step is not None:
343
- log_metrics = {**log_metrics, 'train/step': step}
344
  wandb.log(log_metrics)
345
 
346
 
@@ -795,6 +801,9 @@ def main():
795
 
796
  if global_step % training_args.eval_steps == 0:
797
  run_evaluation()
 
 
 
798
 
799
  # log final train metrics
800
  wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
@@ -809,6 +818,9 @@ def main():
809
  eval_metrics = run_evaluation()
810
 
811
  # save checkpoint after each epoch and push checkpoint to the hub
 
 
 
812
  if jax.process_index() == 0:
813
  params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
814
 
@@ -821,6 +833,8 @@ def main():
821
  # save to W&B
822
  if data_args.log_model:
823
  metadata = {'epoch': epoch+1, 'eval/loss': eval_metrics['loss']}
 
 
824
  artifact = wandb.Artifact(
825
  name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
826
  )
 
231
  log_model: bool = field(
232
  default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
233
  )
234
+ save_model_steps: Optional[int] = field(
235
+ default=3000, # about once every hour in our experiments
236
+ metadata={
237
+ "help": "For logging the model more frequently. Used only when `log_model` is set."
238
+ },
239
+ )
240
 
241
  def __post_init__(self):
242
  if self.dataset_name is None and self.train_file is None and self.validation_file is None:
 
346
  if jax.process_index() == 0:
347
  log_metrics = {f'{prefix}/{k}' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
348
  if step is not None:
349
+ log_metrics['train/step'] = step
350
  wandb.log(log_metrics)
351
 
352
 
 
801
 
802
  if global_step % training_args.eval_steps == 0:
803
  run_evaluation()
804
+
805
+ if global_step % training_args.save_model_steps == 0:
806
+ run_save_model(global_step, epoch)
807
 
808
  # log final train metrics
809
  wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
 
818
  eval_metrics = run_evaluation()
819
 
820
  # save checkpoint after each epoch and push checkpoint to the hub
821
+ run_save_model(global_step, epoch, eval_metrics)
822
+
823
+ def run_save_model(step, epoch, eval_metrics=None):
824
  if jax.process_index() == 0:
825
  params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
826
 
 
833
  # save to W&B
834
  if data_args.log_model:
835
  metadata = {'epoch': epoch+1, 'eval/loss': eval_metrics['loss']}
836
+ if eval_metrics is not None:
837
+ metadata['eval/loss'] = eval_metrics['loss']
838
  artifact = wandb.Artifact(
839
  name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
840
  )