Spaces:
Runtime error
Runtime error
feat: restore state from checkpoint
Browse filesFormer-commit-id: 1cf3567d3cbc24fb1b78684c35ca9a89e4fafedf
- seq2seq/run_seq2seq_flax.py +33 -5
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -436,11 +436,24 @@ def main():
|
|
436 |
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
437 |
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
438 |
|
439 |
-
#
|
440 |
-
tokenizer =
|
441 |
-
|
442 |
-
|
443 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
444 |
if model_args.from_checkpoint is not None:
|
445 |
artifact = wandb.run.use_artifact(model_args.from_checkpoint)
|
446 |
artifact_dir = artifact.download()
|
@@ -455,6 +468,12 @@ def main():
|
|
455 |
# used in the preprocessing function
|
456 |
config = model.config
|
457 |
|
|
|
|
|
|
|
|
|
|
|
|
|
458 |
else:
|
459 |
base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
|
460 |
model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
@@ -480,6 +499,12 @@ def main():
|
|
480 |
model.params['model']['shared'] = base_model.params['model']['shared']
|
481 |
del base_model
|
482 |
|
|
|
|
|
|
|
|
|
|
|
|
|
483 |
print(f"TPUs: {jax.device_count()}")
|
484 |
assert jax.device_count() == 8, "TPUs in use, please check running processes"
|
485 |
|
@@ -676,6 +701,9 @@ def main():
|
|
676 |
grad_accum=jax.tree_map(jnp.zeros_like, model.params),
|
677 |
optimizer_step=0,
|
678 |
)
|
|
|
|
|
|
|
679 |
|
680 |
# label smoothed cross entropy
|
681 |
def loss_fn(logits, labels):
|
|
|
436 |
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
437 |
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
438 |
|
439 |
+
# Set up items to load or create
|
440 |
+
tokenizer = None
|
441 |
+
artifact_dir = None
|
442 |
+
|
443 |
+
def restore_state(state, artifact_dir):
|
444 |
+
# restore optimizer state
|
445 |
+
if (Path(artifact_dir) / 'opt_state.msgpack').exists():
|
446 |
+
with (Path(artifact_dir) / 'opt_state.msgpack').open('rb') as f:
|
447 |
+
opt_state = from_bytes(state.opt_state, f.read())
|
448 |
+
|
449 |
+
# restore steps
|
450 |
+
if (Path(artifact_dir) / 'training_state.json').exists():
|
451 |
+
with (Path(artifact_dir) / 'opt_state.msgpack').open('r') as f:
|
452 |
+
training_state = json.load(f)
|
453 |
+
step = training_state['step']
|
454 |
+
optimizer_step = step // training_args.gradient_accumulation_steps
|
455 |
+
state.replace(step=step, optimizer_step=optimizer_step)
|
456 |
+
|
457 |
if model_args.from_checkpoint is not None:
|
458 |
artifact = wandb.run.use_artifact(model_args.from_checkpoint)
|
459 |
artifact_dir = artifact.download()
|
|
|
468 |
# used in the preprocessing function
|
469 |
config = model.config
|
470 |
|
471 |
+
# load tokenizer if present
|
472 |
+
if (Path(artifact_dir) / 'tokenizer_config.json').exists():
|
473 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
474 |
+
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
475 |
+
)
|
476 |
+
|
477 |
else:
|
478 |
base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
|
479 |
model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
|
|
499 |
model.params['model']['shared'] = base_model.params['model']['shared']
|
500 |
del base_model
|
501 |
|
502 |
+
# Load tokenizer if it has not been set
|
503 |
+
if tokenizer is None:
|
504 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
505 |
+
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
506 |
+
)
|
507 |
+
|
508 |
print(f"TPUs: {jax.device_count()}")
|
509 |
assert jax.device_count() == 8, "TPUs in use, please check running processes"
|
510 |
|
|
|
701 |
grad_accum=jax.tree_map(jnp.zeros_like, model.params),
|
702 |
optimizer_step=0,
|
703 |
)
|
704 |
+
if model_args.from_checkpoint is not None:
|
705 |
+
# restore optimizer state, step and optimizer_step
|
706 |
+
restore_state(state)
|
707 |
|
708 |
# label smoothed cross entropy
|
709 |
def loss_fn(logits, labels):
|