boris commited on
Commit
4aced93
·
1 Parent(s): 28f08be

feat: restore state from checkpoint

Browse files

Former-commit-id: 1cf3567d3cbc24fb1b78684c35ca9a89e4fafedf

Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +33 -5
seq2seq/run_seq2seq_flax.py CHANGED
@@ -436,11 +436,24 @@ def main():
436
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
437
  # https://huggingface.co/docs/datasets/loading_datasets.html.
438
 
439
- # Load pretrained model and tokenizer
440
- tokenizer = AutoTokenizer.from_pretrained(
441
- model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
442
- )
443
-
 
 
 
 
 
 
 
 
 
 
 
 
 
444
  if model_args.from_checkpoint is not None:
445
  artifact = wandb.run.use_artifact(model_args.from_checkpoint)
446
  artifact_dir = artifact.download()
@@ -455,6 +468,12 @@ def main():
455
  # used in the preprocessing function
456
  config = model.config
457
 
 
 
 
 
 
 
458
  else:
459
  base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
460
  model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
@@ -480,6 +499,12 @@ def main():
480
  model.params['model']['shared'] = base_model.params['model']['shared']
481
  del base_model
482
 
 
 
 
 
 
 
483
  print(f"TPUs: {jax.device_count()}")
484
  assert jax.device_count() == 8, "TPUs in use, please check running processes"
485
 
@@ -676,6 +701,9 @@ def main():
676
  grad_accum=jax.tree_map(jnp.zeros_like, model.params),
677
  optimizer_step=0,
678
  )
 
 
 
679
 
680
  # label smoothed cross entropy
681
  def loss_fn(logits, labels):
 
436
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
437
  # https://huggingface.co/docs/datasets/loading_datasets.html.
438
 
439
+ # Set up items to load or create
440
+ tokenizer = None
441
+ artifact_dir = None
442
+
443
+ def restore_state(state, artifact_dir):
444
+ # restore optimizer state
445
+ if (Path(artifact_dir) / 'opt_state.msgpack').exists():
446
+ with (Path(artifact_dir) / 'opt_state.msgpack').open('rb') as f:
447
+ opt_state = from_bytes(state.opt_state, f.read())
448
+
449
+ # restore steps
450
+ if (Path(artifact_dir) / 'training_state.json').exists():
451
+ with (Path(artifact_dir) / 'opt_state.msgpack').open('r') as f:
452
+ training_state = json.load(f)
453
+ step = training_state['step']
454
+ optimizer_step = step // training_args.gradient_accumulation_steps
455
+ state.replace(step=step, optimizer_step=optimizer_step)
456
+
457
  if model_args.from_checkpoint is not None:
458
  artifact = wandb.run.use_artifact(model_args.from_checkpoint)
459
  artifact_dir = artifact.download()
 
468
  # used in the preprocessing function
469
  config = model.config
470
 
471
+ # load tokenizer if present
472
+ if (Path(artifact_dir) / 'tokenizer_config.json').exists():
473
+ tokenizer = AutoTokenizer.from_pretrained(
474
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
475
+ )
476
+
477
  else:
478
  base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
479
  model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
 
499
  model.params['model']['shared'] = base_model.params['model']['shared']
500
  del base_model
501
 
502
+ # Load tokenizer if it has not been set
503
+ if tokenizer is None:
504
+ tokenizer = AutoTokenizer.from_pretrained(
505
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
506
+ )
507
+
508
  print(f"TPUs: {jax.device_count()}")
509
  assert jax.device_count() == 8, "TPUs in use, please check running processes"
510
 
 
701
  grad_accum=jax.tree_map(jnp.zeros_like, model.params),
702
  optimizer_step=0,
703
  )
704
+ if model_args.from_checkpoint is not None:
705
+ # restore optimizer state, step and optimizer_step
706
+ restore_state(state)
707
 
708
  # label smoothed cross entropy
709
  def loss_fn(logits, labels):