Spaces:
Runtime error
Runtime error
File size: 2,463 Bytes
7f962d6 d209547 7f962d6 803c7df d209547 7f962d6 b7d8724 7f962d6 803c7df 7f962d6 b7d8724 7f962d6 803c7df 7f962d6 b7d8724 7f962d6 803c7df 7f962d6 803c7df 7f962d6 92ccf4c 7f962d6 803c7df 7f962d6 b7d8724 7f962d6 803c7df 7f962d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import flax.linen as nn
import jax
from transformers import BartConfig
from transformers.models.bart.modeling_flax_bart import (
FlaxBartDecoder,
FlaxBartEncoder,
FlaxBartForConditionalGeneration,
FlaxBartForConditionalGenerationModule,
FlaxBartModule,
)
class CustomFlaxBartModule(FlaxBartModule):
    """BART encoder-decoder whose decoder operates on image tokens.

    The encoder keeps the text embedding under the name ``shared`` so that
    pre-trained BART weights can be loaded. The decoder gets its own
    embedding table and a copy of the config resized for the image-token
    sequence (``image_length + 1`` positions, ``image_vocab_size + 1`` ids).
    """

    def setup(self):
        # Text embedding; keep the name `shared` so pre-trained weights load.
        self.shared = nn.Embed(
            self.config.vocab_size,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )
        # A separate embedding is used for the decoder: image token ids
        # plus one extra id (used for BOS, per the lm_head sizing).
        self.decoder_embed = nn.Embed(
            self.config.image_vocab_size + 1,
            self.config.d_model,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.encoder = FlaxBartEncoder(
            self.config, dtype=self.dtype, embed_tokens=self.shared
        )

        # The decoder has a different config.
        # TODO: should not be needed once we have custom config/module
        # The dict must be unpacked: BartConfig's first positional parameter
        # is `vocab_size`. Passing the dict positionally would bind it to
        # `vocab_size` and leave every other field (d_model, decoder_layers,
        # attention heads, ...) at BartConfig defaults.
        decoder_config = BartConfig(**self.config.to_dict())
        decoder_config.max_position_embeddings = (
            self.config.image_length + 1  # image tokens + BOS
        )
        decoder_config.vocab_size = self.config.image_vocab_size + 1
        self.decoder = FlaxBartDecoder(
            decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed
        )
class CustomFlaxBartForConditionalGenerationModule(
    FlaxBartForConditionalGenerationModule
):
    """Conditional-generation module whose output head predicts image tokens.

    Missing image-related config attributes are filled with defaults. The
    module then builds the custom encoder-decoder and a language-model head
    with ``image_vocab_size + 1`` outputs (encoded image tokens plus BOS).
    """

    def setup(self):
        cfg = self.config

        # Set default config values for attributes the config may not define.
        # NOTE: this writes onto the (shared) config object in place.
        for attr_name, fallback in (
            ("normalize_text", False),
            ("image_length", 256),
            ("image_vocab_size", 16384),
        ):
            setattr(cfg, attr_name, getattr(cfg, attr_name, fallback))

        # Encoded image token space + 1 for bos.
        num_outputs = cfg.image_vocab_size + 1

        self.model = CustomFlaxBartModule(config=cfg, dtype=self.dtype)
        self.lm_head = nn.Dense(
            num_outputs,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(cfg.init_std),
        )
        self.final_logits_bias = self.param(
            "final_logits_bias", self.bias_init, (1, num_outputs)
        )
class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
    """Model wrapper that uses the image-token conditional-generation module."""

    # Substitute the custom module for the stock BART generation module.
    module_class = CustomFlaxBartForConditionalGenerationModule
|