Spaces:
Runtime error
Runtime error
fix: correct decoder_input_ids and labels
Browse files- seq2seq/run_seq2seq_flax.py +7 -12
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -475,25 +475,20 @@ def main():
|
|
475 |
)
|
476 |
|
477 |
# set up targets
|
478 |
-
# Note:
|
479 |
-
#
|
480 |
-
|
481 |
-
labels = [[config.decoder_start_token_id] + eval(indices) for indices in examples['encoding']]
|
482 |
labels = np.asarray(labels)
|
483 |
|
484 |
# We need the labels, in addition to the decoder_input_ids, for the compute_loss function
|
485 |
-
# In our case, they are the same as decoder_input_ids. Is that correct?
|
486 |
model_inputs["labels"] = labels
|
487 |
|
488 |
-
# TODO: if data processing prevents correct compilation, we will:
|
489 |
-
# - have data saved in JSONL (to avoid `eval` which is needed here to convert string "[2]" to list[int])
|
490 |
-
# - use below `shift_tokens_right_fn`
|
491 |
# In our case, this prepends the bos token and removes the last one
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
|
496 |
-
model_inputs["decoder_input_ids"] =
|
497 |
|
498 |
return model_inputs
|
499 |
|
|
|
475 |
)
|
476 |
|
477 |
# set up targets
|
478 |
+
# Note: labels correspond to our target indices
|
479 |
+
# decoder input ids are the same but shifted to the right with bos at the beginning (and without last token)
|
480 |
+
labels = [[eval(indices) for indices in examples['encoding']]
|
|
|
481 |
labels = np.asarray(labels)
|
482 |
|
483 |
# We need the labels, in addition to the decoder_input_ids, for the compute_loss function
|
|
|
484 |
model_inputs["labels"] = labels
|
485 |
|
|
|
|
|
|
|
486 |
# In our case, this prepends the bos token and removes the last one
|
487 |
+
decoder_input_ids = shift_tokens_right_fn(
|
488 |
+
jnp.array(labels), config.pad_token_id, config.decoder_start_token_id
|
489 |
+
)
|
490 |
|
491 |
+
model_inputs["decoder_input_ids"] = decoder_input_ids
|
492 |
|
493 |
return model_inputs
|
494 |
|