Spaces:
Runtime error
Runtime error
feat: gradient accumulation
Browse files- seq2seq/run_seq2seq_flax.py +37 -10
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -239,6 +239,8 @@ class DataTrainingArguments:
|
|
239 |
|
240 |
class TrainState(train_state.TrainState):
    """Train state extended with the PRNG key used for dropout."""

    # PRNG key for dropout; `replicate` passes it through `shard_prng_key`.
    dropout_rng: jnp.ndarray

    def replicate(self):
        """Return a replicated copy of this state.

        The whole state goes through `jax_utils.replicate`; the dropout key
        is then replaced by the result of `shard_prng_key` on the original key.
        """
        replicated_state = jax_utils.replicate(self)
        sharded_rng = shard_prng_key(self.dropout_rng)
        return replicated_state.replace(dropout_rng=sharded_rng)
|
@@ -590,14 +592,16 @@ def main():
|
|
590 |
# Store some constant
|
591 |
num_epochs = int(training_args.num_train_epochs)
|
592 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
|
|
593 |
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
594 |
steps_per_epoch = len(train_dataset) // train_batch_size
|
595 |
-
|
|
|
596 |
|
597 |
# Create learning rate schedule
|
598 |
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
599 |
len(train_dataset),
|
600 |
-
|
601 |
training_args.num_train_epochs,
|
602 |
training_args.warmup_steps,
|
603 |
training_args.learning_rate,
|
@@ -636,7 +640,14 @@ def main():
|
|
636 |
)
|
637 |
|
638 |
# Setup train state
|
639 |
-
state = TrainState.create(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
640 |
|
641 |
# label smoothed cross entropy
|
642 |
def loss_fn(logits, labels):
|
@@ -655,15 +666,28 @@ def main():
|
|
655 |
return loss
|
656 |
|
657 |
grad_fn = jax.value_and_grad(compute_loss)
|
658 |
-
loss,
|
659 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
660 |
|
661 |
-
new_state =
|
|
|
|
|
|
|
|
|
|
|
662 |
|
663 |
-
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.
|
664 |
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
665 |
|
666 |
-
return new_state, metrics
|
667 |
|
668 |
# Define eval fn
|
669 |
def eval_step(params, batch):
|
@@ -702,8 +726,11 @@ def main():
|
|
702 |
logger.info(f" Num examples = {len(train_dataset)}")
|
703 |
logger.info(f" Num Epochs = {num_epochs}")
|
704 |
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
|
705 |
-
logger.info(
|
706 |
-
|
|
|
|
|
|
|
707 |
|
708 |
train_time = 0
|
709 |
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
|
|
239 |
|
240 |
class TrainState(train_state.TrainState):
    """Train state extended with dropout and gradient-accumulation fields."""

    # PRNG key for dropout; `replicate` passes it through `shard_prng_key`.
    dropout_rng: jnp.ndarray
    # Running sum of gradients collected between optimizer updates.
    grad_accum: jnp.ndarray
    # Counter of optimizer updates, kept separately from `step`.
    optimizer_step: int

    def replicate(self):
        """Return a replicated copy of this state.

        The whole state goes through `jax_utils.replicate`; the dropout key
        is then replaced by the result of `shard_prng_key` on the original key.
        """
        replicated_state = jax_utils.replicate(self)
        sharded_rng = shard_prng_key(self.dropout_rng)
        return replicated_state.replace(dropout_rng=sharded_rng)
|
|
|
592 |
# Store some constant
|
593 |
num_epochs = int(training_args.num_train_epochs)
|
594 |
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
595 |
+
total_batch_size = int(train_batch_size) * training_args.gradient_accumulation_steps
|
596 |
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
597 |
steps_per_epoch = len(train_dataset) // train_batch_size
|
598 |
+
total_steps = steps_per_epoch * num_epochs
|
599 |
+
total_optimization_steps = (len(train_dataset) // total_batch_size) * num_epochs
|
600 |
|
601 |
# Create learning rate schedule
|
602 |
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
603 |
len(train_dataset),
|
604 |
+
total_batch_size,
|
605 |
training_args.num_train_epochs,
|
606 |
training_args.warmup_steps,
|
607 |
training_args.learning_rate,
|
|
|
640 |
)
|
641 |
|
642 |
# Setup train state
|
643 |
+
state = TrainState.create(
|
644 |
+
apply_fn=model.__call__,
|
645 |
+
params=model.params,
|
646 |
+
tx=adamw,
|
647 |
+
dropout_rng=dropout_rng,
|
648 |
+
grad_accum=jax.tree_map(jnp.zeros_like, model.params),
|
649 |
+
optimizer_step=0,
|
650 |
+
)
|
651 |
|
652 |
# label smoothed cross entropy
|
653 |
def loss_fn(logits, labels):
|
|
|
666 |
return loss
|
667 |
|
668 |
grad_fn = jax.value_and_grad(compute_loss)
|
669 |
+
loss, grads = grad_fn(state.params)
|
670 |
+
grad_accum = jax.tree_multimap(lambda x, y: x + y, grads, state.grad_accum)
|
671 |
+
|
672 |
+
def update_fn():
    """Apply the accumulated gradients and reset the accumulator.

    Reads `grad_accum`, `state` and `training_args` from the enclosing scope.

    Returns:
        The state produced by `state.apply_gradients`, with `grad_accum`
        zeroed and `optimizer_step` advanced by one.
    """
    accum_steps = training_args.gradient_accumulation_steps
    # Turn the running sum into a mean over the accumulated micro-batches.
    grads = jax.tree_util.tree_map(lambda g: g / accum_steps, grad_accum)
    # Average over the "batch" mapped axis.
    grads = jax.lax.pmean(grads, "batch")
    new_state = state.apply_gradients(
        grads=grads,
        grad_accum=jax.tree_util.tree_map(jnp.zeros_like, grads),
        # Count this update. Passing the old value through unchanged left the
        # counter, and the learning rate reported from it, stuck at the start.
        optimizer_step=state.optimizer_step + 1,
    )
    return new_state
|
679 |
|
680 |
+
new_state = jax.lax.cond(
|
681 |
+
state.step % training_args.gradient_accumulation_steps == 0,
|
682 |
+
lambda _: update_fn(),
|
683 |
+
lambda _: state.replace(grad_accum=grad_accum, step=state.step + 1),
|
684 |
+
None,
|
685 |
+
)
|
686 |
|
687 |
+
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.optimizer_step)}
|
688 |
metrics = jax.lax.pmean(metrics, axis_name="batch")
|
689 |
|
690 |
+
return new_state.replace(dropout_rng=new_dropout_rng), metrics
|
691 |
|
692 |
# Define eval fn
|
693 |
def eval_step(params, batch):
|
|
|
726 |
logger.info(f" Num examples = {len(train_dataset)}")
|
727 |
logger.info(f" Num Epochs = {num_epochs}")
|
728 |
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
|
729 |
+
logger.info(
|
730 |
+
f" Total train batch size (w. parallel & distributed) = {train_batch_size * training_args.gradient_accumulation_steps}"
|
731 |
+
)
|
732 |
+
logger.info(f" Total global steps = {total_steps}")
|
733 |
+
logger.info(f" Total optimization steps = {total_optimization_steps}")
|
734 |
|
735 |
train_time = 0
|
736 |
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|