Spaces:
Runtime error
Runtime error
Pedro Cuenca
commited on
Commit
·
835ea55
1
Parent(s):
945d86c
Shift tokens in numpy because the built in shift function stalls.
Browse filesPossible cause is the conversion to jax arrays and then back to numpy,
we might be moving data to/from the TPU.
- seq2seq/run_seq2seq_flax.py +11 -11
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -462,16 +462,19 @@ def main():
|
|
462 |
# Temporarily set max_target_length for training.
|
463 |
max_target_length = data_args.max_target_length
|
464 |
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
|
|
|
|
|
|
470 |
|
471 |
-
# Setting padding="max_length" as we need fixed length inputs for jitted functions
|
472 |
def preprocess_function(examples):
|
473 |
inputs = examples[text_column]
|
474 |
inputs = [prefix + inp for inp in inputs]
|
|
|
475 |
model_inputs = tokenizer(
|
476 |
inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
|
477 |
)
|
@@ -486,11 +489,8 @@ def main():
|
|
486 |
model_inputs["labels"] = labels
|
487 |
|
488 |
# In our case, this prepends the bos token and removes the last one
|
489 |
-
decoder_input_ids =
|
490 |
-
|
491 |
-
)
|
492 |
-
|
493 |
-
model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)
|
494 |
|
495 |
return model_inputs
|
496 |
|
|
|
462 |
# Temporarily set max_target_length for training.
|
463 |
max_target_length = data_args.max_target_length
|
464 |
|
465 |
+
def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
|
466 |
+
"""
|
467 |
+
Shift input ids one token to the right.
|
468 |
+
"""
|
469 |
+
shifted_input_ids = np.zeros(input_ids.shape)
|
470 |
+
shifted_input_ids[:, 1:] = input_ids[:, :-1]
|
471 |
+
shifted_input_ids[:, 0] = decoder_start_token_id
|
472 |
+
return shifted_input_ids
|
473 |
|
|
|
474 |
def preprocess_function(examples):
|
475 |
inputs = examples[text_column]
|
476 |
inputs = [prefix + inp for inp in inputs]
|
477 |
+
# Setting padding="max_length" as we need fixed length inputs for jitted functions
|
478 |
model_inputs = tokenizer(
|
479 |
inputs, max_length=data_args.max_source_length, padding="max_length", truncation=True, return_tensors="np"
|
480 |
)
|
|
|
489 |
model_inputs["labels"] = labels
|
490 |
|
491 |
# In our case, this prepends the bos token and removes the last one
|
492 |
+
decoder_input_ids = shift_tokens_right(labels, config.decoder_start_token_id)
|
493 |
+
model_inputs["decoder_input_ids"] = decoder_input_ids
|
|
|
|
|
|
|
494 |
|
495 |
return model_inputs
|
496 |
|