boris commited on
Commit
d61405b
·
1 Parent(s): 833a2d5

feat: padding mask not required

Browse files
Files changed (1) hide show
  1. seq2seq/run_seq2seq_flax.py +4 -15
seq2seq/run_seq2seq_flax.py CHANGED
@@ -487,10 +487,6 @@ def main():
487
 
488
  model_inputs["decoder_input_ids"] = labels
489
 
490
- # We need decoder_attention_mask so we can ignore pad tokens from loss
491
- # TODO: I don't believe we need "decoder_attention_mask" in this case because all labels have same length
492
- #model_inputs["decoder_attention_mask"] = labels["attention_mask"]
493
-
494
  return model_inputs
495
 
496
  if training_args.do_train:
@@ -643,7 +639,7 @@ def main():
643
  state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
644
 
645
  # label smoothed cross entropy
646
- def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
647
  """
648
  The label smoothing implementation is adapted from Flax's official example:
649
  https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
@@ -659,12 +655,7 @@ def main():
659
  loss = optax.softmax_cross_entropy(logits, soft_labels)
660
  loss = loss - normalizing_constant
661
 
662
- if padding_mask is None:
663
- padding_mask = np.ones(loss.shape)
664
-
665
- # ignore padded tokens from loss
666
- loss = loss * padding_mask
667
- loss = loss.sum() / padding_mask.sum()
668
  return loss
669
 
670
  # Define gradient update step fn
@@ -674,8 +665,7 @@ def main():
674
  def compute_loss(params):
675
  labels = batch.pop("labels")
676
  logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
677
- padding_mask = batch.get("decoder_attention_mask", None)
678
- loss = loss_fn(logits, labels, padding_mask, label_smoothing_factor)
679
  return loss
680
 
681
  grad_fn = jax.value_and_grad(compute_loss)
@@ -693,8 +683,7 @@ def main():
693
  def eval_step(params, batch, label_smoothing_factor=0.0):
694
  labels = batch.pop("labels")
695
  logits = model(**batch, params=params, train=False)[0]
696
- padding_mask = batch.get("decoder_attention_mask", None)
697
- loss = loss_fn(logits, labels, padding_mask, label_smoothing_factor)
698
 
699
  # summarize metrics
700
  metrics = {"loss": loss}
 
487
 
488
  model_inputs["decoder_input_ids"] = labels
489
 
 
 
 
 
490
  return model_inputs
491
 
492
  if training_args.do_train:
 
639
  state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
640
 
641
  # label smoothed cross entropy
642
+ def loss_fn(logits, labels, label_smoothing_factor=0.0):
643
  """
644
  The label smoothing implementation is adapted from Flax's official example:
645
  https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
 
655
  loss = optax.softmax_cross_entropy(logits, soft_labels)
656
  loss = loss - normalizing_constant
657
 
658
+ loss = loss.mean()
 
 
 
 
 
659
  return loss
660
 
661
  # Define gradient update step fn
 
665
  def compute_loss(params):
666
  labels = batch.pop("labels")
667
  logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
668
+ loss = loss_fn(logits, labels, label_smoothing_factor)
 
669
  return loss
670
 
671
  grad_fn = jax.value_and_grad(compute_loss)
 
683
  def eval_step(params, batch, label_smoothing_factor=0.0):
684
  labels = batch.pop("labels")
685
  logits = model(**batch, params=params, train=False)[0]
686
+ loss = loss_fn(logits, labels, label_smoothing_factor)
 
687
 
688
  # summarize metrics
689
  metrics = {"loss": loss}