# Provenance: RASP-Synthesis / craft / transformers_test.py
# DeepMind LMI Team -- internal change 9bdaa77 (4.97 kB)
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for transformers."""
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from tracr.craft import bases
from tracr.craft import tests_common
from tracr.craft import transformers
from tracr.craft import vectorspace_fns as vs_fns
# This makes it easier to use comments to annotate dimensions in arrays
# pylint: disable=g-no-space-after-comment
class AttentionHeadTest(tests_common.VectorFnTestCase):
  """Tests for `transformers.AttentionHead`."""

  @parameterized.parameters([
      {"with_residual_stream": False},
      {"with_residual_stream": True},
  ])
  def test_attention_head(self, with_residual_stream):
    """Checks `AttentionHead.apply` on a two-token sequence.

    Token n carries a one-hot value in each of the `i`, `q` and `p`
    directions. The head is built with a scaled identity for `w_qk` and an
    identity `w_ov` from `i` to `o`; the expected output has only `o_n` set
    for token n and zeros everywhere else.

    Args:
      with_residual_stream: If True, the full residual space is handed to the
        head; otherwise `residual_space` is None.
    """
    make_space = bases.VectorSpaceWithBasis.from_values
    in_space = make_space("i", [1, 2])
    out_space = make_space("o", [1, 2])
    query_space = make_space("q", [1, 2])
    key_space = make_space("p", [1, 2])
    residual = bases.direct_sum(in_space, out_space, query_space, key_space)

    token_values = np.array([
        #i1 i2 o1 o2 q1 q2 p1 p2
        [1, 0, 0, 0, 1, 0, 1, 0],
        [0, 1, 0, 0, 0, 1, 0, 1],
    ])
    tokens = bases.VectorInBasis(residual.basis, token_values)

    # NOTE(review): the factor 100 presumably makes each query lock onto its
    # matching key almost exclusively -- confirm against AttentionHead.
    qk_matrix = np.eye(2) * 100
    ov_matrix = np.eye(2)

    head_residual = None
    if with_residual_stream:
      head_residual = residual

    head = transformers.AttentionHead(
        w_qk=vs_fns.ScalarBilinear(query_space, key_space, qk_matrix),
        w_ov=vs_fns.Linear(in_space, out_space, ov_matrix),
        residual_space=head_residual,
        causal=False,
    )

    actual = head.apply(tokens)
    expected_values = np.array([
        #i1 i2 o1 o2 q1 q2 p1 p2
        [0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0],
    ])
    expected = bases.VectorInBasis(residual.basis, expected_values)
    self.assertVectorAllClose(actual, expected)
class MLPTest(tests_common.VectorFnTestCase):
  """Tests for `transformers.MLP`."""

  @parameterized.parameters([
      {"with_residual_stream": False, "same_in_out": False},
      {"with_residual_stream": False, "same_in_out": True},
      {"with_residual_stream": True, "same_in_out": False},
      {"with_residual_stream": True, "same_in_out": True},
  ])
  def test_mlp(self, with_residual_stream, same_in_out):
    """Checks `MLP.apply` when both layers use identity weights.

    The input rows are [1, -1] and [-1, 1]; the expected output rows are
    [1, 0] and [0, 1] in the output directions, i.e. the negative entries
    come out as zero.

    Args:
      with_residual_stream: If True, the residual space is handed to the MLP;
        otherwise `residual_space` is None.
      same_in_out: If True, the output and residual spaces are the input
        space itself; otherwise a separate `o` space is created and the
        residual space is the direct sum of input and output.
    """
    make_space = bases.VectorSpaceWithBasis.from_values
    in_space = make_space("i", [1, 2])

    if not same_in_out:
      out_space = make_space("o", [1, 2])
      residual = bases.direct_sum(in_space, out_space)
      expected_values = np.array([
          #i1 i2 o1 o2
          [0, 0, 1, 0],
          [0, 0, 0, 1],
      ])
    else:
      out_space = residual = in_space
      expected_values = np.array([
          #o1 o2
          [1, 0],
          [0, 1],
      ])

    hidden_space = make_space("p", [1, 2])

    raw_input = bases.VectorInBasis(
        in_space.basis,
        np.array([
            #i1 i2
            [1, -1],
            [-1, 1],
        ]))
    seq = raw_input.project(residual)

    first_layer = vs_fns.Linear(in_space, hidden_space, np.eye(2))
    second_layer = vs_fns.Linear(hidden_space, out_space, np.eye(2))
    mlp = transformers.MLP(
        fst=first_layer,
        snd=second_layer,
        residual_space=residual if with_residual_stream else None,
    )

    actual = mlp.apply(seq)
    expected = bases.VectorInBasis(residual.basis, expected_values)
    self.assertEqual(actual, expected)

  def test_combining_mlps(self):
    """Checks that `MLP.combine_in_parallel` produces both MLPs' outputs.

    The first MLP routes in2 -> h1 -> out1 and the second routes
    in3 -> h2 -> out2; the combined MLP is expected to fill out1 and out2
    from in2 and in3 respectively.
    """
    make_space = bases.VectorSpaceWithBasis.from_values
    in_low = make_space("in", [1, 2])
    in_high = make_space("in", [3, 4])
    out_space = make_space("out", [1, 2])
    residual = bases.join_vector_spaces(in_low, in_high, out_space)

    hidden_a = make_space("h", [1])
    hidden_b = make_space("h", [2])

    # First MLP: in2 -> h1 -> out1.
    pick_in2 = np.array([[0], [1]])
    emit_out1 = np.array([[1, 0]])
    mlp_a = transformers.MLP(
        fst=vs_fns.Linear(in_low, hidden_a, pick_in2),
        snd=vs_fns.Linear(hidden_a, out_space, emit_out1))

    # Second MLP: in3 -> h2 -> out2.
    pick_in3 = np.array([[1], [0]])
    emit_out2 = np.array([[0, 1]])
    mlp_b = transformers.MLP(
        fst=vs_fns.Linear(in_high, hidden_b, pick_in3),
        snd=vs_fns.Linear(hidden_b, out_space, emit_out2))

    combined = transformers.MLP.combine_in_parallel([mlp_a, mlp_b])

    all_inputs = bases.direct_sum(in_low, in_high)
    input_values = np.array([
        #i1 i2 i3 i4
        [1, 2, 0, 0],
        [0, 2, 3, 4],
    ])
    seq = bases.VectorInBasis(all_inputs.basis, input_values).project(residual)

    expected_values = np.array([
        #o1 o2
        [2, 0],
        [2, 3],
    ])
    expected = bases.VectorInBasis(out_space.basis, expected_values)

    # Only the output directions are compared; the rest of the residual
    # space is projected away.
    self.assertEqual(combined.apply(seq).project(out_space), expected)
if __name__ == "__main__":
  # Run the tests through absltest when this file is executed as a script.
  absltest.main()