Merge pull request #32 from borisdayma/feat-model
feat: save and restore checkpoints
Former-commit-id: 6254697762481523764fcb4c8856e63203d2f117
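
Note: the diff below serializes the optimizer state with to_bytes/from_bytes (from flax.serialization) plus a small training_state.json for the step counter, and logs everything as a W&B artifact. The following is a minimal sketch of that save/restore round trip with a stand-in model and optimizer; only the file names and the to_bytes/from_bytes pattern are taken from the diff, everything else is assumed.

# Minimal sketch of the checkpoint round trip (stand-in params/optimizer;
# only the file names and serialization pattern mirror the actual script).
import json
from pathlib import Path

import jax.numpy as jnp
import optax
from flax.serialization import to_bytes, from_bytes
from flax.training import train_state

params = {'w': jnp.zeros((2, 2))}                      # stand-in for model.params
tx = optax.adamw(1e-4)                                 # stand-in optimizer
state = train_state.TrainState.create(apply_fn=lambda *a, **k: None, params=params, tx=tx)

output_dir = Path('output')
output_dir.mkdir(exist_ok=True)

# save: optimizer state as msgpack, step counter as JSON
with (output_dir / 'opt_state.msgpack').open('wb') as f:
    f.write(to_bytes(state.opt_state))
with (output_dir / 'training_state.json').open('w') as f:
    json.dump({'step': int(state.step)}, f)

# restore: rebuild a fresh state, then load the saved pieces back in
restored = train_state.TrainState.create(apply_fn=lambda *a, **k: None, params=params, tx=tx)
with (output_dir / 'opt_state.msgpack').open('rb') as f:
    opt_state = from_bytes(restored.opt_state, f.read())
with (output_dir / 'training_state.json').open('r') as f:
    step = json.load(f)['step']
restored = restored.replace(step=step, opt_state=opt_state)
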
- seq2seq/run_seq2seq_flax.py +55 -12
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -271,6 +271,10 @@ class TrainState(train_state.TrainState):
 
 class CustomFlaxBartModule(FlaxBartModule):
     def setup(self):
+        # check config is valid, otherwise set default values
+        self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
+        self.config.max_position_embeddings_decoder = getattr(self.config, 'max_position_embeddings_decoder', OUTPUT_LENGTH)
+
         # we keep shared to easily load pre-trained weights
         self.shared = nn.Embed(
             self.config.vocab_size,
@@ -280,7 +284,7 @@ class CustomFlaxBartModule(FlaxBartModule):
         )
         # a separate embedding is used for the decoder
         self.decoder_embed = nn.Embed(
-            OUTPUT_VOCAB_SIZE,
+            self.config.vocab_size_output,
             self.config.d_model,
             embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
             dtype=self.dtype,
@@ -289,20 +293,23 @@ class CustomFlaxBartModule(FlaxBartModule):
 
         # the decoder has a different config
         decoder_config = BartConfig(self.config.to_dict())
-        decoder_config.max_position_embeddings = OUTPUT_LENGTH
-        decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
+        decoder_config.max_position_embeddings = self.config.max_position_embeddings_decoder
+        decoder_config.vocab_size = self.config.vocab_size_output
         self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
 
 class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
     def setup(self):
+        # check config is valid, otherwise set default values
+        self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
+
         self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
         self.lm_head = nn.Dense(
-            OUTPUT_VOCAB_SIZE,
+            self.config.vocab_size_output,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
         )
-        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
+        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.config.vocab_size_output))
 
 class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
     module_class = CustomFlaxBartForConditionalGenerationModule
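
Note: both custom modules read vocab_size_output and max_position_embeddings_decoder from the config and fall back to the module-level OUTPUT_VOCAB_SIZE / OUTPUT_LENGTH constants when the fields are missing, so older configs keep working. A small illustration of that getattr-with-default pattern (the constant values here are placeholders, not taken from the diff):

# Illustration only: getattr-with-default lets a config that predates the new
# fields keep working. The constant values below are placeholders.
from transformers import BartConfig

OUTPUT_VOCAB_SIZE = 16385   # placeholder default
OUTPUT_LENGTH = 257         # placeholder default

config = BartConfig()       # a config without the new fields
config.vocab_size_output = getattr(config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
config.max_position_embeddings_decoder = getattr(config, 'max_position_embeddings_decoder', OUTPUT_LENGTH)

print(config.vocab_size_output, config.max_position_embeddings_decoder)
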
@@ -429,11 +436,24 @@ def main():
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
 
-    # Load pretrained model and tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(
-        model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
-    )
+    # Set up items to load or create
+    tokenizer = None
+    artifact_dir = None
 
+    def restore_state(state, artifact_dir):
+        # restore optimizer state
+        if (Path(artifact_dir) / 'opt_state.msgpack').exists():
+            with (Path(artifact_dir) / 'opt_state.msgpack').open('rb') as f:
+                opt_state = from_bytes(state.opt_state, f.read())
+
+        # restore steps
+        if (Path(artifact_dir) / 'training_state.json').exists():
+            with (Path(artifact_dir) / 'training_state.json').open('r') as f:
+                training_state = json.load(f)
+            step = training_state['step']
+            optimizer_step = step // training_args.gradient_accumulation_steps
+            state.replace(step=step, optimizer_step=optimizer_step)
+
     if model_args.from_checkpoint is not None:
         artifact = wandb.run.use_artifact(model_args.from_checkpoint)
         artifact_dir = artifact.download()
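
Note: a Flax TrainState is an immutable struct.dataclass, so state.replace(...) returns a new object rather than modifying state in place. For comparison, here is a sketch of the same restore logic written to return the updated state; the field names (opt_state, step, optimizer_step) follow the diff, the rest is assumed and this is not the committed code:

# Sketch (assumptions labelled): a restore helper that returns the updated
# TrainState, since .replace() produces a new object instead of mutating it.
import json
from pathlib import Path

from flax.serialization import from_bytes

def restore_state_returning(state, artifact_dir, gradient_accumulation_steps):
    artifact_dir = Path(artifact_dir)

    opt_state_file = artifact_dir / 'opt_state.msgpack'
    if opt_state_file.exists():
        with opt_state_file.open('rb') as f:
            state = state.replace(opt_state=from_bytes(state.opt_state, f.read()))

    training_state_file = artifact_dir / 'training_state.json'
    if training_state_file.exists():
        with training_state_file.open('r') as f:
            step = json.load(f)['step']
        state = state.replace(step=step, optimizer_step=step // gradient_accumulation_steps)

    return state
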
@@ -448,6 +468,12 @@ def main():
         # used in the preprocessing function
         config = model.config
 
+        # load tokenizer if present
+        if (Path(artifact_dir) / 'tokenizer_config.json').exists():
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+            )
+
     else:
         base_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
             model_args.model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
@@ -473,6 +499,12 @@ def main():
         model.params['model']['shared'] = base_model.params['model']['shared']
         del base_model
 
+    # Load tokenizer if it has not been set
+    if tokenizer is None:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+        )
+
     print(f"TPUs: {jax.device_count()}")
     assert jax.device_count() == 8, "TPUs in use, please check running processes"
 
@@ -669,6 +701,9 @@ def main():
         grad_accum=jax.tree_map(jnp.zeros_like, model.params),
         optimizer_step=0,
     )
+    if model_args.from_checkpoint is not None:
+        # restore optimizer state, step and optimizer_step
+        restore_state(state, artifact_dir)
 
     # label smoothed cross entropy
     def loss_fn(logits, labels):
@@ -811,13 +846,16 @@ def main():
            params=params,
        )
 
+       # save tokenizer
+       tokenizer.save_pretrained(training_args.output_dir)
+
        # save state
        state = unreplicate(state)
        with (Path(training_args.output_dir) / 'opt_state.msgpack').open('wb') as f:
            f.write(to_bytes(state.opt_state))
        with (Path(training_args.output_dir) / 'training_state.json').open('w') as f:
            json.dump({'step': state.step.item()}, f)
-
+
        # save to W&B
        if data_args.log_model:
            metadata = {'step': step, 'epoch': epoch}
@@ -826,8 +864,13 @@ def main():
            artifact = wandb.Artifact(
                name=f"model-{wandb.run.id}", type="bart_model", metadata=metadata
            )
-           artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
-           artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
+           artifact.add_file(str(Path(training_args.output_dir) / 'flax_model.msgpack'))
+           artifact.add_file(str(Path(training_args.output_dir) / 'config.json'))
+           artifact.add_file(str(Path(training_args.output_dir) / 'tokenizer.json'))
+           artifact.add_file(str(Path(training_args.output_dir) / 'tokenizer_config.json'))
+           artifact.add_file(str(Path(training_args.output_dir) / 'vocab.json'))
+           artifact.add_file(str(Path(training_args.output_dir) / 'merges.txt'))
+           artifact.add_file(str(Path(training_args.output_dir) / 'special_tokens_map.json'))
            artifact.add_file(str(Path(training_args.output_dir) / 'opt_state.msgpack'))
            artifact.add_file(str(Path(training_args.output_dir) / 'training_state.json'))
            wandb.run.log_artifact(artifact)
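
Note: with the tokenizer, optimizer state and training state attached to the artifact, a later run can resume by passing the artifact reference as from_checkpoint, which is fetched at the top of main() via use_artifact/download. A usage sketch with placeholder project and artifact names:

# Usage sketch with placeholder names: fetch a logged checkpoint artifact and
# inspect the files it carries, mirroring the use_artifact/download calls above.
import wandb

run = wandb.init(project="my-project")              # placeholder project
artifact = run.use_artifact("model-abc123:latest")  # placeholder artifact name
artifact_dir = artifact.download()
# artifact_dir now contains flax_model.msgpack, config.json, the tokenizer files,
# opt_state.msgpack and training_state.json added above.
print(artifact_dir)
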