Spaces:
Runtime error
Runtime error
feat: save model frequently
Browse files - seq2seq/run_seq2seq_flax.py +15 -1
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -231,6 +231,12 @@ class DataTrainingArguments:
|
|
231 |
log_model: bool = field(
|
232 |
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
233 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
|
235 |
def __post_init__(self):
|
236 |
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
@@ -340,7 +346,7 @@ def wandb_log(metrics, step=None, prefix=None):
|
|
340 |
if jax.process_index() == 0:
|
341 |
log_metrics = {f'{prefix}/{k}' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
|
342 |
if step is not None:
|
343 |
-
log_metrics
|
344 |
wandb.log(log_metrics)
|
345 |
|
346 |
|
@@ -795,6 +801,9 @@ def main():
|
|
795 |
|
796 |
if global_step % training_args.eval_steps == 0:
|
797 |
run_evaluation()
|
|
|
|
|
|
|
798 |
|
799 |
# log final train metrics
|
800 |
wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
|
@@ -809,6 +818,9 @@ def main():
|
|
809 |
eval_metrics = run_evaluation()
|
810 |
|
811 |
# save checkpoint after each epoch and push checkpoint to the hub
|
|
|
|
|
|
|
812 |
if jax.process_index() == 0:
|
813 |
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
814 |
|
@@ -821,6 +833,8 @@ def main():
|
|
821 |
# save to W&B
|
822 |
if data_args.log_model:
|
823 |
metadata = {'epoch': epoch+1, 'eval/loss': eval_metrics['loss']}
|
|
|
|
|
824 |
artifact = wandb.Artifact(
|
825 |
name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
|
826 |
)
|
|
|
231 |
log_model: bool = field(
|
232 |
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
233 |
)
|
234 |
+
save_model_steps: Optional[int] = field(
|
235 |
+
default=3000, # about once every hour in our experiments
|
236 |
+
metadata={
|
237 |
+
"help": "For logging the model more frequently. Used only when `log_model` is set."
|
238 |
+
},
|
239 |
+
)
|
240 |
|
241 |
def __post_init__(self):
|
242 |
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
|
|
346 |
if jax.process_index() == 0:
|
347 |
log_metrics = {f'{prefix}/{k}' if prefix is not None else k: jax.device_get(v) for k,v in metrics.items()}
|
348 |
if step is not None:
|
349 |
+
log_metrics['train/step'] = step
|
350 |
wandb.log(log_metrics)
|
351 |
|
352 |
|
|
|
801 |
|
802 |
if global_step % training_args.eval_steps == 0:
|
803 |
run_evaluation()
|
804 |
+
|
805 |
+
if global_step % training_args.save_model_steps == 0:
|
806 |
+
run_save_model(global_step, epoch)
|
807 |
|
808 |
# log final train metrics
|
809 |
wandb_log(unreplicate(train_metric), step=global_step, prefix='train')
|
|
|
818 |
eval_metrics = run_evaluation()
|
819 |
|
820 |
# save checkpoint after each epoch and push checkpoint to the hub
|
821 |
+
run_save_model(global_step, epoch, eval_metrics)
|
822 |
+
|
823 |
+
def run_save_model(step, epoch, eval_metrics=None):
|
824 |
if jax.process_index() == 0:
|
825 |
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
826 |
|
|
|
833 |
# save to W&B
|
834 |
if data_args.log_model:
|
835 |
metadata = {'epoch': epoch+1, 'eval/loss': eval_metrics['loss']}
|
836 |
+
if eval_metrics is not None:
|
837 |
+
metadata['eval/loss'] = eval_metrics['loss']
|
838 |
artifact = wandb.Artifact(
|
839 |
name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
|
840 |
)
|