# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for transformer.model.""" | |
from absl.testing import absltest | |
from absl.testing import parameterized | |
import haiku as hk | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |
from tracr.transformer import compressed_model | |
from tracr.transformer import model | |


class CompressedTransformerTest(parameterized.TestCase):
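  """Tests for compressed_model.CompressedTransformer."""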

  def _check_layer_naming(self, params):
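    """Checks that parameter names follow the expected naming scheme."""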
    # Modules should be named, for example:
    #   For MLPs: "compressed_transformer/layer_{i}/mlp/linear_1"
    #   For Attention: "compressed_transformer/layer_{i}/attn/key"
    #   For Layer Norm: "compressed_transformer/layer_{i}/layer_norm"
    for key in params.keys():
      levels = key.split("/")
      self.assertEqual(levels[0], "compressed_transformer")
      if len(levels) == 1:
        self.assertEqual(list(params[key].keys()), ["w_emb"])
        continue
      if levels[1].startswith("layer_norm"):
        continue  # output layer norm
      self.assertStartsWith(levels[1], "layer")
      if levels[2] == "mlp":
        self.assertIn(levels[3], {"linear_1", "linear_2"})
      elif levels[2] == "attn":
        self.assertIn(levels[3], {"key", "query", "value", "linear"})
      else:
        self.assertStartsWith(levels[2], "layer_norm")

  def _zero_mlps(self, params):
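    """Returns params with all MLP weights replaced by zeros."""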
    for module in params:
      if "mlp" in module:
        for param in params[module]:
          params[module][param] = jnp.zeros_like(params[module][param])
    return params

  @parameterized.parameters(dict(layer_norm=True), dict(layer_norm=False))
  def test_layer_norm(self, layer_norm):
    # input = [1, 1, 1, 1]
    # If layer norm is used, this should give all-0 output for a freshly
    # initialized model because LN will subtract the mean after each layer.
    # Else we expect non-zero outputs.

    @hk.transform
    def forward(emb, mask):
      transformer = compressed_model.CompressedTransformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              layer_norm=layer_norm))
      return transformer(emb, mask).output

    seq_len = 4
    emb = jnp.ones((1, seq_len, 1))
    mask = jnp.ones((1, seq_len))

    rng = hk.PRNGSequence(1)
    params = forward.init(next(rng), emb, mask)
    out = forward.apply(params, next(rng), emb, mask)

    self._check_layer_naming(params)
    if layer_norm:
      np.testing.assert_allclose(out, 0)
    else:
      self.assertFalse(np.allclose(out, 0))

  @parameterized.parameters(dict(causal=True), dict(causal=False))
  def test_causal_attention(self, causal):
    # input = [0, random, random, random]
    # mask = [1, 0, 1, 1]
    # For causal attention the second token can only attend to the first one,
    # so it should stay unchanged. For non-causal attention all tokens change.

    @hk.transform
    def forward(emb, mask):
      transformer = compressed_model.CompressedTransformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              layer_norm=False,
              causal=causal))
      return transformer(emb, mask).output

    seq_len = 4
    emb = np.random.random((1, seq_len, 1))
    emb[:, 0, :] = 0
    mask = np.array([[1, 0, 1, 1]])
    emb, mask = jnp.array(emb), jnp.array(mask)

    rng = hk.PRNGSequence(1)
    params = forward.init(next(rng), emb, mask)
    params = self._zero_mlps(params)
    out = forward.apply(params, next(rng), emb, mask)

    self._check_layer_naming(params)
    if causal:
      self.assertEqual(0, out[0, 0, 0])
      self.assertEqual(emb[0, 1, 0], out[0, 1, 0])
    else:
      self.assertNotEqual(0, out[0, 0, 0])
      self.assertNotEqual(emb[0, 1, 0], out[0, 1, 0])
    self.assertNotEqual(emb[0, 2, 0], out[0, 2, 0])
    self.assertNotEqual(emb[0, 3, 0], out[0, 3, 0])

  def test_setting_activation_function_to_zero(self):
    # An activation function that always returns zeros should result in the
    # same model output as setting all MLP weights to zero.

    @hk.transform
    def forward_zero(emb, mask):
      transformer = compressed_model.CompressedTransformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              causal=False,
              layer_norm=False,
              activation_function=jnp.zeros_like))
      return transformer(emb, mask).output

    @hk.transform
    def forward(emb, mask):
      transformer = compressed_model.CompressedTransformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              causal=False,
              layer_norm=False,
              activation_function=jax.nn.gelu))
      return transformer(emb, mask).output

    seq_len = 4
    emb = np.random.random((1, seq_len, 1))
    mask = np.ones((1, seq_len))
    emb, mask = jnp.array(emb), jnp.array(mask)

    rng = hk.PRNGSequence(1)
    params = forward.init(next(rng), emb, mask)
    params_no_mlps = self._zero_mlps(params)

    out_zero_activation = forward_zero.apply(params, next(rng), emb, mask)
    out_no_mlps = forward.apply(params_no_mlps, next(rng), emb, mask)

    self._check_layer_naming(params)
    np.testing.assert_allclose(out_zero_activation, out_no_mlps)
    self.assertFalse(np.allclose(out_zero_activation, 0))

  def test_not_setting_embedding_size_produces_same_output_as_default_model(
      self):
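    # Without an explicit embedding_size, the compressed model should reduce
    # to the standard Transformer, so both models must agree given the same
    # (renamed) parameters.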
    config = model.TransformerConfig(
        num_heads=2,
        num_layers=2,
        key_size=5,
        mlp_hidden_size=64,
        dropout_rate=0.,
        causal=False,
        layer_norm=False)

    @hk.without_apply_rng
    @hk.transform
    def forward_model(emb, mask):
      return model.Transformer(config)(emb, mask).output

    @hk.without_apply_rng
    @hk.transform
    def forward_superposition(emb, mask):
      return compressed_model.CompressedTransformer(config)(emb, mask).output

    seq_len = 4
    emb = np.random.random((1, seq_len, 1))
    mask = np.ones((1, seq_len))
    emb, mask = jnp.array(emb), jnp.array(mask)

    rng = hk.PRNGSequence(1)
    params = forward_model.init(next(rng), emb, mask)
    params_superposition = {
        k.replace("transformer", "compressed_transformer"): v
        for k, v in params.items()
    }

    out_model = forward_model.apply(params, emb, mask)
    out_superposition = forward_superposition.apply(params_superposition, emb,
                                                    mask)

    self._check_layer_naming(params_superposition)
    np.testing.assert_allclose(out_model, out_superposition)

  # NOTE: the embedding_size values below are illustrative; the test only
  # checks the shapes of residuals and layer outputs.
  @parameterized.parameters(
      dict(embedding_size=2, unembed_at_every_layer=True),
      dict(embedding_size=2, unembed_at_every_layer=False),
      dict(embedding_size=6, unembed_at_every_layer=True),
      dict(embedding_size=6, unembed_at_every_layer=False))
  def test_embedding_size_produces_correct_shape_of_residuals_and_layer_outputs(
      self, embedding_size, unembed_at_every_layer):
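    # Residuals should live in the compressed space of size embedding_size,
    # while per-layer outputs are projected back to the full model width.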

    @hk.transform
    def forward(emb, mask):
      transformer = compressed_model.CompressedTransformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              causal=False,
              layer_norm=False))
      return transformer(
          emb,
          mask,
          embedding_size=embedding_size,
          unembed_at_every_layer=unembed_at_every_layer,
      )

    seq_len = 4
    model_size = 16

    emb = np.random.random((1, seq_len, model_size))
    mask = np.ones((1, seq_len))
    emb, mask = jnp.array(emb), jnp.array(mask)

    rng = hk.PRNGSequence(1)
    params = forward.init(next(rng), emb, mask)
    activations = forward.apply(params, next(rng), emb, mask)

    self._check_layer_naming(params)

    for residual in activations.residuals:
      self.assertEqual(residual.shape, (1, seq_len, embedding_size))

    for layer_output in activations.layer_outputs:
      self.assertEqual(layer_output.shape, (1, seq_len, model_size))

  # NOTE: the model_size values below are illustrative.
  @parameterized.parameters(
      dict(model_size=2, unembed_at_every_layer=True),
      dict(model_size=2, unembed_at_every_layer=False),
      dict(model_size=6, unembed_at_every_layer=True),
      dict(model_size=6, unembed_at_every_layer=False))
  def test_identity_embedding_produces_same_output_as_standard_model(
      self, model_size, unembed_at_every_layer):
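    # With w_emb set to the identity and embedding_size equal to model_size,
    # compression is a no-op, so the output must match the standard model.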
    config = model.TransformerConfig(
        num_heads=2,
        num_layers=2,
        key_size=5,
        mlp_hidden_size=64,
        dropout_rate=0.,
        causal=False,
        layer_norm=False)

    @hk.without_apply_rng
    @hk.transform
    def forward_model(emb, mask):
      return model.Transformer(config)(emb, mask).output

    @hk.without_apply_rng
    @hk.transform
    def forward_superposition(emb, mask):
      return compressed_model.CompressedTransformer(config)(
          emb,
          mask,
          embedding_size=model_size,
          unembed_at_every_layer=unembed_at_every_layer).output

    seq_len = 4
    emb = np.random.random((1, seq_len, model_size))
    mask = np.ones((1, seq_len))
    emb, mask = jnp.array(emb), jnp.array(mask)

    rng = hk.PRNGSequence(1)
    params = forward_model.init(next(rng), emb, mask)
    params_superposition = {
        k.replace("transformer", "compressed_transformer"): v
        for k, v in params.items()
    }
    params_superposition["compressed_transformer"] = {
        "w_emb": jnp.identity(model_size)
    }

    out_model = forward_model.apply(params, emb, mask)
    out_superposition = forward_superposition.apply(params_superposition, emb,
                                                    mask)

    self._check_layer_naming(params_superposition)
    np.testing.assert_allclose(out_model, out_superposition)


if __name__ == "__main__":
  absltest.main()