# Hosting-page header (metadata from the page this file was copied from):
#   Path: RASP-Synthesis / tracr/craft/vectorspace_fns_test.py
#   Last commit: d4d39d0 by Vladimir Mikulik, "add typing_extensions to list of deps."
#   Size: 5.93 kB
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for vectorspace_fns."""
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from tracr.craft import bases
from tracr.craft import tests_common
from tracr.craft import vectorspace_fns as vs_fns
class LinearTest(tests_common.VectorFnTestCase):
  """Tests for constructing, applying and combining `vs_fns.Linear` maps."""

  def test_identity_from_matrix(self):
    """An identity matrix maps every basis vector to itself."""
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"])
    f = vs_fns.Linear(vs, vs, np.eye(3))
    for v in vs.basis_vectors():
      self.assertEqual(f(v), v)

  def test_identity_from_action(self):
    """`from_action` with the basis-vector constructor yields the identity."""
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"])
    f = vs_fns.Linear.from_action(vs, vs, vs.vector_from_basis_direction)
    for v in vs.basis_vectors():
      self.assertEqual(f(v), v)

  def test_nonidentity(self):
    """Each input basis vector is mapped to the matching row of the matrix."""
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a = vs.vector_from_basis_direction(bases.BasisDirection("a"))
    b = vs.vector_from_basis_direction(bases.BasisDirection("b"))
    f = vs_fns.Linear(vs, vs, np.array([[0.3, 0.7], [0.2, 0.1]]))
    # Row 0 is the image of `a`, row 1 the image of `b`.
    self.assertEqual(
        f(a), bases.VectorInBasis(vs.basis, np.array([0.3, 0.7])))
    self.assertEqual(
        f(b), bases.VectorInBasis(vs.basis, np.array([0.2, 0.1])))

  def test_different_vector_spaces(self):
    """The identity matrix maps basis vectors of one space onto another's."""
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["c", "d"])
    a, b = vs1.basis_vectors()
    c, d = vs2.basis_vectors()
    f = vs_fns.Linear(vs1, vs2, np.eye(2))
    self.assertEqual(f(a), c)
    self.assertEqual(f(b), d)

  def test_combining_linear_functions_with_different_input(self):
    """Combining maps on disjoint spaces acts block-wise on their direct sum."""
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["c", "d"])
    vs = bases.direct_sum(vs1, vs2)
    a = vs.vector_from_basis_direction(bases.BasisDirection("a"))
    b = vs.vector_from_basis_direction(bases.BasisDirection("b"))
    c = vs.vector_from_basis_direction(bases.BasisDirection("c"))
    d = vs.vector_from_basis_direction(bases.BasisDirection("d"))
    # f1 swaps a and b; f2 keeps c and sends d to zero.
    f1 = vs_fns.Linear(vs1, vs1, np.array([[0, 1], [1, 0]]))
    f2 = vs_fns.Linear(vs2, vs2, np.array([[1, 0], [0, 0]]))
    f3 = vs_fns.Linear.combine_in_parallel([f1, f2])
    self.assertEqual(
        f3(a), bases.VectorInBasis(vs.basis, np.array([0, 1, 0, 0])))
    self.assertEqual(
        f3(b), bases.VectorInBasis(vs.basis, np.array([1, 0, 0, 0])))
    self.assertEqual(
        f3(c), bases.VectorInBasis(vs.basis, np.array([0, 0, 1, 0])))
    self.assertEqual(
        f3(d), bases.VectorInBasis(vs.basis, np.array([0, 0, 0, 0])))

  def test_combining_linear_functions_with_same_input(self):
    """Combining maps on the same space sums their outputs."""
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a = vs.vector_from_basis_direction(bases.BasisDirection("a"))
    b = vs.vector_from_basis_direction(bases.BasisDirection("b"))
    f1 = vs_fns.Linear(vs, vs, np.array([[0, 1], [1, 0]]))
    f2 = vs_fns.Linear(vs, vs, np.array([[1, 0], [0, 0]]))
    f3 = vs_fns.Linear.combine_in_parallel([f1, f2])
    self.assertEqual(
        f3(a), bases.VectorInBasis(vs.basis, np.array([1, 1])))
    self.assertEqual(
        f3(b), bases.VectorInBasis(vs.basis, np.array([1, 0])))
    # The combined map must agree with the pointwise sum of its parts.
    self.assertEqual(f3(a), f1(a) + f2(a))
    self.assertEqual(f3(b), f1(b) + f2(b))


class ProjectionTest(tests_common.VectorFnTestCase):
  """Tests for `vs_fns.project` between spaces sharing basis names."""

  def test_projection_to_larger_space(self):
    """Projecting into a superset space keeps each shared basis vector."""
    small = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    big = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    a_small, b_small = small.basis_vectors()
    a_big, b_big, _, _ = big.basis_vectors()
    proj = vs_fns.project(small, big)
    for source, target in ((a_small, a_big), (b_small, b_big)):
      self.assertEqual(proj(source), target)

  def test_projection_to_smaller_space(self):
    """Shared directions survive; directions missing from the target vanish."""
    big = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    small = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a_big, b_big, c_big, d_big = big.basis_vectors()
    a_small, b_small = small.basis_vectors()
    proj = vs_fns.project(big, small)
    for source, target in ((a_big, a_small), (b_big, b_small)):
      self.assertEqual(proj(source), target)
    for dropped in (c_big, d_big):
      self.assertEqual(proj(dropped), small.null_vector())


class ScalarBilinearTest(parameterized.TestCase):
  """Tests for `vs_fns.ScalarBilinear` on a two-dimensional space."""

  def test_identity_matrix(self):
    """With the identity matrix the form is 1 on equal basis vectors, else 0."""
    space = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a, b = space.basis_vectors()
    form = vs_fns.ScalarBilinear(space, space, np.eye(2))
    expectations = (((a, a), 1), ((a, b), 0), ((b, a), 0), ((b, b), 1))
    for (left, right), want in expectations:
      self.assertEqual(form(left, right), want)

  def test_identity_from_action(self):
    """An equality-indicator action reproduces the identity form."""
    space = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a, b = space.basis_vectors()
    form = vs_fns.ScalarBilinear.from_action(space, space, lambda x, y: int(x == y))
    expectations = (((a, a), 1), ((a, b), 0), ((b, a), 0), ((b, b), 1))
    for (left, right), want in expectations:
      self.assertEqual(form(left, right), want)

  def test_non_identity(self):
    """An action that looks only at the left argument gives an asymmetric form."""
    space = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a, b = space.basis_vectors()
    form = vs_fns.ScalarBilinear.from_action(space, space,
                                             lambda x, y: int(x.name == "a"))
    # Every pair whose left vector is `a` scores 1, regardless of the right one.
    expectations = (((a, a), 1), ((a, b), 1), ((b, a), 0), ((b, b), 0))
    for (left, right), want in expectations:
      self.assertEqual(form(left, right), want)


if __name__ == "__main__":
  # Run this module's tests through absl's test runner when executed directly.
  absltest.main()