import copy

from transformers import AutoConfig, T5Config
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

from model.encoders import VAE_ENCODER_MODELS
from model.decoders import VAE_DECODER_MODELS
from model.utils import assertEqual, assertIn


logger = logging.get_logger(__name__)


class T5VaeConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of :class:`FlaxT5VAE`. It is used to instantiate a
    T5-VAE model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a configuration similar to that of the T5 ``t5-vae-base`` architecture.

    To be usable with :class:`~transformers.Trainer`, the model needs some training-specific logic and configuration.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

    Arguments:
        n_latent_tokens (:obj:`int`, `optional`, defaults to 6):
            Number of latent tokens (must be no greater than the sequence length).
        latent_token_size (:obj:`int`, `optional`, defaults to 32):
            Number of dimensions to use for each latent token.
        t5_model_name_or_path (:obj:`str`, `optional`):
            Name or path of a pretrained T5 model whose configuration is used for the underlying Transformer.
        block_size (:obj:`int`, `optional`, defaults to 60):
            Fixed sequence length. NOTE: every input sequence must be padded to exactly this length.
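
    Example (usage sketch; assumes the default VAE encoder/decoder names registered in
    ``VAE_ENCODER_MODELS``/``VAE_DECODER_MODELS``, with ``t5-base`` used purely for illustration)::

        >>> # Build a config around a pretrained T5 checkpoint (fetches that checkpoint's config).
        >>> config = T5VaeConfig(t5_model_name_or_path="t5-base", n_latent_tokens=6, latent_token_size=32)

        >>> # Or build one from scratch via the explicit T5 arguments (no download needed).
        >>> config = T5VaeConfig(d_model=512, num_layers=6, block_size=60)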
""" |

    model_type = "transformer_vae"
    is_composition = True

    def __init__(
        self,
        t5_model_name_or_path=None,
        n_latent_tokens=6,
        latent_token_size=32,
        vae_encoder_model='',
        vae_decoder_model='',
        block_size=60,
        decoder_start_token_id=0,
        cache_dir=None,
        tie_word_embeddings=True,
        # Nested T5 configuration: either pass a serialised `t5` dict or the individual T5 arguments below.
        t5=dict(),
        vocab_size=32128,
        d_model=512,
        d_kv=64,
        d_ff=2048,
        num_layers=6,
        num_decoder_layers=None,
        num_heads=8,
        relative_attention_num_buckets=32,
        dropout_rate=0.1,
        layer_norm_epsilon=1e-6,
        initializer_factor=1.0,
        feed_forward_proj="relu",
        is_encoder_decoder=True,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        gradient_checkpointing=False,
        **kwargs,
    ):
        assertIn(vae_encoder_model, VAE_ENCODER_MODELS.keys(), "Unexpected VAE encoder.")
        assertIn(vae_decoder_model, VAE_DECODER_MODELS.keys(), "Unexpected VAE decoder.")

        super().__init__(**kwargs)

        self.set_seq_size = block_size

        self.vae_encoder_model = vae_encoder_model
        self.vae_decoder_model = vae_decoder_model

        self.latent_token_size = latent_token_size
        assert n_latent_tokens <= self.set_seq_size, 'Cannot use more latent tokens than input tokens.'
        self.n_latent_tokens = n_latent_tokens
        self.use_cache = use_cache

        # Take the nested T5 config from a pretrained checkpoint, from an explicit `t5` dict
        # (e.g. when reloading a serialised config), or build it from the T5 arguments above.
        if t5_model_name_or_path:
            self.t5 = AutoConfig.from_pretrained(t5_model_name_or_path, cache_dir=cache_dir)
            assertEqual(self.t5.model_type, "t5", "Need t5 model type for transformer_decoder.")
            self.t5.decoder_start_token_id = decoder_start_token_id
        elif t5:
            self.t5 = T5Config(**t5)
        else:
            self.t5 = T5Config(
                vocab_size=vocab_size,
                d_model=d_model,
                d_kv=d_kv,
                d_ff=d_ff,
                num_layers=num_layers,
                num_decoder_layers=num_decoder_layers,
                num_heads=num_heads,
                relative_attention_num_buckets=relative_attention_num_buckets,
                dropout_rate=dropout_rate,
                layer_norm_epsilon=layer_norm_epsilon,
                initializer_factor=initializer_factor,
                feed_forward_proj=feed_forward_proj,
                is_encoder_decoder=is_encoder_decoder,
                use_cache=use_cache,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                gradient_checkpointing=gradient_checkpointing,
                **kwargs
            )

        if self.t5.d_model < self.latent_token_size:
            raise Exception('Cannot use a latent token dimension larger than the T5 hidden dimension.')

        self.tie_word_embeddings = tie_word_embeddings
        self.t5.tie_word_embeddings = self.tie_word_embeddings
        self.t5.use_cache = self.use_cache
        self.pad_token_id = pad_token_id
        self.eos_token_id = eos_token_id
        self.decoder_start_token_id = self.t5.decoder_start_token_id

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary, overriding the default
        :meth:`~transformers.PretrainedConfig.to_dict` so that the nested T5 config is serialized too.

        Returns:
            :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
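
        Example (illustrative sketch; assumes the default VAE encoder/decoder names are registered)::

            >>> d = T5VaeConfig().to_dict()
            >>> d["model_type"]
            'transformer_vae'
            >>> isinstance(d["t5"], dict)
            True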
""" |
        output = copy.deepcopy(self.__dict__)
        output["model_type"] = self.__class__.model_type
        output["t5"] = self.t5.to_dict()
        return output