import copy
from transformers.utils import logging
from transformers.configuration_utils import PretrainedConfig
from transformers import AutoConfig, T5Config
from model.encoders import VAE_ENCODER_MODELS
from model.decoders import VAE_DECODER_MODELS
from model.utils import assertEqual, assertIn

logger = logging.get_logger(__name__)


class T5VaeConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of :class:`FlaxT5VAE`.
It is used to instantiate a T5-VAE model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the T5 `t5-vae-base architecture.
To be able to use `transformer.trainer.Trainer` we need some specific training logic & config in the model.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
Arguments:
n_latent_tokens (:obj:`int`, `optional`, defaults to 6):
Number of latent tokens (must be less than seq length).
latent_token_size (:obj:`int`, `optional`, defaults to 32):
Number of dimensions to use for each latent token.
t5_name (:obj:`str`, `optional`, defaults to t5-base):
Name of the Transformer model to use as a decoder.
block_size (:obj:`int`, `optional`, defaults to 60):
NOTE: Every input sequence must be padded to be equal to this length.
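
    Example (a minimal sketch; assumes the empty-string defaults for ``vae_encoder_model`` and
    ``vae_decoder_model`` are registered in ``VAE_ENCODER_MODELS`` / ``VAE_DECODER_MODELS``)::

        # Reuse the architecture settings of a pretrained T5 checkpoint.
        config = T5VaeConfig(t5_model_name_or_path="t5-base", n_latent_tokens=6,
                             latent_token_size=32, block_size=60)

        # Or define the underlying T5 architecture explicitly (no checkpoint needed).
        config = T5VaeConfig(d_model=512, num_layers=6, num_heads=8, n_latent_tokens=6)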
"""
    model_type = "transformer_vae"
    is_composition = True

    def __init__(
        self,
        t5_model_name_or_path=None,
        n_latent_tokens=6,  # set to -1 for full sequence
        latent_token_size=32,
        vae_encoder_model='',
        vae_decoder_model='',
        block_size=60,
        decoder_start_token_id=0,
        cache_dir=None,
        tie_word_embeddings=True,
        # T5 config options (used when no pretrained T5 config is loaded)
        t5=dict(),
        vocab_size=32128,
        d_model=512,
        d_kv=64,
        d_ff=2048,
        num_layers=6,
        num_decoder_layers=None,
        num_heads=8,
        relative_attention_num_buckets=32,
        dropout_rate=0.1,
        layer_norm_epsilon=1e-6,
        initializer_factor=1.0,
        feed_forward_proj="relu",
        is_encoder_decoder=True,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        gradient_checkpointing=False,
        # end of T5 config options
        **kwargs,
    ):
        assertIn(vae_encoder_model, VAE_ENCODER_MODELS.keys(), "Unexpected VAE encoder.")
        assertIn(vae_decoder_model, VAE_DECODER_MODELS.keys(), "Unexpected VAE decoder.")
        super().__init__(**kwargs)
        self.set_seq_size = block_size

        # VAE
        self.vae_encoder_model = vae_encoder_model
        self.vae_decoder_model = vae_decoder_model
        self.latent_token_size = latent_token_size
        assert n_latent_tokens <= self.set_seq_size, 'Cannot use more latent tokens than input tokens.'
        self.n_latent_tokens = n_latent_tokens
        self.use_cache = use_cache

        # T5
        if t5_model_name_or_path:
            # Reuse the configuration of a pretrained T5 model.
            self.t5 = AutoConfig.from_pretrained(t5_model_name_or_path, cache_dir=cache_dir)
            assertEqual(self.t5.model_type, "t5", "Need t5 model type for transformer_decoder.")
            self.t5.decoder_start_token_id = decoder_start_token_id
        elif t5:
            # Used when loading a saved config (`t5` is the serialized sub-config dict).
            self.t5 = T5Config(**t5)
        else:
            # Build a fresh T5 config from the explicit arguments.
            self.t5 = T5Config(
                vocab_size=vocab_size,
                d_model=d_model,
                d_kv=d_kv,
                d_ff=d_ff,
                num_layers=num_layers,
                num_decoder_layers=num_decoder_layers,
                num_heads=num_heads,
                relative_attention_num_buckets=relative_attention_num_buckets,
                dropout_rate=dropout_rate,
                layer_norm_epsilon=layer_norm_epsilon,
                initializer_factor=initializer_factor,
                feed_forward_proj=feed_forward_proj,
                is_encoder_decoder=is_encoder_decoder,
                use_cache=use_cache,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                gradient_checkpointing=gradient_checkpointing,
                **kwargs,
            )

        if self.t5.d_model < self.latent_token_size:
            raise ValueError('Cannot use a latent token dimension larger than the T5 hidden dimension.')

        # Add t5 config options
        self.tie_word_embeddings = tie_word_embeddings
        self.t5.tie_word_embeddings = self.tie_word_embeddings
        self.t5.use_cache = self.use_cache
        self.pad_token_id = pad_token_id
        self.eos_token_id = eos_token_id
        self.decoder_start_token_id = self.t5.decoder_start_token_id

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Overrides the default `to_dict()` from `PretrainedConfig`.

        Returns:
            :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        output["model_type"] = self.__class__.model_type
        output["t5"] = self.t5.to_dict()
        return output
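

# Usage sketch, not part of the library API: a quick smoke test of the config and its
# serialization. It assumes the default '' entries exist in VAE_ENCODER_MODELS /
# VAE_DECODER_MODELS (as the argument defaults above imply).
if __name__ == "__main__":
    config = T5VaeConfig(n_latent_tokens=4, latent_token_size=16, block_size=60)
    as_dict = config.to_dict()
    print(as_dict["model_type"])     # "transformer_vae"
    print(as_dict["t5"]["d_model"])  # 512, from the nested T5 sub-config serialized by to_dict()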