Spaces:
Runtime error
Runtime error
fix: accumulation vs lr
Browse files
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -673,12 +673,12 @@ def main():
|
|
673 |
grads = jax.tree_map(lambda x: x / training_args.gradient_accumulation_steps, grad_accum)
|
674 |
grads = jax.lax.pmean(grads, "batch")
|
675 |
new_state = state.apply_gradients(
|
676 |
-
grads=grads, grad_accum=jax.tree_map(jnp.zeros_like, grads), optimizer_step=state.optimizer_step
|
677 |
)
|
678 |
return new_state
|
679 |
|
680 |
new_state = jax.lax.cond(
|
681 |
-
state.step % training_args.gradient_accumulation_steps == 0,
|
682 |
lambda _: update_fn(),
|
683 |
lambda _: state.replace(grad_accum=grad_accum, step=state.step + 1),
|
684 |
None,
|
|
|
673 |
grads = jax.tree_map(lambda x: x / training_args.gradient_accumulation_steps, grad_accum)
|
674 |
grads = jax.lax.pmean(grads, "batch")
|
675 |
new_state = state.apply_gradients(
|
676 |
+
grads=grads, grad_accum=jax.tree_map(jnp.zeros_like, grads), optimizer_step=state.optimizer_step + 1
|
677 |
)
|
678 |
return new_state
|
679 |
|
680 |
new_state = jax.lax.cond(
|
681 |
+
(state.step + 1) % training_args.gradient_accumulation_steps == 0,
|
682 |
lambda _: update_fn(),
|
683 |
lambda _: state.replace(grad_accum=grad_accum, step=state.step + 1),
|
684 |
None,
|