Spaces:
Sleeping
Sleeping
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for vectorspace_fns.""" | |
from absl.testing import absltest | |
from absl.testing import parameterized | |
import numpy as np | |
from tracr.craft import bases | |
from tracr.craft import tests_common | |
from tracr.craft import vectorspace_fns as vs_fns | |
class LinearTest(tests_common.VectorFnTestCase):
  """Tests for constructing, applying and combining `vs_fns.Linear` maps."""

  def test_identity_from_matrix(self):
    """An identity matrix maps every basis vector to itself."""
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"])
    f = vs_fns.Linear(vs, vs, np.eye(3))
    for v in vs.basis_vectors():
      self.assertEqual(f(v), v)

  def test_identity_from_action(self):
    """`from_action` with the basis-vector constructor yields the identity."""
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"])
    f = vs_fns.Linear.from_action(vs, vs, vs.vector_from_basis_direction)
    for v in vs.basis_vectors():
      self.assertEqual(f(v), v)

  # Renamed from the misspelled `test_nonidentiy`.
  def test_nonidentity(self):
    """Each basis vector is mapped to the matching row of the matrix."""
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a = vs.vector_from_basis_direction(bases.BasisDirection("a"))
    b = vs.vector_from_basis_direction(bases.BasisDirection("b"))
    f = vs_fns.Linear(vs, vs, np.array([[0.3, 0.7], [0.2, 0.1]]))
    self.assertEqual(
        f(a), bases.VectorInBasis(vs.basis, np.array([0.3, 0.7])))
    self.assertEqual(
        f(b), bases.VectorInBasis(vs.basis, np.array([0.2, 0.1])))

  def test_different_vector_spaces(self):
    """An identity-shaped matrix maps between spaces with different names."""
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["c", "d"])
    a, b = vs1.basis_vectors()
    c, d = vs2.basis_vectors()
    f = vs_fns.Linear(vs1, vs2, np.eye(2))
    self.assertEqual(f(a), c)
    self.assertEqual(f(b), d)

  def test_combining_linear_functions_with_different_input(self):
    """Parallel maps over disjoint spaces act independently on the sum."""
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["c", "d"])
    vs = bases.direct_sum(vs1, vs2)
    a = vs.vector_from_basis_direction(bases.BasisDirection("a"))
    b = vs.vector_from_basis_direction(bases.BasisDirection("b"))
    c = vs.vector_from_basis_direction(bases.BasisDirection("c"))
    d = vs.vector_from_basis_direction(bases.BasisDirection("d"))
    f1 = vs_fns.Linear(vs1, vs1, np.array([[0, 1], [1, 0]]))  # swaps a <-> b
    f2 = vs_fns.Linear(vs2, vs2, np.array([[1, 0], [0, 0]]))  # keeps c, zeros d
    f3 = vs_fns.Linear.combine_in_parallel([f1, f2])
    self.assertEqual(
        f3(a), bases.VectorInBasis(vs.basis, np.array([0, 1, 0, 0])))
    self.assertEqual(
        f3(b), bases.VectorInBasis(vs.basis, np.array([1, 0, 0, 0])))
    self.assertEqual(
        f3(c), bases.VectorInBasis(vs.basis, np.array([0, 0, 1, 0])))
    self.assertEqual(
        f3(d), bases.VectorInBasis(vs.basis, np.array([0, 0, 0, 0])))

  def test_combining_linear_functions_with_same_input(self):
    """Parallel maps over the same space add their outputs."""
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a = vs.vector_from_basis_direction(bases.BasisDirection("a"))
    b = vs.vector_from_basis_direction(bases.BasisDirection("b"))
    f1 = vs_fns.Linear(vs, vs, np.array([[0, 1], [1, 0]]))
    f2 = vs_fns.Linear(vs, vs, np.array([[1, 0], [0, 0]]))
    f3 = vs_fns.Linear.combine_in_parallel([f1, f2])
    self.assertEqual(
        f3(a), bases.VectorInBasis(vs.basis, np.array([1, 1])))
    self.assertEqual(
        f3(b), bases.VectorInBasis(vs.basis, np.array([1, 0])))
    self.assertEqual(f3(a), f1(a) + f2(a))
    self.assertEqual(f3(b), f1(b) + f2(b))
class ProjectionTest(tests_common.VectorFnTestCase):
  """Tests for `vs_fns.project` between spaces that share basis names."""

  def test_projection_to_larger_space(self):
    """Projecting into a superset space keeps each direction's name."""
    small = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    large = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    small_a, small_b = small.basis_vectors()
    large_a, large_b, _, _ = large.basis_vectors()
    projection = vs_fns.project(small, large)
    self.assertEqual(projection(small_a), large_a)
    self.assertEqual(projection(small_b), large_b)

  def test_projection_to_smaller_space(self):
    """Directions missing from the target space project to its null vector."""
    large = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    small = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    large_a, large_b, large_c, large_d = large.basis_vectors()
    small_a, small_b = small.basis_vectors()
    projection = vs_fns.project(large, small)
    self.assertEqual(projection(large_a), small_a)
    self.assertEqual(projection(large_b), small_b)
    for dropped in (large_c, large_d):
      self.assertEqual(projection(dropped), small.null_vector())
class ScalarBilinearTest(parameterized.TestCase):
  """Tests for `vs_fns.ScalarBilinear` built from matrices and actions."""

  def _check_pairs(self, form, cases):
    """Assert `form(left, right) == expected` for each (left, right, expected)."""
    for left, right, expected in cases:
      self.assertEqual(form(left, right), expected)

  def test_identity_matrix(self):
    """An identity matrix yields 1 on matching basis pairs, else 0."""
    space = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vec_a, vec_b = space.basis_vectors()
    form = vs_fns.ScalarBilinear(space, space, np.eye(2))
    self._check_pairs(form, [
        (vec_a, vec_a, 1),
        (vec_a, vec_b, 0),
        (vec_b, vec_a, 0),
        (vec_b, vec_b, 1),
    ])

  def test_identity_from_action(self):
    """An equality action over basis directions reproduces the identity."""
    space = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vec_a, vec_b = space.basis_vectors()

    def same_direction(x, y):
      return int(x == y)

    form = vs_fns.ScalarBilinear.from_action(space, space, same_direction)
    self._check_pairs(form, [
        (vec_a, vec_a, 1),
        (vec_a, vec_b, 0),
        (vec_b, vec_a, 0),
        (vec_b, vec_b, 1),
    ])

  def test_non_identity(self):
    """An action that reads only the left direction ignores the right one."""
    space = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vec_a, vec_b = space.basis_vectors()

    def left_is_a(x, y):
      del y  # Only the left direction matters for this action.
      return int(x.name == "a")

    form = vs_fns.ScalarBilinear.from_action(space, space, left_is_a)
    self._check_pairs(form, [
        (vec_a, vec_a, 1),
        (vec_a, vec_b, 1),
        (vec_b, vec_a, 0),
        (vec_b, vec_b, 0),
    ])
if __name__ == "__main__":
  # Run this module's tests with absl's test runner when executed directly.
  absltest.main()