# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for selector_width."""
from absl.testing import absltest | |
from absl.testing import parameterized | |
from tracr.craft import bases | |
from tracr.craft import tests_common | |
from tracr.craft.chamber import selector_width | |
class SelectorWidthTest(tests_common.VectorFnTestCase):
  """Checks `selector_width` blocks against selectors with a known width.

  Two degenerate selectors are used: one whose predicate is always true and
  one whose predicate is always false. Each test runs over a grid of
  `causal` / `categorical_output` / `input_seq` values supplied by
  `parameterized.product`.

  NOTE(review): the decorators require `tests_common.VectorFnTestCase` to
  derive from `absl.testing.parameterized.TestCase` -- confirm. Without the
  decorators the runner would call each test with only `self` and fail with
  a `TypeError` for the missing arguments.
  """

  def _build_block_and_inputs(self, *, attn_fn, label, causal,
                              categorical_output, input_seq):
    """Builds the selector-width block and a stacked input for `input_seq`.

    Args:
      attn_fn: Predicate forwarded to `selector_width.selector_width`.
      label: Label forwarded to `selector_width.selector_width`.
      causal: Forwarded as the `causal` argument of the block.
      categorical_output: If true, the output space has one direction per
        value in `range(10)`; otherwise it has a single "output" direction.
      input_seq: Token values; each must be in `range(-20, 20)`.

    Returns:
      A `(block, output_space, test_inputs)` tuple. `test_inputs` stacks a
      leading BOS row followed by one row per token of `input_seq`; every
      row also has the "one" direction set.
    """
    vocab = range(-20, 20)
    token_space = bases.VectorSpaceWithBasis.from_values("input", vocab)

    if categorical_output:
      output_space = bases.VectorSpaceWithBasis.from_values(
          "output", range(10))
    else:
      output_space = bases.VectorSpaceWithBasis(
          [bases.BasisDirection("output")])

    bos_dir = bases.BasisDirection("bos_dimension")
    bos_space = bases.VectorSpaceWithBasis([bos_dir])
    one_dir = bases.BasisDirection("one_dimension")
    one_space = bases.VectorSpaceWithBasis([one_dir])

    # Queries and keys both read from tokens + BOS + one; the residual
    # stream additionally carries the output directions.
    input_space = bases.join_vector_spaces(token_space, bos_space, one_space)
    residual_space = bases.join_vector_spaces(input_space, output_space)

    bos_vec = residual_space.vector_from_basis_direction(bos_dir)
    one_vec = residual_space.vector_from_basis_direction(one_dir)

    block = selector_width.selector_width(
        query_space=input_space,
        key_space=input_space,
        output_space=output_space,
        bos_space=bos_space,
        one_space=one_space,
        attn_fn=attn_fn,
        out_value_set=set(range(len(input_seq) + 1)),
        categorical_output=categorical_output,
        causal=causal,
        label=label)

    # Row 0 is the BOS position; it is stripped again before comparing.
    rows = [bos_vec + one_vec]
    for token in input_seq:
      token_vec = residual_space.vector_from_basis_direction(
          bases.BasisDirection("input", token))
      rows.append(token_vec + one_vec)
    test_inputs = bases.VectorInBasis.stack(rows)

    return block, output_space, test_inputs

  # Sequence lengths stay below 10 so that every expected width has a
  # direction in the categorical output space (`range(10)`).
  @parameterized.product(
      causal=[False, True],
      categorical_output=[False, True],
      input_seq=[[1, 2, 3, 4, 5], [-1, 0, 1], [10]])
  def test_selector_width_of_select_all_is_length(self, causal,
                                                  categorical_output,
                                                  input_seq):
    """An always-true selector yields the number of visible tokens."""
    block, output_space, test_inputs = self._build_block_and_inputs(
        attn_fn=lambda x, y: True,
        label="select_all",
        causal=causal,
        categorical_output=categorical_output,
        input_seq=input_seq)

    # Expect length of the input sequence; in the causal case the expected
    # width grows from 1 at the first token up to the full length.
    seq_len = len(input_seq)
    if causal:
      expected_widths = list(range(1, seq_len + 1))
    else:
      expected_widths = [seq_len] * seq_len

    if categorical_output:
      expected_results = [
          output_space.vector_from_basis_direction(
              bases.BasisDirection("output", width))
          for width in expected_widths
      ]
    else:
      output_vec = output_space.vector_from_basis_direction(
          bases.BasisDirection("output"))
      expected_results = [width * output_vec for width in expected_widths]
    expected_results = bases.VectorInBasis.stack(expected_results)

    test_outputs = block.apply(test_inputs).project(output_space)
    self.assertVectorAllClose(
        tests_common.strip_bos_token(test_outputs), expected_results)

  @parameterized.product(
      causal=[False, True],
      categorical_output=[False, True],
      input_seq=[[1, 2, 3, 4, 5], [-1, 0, 1], [10]])
  def test_selector_width_of_select_none_is_zero(self, causal,
                                                 categorical_output,
                                                 input_seq):
    """An always-false selector yields width zero at every position."""
    block, output_space, test_inputs = self._build_block_and_inputs(
        attn_fn=lambda x, y: False,
        # Previously the copy-pasted label "select_all" was used here.
        label="select_none",
        causal=causal,
        categorical_output=categorical_output,
        input_seq=input_seq)

    # Expect zero output: the "0" direction when categorical, otherwise the
    # null vector of the single-direction output space.
    if categorical_output:
      expected_results = [
          output_space.vector_from_basis_direction(
              bases.BasisDirection("output", 0)) for _ in input_seq
      ]
    else:
      expected_results = [output_space.null_vector() for _ in input_seq]
    expected_results = bases.VectorInBasis.stack(expected_results)

    test_outputs = block.apply(test_inputs).project(output_space)
    self.assertVectorAllClose(
        tests_common.strip_bos_token(test_outputs), expected_results)
# Run the absltest runner only when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()