Fraser commited on
Commit
f0eb504
·
1 Parent(s): 2a2a24e

switch to using shared submodule for model code

Browse files
Files changed (12) hide show
  1. model/__init__.py +0 -0
  2. model/config.py +0 -137
  3. model/decoders.py +0 -23
  4. model/encoders.py +0 -26
  5. model/outputs.py +0 -74
  6. model/t5_vae.py +0 -522
  7. model/utils.py +0 -24
  8. model/vae.py +0 -30
  9. t5-vae-flax +0 -1
  10. t5_vae_flax +1 -1
  11. train.py +2 -2
  12. train.sh +3 -3
model/__init__.py DELETED
File without changes
model/config.py DELETED
@@ -1,137 +0,0 @@
1
- import copy
2
- from transformers.utils import logging
3
- from transformers.configuration_utils import PretrainedConfig
4
- from transformers import AutoConfig, T5Config
5
-
6
- from model.encoders import VAE_ENCODER_MODELS
7
- from model.decoders import VAE_DECODER_MODELS
8
- from model.utils import assertEqual, assertIn
9
-
10
- logger = logging.get_logger(__name__)
11
-
12
-
13
- class T5VaeConfig(PretrainedConfig):
14
- r"""
15
- This is the configuration class to store the configuration of :class:`FlaxT5VAE`.
16
- It is used to instantiate a T5-VAE model according to the specified arguments, defining the model architecture.
17
- Instantiating a configuration with the defaults will yield a similar configuration to that of the T5 `t5-vae-base architecture.
18
-
19
- To be able to use `transformer.trainer.Trainer` we need some specific training logic & config in the model.
20
-
21
- Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
22
- outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
23
-
24
- Arguments:
25
- n_latent_tokens (:obj:`int`, `optional`, defaults to 6):
26
- Number of latent tokens (must be less than seq length).
27
- latent_token_size (:obj:`int`, `optional`, defaults to 32):
28
- Number of dimensions to use for each latent token.
29
- t5_name (:obj:`str`, `optional`, defaults to t5-base):
30
- Name of the Transformer model to use as a decoder.
31
- block_size (:obj:`int`, `optional`, defaults to 60):
32
- NOTE: Every input sequence must be padded to be equal to this length.
33
- """
34
- model_type = "transformer_vae"
35
- is_composition = True
36
-
37
- def __init__(
38
- self,
39
- t5_model_name_or_path=None,
40
- n_latent_tokens=6, # set to -1 for full sequence
41
- latent_token_size=32,
42
- vae_encoder_model='',
43
- vae_decoder_model='',
44
- block_size=60,
45
- decoder_start_token_id=0,
46
- cache_dir=None,
47
- tie_word_embeddings=True,
48
- # T5 config
49
- t5=dict(),
50
- vocab_size=32128,
51
- d_model=512,
52
- d_kv=64,
53
- d_ff=2048,
54
- num_layers=6,
55
- num_decoder_layers=None,
56
- num_heads=8,
57
- relative_attention_num_buckets=32,
58
- dropout_rate=0.1,
59
- layer_norm_epsilon=1e-6,
60
- initializer_factor=1.0,
61
- feed_forward_proj="relu",
62
- is_encoder_decoder=True,
63
- use_cache=True,
64
- pad_token_id=0,
65
- eos_token_id=1,
66
- gradient_checkpointing=False,
67
- # end
68
- **kwargs,
69
- ):
70
- assertIn(vae_encoder_model, VAE_ENCODER_MODELS.keys(), "Unexpected VAE encoder.")
71
- assertIn(vae_decoder_model, VAE_DECODER_MODELS.keys(), "Unexpected VAE decoder.")
72
-
73
- super().__init__(**kwargs)
74
-
75
- self.set_seq_size = block_size
76
-
77
- # VAE
78
- self.vae_encoder_model = vae_encoder_model
79
- self.vae_decoder_model = vae_decoder_model
80
-
81
- self.latent_token_size = latent_token_size
82
- assert(n_latent_tokens <= self.set_seq_size, 'Cannot use more latent tokens than input tokens.')
83
- self.n_latent_tokens = n_latent_tokens
84
- self.use_cache = use_cache
85
-
86
- # T5
87
- if t5_model_name_or_path:
88
- self.t5 = AutoConfig.from_pretrained(t5_model_name_or_path, cache_dir=cache_dir)
89
- assertEqual(self.t5.model_type, "t5", "Need t5 model type for transformer_decoder.")
90
- self.t5.decoder_start_token_id = decoder_start_token_id
91
- elif t5:
92
- # use for loading a config
93
- self.t5 = T5Config(**t5)
94
- else:
95
- self.t5 = T5Config(
96
- vocab_size=vocab_size,
97
- d_model=d_model,
98
- d_kv=d_kv,
99
- d_ff=d_ff,
100
- num_layers=num_layers,
101
- num_decoder_layers=num_decoder_layers,
102
- num_heads=num_heads,
103
- relative_attention_num_buckets=relative_attention_num_buckets,
104
- dropout_rate=dropout_rate,
105
- layer_norm_epsilon=layer_norm_epsilon,
106
- initializer_factor=initializer_factor,
107
- feed_forward_proj=feed_forward_proj,
108
- is_encoder_decoder=is_encoder_decoder,
109
- use_cache=use_cache,
110
- pad_token_id=pad_token_id,
111
- eos_token_id=eos_token_id,
112
- gradient_checkpointing=gradient_checkpointing,
113
- **kwargs
114
- )
115
-
116
- if self.t5.d_model < self.latent_token_size:
117
- raise Exception('Using larger latent token dimension then T5 hidden dimension.')
118
-
119
- # Add t5 config options
120
- self.tie_word_embeddings = tie_word_embeddings
121
- self.t5.tie_word_embeddings = self.tie_word_embeddings
122
- self.t5.use_cache = self.use_cache
123
- self.pad_token_id = pad_token_id
124
- self.eos_token_id = eos_token_id
125
- self.decoder_start_token_id = self.t5.decoder_start_token_id
126
-
127
- def to_dict(self):
128
- """
129
- Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig`.
130
-
131
- Returns:
132
- :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
133
- """
134
- output = copy.deepcopy(self.__dict__)
135
- output["model_type"] = self.__class__.model_type
136
- output['t5'] = self.t5.to_dict()
137
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/decoders.py DELETED
@@ -1,23 +0,0 @@
1
- import logging
2
- import flax.linen as nn
3
-
4
- logger = logging.getLogger(__name__)
5
-
6
-
7
- class Decoder(nn.Module):
8
- '''
9
- Converts latent code -> transformer encoding.
10
- '''
11
- dim_model: int
12
- n_latent_tokens: int
13
-
14
- @nn.compact
15
- def __call__(self, latent_code): # (batch, latent_tokens_per_sequence, latent_token_dim)
16
- raw_latent_tokens = nn.Dense(self.dim_model)(latent_code)
17
- latent_tokens = nn.LayerNorm()(raw_latent_tokens)
18
- return latent_tokens # (batch, latent_tokens_per_sequence, dim_model)
19
-
20
-
21
- VAE_DECODER_MODELS = {
22
- '': Decoder,
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/encoders.py DELETED
@@ -1,26 +0,0 @@
1
- import logging
2
- import jax.numpy as jnp
3
- import flax.linen as nn
4
-
5
- logger = logging.getLogger(__name__)
6
-
7
-
8
- class Encoder(nn.Module):
9
- '''
10
- Converts N hidden tokens into N seperate latent codes.
11
- '''
12
- latent_token_size: int
13
- n_latent_tokens: int
14
-
15
- @nn.compact
16
- def __call__(self, encoding):
17
- latent_tokens = nn.Dense(self.latent_token_size)(encoding)
18
- raw_latent_code = latent_tokens[:, : self.n_latent_tokens, :]
19
- # TODO does this just apply tanh to each latent token? Or across the whole batch
20
- latent_code = jnp.tanh(raw_latent_code)
21
- return latent_code # (batch, latent_tokens_per_sequence, latent_token_dim)
22
-
23
-
24
- VAE_ENCODER_MODELS = {
25
- '': Encoder,
26
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/outputs.py DELETED
@@ -1,74 +0,0 @@
1
- from typing import Optional, Tuple
2
-
3
- import flax
4
- import jaxlib.xla_extension as jax_xla
5
-
6
- from transformers.file_utils import ModelOutput
7
-
8
-
9
- @flax.struct.dataclass
10
- class TransformerVaeOutput(ModelOutput):
11
- """
12
- Base class for a Transformer-VAE's outputs.
13
-
14
- Args:
15
- latent_codes (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_latent_tokens, latent_token_size)`):
16
- Latent codes representing encoded sequences.
17
- remade_encoder_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_tokens, model_dim)`):
18
- Reconstructed encoder hidden states representing sequences.
19
-
20
- (std Seq2Seq) Args:
21
- logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
22
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
23
- past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
24
- Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
25
- tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
26
- tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
27
-
28
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
29
- blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
30
- last_hidden_state (:obj:`tuple(jax_xla.DeviceArray)`:
31
- Last model hidden state.
32
- decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
33
- Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
34
- layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
35
-
36
- Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
37
- decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
38
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
39
- sequence_length, sequence_length)`.
40
-
41
- Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
42
- self-attention heads.
43
- cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
44
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
45
- sequence_length, sequence_length)`.
46
-
47
- Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
48
- weighted average in the cross-attention heads.
49
- encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
50
- Sequence of hidden-states at the output of the last layer of the encoder of the model.
51
- encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
52
- Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
53
- layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
54
-
55
- Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
56
- encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
57
- Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
58
- sequence_length, sequence_length)`.
59
-
60
- Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
61
- self-attention heads.
62
- """
63
- logits: jax_xla.DeviceArray = None
64
- latent_codes: jax_xla.DeviceArray = None
65
- remade_encoder_hidden_state: jax_xla.DeviceArray = None
66
- # seq2seq
67
- past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
68
- decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
69
- decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
70
- cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
71
- last_hidden_state: Optional[jax_xla.DeviceArray] = None
72
- encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
73
- encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
74
- encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/t5_vae.py DELETED
@@ -1,522 +0,0 @@
1
- from typing import Optional, Tuple
2
-
3
- import jax
4
- import jax.numpy as jnp
5
- from jax.random import PRNGKey
6
- import flax.linen as nn
7
- from flax.core.frozen_dict import FrozenDict, unfreeze
8
-
9
- from transformers.modeling_flax_outputs import FlaxCausalLMOutputWithCrossAttentions
10
- from transformers.file_utils import add_start_docstrings
11
- from transformers.modeling_flax_utils import FlaxPreTrainedModel
12
- from transformers.models.t5.modeling_flax_t5 import FlaxT5ForConditionalGenerationModule
13
-
14
- from model.vae import VAE
15
- from model.outputs import TransformerVaeOutput
16
- from model.config import T5VaeConfig
17
-
18
-
19
- @add_start_docstrings("""T5 Model with a `language modeling` head on top converted into a VAE.""")
20
- class FlaxT5VaeForAutoencodingModule(nn.Module):
21
- config: T5VaeConfig
22
- dtype: jnp.dtype = jnp.float32 # the dtype of the computation
23
-
24
- def _get_encoder_module(self):
25
- return self.t5.encoder
26
-
27
- def _get_vae_encoder_module(self):
28
- return self.vae.encoder
29
-
30
- def _get_vae_decoder_module(self):
31
- return self.vae.decoder
32
-
33
- def _get_decoder_module(self):
34
- return self.t5.decoder
35
-
36
- def setup(self):
37
- self.t5 = FlaxT5ForConditionalGenerationModule(self.config.t5)
38
- self.vae = VAE(self.config)
39
-
40
- def __call__(
41
- self,
42
- input_ids=None,
43
- attention_mask=None,
44
- decoder_input_ids=None,
45
- decoder_attention_mask=None,
46
- encoder_outputs=None,
47
- latent_codes=None,
48
- output_attentions=None,
49
- output_hidden_states=None,
50
- return_dict=None,
51
- deterministic: bool = True,
52
- ):
53
- """
54
- Adapted from `FlaxT5ForConditionalGenerationModule`
55
- """
56
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
57
-
58
- # Encode
59
- encoder_outputs = self.t5.encoder(
60
- input_ids=input_ids,
61
- attention_mask=attention_mask,
62
- output_attentions=output_attentions,
63
- output_hidden_states=output_hidden_states,
64
- return_dict=return_dict,
65
- deterministic=deterministic,
66
- )
67
-
68
- hidden_states = encoder_outputs[0]
69
-
70
- # Autoencode
71
- hidden_states, latent_codes = self.vae(hidden_states, latent_codes)
72
- encoder_attention_mask = jnp.ones((hidden_states.shape[0], hidden_states.shape[1]))
73
-
74
- # Decode
75
- decoder_outputs = self.t5.decoder(
76
- input_ids=decoder_input_ids,
77
- attention_mask=decoder_attention_mask,
78
- encoder_hidden_states=hidden_states,
79
- encoder_attention_mask=encoder_attention_mask,
80
- output_attentions=output_attentions,
81
- output_hidden_states=output_hidden_states,
82
- return_dict=return_dict,
83
- deterministic=deterministic,
84
- )
85
-
86
- sequence_output = decoder_outputs[0]
87
-
88
- if self.t5.config.tie_word_embeddings:
89
- # Rescale output before projecting on vocab
90
- # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
91
- sequence_output = sequence_output * (self.t5.config.d_model ** -0.5)
92
-
93
- if self.t5.config.tie_word_embeddings:
94
- shared_embedding = self.t5.shared.variables["params"]["embedding"]
95
- lm_logits = self.t5.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
96
- else:
97
- lm_logits = self.t5.lm_head(sequence_output)
98
-
99
- if not return_dict:
100
- return [lm_logits, latent_codes] + decoder_outputs[1:] + encoder_outputs
101
-
102
- return TransformerVaeOutput(
103
- logits=lm_logits,
104
- latent_codes=latent_codes,
105
- last_hidden_state=decoder_outputs.last_hidden_state,
106
- past_key_values=decoder_outputs.past_key_values,
107
- decoder_hidden_states=decoder_outputs.hidden_states,
108
- decoder_attentions=decoder_outputs.attentions,
109
- cross_attentions=decoder_outputs.cross_attentions,
110
- encoder_last_hidden_state=encoder_outputs.last_hidden_state,
111
- encoder_hidden_states=encoder_outputs.hidden_states,
112
- encoder_attentions=encoder_outputs.attentions,
113
- )
114
-
115
-
116
- class FlaxT5VaePreTrainedModel(FlaxPreTrainedModel):
117
- """
118
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
119
- models.
120
- """
121
-
122
- config_class = T5VaeConfig
123
- base_model_prefix = "transformer"
124
- module_class: nn.Module = None
125
-
126
- def __init__(
127
- self,
128
- config: T5VaeConfig,
129
- input_shape: Tuple[int] = (1, 1),
130
- seed: int = 0,
131
- dtype: jnp.dtype = jnp.float32,
132
- **kwargs
133
- ):
134
- module = self.module_class(config=config, dtype=dtype, **kwargs)
135
- super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
136
-
137
- def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
138
- # init input tensors
139
- input_ids = jnp.zeros(input_shape, dtype="i4")
140
-
141
- attention_mask = jnp.ones_like(input_ids)
142
- decoder_input_ids = jnp.ones_like(input_ids)
143
- decoder_attention_mask = jnp.ones_like(input_ids)
144
-
145
- params_rng, dropout_rng = jax.random.split(rng)
146
- rngs = {"params": params_rng, "dropout": dropout_rng}
147
-
148
- return self.module.init(
149
- rngs,
150
- input_ids,
151
- attention_mask,
152
- decoder_input_ids,
153
- decoder_attention_mask,
154
- )["params"]
155
-
156
- def __call__(
157
- self,
158
- input_ids: jnp.ndarray,
159
- attention_mask: Optional[jnp.ndarray] = None,
160
- decoder_input_ids: jnp.ndarray = None,
161
- decoder_attention_mask: Optional[jnp.ndarray] = None,
162
- output_attentions: Optional[bool] = None,
163
- output_hidden_states: Optional[bool] = None,
164
- return_dict: Optional[bool] = None,
165
- train: bool = False,
166
- params: dict = None,
167
- dropout_rng: PRNGKey = None,
168
- ):
169
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
170
- output_hidden_states = (
171
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
172
- )
173
- return_dict = return_dict if return_dict is not None else self.config.return_dict
174
-
175
- if decoder_input_ids is None:
176
- raise ValueError(
177
- "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed here."
178
- )
179
-
180
- # prepare encoder inputs
181
- if attention_mask is None:
182
- attention_mask = jnp.ones_like(input_ids)
183
-
184
- # prepare decoder inputs
185
- if decoder_attention_mask is None:
186
- decoder_attention_mask = jnp.ones_like(decoder_input_ids)
187
-
188
- # Handle any PRNG if needed
189
- rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
190
-
191
- return self.module.apply(
192
- {"params": params or self.params},
193
- input_ids=jnp.array(input_ids, dtype="i4"),
194
- attention_mask=jnp.array(attention_mask, dtype="i4"),
195
- decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
196
- decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
197
- output_attentions=output_attentions,
198
- output_hidden_states=output_hidden_states,
199
- return_dict=return_dict,
200
- deterministic=not train,
201
- rngs=rngs,
202
- )
203
-
204
- def init_cache(self, batch_size, max_length, latent_codes):
205
- r"""
206
- Args:
207
- batch_size (:obj:`int`):
208
- batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
209
- max_length (:obj:`int`):
210
- maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
211
- cache.
212
- latent_codes (:obj:`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
213
- ``latent_codes`` consists of compressed hidden-states at the output of the last layer of the encoder.
214
- Used in the cross-attention of the decoder.
215
- """
216
- # init input variables to retrieve cache
217
- decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
218
- decoder_attention_mask = jnp.ones_like(decoder_input_ids)
219
-
220
- def _decoder_forward(module, decoder_input_ids, latent_codes, decoder_attention_mask, **kwargs):
221
- vae_decoder_module = module._get_vae_decoder_module()
222
- decoder_module = module._get_decoder_module()
223
- return decoder_module(
224
- decoder_input_ids,
225
- decoder_attention_mask,
226
- encoder_hidden_states=vae_decoder_module(latent_codes),
227
- **kwargs,
228
- )
229
-
230
- init_variables = self.module.init(
231
- jax.random.PRNGKey(0),
232
- decoder_input_ids=decoder_input_ids,
233
- latent_codes=latent_codes,
234
- decoder_attention_mask=decoder_attention_mask,
235
- init_cache=True,
236
- method=_decoder_forward, # we only need to call the decoder to init the cache
237
- )
238
- return unfreeze(init_variables["cache"])
239
-
240
- def encode(
241
- self,
242
- input_ids: jnp.ndarray,
243
- attention_mask: Optional[jnp.ndarray] = None,
244
- output_attentions: Optional[bool] = None,
245
- output_hidden_states: Optional[bool] = None,
246
- return_dict: Optional[bool] = None,
247
- train: bool = False,
248
- params: dict = None,
249
- dropout_rng: PRNGKey = None,
250
- ):
251
- raise NotImplementedError()
252
-
253
- def decode(
254
- self,
255
- decoder_input_ids,
256
- latent_codes,
257
- encoder_attention_mask: Optional[jnp.ndarray] = None,
258
- decoder_attention_mask: Optional[jnp.ndarray] = None,
259
- past_key_values: dict = None,
260
- output_attentions: Optional[bool] = None,
261
- output_hidden_states: Optional[bool] = None,
262
- return_dict: Optional[bool] = None,
263
- train: bool = False,
264
- params: dict = None,
265
- dropout_rng: PRNGKey = None,
266
- ):
267
- raise NotImplementedError()
268
-
269
-
270
- class FlaxT5VaeForAutoencoding(FlaxT5VaePreTrainedModel):
271
- module_class = FlaxT5VaeForAutoencodingModule
272
-
273
- def __call__(
274
- self,
275
- input_ids: jnp.ndarray,
276
- attention_mask: Optional[jnp.ndarray] = None,
277
- decoder_input_ids=None,
278
- decoder_attention_mask=None,
279
- output_attentions: Optional[bool] = None,
280
- output_hidden_states: Optional[bool] = None,
281
- return_dict: Optional[bool] = None,
282
- train: bool = False,
283
- params: dict = None,
284
- dropout_rng: PRNGKey = None,
285
- ):
286
- '''
287
- Adapted from `FlaxT5PreTrainedModel`
288
- '''
289
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
290
- output_hidden_states = (
291
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
292
- )
293
- return_dict = return_dict if return_dict is not None else self.config.return_dict
294
-
295
- if decoder_input_ids is None:
296
- raise ValueError(
297
- "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed here."
298
- )
299
-
300
- # prepare encoder inputs
301
- if attention_mask is None:
302
- attention_mask = jnp.ones_like(input_ids)
303
-
304
- # prepare decoder inputs
305
- if decoder_attention_mask is None:
306
- decoder_attention_mask = jnp.ones_like(decoder_input_ids)
307
-
308
- # Handle any PRNG if needed
309
- rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
310
-
311
- return self.module.apply(
312
- {"params": params or self.params},
313
- input_ids=jnp.array(input_ids, dtype="i4"),
314
- attention_mask=jnp.array(attention_mask, dtype="i4"),
315
- decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
316
- decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
317
- output_attentions=output_attentions,
318
- output_hidden_states=output_hidden_states,
319
- return_dict=return_dict,
320
- deterministic=not train,
321
- rngs=rngs,
322
- )
323
-
324
- def encode(
325
- self,
326
- input_ids: jnp.ndarray,
327
- attention_mask: Optional[jnp.ndarray] = None,
328
- output_attentions: Optional[bool] = None,
329
- output_hidden_states: Optional[bool] = None,
330
- return_dict: Optional[bool] = None,
331
- train: bool = False,
332
- params: dict = None,
333
- dropout_rng: PRNGKey = None,
334
- ):
335
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
336
- output_hidden_states = (
337
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
338
- )
339
- return_dict = return_dict if return_dict is not None else self.config.return_dict
340
-
341
- if attention_mask is None:
342
- attention_mask = jnp.ones_like(input_ids)
343
-
344
- # Handle any PRNG if needed
345
- rngs = {}
346
- if dropout_rng is not None:
347
- rngs["dropout"] = dropout_rng
348
-
349
- def _encoder_forward(module, input_ids, attention_mask, **kwargs):
350
- encode_module = module._get_encoder_module()
351
- vae_encoder_module = module._get_vae_encoder_module()
352
- return vae_encoder_module(encode_module(input_ids, attention_mask, **kwargs)[0])
353
-
354
- return self.module.apply(
355
- {"params": params or self.params},
356
- input_ids=jnp.array(input_ids, dtype="i4"),
357
- attention_mask=jnp.array(attention_mask, dtype="i4"),
358
- output_attentions=output_attentions,
359
- output_hidden_states=output_hidden_states,
360
- return_dict=return_dict,
361
- deterministic=not train,
362
- rngs=rngs,
363
- method=_encoder_forward,
364
- )
365
-
366
- def decode(
367
- self,
368
- decoder_input_ids,
369
- latent_codes,
370
- encoder_attention_mask: Optional[jnp.ndarray] = None,
371
- decoder_attention_mask: Optional[jnp.ndarray] = None,
372
- past_key_values: dict = None,
373
- output_attentions: Optional[bool] = None,
374
- output_hidden_states: Optional[bool] = None,
375
- return_dict: Optional[bool] = None,
376
- train: bool = False,
377
- params: dict = None,
378
- dropout_rng: PRNGKey = None,
379
- ):
380
- r"""
381
- Returns:
382
-
383
- Example::
384
-
385
- >>> model = FlaxT5VaeForAutoencoding.from_pretrained('t5-small')
386
- >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
387
-
388
- >>> text = "My friends are cool but they eat too many carbs."
389
- >>> inputs = tokenizer(text, max_length=512, return_tensors='jax')
390
- >>> latent_codes = model.encode(**inputs)
391
-
392
- >>> decoder_start_token_id = model.config.decoder_start_token_id
393
- >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
394
-
395
- >>> outputs = model.decode(decoder_input_ids, latent_codes)
396
- >>> last_decoder_hidden_states = outputs.last_hidden_state
397
- """
398
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
399
- output_hidden_states = (
400
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
401
- )
402
- return_dict = return_dict if return_dict is not None else self.config.return_dict
403
-
404
- if encoder_attention_mask is None:
405
- batch_size, sequence_length = latent_codes.shape[:2]
406
- encoder_attention_mask = jnp.ones((batch_size, sequence_length))
407
-
408
- batch_size, sequence_length = decoder_input_ids.shape
409
- if decoder_attention_mask is None:
410
- decoder_attention_mask = jnp.ones((batch_size, sequence_length))
411
-
412
- # Handle any PRNG if needed
413
- rngs = {}
414
- if dropout_rng is not None:
415
- rngs["dropout"] = dropout_rng
416
-
417
- inputs = {"params": params or self.params}
418
-
419
- # if past_key_values are passed then cache is already initialized a private flag init_cache has to be
420
- # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
421
- # it can be changed by FlaxT5Attention module
422
- if past_key_values:
423
- inputs["cache"] = past_key_values
424
- mutable = ["cache"]
425
- else:
426
- mutable = False
427
-
428
- def _decoder_forward(module, decoder_input_ids, latent_codes, decoder_attention_mask, **kwargs):
429
- vae_decoder_module = module._get_vae_decoder_module()
430
- decoder_module = module._get_decoder_module()
431
- decoder_outputs = decoder_module(
432
- decoder_input_ids,
433
- decoder_attention_mask,
434
- encoder_hidden_states=vae_decoder_module(latent_codes),
435
- **kwargs,
436
- )
437
- sequence_output = decoder_outputs[0]
438
-
439
- if self.config.tie_word_embeddings:
440
- # Rescale output before projecting on vocab
441
- # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
442
- sequence_output = sequence_output * (self.config.d_model ** -0.5)
443
-
444
- if self.config.tie_word_embeddings:
445
- shared_embedding = module.t5.shared.variables["params"]["embedding"]
446
- lm_logits = module.t5.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
447
- else:
448
- lm_logits = module.t5.lm_head(sequence_output)
449
-
450
- return lm_logits, decoder_outputs
451
-
452
- outputs = self.module.apply(
453
- inputs,
454
- decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
455
- latent_codes=latent_codes,
456
- decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
457
- encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
458
- output_attentions=output_attentions,
459
- output_hidden_states=output_hidden_states,
460
- return_dict=return_dict,
461
- deterministic=not train,
462
- rngs=rngs,
463
- mutable=mutable,
464
- method=_decoder_forward,
465
- )
466
-
467
- if past_key_values is None:
468
- lm_logits, decoder_outputs = outputs
469
- else:
470
- (lm_logits, decoder_outputs), past = outputs
471
-
472
- if return_dict:
473
- outputs = FlaxCausalLMOutputWithCrossAttentions(
474
- logits=lm_logits,
475
- hidden_states=decoder_outputs.hidden_states,
476
- attentions=decoder_outputs.attentions,
477
- cross_attentions=decoder_outputs.cross_attentions,
478
- )
479
- else:
480
- outputs = (lm_logits,) + decoder_outputs[1:]
481
-
482
- # add updated cache to model output
483
- if past_key_values is not None and return_dict:
484
- outputs["past_key_values"] = unfreeze(past["cache"])
485
- return outputs
486
- elif past_key_values is not None and not return_dict:
487
- outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
488
-
489
- return outputs
490
-
491
- def prepare_inputs_for_generation(
492
- self,
493
- decoder_input_ids,
494
- max_length,
495
- attention_mask: Optional[jnp.DeviceArray] = None,
496
- decoder_attention_mask: Optional[jnp.DeviceArray] = None,
497
- latent_codes=None,
498
- **kwargs
499
- ):
500
- # initializing the cache
501
- batch_size, seq_length = decoder_input_ids.shape
502
-
503
- past_key_values = self.init_cache(batch_size, max_length, latent_codes)
504
- # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
505
- # But since the decoder uses a causal mask, those positions are masked anyways.
506
- # Thus we can create a single static attention_mask here, which is more efficient for compilation
507
- extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
508
- if decoder_attention_mask is not None:
509
- extended_attention_mask = jax.lax.dynamic_update_slice(
510
- extended_attention_mask, decoder_attention_mask, (0, 0)
511
- )
512
-
513
- return {
514
- "past_key_values": past_key_values,
515
- "latent_codes": latent_codes,
516
- "encoder_attention_mask": attention_mask,
517
- "decoder_attention_mask": extended_attention_mask,
518
- }
519
-
520
- def update_inputs_for_generation(self, model_outputs, model_kwargs):
521
- model_kwargs["past_key_values"] = model_outputs.past_key_values
522
- return model_kwargs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/utils.py DELETED
@@ -1,24 +0,0 @@
1
- from typing import Sequence
2
-
3
- import flax.linen as nn
4
-
5
-
6
- class MLP(nn.Module):
7
- features: Sequence[int]
8
-
9
- @nn.compact
10
- def __call__(self, x):
11
- for feat in self.features[:-1]:
12
- x = nn.relu(nn.Dense(feat)(x))
13
- x = nn.Dense(self.features[-1])(x)
14
- return x
15
-
16
-
17
- def assertEqual(actual, expected, msg, first="Got", second="Expected"):
18
- if actual != expected:
19
- raise ValueError(msg + f' {first}: "{actual}" {second}: "{expected}"')
20
-
21
-
22
- def assertIn(actual, expected, msg, first="Got", second="Expected one of"):
23
- if actual not in expected:
24
- raise ValueError(msg + f' {first}: "{actual}" {second}: {expected}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/vae.py DELETED
@@ -1,30 +0,0 @@
1
- import jax.numpy as jnp
2
- import flax.linen as nn
3
-
4
- from model.encoders import VAE_ENCODER_MODELS
5
- from model.decoders import VAE_DECODER_MODELS
6
- from model.config import T5VaeConfig
7
-
8
-
9
- class VAE(nn.Module):
10
- # see https://github.com/google/flax#what-does-flax-look-like
11
- """
12
- An MMD-VAE used with encoder-decoder models.
13
- Encodes all token encodings into a single latent & spits them back out.
14
- """
15
- config: T5VaeConfig
16
- dtype: jnp.dtype = jnp.float32 # the dtype of the computation
17
-
18
- def setup(self):
19
- self.encoder = VAE_ENCODER_MODELS[self.config.vae_encoder_model](self.config.latent_token_size, self.config.n_latent_tokens)
20
- self.decoder = VAE_DECODER_MODELS[self.config.vae_decoder_model](self.config.t5.d_model, self.config.n_latent_tokens)
21
-
22
- def __call__(self, encoding=None, latent_codes=None):
23
- latent_codes = self.encode(encoding)
24
- return self.decode(latent_codes), latent_codes
25
-
26
- def encode(self, encoding):
27
- return self.encoder(encoding)
28
-
29
- def decode(self, latent):
30
- return self.decoder(latent)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
t5-vae-flax DELETED
@@ -1 +0,0 @@
1
- Subproject commit 0a7735b81b50995c0d1901501c5e6928ce62c0ef
 
 
t5_vae_flax CHANGED
@@ -1 +1 @@
1
- Subproject commit 78562617b5fac81e1798f5dbde27c8ff9d4e378b
 
1
+ Subproject commit 0c030dca4751e6def730968a2f33fe093a608cdb
train.py CHANGED
@@ -46,8 +46,8 @@ from transformers import (
46
  from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
47
  from transformers.testing_utils import CaptureLogger
48
 
49
- from model.t5_vae import FlaxT5VaeForAutoencoding
50
- from model.config import T5VaeConfig
51
 
52
 
53
  logger = logging.getLogger(__name__)
 
46
  from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
47
  from transformers.testing_utils import CaptureLogger
48
 
49
+ from t5_vae_flax.src.t5_vae import FlaxT5VaeForAutoencoding
50
+ from t5_vae_flax.src.config import T5VaeConfig
51
 
52
 
53
  logger = logging.getLogger(__name__)
train.sh CHANGED
@@ -1,4 +1,4 @@
1
- export RUN_NAME=single_latent
2
 
3
  ./venv/bin/python train.py \
4
  --t5_model_name_or_path="t5-base" \
@@ -6,8 +6,8 @@ export RUN_NAME=single_latent
6
  --overwrite_output_dir \
7
  --dataset_name="Fraser/python-lines" \
8
  --do_train --do_eval \
9
- --n_latent_tokens 1 \
10
- --latent_token_size 32 \
11
  --save_steps="2500" \
12
  --eval_steps="2500" \
13
  --block_size="32" \
 
1
+ export RUN_NAME=two_latent
2
 
3
  ./venv/bin/python train.py \
4
  --t5_model_name_or_path="t5-base" \
 
6
  --overwrite_output_dir \
7
  --dataset_name="Fraser/python-lines" \
8
  --do_train --do_eval \
9
+ --n_latent_tokens 2 \
10
+ --latent_token_size 16 \
11
  --save_steps="2500" \
12
  --eval_steps="2500" \
13
  --block_size="32" \