boris commited on
Commit
f139b0b
·
unverified ·
1 Parent(s): 69bcbeb

feat: add sinkformer + custom final ln + pre-ln (#151)

Browse files
README.md CHANGED
@@ -124,8 +124,9 @@ Sequence to sequence model based on "[BART: Denoising Sequence-to-Sequence Pre-t
124
  - "[Deepnet: Scaling Transformers to 1,000 Layers](https://arxiv.org/abs/2203.00555)"
125
  - "[NormFormer: Improved Transformer Pretraining with Extra Normalization](https://arxiv.org/abs/2110.09456)"
126
  - "[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)"
127
- - "[CogView: Mastering Text-to-Image Generation via Transformers](https://arxiv.org/abs/2105.13290v2)
128
  - "[Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)"
 
129
 
130
  Main optimizer (Distributed Shampoo) from "[Scalable Second Order Optimization for Deep Learning](https://arxiv.org/abs/2002.09018)".
131
 
@@ -247,3 +248,12 @@ Main optimizer (Distributed Shampoo) from "[Scalable Second Order Optimization f
247
  primaryClass = {cs.LG}
248
  }
249
  ```
 
 
 
 
 
 
 
 
 
 
124
  - "[Deepnet: Scaling Transformers to 1,000 Layers](https://arxiv.org/abs/2203.00555)"
125
  - "[NormFormer: Improved Transformer Pretraining with Extra Normalization](https://arxiv.org/abs/2110.09456)"
126
  - "[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)"
127
+ - "[CogView: Mastering Text-to-Image Generation via Transformers](https://arxiv.org/abs/2105.13290v2)"
128
  - "[Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)"
129
+ - "[Sinkformers: Transformers with Doubly Stochastic Attention](https://arxiv.org/abs/2110.11773)"
130
 
131
  Main optimizer (Distributed Shampoo) from "[Scalable Second Order Optimization for Deep Learning](https://arxiv.org/abs/2002.09018)".
132
 
 
248
  primaryClass = {cs.LG}
249
  }
250
  ```
251
+
252
+ ```text
253
+ @misc{title = {Sinkformers: Transformers with Doubly Stochastic Attention},
254
+ url = {https://arxiv.org/abs/2110.11773},
255
+ author = {Sander, Michael E. and Ablin, Pierre and Blondel, Mathieu and Peyré, Gabriel},
256
+ publisher = {arXiv},
257
+ year = {2021},
258
+ }
259
+ ```
src/dalle_mini/model/configuration.py CHANGED
@@ -59,37 +59,39 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
59
  do_sample=True,
60
  # transformer variants
61
  ln_type="layernorm", # layer normalization type, "rmsnorm", "layernorm"
62
- ln_positions="normformer", # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "deepnet" (same as postln)
63
- head_scale=False, # used in NormFormer
64
  use_cosine_attention=False, # used in Swin v2
65
  tau_init=0.05, # used only in cosine attention (Swin v2)
66
  use_deepnet_scaling=False, # used in Deepnet
67
  use_glu=False, # "GLU Variants Improve Transformer"
68
  use_alibi=False, # from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
69
- sink_iters=1, # used in SinkFormers
 
 
70
  # parameters that should not be necessary but could affect results
71
- force_ln_scale=True, # force scale in layernorm even when followed by dense layers
72
- force_final_ln_encoder=False, # force layer normalization in encoder final layer even when followed by dense layers
73
  **kwargs,
74
  ):
75
  # text normalizer
76
  self.normalize_text = normalize_text
77
 
78
  # transformer variants
79
- self.head_scale = head_scale # per Normformer
80
  assert ln_type in [
81
  "rmsnorm",
82
  "layernorm",
83
  ], "ln_type must be 'rmsnorm' or 'layernorm'"
84
  self.ln_type = ln_type
 
 
85
  assert ln_positions in [
86
  "normformer",
87
  "swinv2",
88
  "cogview",
89
- "deepnet",
90
- ], "ln_positions must be 'normformer', 'swinv2' or 'deepnet'"
91
- if ln_positions == "deepnet":
92
- ln_positions = "postln"
93
  assert use_alibi is False, "use_alibi is not supported yet"
94
  self.ln_positions = ln_positions
95
  self.use_cosine_attention = use_cosine_attention
@@ -97,9 +99,17 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
97
  self.use_deepnet_scaling = use_deepnet_scaling
98
  self.use_glu = use_glu
99
  self.use_alibi = use_alibi
100
- self.sink_iters = sink_iters
 
 
 
 
 
 
 
 
 
101
  self.force_ln_scale = force_ln_scale
102
- self.force_final_ln_encoder = force_final_ln_encoder
103
 
104
  # common parameters
105
  self.encoder_vocab_size = encoder_vocab_size
 
59
  do_sample=True,
60
  # transformer variants
61
  ln_type="layernorm", # layer normalization type, "rmsnorm", "layernorm"
62
+ ln_positions="normformer", # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "preln", "deepnet" (same as postln)
63
+ use_head_scale=False, # used in NormFormer
64
  use_cosine_attention=False, # used in Swin v2
65
  tau_init=0.05, # used only in cosine attention (Swin v2)
66
  use_deepnet_scaling=False, # used in Deepnet
67
  use_glu=False, # "GLU Variants Improve Transformer"
68
  use_alibi=False, # from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
69
+ sinkhorn_iters=1, # used in SinkFormers
70
+ use_final_ln_encoder=False, # final layer normalization in encoder
71
+ use_final_ln_decoder=False, # final layer normalization in decoder
72
  # parameters that should not be necessary but could affect results
73
+ force_ln_scale=False, # force scale in layernorm even when followed by dense layers
 
74
  **kwargs,
75
  ):
76
  # text normalizer
77
  self.normalize_text = normalize_text
78
 
79
  # transformer variants
80
+ self.use_head_scale = use_head_scale # per Normformer
81
  assert ln_type in [
82
  "rmsnorm",
83
  "layernorm",
84
  ], "ln_type must be 'rmsnorm' or 'layernorm'"
85
  self.ln_type = ln_type
86
+ if ln_positions == "deepnet":
87
+ ln_positions = "postln"
88
  assert ln_positions in [
89
  "normformer",
90
  "swinv2",
91
  "cogview",
92
+ "postln",
93
+ "preln",
94
+ ], "ln_positions must be 'normformer', 'swinv2', 'cogview', 'postln', 'preln'"
 
95
  assert use_alibi is False, "use_alibi is not supported yet"
96
  self.ln_positions = ln_positions
97
  self.use_cosine_attention = use_cosine_attention
 
99
  self.use_deepnet_scaling = use_deepnet_scaling
100
  self.use_glu = use_glu
101
  self.use_alibi = use_alibi
102
+ self.sinkhorn_iters = sinkhorn_iters
103
+ if ln_positions == "postln":
104
+ assert (
105
+ use_final_ln_encoder
106
+ ), "use_final_ln_encoder must be True when ln_positions is 'postln'"
107
+ assert (
108
+ use_final_ln_decoder
109
+ ), "use_final_ln_decoder must be True when ln_positions is 'postln'"
110
+ self.use_final_ln_encoder = use_final_ln_encoder
111
+ self.use_final_ln_decoder = use_final_ln_decoder
112
  self.force_ln_scale = force_ln_scale
 
113
 
114
  # common parameters
115
  self.encoder_vocab_size = encoder_vocab_size
src/dalle_mini/model/modeling.py CHANGED
@@ -28,7 +28,7 @@ import msgpack.exceptions
28
  from flax.core.frozen_dict import unfreeze
29
  from flax.linen import combine_masks, make_causal_mask
30
  from flax.linen import partitioning as nn_partitioning
31
- from flax.linen.attention import dot_product_attention_weights
32
  from flax.serialization import from_bytes
33
  from flax.traverse_util import flatten_dict, unflatten_dict
34
  from jax import lax
@@ -175,6 +175,66 @@ def norm(type, *args, **kwargs):
175
  raise ValueError(f"Unknown norm type {type}")
176
 
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  class FlaxBartAttention(FlaxBartAttention):
179
  """
180
  Edits:
@@ -225,7 +285,7 @@ class FlaxBartAttention(FlaxBartAttention):
225
  )
226
  self.dropout_layer = nn.Dropout(rate=self.dropout)
227
 
228
- if self.config.head_scale:
229
  self.head_scale = self.param(
230
  "head_scale", jax.nn.initializers.ones, (1, 1, self.num_heads, 1)
231
  )
@@ -342,13 +402,14 @@ class FlaxBartAttention(FlaxBartAttention):
342
  deterministic=deterministic,
343
  dtype=self.dtype,
344
  precision=None,
 
345
  )
346
  if self.config.use_cosine_attention:
347
  # divide by tau
348
  attn_weights = attn_weights / jnp.maximum(self.tau, 0.01)
349
 
350
  attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
351
- if self.config.head_scale:
352
  # per Normformer
353
  attn_output = attn_output * self.head_scale
354
  attn_output = self._merge_heads(attn_output)
@@ -373,7 +434,7 @@ class GLU(nn.Module):
373
  self.config
374
  )
375
 
376
- if self.config.ln_positions in ["normformer", "cogview"]:
377
  x = norm(
378
  self.config.ln_type,
379
  dtype=self.dtype,
@@ -438,7 +499,7 @@ class FFN(nn.Module):
438
  gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
439
  self.config
440
  )
441
- if self.config.ln_positions in ["normformer", "cogview"]:
442
  x = norm(
443
  self.config.ln_type,
444
  dtype=self.dtype,
@@ -507,7 +568,7 @@ class FlaxBartEncoderLayer(nn.Module):
507
 
508
  embed_dim = self.config.d_model
509
  residual = hidden_states
510
- if self.config.ln_positions in ["normformer", "cogview"]:
511
  hidden_states = norm(
512
  self.config.ln_type,
513
  dtype=self.dtype,
@@ -612,7 +673,7 @@ class FlaxBartDecoderLayer(nn.Module):
612
  residual = hidden_states
613
 
614
  # Self Attention
615
- if self.config.ln_positions in ["normformer", "cogview"]:
616
  hidden_states = norm(
617
  self.config.ln_type,
618
  dtype=self.dtype,
@@ -651,7 +712,7 @@ class FlaxBartDecoderLayer(nn.Module):
651
  cross_attn_weights = None
652
  if encoder_hidden_states is not None:
653
  residual = hidden_states
654
- if self.config.ln_positions in ["normformer", "cogview"]:
655
  hidden_states = norm(
656
  self.config.ln_type,
657
  dtype=self.dtype,
@@ -759,12 +820,9 @@ class FlaxBartEncoderLayerCollection(nn.Module):
759
  all_hidden_states += (hidden_states,)
760
  # final layernorm on the output of the last layer
761
  # or every 6 layers for Swin v2
762
- # not needed for other models which use layernorm before x-attention
763
- # ignored args for deepnet which always add a norm with scale
764
- add_norm = self.config.force_final_ln_encoder or (
765
- self.config.ln_positions == "swinv2"
766
- and ((i == n_layers - 1) or ((i + 1) % 6 == 0))
767
- )
768
  # we don't need to scale the norm for the last layer
769
  use_scale = i != n_layers - 1
770
  layer_outputs = layer(
@@ -839,9 +897,9 @@ class FlaxBartDecoderLayerCollection(nn.Module):
839
  all_hidden_states += (hidden_states,)
840
  # final layernorm on the output of the last layer
841
  # or every 6 layers for Swin v2
842
- add_norm = (i == n_layers - 1) or (
843
- (self.config.ln_positions == "swinv2") and ((i + 1) % 6 == 0)
844
- )
845
  # we don't need to scale the norm for the last layer
846
  use_scale = i != n_layers - 1
847
  layer_outputs = layer(
 
28
  from flax.core.frozen_dict import unfreeze
29
  from flax.linen import combine_masks, make_causal_mask
30
  from flax.linen import partitioning as nn_partitioning
31
+ from flax.linen.linear import PrecisionLike
32
  from flax.serialization import from_bytes
33
  from flax.traverse_util import flatten_dict, unflatten_dict
34
  from jax import lax
 
175
  raise ValueError(f"Unknown norm type {type}")
176
 
177
 
178
+ def dot_product_attention_weights(
179
+ query: Any,
180
+ key: Any,
181
+ bias: Optional[Any] = None,
182
+ mask: Optional[Any] = None,
183
+ broadcast_dropout: bool = True,
184
+ dropout_rng: Optional[PRNGKey] = None,
185
+ dropout_rate: float = 0.0,
186
+ deterministic: bool = False,
187
+ dtype: Any = jnp.float32,
188
+ precision: PrecisionLike = None,
189
+ sinkhorn_iters: int = 1,
190
+ ):
191
+ """
192
+ Computes dot-product attention weights given query and key.
193
+
194
+ Adapted from flax.linen.attention.dot_product_attention_weights"
195
+ """
196
+ assert query.ndim == key.ndim, "q, k must have same rank."
197
+ assert query.shape[:-3] == key.shape[:-3], "q, k batch dims must match."
198
+ assert query.shape[-2] == key.shape[-2], "q, k num_heads must match."
199
+ assert query.shape[-1] == key.shape[-1], "q, k depths must match."
200
+
201
+ # calculate attention matrix
202
+ depth = query.shape[-1]
203
+ query = query / jnp.sqrt(depth).astype(dtype)
204
+ # attn weight shape is (batch..., num_heads, q_length, kv_length)
205
+ attn_weights = jnp.einsum("...qhd,...khd->...hqk", query, key, precision=precision)
206
+
207
+ # apply attention bias: masking, dropout, proximity bias, etc.
208
+ if bias is not None:
209
+ attn_weights = attn_weights + bias
210
+ # apply attention mask
211
+ if mask is not None:
212
+ big_neg = jnp.finfo(dtype).min
213
+ attn_weights = jnp.where(mask, attn_weights, big_neg)
214
+
215
+ # normalize the attention weights
216
+ attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
217
+ for i in range(sinkhorn_iters - 1):
218
+ axis = -2 if i % 2 == 0 else -1
219
+ attn_weights /= 1e-8 + jnp.sum(attn_weights, axis=axis, keepdims=True)
220
+
221
+ # apply attention dropout
222
+ if not deterministic and dropout_rate > 0.0:
223
+ keep_prob = 1.0 - dropout_rate
224
+ if broadcast_dropout:
225
+ # dropout is broadcast across the batch + head dimensions
226
+ dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
227
+ keep = jax.random.bernoulli(dropout_rng, keep_prob, dropout_shape)
228
+ else:
229
+ keep = jax.random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
230
+ multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(
231
+ keep_prob, dtype=dtype
232
+ )
233
+ attn_weights = attn_weights * multiplier
234
+
235
+ return attn_weights
236
+
237
+
238
  class FlaxBartAttention(FlaxBartAttention):
239
  """
240
  Edits:
 
285
  )
286
  self.dropout_layer = nn.Dropout(rate=self.dropout)
287
 
288
+ if self.config.use_head_scale:
289
  self.head_scale = self.param(
290
  "head_scale", jax.nn.initializers.ones, (1, 1, self.num_heads, 1)
291
  )
 
402
  deterministic=deterministic,
403
  dtype=self.dtype,
404
  precision=None,
405
+ sinkhorn_iters=self.config.sinkhorn_iters,
406
  )
407
  if self.config.use_cosine_attention:
408
  # divide by tau
409
  attn_weights = attn_weights / jnp.maximum(self.tau, 0.01)
410
 
411
  attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
412
+ if self.config.use_head_scale:
413
  # per Normformer
414
  attn_output = attn_output * self.head_scale
415
  attn_output = self._merge_heads(attn_output)
 
434
  self.config
435
  )
436
 
437
+ if self.config.ln_positions in ["normformer", "cogview", "preln"]:
438
  x = norm(
439
  self.config.ln_type,
440
  dtype=self.dtype,
 
499
  gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
500
  self.config
501
  )
502
+ if self.config.ln_positions in ["normformer", "cogview", "preln"]:
503
  x = norm(
504
  self.config.ln_type,
505
  dtype=self.dtype,
 
568
 
569
  embed_dim = self.config.d_model
570
  residual = hidden_states
571
+ if self.config.ln_positions in ["normformer", "cogview", "preln"]:
572
  hidden_states = norm(
573
  self.config.ln_type,
574
  dtype=self.dtype,
 
673
  residual = hidden_states
674
 
675
  # Self Attention
676
+ if self.config.ln_positions in ["normformer", "cogview", "preln"]:
677
  hidden_states = norm(
678
  self.config.ln_type,
679
  dtype=self.dtype,
 
712
  cross_attn_weights = None
713
  if encoder_hidden_states is not None:
714
  residual = hidden_states
715
+ if self.config.ln_positions in ["normformer", "cogview", "preln"]:
716
  hidden_states = norm(
717
  self.config.ln_type,
718
  dtype=self.dtype,
 
820
  all_hidden_states += (hidden_states,)
821
  # final layernorm on the output of the last layer
822
  # or every 6 layers for Swin v2
823
+ add_norm = (
824
+ self.config.ln_positions == "swinv2" and ((i + 1) % 6 == 0)
825
+ ) or (self.config.use_final_ln_encoder and (i == n_layers - 1))
 
 
 
826
  # we don't need to scale the norm for the last layer
827
  use_scale = i != n_layers - 1
828
  layer_outputs = layer(
 
897
  all_hidden_states += (hidden_states,)
898
  # final layernorm on the output of the last layer
899
  # or every 6 layers for Swin v2
900
+ add_norm = (
901
+ self.config.ln_positions == "swinv2" and ((i + 1) % 6 == 0)
902
+ ) or (self.config.use_final_ln_decoder and (i == n_layers - 1))
903
  # we don't need to scale the norm for the last layer
904
  use_scale = i != n_layers - 1
905
  layer_outputs = layer(