Spaces:
Sleeping
Sleeping
File size: 6,548 Bytes
9bdaa77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Modified transformer to learn a linear compression of the residual stream.
CompressedTransformer adds two call arguments compared to Transformer:
- embedding_size: the size of the compressed residual stream.
- unembed_at_every_layer: whether to apply the unembedding before applying
attention and MLP layers.
Model activations (layer outputs, residuals and attention logits) are always
returned alongside the outputs; there is no separate return_activations
argument.
"""
import collections
import dataclasses
from typing import Optional
import haiku as hk
import jax
import numpy as np
from tracr.transformer import attention
from tracr.transformer import model
@dataclasses.dataclass
class CompressedTransformer(hk.Module):
    """Transformer stack whose residual stream can be linearly compressed.

    Attributes:
      config: Hyperparameters of the transformer stack (number of layers and
        heads, key size, MLP hidden size, dropout rate, activation function,
        and the layer-norm / causal switches read in `__call__`).
      name: Optional name for this Haiku module.
    """

    config: model.TransformerConfig
    name: Optional[str] = None

    def __call__(
        self,
        embeddings: jax.Array,  # [B, T, D]
        mask: jax.Array,  # [B, T]
        *,
        use_dropout: bool = True,
        embedding_size: Optional[int] = None,
        unembed_at_every_layer: bool = False,
    ) -> model.TransformerOutput:  # [B, T, D]
        """Runs the (optionally compressed) transformer on a batch of embeddings.

        When `embedding_size` is given, a parameter `w_emb` of shape
        `(embedding_size, D)` is created. Values written to the residual
        stream are multiplied by `w_emb.T`, and values read from it are
        multiplied by `w_emb`. Without `embedding_size` both operations are
        the identity.

        Args:
          embeddings: Input embeddings to pass through the model.
          mask: Boolean mask to restrict the inputs the model uses.
          use_dropout: If False, the dropout rate is forced to 0.
          embedding_size: Dimension to compress the residual stream to, or
            None to leave the residual stream uncompressed.
          unembed_at_every_layer: If True, every attention / MLP block reads
            an unembedded copy of the residual stream (so the blocks keep the
            input width D) and its output is embedded again before being
            added back. If False, the blocks operate directly on the
            compressed stream and only the final output is unembedded.

        Returns:
          A `model.TransformerOutput` holding the per-block outputs, the
          residual stream after every block, the attention logits of every
          layer, the final output and the original input embeddings.
        """
        cfg = self.config

        def layer_norm(x: jax.Array) -> jax.Array:
            """Applies a fresh default LayerNorm to x if enabled in the config."""
            if not cfg.layer_norm:
                return x
            return hk.LayerNorm(
                axis=-1, create_scale=True, create_offset=True)(x)

        initializer = hk.initializers.VarianceScaling(2 / cfg.num_layers)
        if use_dropout:
            dropout_rate = cfg.dropout_rate
        else:
            dropout_rate = 0.
        _, seq_len, model_size = embeddings.shape

        # Compression: reading from the residual stream multiplies by W,
        # writing to it multiplies by W^T.
        if embedding_size is None:

            def write_to_residual(x):
                return x

            def read_from_residual(x):
                return x

        else:
            # [to_size, from_size]
            w_emb = hk.get_parameter(
                "w_emb",
                (embedding_size, model_size),
                init=hk.initializers.RandomNormal())

            def write_to_residual(x):
                return x @ w_emb.T

            def read_from_residual(x):
                return x @ w_emb

            # Blocks that consume the compressed stream directly must be
            # built with the compressed width.
            if not unembed_at_every_layer:
                model_size = embedding_size

        # Broadcast the padding mask over query positions and, for
        # autoregressive models, combine it with a lower-triangular mask.
        attn_mask = mask[:, None, None, :]  # [B, H=1, T'=1, T]
        attn_mask = attn_mask.repeat(seq_len, axis=2)  # [B, H=1, T, T]
        if cfg.causal:
            lower_tri = np.tril(
                np.ones((1, 1, seq_len, seq_len)))  # [B=1, H=1, T, T]
            attn_mask = attn_mask * lower_tri  # [B, H=1, T, T]

        # Activations gathered during the forward pass.
        layer_outputs = []
        residuals = []
        attn_logits = []

        residual = write_to_residual(embeddings)

        for layer_idx in range(cfg.num_layers):
            with hk.experimental.name_scope(f"layer_{layer_idx}"):
                # --- Attention block ---
                attention_module = attention.MultiHeadAttention(
                    num_heads=cfg.num_heads,
                    key_size=cfg.key_size,
                    model_size=model_size,
                    w_init=initializer,
                    name="attn")
                if unembed_at_every_layer:
                    attn_input = read_from_residual(residual)
                else:
                    attn_input = residual
                attn_input = layer_norm(attn_input)
                attn_result = attention_module(
                    attn_input, attn_input, attn_input, mask=attn_mask)
                attn_update = attn_result.out
                if dropout_rate > 0:
                    attn_update = hk.dropout(
                        hk.next_rng_key(), dropout_rate, attn_update)

                # The recorded layer output is the block output as is when
                # unembedding at every layer; otherwise read_from_residual
                # is applied to it first.
                if unembed_at_every_layer:
                    layer_outputs.append(attn_update)
                    attn_update = write_to_residual(attn_update)
                else:
                    layer_outputs.append(read_from_residual(attn_update))
                attn_logits.append(attn_result.logits)

                residual = residual + attn_update
                residuals.append(residual)

                # --- Dense (MLP) block ---
                with hk.experimental.name_scope("mlp"):
                    mlp_module = hk.Sequential([
                        hk.Linear(
                            cfg.mlp_hidden_size,
                            w_init=initializer,
                            name="linear_1"),
                        cfg.activation_function,
                        hk.Linear(
                            model_size,
                            w_init=initializer,
                            name="linear_2"),
                    ])
                if unembed_at_every_layer:
                    mlp_input = read_from_residual(residual)
                else:
                    mlp_input = residual
                mlp_input = layer_norm(mlp_input)
                mlp_update = mlp_module(mlp_input)
                if dropout_rate > 0:
                    mlp_update = hk.dropout(
                        hk.next_rng_key(), dropout_rate, mlp_update)

                if unembed_at_every_layer:
                    layer_outputs.append(mlp_update)
                    mlp_update = write_to_residual(mlp_update)
                else:
                    layer_outputs.append(read_from_residual(mlp_update))

                residual = residual + mlp_update
                residuals.append(residual)

        # Unembed once more to produce the model output.
        final_output = layer_norm(read_from_residual(residual))

        return model.TransformerOutput(
            layer_outputs=layer_outputs,
            residuals=residuals,
            attn_logits=attn_logits,
            output=final_output,
            input_embeddings=embeddings,
        )
|