Spaces:
Runtime error
Runtime error
fix: output directory must exist
Browse files
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -783,6 +783,12 @@ def main():
|
|
783 |
if jax.process_index() == 0:
|
784 |
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
785 |
|
|
|
|
|
|
|
|
|
|
|
|
|
786 |
# save state
|
787 |
state = unreplicate(state)
|
788 |
with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
|
@@ -790,12 +796,6 @@ def main():
|
|
790 |
with (Path(training_args.output_dir) / 'training_state.json').open('wb') as f:
|
791 |
json.dump({'step': state.step.item()}, f)
|
792 |
|
793 |
-
# save model locally
|
794 |
-
model.save_pretrained(
|
795 |
-
training_args.output_dir,
|
796 |
-
params=params,
|
797 |
-
)
|
798 |
-
|
799 |
# save to W&B
|
800 |
if data_args.log_model:
|
801 |
metadata = {'step': step, 'epoch': epoch}
|
|
|
783 |
if jax.process_index() == 0:
|
784 |
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
785 |
|
786 |
+
# save model locally
|
787 |
+
model.save_pretrained(
|
788 |
+
training_args.output_dir,
|
789 |
+
params=params,
|
790 |
+
)
|
791 |
+
|
792 |
# save state
|
793 |
state = unreplicate(state)
|
794 |
with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
|
|
|
796 |
with (Path(training_args.output_dir) / 'training_state.json').open('wb') as f:
|
797 |
json.dump({'step': state.step.item()}, f)
|
798 |
|
|
|
|
|
|
|
|
|
|
|
|
|
799 |
# save to W&B
|
800 |
if data_args.log_model:
|
801 |
metadata = {'step': step, 'epoch': epoch}
|