RASP-Synthesis / tracr /transformer /compressed_model_test.py
Vladimir Mikulik
add typing_extensions to list of deps.
d4d39d0
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for transformer.model."""
from absl.testing import absltest
from absl.testing import parameterized
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from tracr.transformer import compressed_model
from tracr.transformer import model
class CompressedTransformerTest(parameterized.TestCase):
def _check_layer_naming(self, params):
# Modules should be named for example
# For MLPs: "compressed_transformer/layer_{i}/mlp/linear_1"
# For Attention: "compressed_transformer/layer_{i}/attn/key"
# For Layer Norm: "compressed_transformer/layer_{i}/layer_norm"
for key in params.keys():
levels = key.split("/")
self.assertEqual(levels[0], "compressed_transformer")
if len(levels) == 1:
self.assertEqual(list(params[key].keys()), ["w_emb"])
continue
if levels[1].startswith("layer_norm"):
continue # output layer norm
self.assertStartsWith(levels[1], "layer")
if levels[2] == "mlp":
self.assertIn(levels[3], {"linear_1", "linear_2"})
elif levels[2] == "attn":
self.assertIn(levels[3], {"key", "query", "value", "linear"})
else:
self.assertStartsWith(levels[2], "layer_norm")
def _zero_mlps(self, params):
for module in params:
if "mlp" in module:
for param in params[module]:
params[module][param] = jnp.zeros_like(params[module][param])
return params
@parameterized.parameters(dict(layer_norm=True), dict(layer_norm=False))
def test_layer_norm(self, layer_norm):
# input = [1, 1, 1, 1]
# If layer norm is used, this should give all-0 output for a freshly
# initialized model because LN will subtract the mean after each layer.
# Else we expect non-zero outputs.
@hk.transform
def forward(emb, mask):
transformer = compressed_model.CompressedTransformer(
model.TransformerConfig(
num_heads=2,
num_layers=2,
key_size=5,
mlp_hidden_size=64,
dropout_rate=0.,
layer_norm=layer_norm))
return transformer(emb, mask).output
seq_len = 4
emb = jnp.ones((1, seq_len, 1))
mask = jnp.ones((1, seq_len))
rng = hk.PRNGSequence(1)
params = forward.init(next(rng), emb, mask)
out = forward.apply(params, next(rng), emb, mask)
self._check_layer_naming(params)
if layer_norm:
np.testing.assert_allclose(out, 0)
else:
self.assertFalse(np.allclose(out, 0))
@parameterized.parameters(dict(causal=True), dict(causal=False))
def test_causal_attention(self, causal):
# input = [0, random, random, random]
# mask = [1, 0, 1, 1]
# For causal attention the second token can only attend to the first one, so
# it should be the same. For non-causal attention all tokens should change.
@hk.transform
def forward(emb, mask):
transformer = compressed_model.CompressedTransformer(
model.TransformerConfig(
num_heads=2,
num_layers=2,
key_size=5,
mlp_hidden_size=64,
dropout_rate=0.,
layer_norm=False,
causal=causal))
return transformer(emb, mask).output
seq_len = 4
emb = np.random.random((1, seq_len, 1))
emb[:, 0, :] = 0
mask = np.array([[1, 0, 1, 1]])
emb, mask = jnp.array(emb), jnp.array(mask)
rng = hk.PRNGSequence(1)
params = forward.init(next(rng), emb, mask)
params = self._zero_mlps(params)
out = forward.apply(params, next(rng), emb, mask)
self._check_layer_naming(params)
if causal:
self.assertEqual(0, out[0, 0, 0])
self.assertEqual(emb[0, 1, 0], out[0, 1, 0])
else:
self.assertNotEqual(0, out[0, 0, 0])
self.assertNotEqual(emb[0, 1, 0], out[0, 1, 0])
self.assertNotEqual(emb[0, 2, 0], out[0, 2, 0])
self.assertNotEqual(emb[0, 3, 0], out[0, 3, 0])
def test_setting_activation_function_to_zero(self):
# An activation function that always returns zeros should result in the
# same model output as setting all MLP weights to zero.
@hk.transform
def forward_zero(emb, mask):
transformer = compressed_model.CompressedTransformer(
model.TransformerConfig(
num_heads=2,
num_layers=2,
key_size=5,
mlp_hidden_size=64,
dropout_rate=0.,
causal=False,
layer_norm=False,
activation_function=jnp.zeros_like))
return transformer(emb, mask).output
@hk.transform
def forward(emb, mask):
transformer = compressed_model.CompressedTransformer(
model.TransformerConfig(
num_heads=2,
num_layers=2,
key_size=5,
mlp_hidden_size=64,
dropout_rate=0.,
causal=False,
layer_norm=False,
activation_function=jax.nn.gelu))
return transformer(emb, mask).output
seq_len = 4
emb = np.random.random((1, seq_len, 1))
mask = np.ones((1, seq_len))
emb, mask = jnp.array(emb), jnp.array(mask)
rng = hk.PRNGSequence(1)
params = forward.init(next(rng), emb, mask)
params_no_mlps = self._zero_mlps(params)
out_zero_activation = forward_zero.apply(params, next(rng), emb, mask)
out_no_mlps = forward.apply(params_no_mlps, next(rng), emb, mask)
self._check_layer_naming(params)
np.testing.assert_allclose(out_zero_activation, out_no_mlps)
self.assertFalse(np.allclose(out_zero_activation, 0))
def test_not_setting_embedding_size_produces_same_output_as_default_model(
self):
config = model.TransformerConfig(
num_heads=2,
num_layers=2,
key_size=5,
mlp_hidden_size=64,
dropout_rate=0.,
causal=False,
layer_norm=False)
@hk.without_apply_rng
@hk.transform
def forward_model(emb, mask):
return model.Transformer(config)(emb, mask).output
@hk.without_apply_rng
@hk.transform
def forward_superposition(emb, mask):
return compressed_model.CompressedTransformer(config)(emb, mask).output
seq_len = 4
emb = np.random.random((1, seq_len, 1))
mask = np.ones((1, seq_len))
emb, mask = jnp.array(emb), jnp.array(mask)
rng = hk.PRNGSequence(1)
params = forward_model.init(next(rng), emb, mask)
params_superposition = {
k.replace("transformer", "compressed_transformer"): v
for k, v in params.items()
}
out_model = forward_model.apply(params, emb, mask)
out_superposition = forward_superposition.apply(params_superposition, emb,
mask)
self._check_layer_naming(params_superposition)
np.testing.assert_allclose(out_model, out_superposition)
@parameterized.parameters(
dict(embedding_size=2, unembed_at_every_layer=True),
dict(embedding_size=2, unembed_at_every_layer=False),
dict(embedding_size=6, unembed_at_every_layer=True),
dict(embedding_size=6, unembed_at_every_layer=False))
def test_embbeding_size_produces_correct_shape_of_residuals_and_layer_outputs(
self, embedding_size, unembed_at_every_layer):
@hk.transform
def forward(emb, mask):
transformer = compressed_model.CompressedTransformer(
model.TransformerConfig(
num_heads=2,
num_layers=2,
key_size=5,
mlp_hidden_size=64,
dropout_rate=0.,
causal=False,
layer_norm=False))
return transformer(
emb,
mask,
embedding_size=embedding_size,
unembed_at_every_layer=unembed_at_every_layer,
)
seq_len = 4
model_size = 16
emb = np.random.random((1, seq_len, model_size))
mask = np.ones((1, seq_len))
emb, mask = jnp.array(emb), jnp.array(mask)
rng = hk.PRNGSequence(1)
params = forward.init(next(rng), emb, mask)
activations = forward.apply(params, next(rng), emb, mask)
self._check_layer_naming(params)
for residual in activations.residuals:
self.assertEqual(residual.shape, (1, seq_len, embedding_size))
for layer_output in activations.layer_outputs:
self.assertEqual(layer_output.shape, (1, seq_len, model_size))
@parameterized.parameters(
dict(model_size=2, unembed_at_every_layer=True),
dict(model_size=2, unembed_at_every_layer=False),
dict(model_size=6, unembed_at_every_layer=True),
dict(model_size=6, unembed_at_every_layer=False))
def test_identity_embedding_produces_same_output_as_standard_model(
self, model_size, unembed_at_every_layer):
config = model.TransformerConfig(
num_heads=2,
num_layers=2,
key_size=5,
mlp_hidden_size=64,
dropout_rate=0.,
causal=False,
layer_norm=False)
@hk.without_apply_rng
@hk.transform
def forward_model(emb, mask):
return model.Transformer(config)(emb, mask).output
@hk.without_apply_rng
@hk.transform
def forward_superposition(emb, mask):
return compressed_model.CompressedTransformer(config)(
emb,
mask,
embedding_size=model_size,
unembed_at_every_layer=unembed_at_every_layer).output
seq_len = 4
emb = np.random.random((1, seq_len, model_size))
mask = np.ones((1, seq_len))
emb, mask = jnp.array(emb), jnp.array(mask)
rng = hk.PRNGSequence(1)
params = forward_model.init(next(rng), emb, mask)
params_superposition = {
k.replace("transformer", "compressed_transformer"): v
for k, v in params.items()
}
params_superposition["compressed_transformer"] = {
"w_emb": jnp.identity(model_size)
}
out_model = forward_model.apply(params, emb, mask)
out_superposition = forward_superposition.apply(params_superposition, emb,
mask)
self._check_layer_naming(params_superposition)
np.testing.assert_allclose(out_model, out_superposition)
if __name__ == "__main__":
absltest.main()