Spaces:
Runtime error
Runtime error
fix: model config
Browse files
seq2seq/run_seq2seq_flax.py
CHANGED
@@ -282,8 +282,6 @@ class CustomFlaxBartModule(FlaxBartModule):
|
|
282 |
# the decoder has a different config
|
283 |
decoder_config = BartConfig(self.config.to_dict())
|
284 |
decoder_config.max_position_embeddings = OUTPUT_LENGTH
|
285 |
-
decoder_config.min_length = OUTPUT_LENGTH
|
286 |
-
decoder_config.max_length = OUTPUT_LENGTH
|
287 |
decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
|
288 |
self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
|
289 |
|
@@ -440,8 +438,8 @@ def main():
|
|
440 |
config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
|
441 |
config.forced_bos_token_id = None # we don't need this token
|
442 |
config.forced_eos_token_id = None # we don't need this token
|
443 |
-
|
444 |
-
|
445 |
|
446 |
print(f"TPUs: {jax.device_count()}")
|
447 |
assert jax.device_count() == 8, "TPUs in use, please check running processes"
|
|
|
282 |
# the decoder has a different config
|
283 |
decoder_config = BartConfig(self.config.to_dict())
|
284 |
decoder_config.max_position_embeddings = OUTPUT_LENGTH
|
|
|
|
|
285 |
decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
|
286 |
self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
|
287 |
|
|
|
438 |
config.eos_token_id = BOS_TOKEN_ID + 1 # unreachable
|
439 |
config.forced_bos_token_id = None # we don't need this token
|
440 |
config.forced_eos_token_id = None # we don't need this token
|
441 |
+
config.min_length = data_args.max_target_length
|
442 |
+
config.max_length = data_args.max_target_length
|
443 |
|
444 |
print(f"TPUs: {jax.device_count()}")
|
445 |
assert jax.device_count() == 8, "TPUs in use, please check running processes"
|