Spaces:
Runtime error
Runtime error
Merge pull request #10 from borisdayma/feat-loss
Browse files- seq2seq/run_seq2seq_flax.py +9 -34
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -487,10 +487,6 @@ def main():
|
|
487 |
|
488 |
model_inputs["decoder_input_ids"] = labels
|
489 |
|
490 |
-
# We need decoder_attention_mask so we can ignore pad tokens from loss
|
491 |
-
# TODO: I don't believe we need "decoder_attention_mask" in this case because all labels have same length
|
492 |
-
#model_inputs["decoder_attention_mask"] = labels["attention_mask"]
|
493 |
-
|
494 |
return model_inputs
|
495 |
|
496 |
if training_args.do_train:
|
@@ -643,39 +639,19 @@ def main():
|
|
643 |
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
|
644 |
|
645 |
# label smoothed cross entropy
|
646 |
-
def loss_fn(logits, labels
|
647 |
-
|
648 |
-
|
649 |
-
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
|
650 |
-
"""
|
651 |
-
vocab_size = logits.shape[-1]
|
652 |
-
confidence = 1.0 - label_smoothing_factor
|
653 |
-
low_confidence = (1.0 - confidence) / (vocab_size - 1)
|
654 |
-
normalizing_constant = -(
|
655 |
-
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
|
656 |
-
)
|
657 |
-
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
|
658 |
-
|
659 |
-
loss = optax.softmax_cross_entropy(logits, soft_labels)
|
660 |
-
loss = loss - normalizing_constant
|
661 |
-
|
662 |
-
if padding_mask is None:
|
663 |
-
padding_mask = np.ones(loss.shape)
|
664 |
-
|
665 |
-
# ignore padded tokens from loss
|
666 |
-
loss = loss * padding_mask
|
667 |
-
loss = loss.sum() / padding_mask.sum()
|
668 |
return loss
|
669 |
|
670 |
# Define gradient update step fn
|
671 |
-
def train_step(state, batch
|
672 |
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
673 |
|
674 |
def compute_loss(params):
|
675 |
labels = batch.pop("labels")
|
676 |
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
677 |
-
|
678 |
-
loss = loss_fn(logits, labels, padding_mask, label_smoothing_factor)
|
679 |
return loss
|
680 |
|
681 |
grad_fn = jax.value_and_grad(compute_loss)
|
@@ -690,11 +666,10 @@ def main():
|
|
690 |
return new_state, metrics
|
691 |
|
692 |
# Define eval fn
|
693 |
-
def eval_step(params, batch
|
694 |
labels = batch.pop("labels")
|
695 |
logits = model(**batch, params=params, train=False)[0]
|
696 |
-
|
697 |
-
loss = loss_fn(logits, labels, padding_mask, label_smoothing_factor)
|
698 |
|
699 |
# summarize metrics
|
700 |
metrics = {"loss": loss}
|
@@ -715,9 +690,9 @@ def main():
|
|
715 |
|
716 |
# Create parallel version of the train and eval step
|
717 |
p_train_step = jax.pmap(
|
718 |
-
|
719 |
)
|
720 |
-
p_eval_step = jax.pmap(
|
721 |
p_generate_step = jax.pmap(generate_step, "batch")
|
722 |
|
723 |
# Replicate the train state on each device
|
|
|
487 |
|
488 |
model_inputs["decoder_input_ids"] = labels
|
489 |
|
|
|
|
|
|
|
|
|
490 |
return model_inputs
|
491 |
|
492 |
if training_args.do_train:
|
|
|
639 |
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
|
640 |
|
641 |
# label smoothed cross entropy
|
642 |
+
def loss_fn(logits, labels):
|
643 |
+
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
|
644 |
+
loss = loss.mean()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
645 |
return loss
|
646 |
|
647 |
# Define gradient update step fn
|
648 |
+
def train_step(state, batch):
|
649 |
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
650 |
|
651 |
def compute_loss(params):
|
652 |
labels = batch.pop("labels")
|
653 |
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
654 |
+
loss = loss_fn(logits, labels)
|
|
|
655 |
return loss
|
656 |
|
657 |
grad_fn = jax.value_and_grad(compute_loss)
|
|
|
666 |
return new_state, metrics
|
667 |
|
668 |
# Define eval fn
|
669 |
+
def eval_step(params, batch):
|
670 |
labels = batch.pop("labels")
|
671 |
logits = model(**batch, params=params, train=False)[0]
|
672 |
+
loss = loss_fn(logits, labels)
|
|
|
673 |
|
674 |
# summarize metrics
|
675 |
metrics = {"loss": loss}
|
|
|
690 |
|
691 |
# Create parallel version of the train and eval step
|
692 |
p_train_step = jax.pmap(
|
693 |
+
train_step, "batch", donate_argnums=(0,)
|
694 |
)
|
695 |
+
p_eval_step = jax.pmap(eval_step, "batch")
|
696 |
p_generate_step = jax.pmap(generate_step, "batch")
|
697 |
|
698 |
# Replicate the train state on each device
|