RASP-Synthesis / tracr /craft /chamber /selector_width_test.py
Vladimir Mikulik
add typing_extensions to list of deps.
d4d39d0
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for selector_width."""
from absl.testing import absltest
from absl.testing import parameterized
from tracr.craft import bases
from tracr.craft import tests_common
from tracr.craft.chamber import selector_width
class SelectorWidthTest(tests_common.VectorFnTestCase):
@parameterized.product(
causal=[False, True],
categorical_output=[False, True],
input_seq=[[1, 2, 3, 4, 5], [-1, 0, 1], [10]])
def test_selector_width_of_select_all_is_length(self, causal,
categorical_output,
input_seq):
vocab = range(-20, 20)
input_space = bases.VectorSpaceWithBasis.from_values("input", vocab)
if categorical_output:
output_space = bases.VectorSpaceWithBasis.from_values("output", range(10))
else:
output_space = bases.VectorSpaceWithBasis(
[bases.BasisDirection("output")])
bos_dir = bases.BasisDirection("bos_dimension")
bos_space = bases.VectorSpaceWithBasis([bos_dir])
one_dir = bases.BasisDirection("one_dimension")
one_space = bases.VectorSpaceWithBasis([one_dir])
input_space = bases.join_vector_spaces(input_space, bos_space, one_space)
residual_space = bases.join_vector_spaces(input_space, output_space)
bos_vec = residual_space.vector_from_basis_direction(bos_dir)
one_vec = residual_space.vector_from_basis_direction(one_dir)
block = selector_width.selector_width(
query_space=input_space,
key_space=input_space,
output_space=output_space,
bos_space=bos_space,
one_space=one_space,
attn_fn=lambda x, y: True,
out_value_set=set(range(len(input_seq) + 1)),
categorical_output=categorical_output,
causal=causal,
label="select_all")
test_inputs = [bos_vec + one_vec]
for x in input_seq:
test_inputs.append(
residual_space.vector_from_basis_direction(
bases.BasisDirection("input", x)) + one_vec)
test_inputs = bases.VectorInBasis.stack(test_inputs)
# Expect length of the input sequence
if causal:
expected_results = list(range(1, len(input_seq) + 1))
else:
expected_results = [len(input_seq) for _ in input_seq]
if categorical_output:
expected_results = [
output_space.vector_from_basis_direction(
bases.BasisDirection("output", x)) for x in expected_results
]
else:
output_vec = output_space.vector_from_basis_direction(
bases.BasisDirection("output"))
expected_results = [x * output_vec for x in expected_results]
expected_results = bases.VectorInBasis.stack(expected_results)
test_outputs = block.apply(test_inputs).project(output_space)
self.assertVectorAllClose(
tests_common.strip_bos_token(test_outputs), expected_results)
@parameterized.product(
causal=[False, True],
categorical_output=[False, True],
input_seq=[[1, 2, 3, 4, 5], [-1, 0, 1], [10]])
def test_selector_width_of_select_none_is_zero(self, causal,
categorical_output, input_seq):
vocab = range(-20, 20)
input_space = bases.VectorSpaceWithBasis.from_values("input", vocab)
if categorical_output:
output_space = bases.VectorSpaceWithBasis.from_values("output", range(10))
else:
output_space = bases.VectorSpaceWithBasis(
[bases.BasisDirection("output")])
bos_dir = bases.BasisDirection("bos_dimension")
bos_space = bases.VectorSpaceWithBasis([bos_dir])
one_dir = bases.BasisDirection("one_dimension")
one_space = bases.VectorSpaceWithBasis([one_dir])
input_space = bases.join_vector_spaces(input_space, bos_space, one_space)
residual_space = bases.join_vector_spaces(input_space, output_space)
bos_vec = residual_space.vector_from_basis_direction(bos_dir)
one_vec = residual_space.vector_from_basis_direction(one_dir)
block = selector_width.selector_width(
query_space=input_space,
key_space=input_space,
output_space=output_space,
bos_space=bos_space,
one_space=one_space,
attn_fn=lambda x, y: False,
out_value_set=set(range(len(input_seq) + 1)),
categorical_output=categorical_output,
causal=causal,
label="select_all")
test_inputs = [bos_vec + one_vec]
for x in input_seq:
test_inputs.append(
residual_space.vector_from_basis_direction(
bases.BasisDirection("input", x)) + one_vec)
test_inputs = bases.VectorInBasis.stack(test_inputs)
# Expect zero output
if categorical_output:
expected_results = [
output_space.vector_from_basis_direction(
bases.BasisDirection("output", 0)) for _ in input_seq
]
else:
expected_results = [output_space.null_vector() for _ in input_seq]
expected_results = bases.VectorInBasis.stack(expected_results)
test_outputs = block.apply(test_inputs).project(output_space)
self.assertVectorAllClose(
tests_common.strip_bos_token(test_outputs), expected_results)
if __name__ == "__main__":
absltest.main()