Spaces:
Runtime error
Runtime error
feat: add sinkformer + custom final ln + pre-ln (#151)
Browse files- README.md +11 -1
- src/dalle_mini/model/configuration.py +22 -12
- src/dalle_mini/model/modeling.py +75 -17
README.md
CHANGED
@@ -124,8 +124,9 @@ Sequence to sequence model based on "[BART: Denoising Sequence-to-Sequence Pre-t
|
|
124 |
- "[Deepnet: Scaling Transformers to 1,000 Layers](https://arxiv.org/abs/2203.00555)"
|
125 |
- "[NormFormer: Improved Transformer Pretraining with Extra Normalization](https://arxiv.org/abs/2110.09456)"
|
126 |
- "[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)"
|
127 |
-
- "[CogView: Mastering Text-to-Image Generation via Transformers](https://arxiv.org/abs/2105.13290v2)
|
128 |
- "[Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)"
|
|
|
129 |
|
130 |
Main optimizer (Distributed Shampoo) from "[Scalable Second Order Optimization for Deep Learning](https://arxiv.org/abs/2002.09018)".
|
131 |
|
@@ -247,3 +248,12 @@ Main optimizer (Distributed Shampoo) from "[Scalable Second Order Optimization f
|
|
247 |
primaryClass = {cs.LG}
|
248 |
}
|
249 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
- "[Deepnet: Scaling Transformers to 1,000 Layers](https://arxiv.org/abs/2203.00555)"
|
125 |
- "[NormFormer: Improved Transformer Pretraining with Extra Normalization](https://arxiv.org/abs/2110.09456)"
|
126 |
- "[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)"
|
127 |
+
- "[CogView: Mastering Text-to-Image Generation via Transformers](https://arxiv.org/abs/2105.13290v2)"
|
128 |
- "[Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)"
|
129 |
+
- "[Sinkformers: Transformers with Doubly Stochastic Attention](https://arxiv.org/abs/2110.11773)"
|
130 |
|
131 |
Main optimizer (Distributed Shampoo) from "[Scalable Second Order Optimization for Deep Learning](https://arxiv.org/abs/2002.09018)".
|
132 |
|
|
|
248 |
primaryClass = {cs.LG}
|
249 |
}
|
250 |
```
|
251 |
+
|
252 |
+
```text
|
253 |
+
@misc{title = {Sinkformers: Transformers with Doubly Stochastic Attention},
|
254 |
+
url = {https://arxiv.org/abs/2110.11773},
|
255 |
+
author = {Sander, Michael E. and Ablin, Pierre and Blondel, Mathieu and Peyré, Gabriel},
|
256 |
+
publisher = {arXiv},
|
257 |
+
year = {2021},
|
258 |
+
}
|
259 |
+
```
|
src/dalle_mini/model/configuration.py
CHANGED
@@ -59,37 +59,39 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
|
|
59 |
do_sample=True,
|
60 |
# transformer variants
|
61 |
ln_type="layernorm", # layer normalization type, "rmsnorm", "layernorm"
|
62 |
-
ln_positions="normformer", # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "deepnet" (same as postln)
|
63 |
-
|
64 |
use_cosine_attention=False, # used in Swin v2
|
65 |
tau_init=0.05, # used only in cosine attention (Swin v2)
|
66 |
use_deepnet_scaling=False, # used in Deepnet
|
67 |
use_glu=False, # "GLU Variants Improve Transformer"
|
68 |
use_alibi=False, # from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
|
69 |
-
|
|
|
|
|
70 |
# parameters that should not be necessary but could affect results
|
71 |
-
force_ln_scale=
|
72 |
-
force_final_ln_encoder=False, # force layer normalization in encoder final layer even when followed by dense layers
|
73 |
**kwargs,
|
74 |
):
|
75 |
# text normalizer
|
76 |
self.normalize_text = normalize_text
|
77 |
|
78 |
# transformer variants
|
79 |
-
self.
|
80 |
assert ln_type in [
|
81 |
"rmsnorm",
|
82 |
"layernorm",
|
83 |
], "ln_type must be 'rmsnorm' or 'layernorm'"
|
84 |
self.ln_type = ln_type
|
|
|
|
|
85 |
assert ln_positions in [
|
86 |
"normformer",
|
87 |
"swinv2",
|
88 |
"cogview",
|
89 |
-
"
|
90 |
-
|
91 |
-
|
92 |
-
ln_positions = "postln"
|
93 |
assert use_alibi is False, "use_alibi is not supported yet"
|
94 |
self.ln_positions = ln_positions
|
95 |
self.use_cosine_attention = use_cosine_attention
|
@@ -97,9 +99,17 @@ class DalleBartConfig(PretrainedFromWandbMixin, PretrainedConfig):
|
|
97 |
self.use_deepnet_scaling = use_deepnet_scaling
|
98 |
self.use_glu = use_glu
|
99 |
self.use_alibi = use_alibi
|
100 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
self.force_ln_scale = force_ln_scale
|
102 |
-
self.force_final_ln_encoder = force_final_ln_encoder
|
103 |
|
104 |
# common parameters
|
105 |
self.encoder_vocab_size = encoder_vocab_size
|
|
|
59 |
do_sample=True,
|
60 |
# transformer variants
|
61 |
ln_type="layernorm", # layer normalization type, "rmsnorm", "layernorm"
|
62 |
+
ln_positions="normformer", # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "preln", "deepnet" (same as postln)
|
63 |
+
use_head_scale=False, # used in NormFormer
|
64 |
use_cosine_attention=False, # used in Swin v2
|
65 |
tau_init=0.05, # used only in cosine attention (Swin v2)
|
66 |
use_deepnet_scaling=False, # used in Deepnet
|
67 |
use_glu=False, # "GLU Variants Improve Transformer"
|
68 |
use_alibi=False, # from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
|
69 |
+
sinkhorn_iters=1, # used in SinkFormers
|
70 |
+
use_final_ln_encoder=False, # final layer normalization in encoder
|
71 |
+
use_final_ln_decoder=False, # final layer normalization in decoder
|
72 |
# parameters that should not be necessary but could affect results
|
73 |
+
force_ln_scale=False, # force scale in layernorm even when followed by dense layers
|
|
|
74 |
**kwargs,
|
75 |
):
|
76 |
# text normalizer
|
77 |
self.normalize_text = normalize_text
|
78 |
|
79 |
# transformer variants
|
80 |
+
self.use_head_scale = use_head_scale # per Normformer
|
81 |
assert ln_type in [
|
82 |
"rmsnorm",
|
83 |
"layernorm",
|
84 |
], "ln_type must be 'rmsnorm' or 'layernorm'"
|
85 |
self.ln_type = ln_type
|
86 |
+
if ln_positions == "deepnet":
|
87 |
+
ln_positions = "postln"
|
88 |
assert ln_positions in [
|
89 |
"normformer",
|
90 |
"swinv2",
|
91 |
"cogview",
|
92 |
+
"postln",
|
93 |
+
"preln",
|
94 |
+
], "ln_positions must be 'normformer', 'swinv2', 'cogview', 'postln', 'preln'"
|
|
|
95 |
assert use_alibi is False, "use_alibi is not supported yet"
|
96 |
self.ln_positions = ln_positions
|
97 |
self.use_cosine_attention = use_cosine_attention
|
|
|
99 |
self.use_deepnet_scaling = use_deepnet_scaling
|
100 |
self.use_glu = use_glu
|
101 |
self.use_alibi = use_alibi
|
102 |
+
self.sinkhorn_iters = sinkhorn_iters
|
103 |
+
if ln_positions == "postln":
|
104 |
+
assert (
|
105 |
+
use_final_ln_encoder
|
106 |
+
), "use_final_ln_encoder must be True when ln_positions is 'postln'"
|
107 |
+
assert (
|
108 |
+
use_final_ln_decoder
|
109 |
+
), "use_final_ln_decoder must be True when ln_positions is 'postln'"
|
110 |
+
self.use_final_ln_encoder = use_final_ln_encoder
|
111 |
+
self.use_final_ln_decoder = use_final_ln_decoder
|
112 |
self.force_ln_scale = force_ln_scale
|
|
|
113 |
|
114 |
# common parameters
|
115 |
self.encoder_vocab_size = encoder_vocab_size
|
src/dalle_mini/model/modeling.py
CHANGED
@@ -28,7 +28,7 @@ import msgpack.exceptions
|
|
28 |
from flax.core.frozen_dict import unfreeze
|
29 |
from flax.linen import combine_masks, make_causal_mask
|
30 |
from flax.linen import partitioning as nn_partitioning
|
31 |
-
from flax.linen.
|
32 |
from flax.serialization import from_bytes
|
33 |
from flax.traverse_util import flatten_dict, unflatten_dict
|
34 |
from jax import lax
|
@@ -175,6 +175,66 @@ def norm(type, *args, **kwargs):
|
|
175 |
raise ValueError(f"Unknown norm type {type}")
|
176 |
|
177 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
class FlaxBartAttention(FlaxBartAttention):
|
179 |
"""
|
180 |
Edits:
|
@@ -225,7 +285,7 @@ class FlaxBartAttention(FlaxBartAttention):
|
|
225 |
)
|
226 |
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
227 |
|
228 |
-
if self.config.
|
229 |
self.head_scale = self.param(
|
230 |
"head_scale", jax.nn.initializers.ones, (1, 1, self.num_heads, 1)
|
231 |
)
|
@@ -342,13 +402,14 @@ class FlaxBartAttention(FlaxBartAttention):
|
|
342 |
deterministic=deterministic,
|
343 |
dtype=self.dtype,
|
344 |
precision=None,
|
|
|
345 |
)
|
346 |
if self.config.use_cosine_attention:
|
347 |
# divide by tau
|
348 |
attn_weights = attn_weights / jnp.maximum(self.tau, 0.01)
|
349 |
|
350 |
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
351 |
-
if self.config.
|
352 |
# per Normformer
|
353 |
attn_output = attn_output * self.head_scale
|
354 |
attn_output = self._merge_heads(attn_output)
|
@@ -373,7 +434,7 @@ class GLU(nn.Module):
|
|
373 |
self.config
|
374 |
)
|
375 |
|
376 |
-
if self.config.ln_positions in ["normformer", "cogview"]:
|
377 |
x = norm(
|
378 |
self.config.ln_type,
|
379 |
dtype=self.dtype,
|
@@ -438,7 +499,7 @@ class FFN(nn.Module):
|
|
438 |
gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
|
439 |
self.config
|
440 |
)
|
441 |
-
if self.config.ln_positions in ["normformer", "cogview"]:
|
442 |
x = norm(
|
443 |
self.config.ln_type,
|
444 |
dtype=self.dtype,
|
@@ -507,7 +568,7 @@ class FlaxBartEncoderLayer(nn.Module):
|
|
507 |
|
508 |
embed_dim = self.config.d_model
|
509 |
residual = hidden_states
|
510 |
-
if self.config.ln_positions in ["normformer", "cogview"]:
|
511 |
hidden_states = norm(
|
512 |
self.config.ln_type,
|
513 |
dtype=self.dtype,
|
@@ -612,7 +673,7 @@ class FlaxBartDecoderLayer(nn.Module):
|
|
612 |
residual = hidden_states
|
613 |
|
614 |
# Self Attention
|
615 |
-
if self.config.ln_positions in ["normformer", "cogview"]:
|
616 |
hidden_states = norm(
|
617 |
self.config.ln_type,
|
618 |
dtype=self.dtype,
|
@@ -651,7 +712,7 @@ class FlaxBartDecoderLayer(nn.Module):
|
|
651 |
cross_attn_weights = None
|
652 |
if encoder_hidden_states is not None:
|
653 |
residual = hidden_states
|
654 |
-
if self.config.ln_positions in ["normformer", "cogview"]:
|
655 |
hidden_states = norm(
|
656 |
self.config.ln_type,
|
657 |
dtype=self.dtype,
|
@@ -759,12 +820,9 @@ class FlaxBartEncoderLayerCollection(nn.Module):
|
|
759 |
all_hidden_states += (hidden_states,)
|
760 |
# final layernorm on the output of the last layer
|
761 |
# or every 6 layers for Swin v2
|
762 |
-
|
763 |
-
|
764 |
-
|
765 |
-
self.config.ln_positions == "swinv2"
|
766 |
-
and ((i == n_layers - 1) or ((i + 1) % 6 == 0))
|
767 |
-
)
|
768 |
# we don't need to scale the norm for the last layer
|
769 |
use_scale = i != n_layers - 1
|
770 |
layer_outputs = layer(
|
@@ -839,9 +897,9 @@ class FlaxBartDecoderLayerCollection(nn.Module):
|
|
839 |
all_hidden_states += (hidden_states,)
|
840 |
# final layernorm on the output of the last layer
|
841 |
# or every 6 layers for Swin v2
|
842 |
-
add_norm = (
|
843 |
-
|
844 |
-
)
|
845 |
# we don't need to scale the norm for the last layer
|
846 |
use_scale = i != n_layers - 1
|
847 |
layer_outputs = layer(
|
|
|
28 |
from flax.core.frozen_dict import unfreeze
|
29 |
from flax.linen import combine_masks, make_causal_mask
|
30 |
from flax.linen import partitioning as nn_partitioning
|
31 |
+
from flax.linen.linear import PrecisionLike
|
32 |
from flax.serialization import from_bytes
|
33 |
from flax.traverse_util import flatten_dict, unflatten_dict
|
34 |
from jax import lax
|
|
|
175 |
raise ValueError(f"Unknown norm type {type}")
|
176 |
|
177 |
|
178 |
+
def dot_product_attention_weights(
|
179 |
+
query: Any,
|
180 |
+
key: Any,
|
181 |
+
bias: Optional[Any] = None,
|
182 |
+
mask: Optional[Any] = None,
|
183 |
+
broadcast_dropout: bool = True,
|
184 |
+
dropout_rng: Optional[PRNGKey] = None,
|
185 |
+
dropout_rate: float = 0.0,
|
186 |
+
deterministic: bool = False,
|
187 |
+
dtype: Any = jnp.float32,
|
188 |
+
precision: PrecisionLike = None,
|
189 |
+
sinkhorn_iters: int = 1,
|
190 |
+
):
|
191 |
+
"""
|
192 |
+
Computes dot-product attention weights given query and key.
|
193 |
+
|
194 |
+
Adapted from flax.linen.attention.dot_product_attention_weights"
|
195 |
+
"""
|
196 |
+
assert query.ndim == key.ndim, "q, k must have same rank."
|
197 |
+
assert query.shape[:-3] == key.shape[:-3], "q, k batch dims must match."
|
198 |
+
assert query.shape[-2] == key.shape[-2], "q, k num_heads must match."
|
199 |
+
assert query.shape[-1] == key.shape[-1], "q, k depths must match."
|
200 |
+
|
201 |
+
# calculate attention matrix
|
202 |
+
depth = query.shape[-1]
|
203 |
+
query = query / jnp.sqrt(depth).astype(dtype)
|
204 |
+
# attn weight shape is (batch..., num_heads, q_length, kv_length)
|
205 |
+
attn_weights = jnp.einsum("...qhd,...khd->...hqk", query, key, precision=precision)
|
206 |
+
|
207 |
+
# apply attention bias: masking, dropout, proximity bias, etc.
|
208 |
+
if bias is not None:
|
209 |
+
attn_weights = attn_weights + bias
|
210 |
+
# apply attention mask
|
211 |
+
if mask is not None:
|
212 |
+
big_neg = jnp.finfo(dtype).min
|
213 |
+
attn_weights = jnp.where(mask, attn_weights, big_neg)
|
214 |
+
|
215 |
+
# normalize the attention weights
|
216 |
+
attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
|
217 |
+
for i in range(sinkhorn_iters - 1):
|
218 |
+
axis = -2 if i % 2 == 0 else -1
|
219 |
+
attn_weights /= 1e-8 + jnp.sum(attn_weights, axis=axis, keepdims=True)
|
220 |
+
|
221 |
+
# apply attention dropout
|
222 |
+
if not deterministic and dropout_rate > 0.0:
|
223 |
+
keep_prob = 1.0 - dropout_rate
|
224 |
+
if broadcast_dropout:
|
225 |
+
# dropout is broadcast across the batch + head dimensions
|
226 |
+
dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
|
227 |
+
keep = jax.random.bernoulli(dropout_rng, keep_prob, dropout_shape)
|
228 |
+
else:
|
229 |
+
keep = jax.random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
|
230 |
+
multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(
|
231 |
+
keep_prob, dtype=dtype
|
232 |
+
)
|
233 |
+
attn_weights = attn_weights * multiplier
|
234 |
+
|
235 |
+
return attn_weights
|
236 |
+
|
237 |
+
|
238 |
class FlaxBartAttention(FlaxBartAttention):
|
239 |
"""
|
240 |
Edits:
|
|
|
285 |
)
|
286 |
self.dropout_layer = nn.Dropout(rate=self.dropout)
|
287 |
|
288 |
+
if self.config.use_head_scale:
|
289 |
self.head_scale = self.param(
|
290 |
"head_scale", jax.nn.initializers.ones, (1, 1, self.num_heads, 1)
|
291 |
)
|
|
|
402 |
deterministic=deterministic,
|
403 |
dtype=self.dtype,
|
404 |
precision=None,
|
405 |
+
sinkhorn_iters=self.config.sinkhorn_iters,
|
406 |
)
|
407 |
if self.config.use_cosine_attention:
|
408 |
# divide by tau
|
409 |
attn_weights = attn_weights / jnp.maximum(self.tau, 0.01)
|
410 |
|
411 |
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
412 |
+
if self.config.use_head_scale:
|
413 |
# per Normformer
|
414 |
attn_output = attn_output * self.head_scale
|
415 |
attn_output = self._merge_heads(attn_output)
|
|
|
434 |
self.config
|
435 |
)
|
436 |
|
437 |
+
if self.config.ln_positions in ["normformer", "cogview", "preln"]:
|
438 |
x = norm(
|
439 |
self.config.ln_type,
|
440 |
dtype=self.dtype,
|
|
|
499 |
gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
|
500 |
self.config
|
501 |
)
|
502 |
+
if self.config.ln_positions in ["normformer", "cogview", "preln"]:
|
503 |
x = norm(
|
504 |
self.config.ln_type,
|
505 |
dtype=self.dtype,
|
|
|
568 |
|
569 |
embed_dim = self.config.d_model
|
570 |
residual = hidden_states
|
571 |
+
if self.config.ln_positions in ["normformer", "cogview", "preln"]:
|
572 |
hidden_states = norm(
|
573 |
self.config.ln_type,
|
574 |
dtype=self.dtype,
|
|
|
673 |
residual = hidden_states
|
674 |
|
675 |
# Self Attention
|
676 |
+
if self.config.ln_positions in ["normformer", "cogview", "preln"]:
|
677 |
hidden_states = norm(
|
678 |
self.config.ln_type,
|
679 |
dtype=self.dtype,
|
|
|
712 |
cross_attn_weights = None
|
713 |
if encoder_hidden_states is not None:
|
714 |
residual = hidden_states
|
715 |
+
if self.config.ln_positions in ["normformer", "cogview", "preln"]:
|
716 |
hidden_states = norm(
|
717 |
self.config.ln_type,
|
718 |
dtype=self.dtype,
|
|
|
820 |
all_hidden_states += (hidden_states,)
|
821 |
# final layernorm on the output of the last layer
|
822 |
# or every 6 layers for Swin v2
|
823 |
+
add_norm = (
|
824 |
+
self.config.ln_positions == "swinv2" and ((i + 1) % 6 == 0)
|
825 |
+
) or (self.config.use_final_ln_encoder and (i == n_layers - 1))
|
|
|
|
|
|
|
826 |
# we don't need to scale the norm for the last layer
|
827 |
use_scale = i != n_layers - 1
|
828 |
layer_outputs = layer(
|
|
|
897 |
all_hidden_states += (hidden_states,)
|
898 |
# final layernorm on the output of the last layer
|
899 |
# or every 6 layers for Swin v2
|
900 |
+
add_norm = (
|
901 |
+
self.config.ln_positions == "swinv2" and ((i + 1) % 6 == 0)
|
902 |
+
) or (self.config.use_final_ln_decoder and (i == n_layers - 1))
|
903 |
# we don't need to scale the norm for the last layer
|
904 |
use_scale = i != n_layers - 1
|
905 |
layer_outputs = layer(
|