# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A set of RASP programs and input/output pairs used in integration tests."""
from tracr.compiler import lib | |
from tracr.rasp import rasp | |
# Cases shared by every compilation mode: this list is the common prefix of
# both TEST_CASES and CAUSAL_TEST_CASES. Every entry has the same keys:
#   testcase_name    unique name of the case,
#   program          the RASP program under test,
#   vocab            the input vocabulary to compile for,
#   test_input       an input sequence,
#   expected_output  the program's output on `test_input`,
#   max_seq_len      the maximum sequence length to compile for.
UNIVERSAL_TEST_CASES = [
    # make_frac_prevs: at each position, the fraction of positions so far
    # (inclusive) whose token satisfies the predicate.
    {
        "testcase_name": "frac_prevs_1",
        "program": lib.make_frac_prevs(rasp.tokens == "l"),
        "vocab": {"h", "e", "l", "o"},
        "test_input": list("hello"),
        "expected_output": [0.0, 0.0, 1 / 3, 1 / 2, 2 / 5],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "frac_prevs_2",
        "program": lib.make_frac_prevs(rasp.tokens == "("),
        "vocab": {"a", "b", "c", "(", ")"},
        "test_input": list("a()b(c))"),
        "expected_output": [
            0.0, 1 / 2, 1 / 3, 1 / 4, 2 / 5, 2 / 6, 2 / 7, 2 / 8
        ],
        "max_seq_len": 10,
    },
    {
        "testcase_name": "frac_prevs_3",
        "program": lib.make_frac_prevs(rasp.tokens == ")"),
        "vocab": {"a", "b", "c", "(", ")"},
        "test_input": list("a()b(c))"),
        "expected_output": [
            0.0, 0.0, 1 / 3, 1 / 4, 1 / 5, 1 / 6, 2 / 7, 3 / 8
        ],
        "max_seq_len": 10,
    },
    # shift_by: moves the sequence right by the offset; vacated leading
    # positions read as None.
    {
        "testcase_name": "shift_by_one",
        "program": lib.shift_by(1, rasp.tokens),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("abcd"),
        "expected_output": [None, "a", "b", "c"],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "shift_by_two",
        "program": lib.shift_by(2, rasp.tokens),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("abcd"),
        "expected_output": [None, None, "a", "b"],
        "max_seq_len": 5,
    },
    # detect_pattern: True at the final position of each occurrence of the
    # pattern, False elsewhere; the first len(pattern) - 1 positions are None.
    {
        "testcase_name": "detect_pattern_a",
        "program": lib.detect_pattern(rasp.tokens, "a"),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("bacd"),
        "expected_output": [False, True, False, False],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "detect_pattern_ab",
        "program": lib.detect_pattern(rasp.tokens, "ab"),
        "vocab": {"a", "b"},
        "test_input": list("aaba"),
        "expected_output": [None, False, True, False],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "detect_pattern_ab_2",
        "program": lib.detect_pattern(rasp.tokens, "ab"),
        "vocab": {"a", "b"},
        "test_input": list("abaa"),
        "expected_output": [None, True, False, False],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "detect_pattern_ab_3",
        "program": lib.detect_pattern(rasp.tokens, "ab"),
        "vocab": {"a", "b"},
        "test_input": list("aaaa"),
        "expected_output": [None, False, False, False],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "detect_pattern_abc",
        "program": lib.detect_pattern(rasp.tokens, "abc"),
        "vocab": {"a", "b", "c"},
        "test_input": list("abcabc"),
        "expected_output": [None, None, True, False, False, True],
        "max_seq_len": 6,
    },
]
# Complete case list: the shared UNIVERSAL_TEST_CASES followed by cases that
# are only listed here (presumably the non-causal setting, given the separate
# CAUSAL_TEST_CASES list). Entries use the same keys as UNIVERSAL_TEST_CASES.
TEST_CASES = [
    *UNIVERSAL_TEST_CASES,
    # make_reverse: tokens in reverse order, including inputs shorter than
    # max_seq_len and a single token.
    {
        "testcase_name": "reverse_1",
        "program": lib.make_reverse(rasp.tokens),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("abcd"),
        "expected_output": list("dcba"),
        "max_seq_len": 5,
    },
    {
        "testcase_name": "reverse_2",
        "program": lib.make_reverse(rasp.tokens),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("abc"),
        "expected_output": list("cba"),
        "max_seq_len": 5,
    },
    {
        "testcase_name": "reverse_3",
        "program": lib.make_reverse(rasp.tokens),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("ad"),
        "expected_output": list("da"),
        "max_seq_len": 5,
    },
    {
        "testcase_name": "reverse_4",
        "program": lib.make_reverse(rasp.tokens),
        "vocab": {"a", "b", "c", "d"},
        "test_input": ["c"],
        "expected_output": ["c"],
        "max_seq_len": 5,
    },
    # make_length wrapped in rasp.categorical: the input length, broadcast to
    # every position.
    {
        "testcase_name": "length_categorical_1",
        "program": rasp.categorical(lib.make_length()),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("abc"),
        "expected_output": [3, 3, 3],
        "max_seq_len": 3,
    },
    {
        "testcase_name": "length_categorical_2",
        "program": rasp.categorical(lib.make_length()),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("ad"),
        "expected_output": [2, 2],
        "max_seq_len": 3,
    },
    {
        "testcase_name": "length_categorical_3",
        "program": rasp.categorical(lib.make_length()),
        "vocab": {"a", "b", "c", "d"},
        "test_input": ["c"],
        "expected_output": [1],
        "max_seq_len": 3,
    },
    # The same length program wrapped in rasp.numerical instead.
    {
        "testcase_name": "length_numerical_1",
        "program": rasp.numerical(lib.make_length()),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("abc"),
        "expected_output": [3, 3, 3],
        "max_seq_len": 3,
    },
    {
        "testcase_name": "length_numerical_2",
        "program": rasp.numerical(lib.make_length()),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("ad"),
        "expected_output": [2, 2],
        "max_seq_len": 3,
    },
    {
        "testcase_name": "length_numerical_3",
        "program": rasp.numerical(lib.make_length()),
        "vocab": {"a", "b", "c", "d"},
        "test_input": ["c"],
        "expected_output": [1],
        "max_seq_len": 3,
    },
    # make_pair_balance: the expected values equal the running fraction of
    # "(" minus the running fraction of ")" on the same input.
    {
        "testcase_name": "pair_balance_1",
        "program": lib.make_pair_balance(rasp.tokens, "(", ")"),
        "vocab": {"a", "b", "c", "(", ")"},
        "test_input": list("a()b(c))"),
        "expected_output": [
            0.0, 1 / 2, 0.0, 0.0, 1 / 5, 1 / 6, 0.0, -1 / 8
        ],
        "max_seq_len": 10,
    },
    # make_shuffle_dyck with two bracket pairs: 1 everywhere when each pair is
    # properly matched on its own, 0 everywhere otherwise.
    {
        "testcase_name": "shuffle_dyck2_1",
        "program": lib.make_shuffle_dyck(pairs=["()", "{}"]),
        "vocab": {"(", ")", "{", "}"},
        "test_input": list("({)}"),
        "expected_output": [1, 1, 1, 1],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "shuffle_dyck2_2",
        "program": lib.make_shuffle_dyck(pairs=["()", "{}"]),
        "vocab": {"(", ")", "{", "}"},
        "test_input": list("(){)}"),
        "expected_output": [0, 0, 0, 0, 0],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "shuffle_dyck2_3",
        "program": lib.make_shuffle_dyck(pairs=["()", "{}"]),
        "vocab": {"(", ")", "{", "}"},
        "test_input": list("{}("),
        "expected_output": [0, 0, 0],
        "max_seq_len": 5,
    },
    # make_shuffle_dyck with three bracket pairs.
    {
        "testcase_name": "shuffle_dyck3_1",
        "program": lib.make_shuffle_dyck(pairs=["()", "{}", "[]"]),
        "vocab": {"(", ")", "{", "}", "[", "]"},
        "test_input": list("({)[}]"),
        "expected_output": [1, 1, 1, 1, 1, 1],
        "max_seq_len": 6,
    },
    {
        "testcase_name": "shuffle_dyck3_2",
        "program": lib.make_shuffle_dyck(pairs=["()", "{}", "[]"]),
        "vocab": {"(", ")", "{", "}", "[", "]"},
        "test_input": list("(){)}"),
        "expected_output": [0, 0, 0, 0, 0],
        "max_seq_len": 6,
    },
    {
        "testcase_name": "shuffle_dyck3_3",
        "program": lib.make_shuffle_dyck(pairs=["()", "{}", "[]"]),
        "vocab": {"(", ")", "{", "}", "[", "]"},
        "test_input": list("{}[(]"),
        "expected_output": [0, 0, 0, 0, 0],
        "max_seq_len": 6,
    },
    # make_hist: at each position, how many times that position's token
    # occurs in the whole input.
    {
        "testcase_name": "hist",
        "program": lib.make_hist(),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("abac"),
        "expected_output": [2, 1, 2, 1],
        "max_seq_len": 5,
    },
    # make_sort_unique: vals ordered by ascending keys. Using 1 - indices as
    # the keys reverses the input.
    {
        "testcase_name": "sort_unique_1",
        "program": lib.make_sort_unique(vals=rasp.tokens, keys=rasp.tokens),
        "vocab": {1, 2, 3, 4},
        "test_input": [2, 4, 3, 1],
        "expected_output": [1, 2, 3, 4],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "sort_unique_2",
        "program": lib.make_sort_unique(
            vals=rasp.tokens, keys=1 - rasp.indices
        ),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("abcd"),
        "expected_output": ["d", "c", "b", "a"],
        "max_seq_len": 5,
    },
    # make_sort: like make_sort_unique, but sort_3 shows it also handles
    # duplicate keys.
    {
        "testcase_name": "sort_1",
        "program": lib.make_sort(
            vals=rasp.tokens, keys=rasp.tokens, max_seq_len=5, min_key=1
        ),
        "vocab": {1, 2, 3, 4},
        "test_input": [2, 4, 3, 1],
        "expected_output": [1, 2, 3, 4],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "sort_2",
        "program": lib.make_sort(
            vals=rasp.tokens, keys=1 - rasp.indices, max_seq_len=5, min_key=1
        ),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("abcd"),
        "expected_output": ["d", "c", "b", "a"],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "sort_3",
        "program": lib.make_sort(
            vals=rasp.tokens, keys=rasp.tokens, max_seq_len=5, min_key=1
        ),
        "vocab": {1, 2, 3, 4},
        "test_input": [2, 4, 1, 2],
        "expected_output": [1, 2, 2, 4],
        "max_seq_len": 5,
    },
    # make_sort_freq: the most frequent tokens come first.
    {
        "testcase_name": "sort_freq",
        "program": lib.make_sort_freq(max_seq_len=5),
        "vocab": {1, 2, 3, 4},
        "test_input": [2, 4, 2, 1],
        "expected_output": [2, 2, 4, 1],
        "max_seq_len": 5,
    },
    # make_count_less_freq(n=2), used directly and wrapped in rasp.numerical.
    # The expected counts follow lib's exact counting rule; see lib for the
    # definition.
    {
        "testcase_name": "make_count_less_freq_categorical_1",
        "program": lib.make_count_less_freq(n=2),
        "vocab": {"a", "b", "c", "d"},
        "test_input": ["a", "a", "a", "b", "b", "c"],
        "expected_output": [3, 3, 3, 3, 3, 3],
        "max_seq_len": 6,
    },
    {
        "testcase_name": "make_count_less_freq_categorical_2",
        "program": lib.make_count_less_freq(n=2),
        "vocab": {"a", "b", "c", "d"},
        "test_input": ["a", "a", "c", "b", "b", "c"],
        "expected_output": [6, 6, 6, 6, 6, 6],
        "max_seq_len": 6,
    },
    {
        "testcase_name": "make_count_less_freq_numerical_1",
        "program": rasp.numerical(lib.make_count_less_freq(n=2)),
        "vocab": {"a", "b", "c", "d"},
        "test_input": ["a", "a", "a", "b", "b", "c"],
        "expected_output": [3, 3, 3, 3, 3, 3],
        "max_seq_len": 6,
    },
    {
        "testcase_name": "make_count_less_freq_numerical_2",
        "program": rasp.numerical(lib.make_count_less_freq(n=2)),
        "vocab": {"a", "b", "c", "d"},
        "test_input": ["a", "a", "c", "b", "b", "c"],
        "expected_output": [6, 6, 6, 6, 6, 6],
        "max_seq_len": 6,
    },
    # make_count: the number of "a" tokens in the input, broadcast to every
    # position. make_count_3 covers the zero case.
    {
        "testcase_name": "make_count_1",
        "program": lib.make_count(rasp.tokens, "a"),
        "vocab": {"a", "b", "c"},
        "test_input": ["a", "a", "a", "b", "b", "c"],
        "expected_output": [3, 3, 3, 3, 3, 3],
        "max_seq_len": 8,
    },
    {
        "testcase_name": "make_count_2",
        "program": lib.make_count(rasp.tokens, "a"),
        "vocab": {"a", "b", "c"},
        "test_input": ["c", "a", "b", "c"],
        "expected_output": [1, 1, 1, 1],
        "max_seq_len": 8,
    },
    {
        "testcase_name": "make_count_3",
        "program": lib.make_count(rasp.tokens, "a"),
        "vocab": {"a", "b", "c"},
        "test_input": ["b", "b", "c"],
        "expected_output": [0, 0, 0],
        "max_seq_len": 8,
    },
    # make_nary_sequencemap: applies a three-argument function position-wise
    # across three SOps.
    {
        "testcase_name": "make_nary_sequencemap_1",
        "program": lib.make_nary_sequencemap(
            lambda x, y, z: x + y - z, rasp.tokens, rasp.tokens, rasp.indices
        ),
        "vocab": {1, 2, 3},
        "test_input": [1, 2, 3],
        "expected_output": [2, 3, 4],
        "max_seq_len": 5,
    },
    {
        "testcase_name": "make_nary_sequencemap_2",
        "program": lib.make_nary_sequencemap(
            lambda x, y, z: x * y / z, rasp.indices, rasp.indices, rasp.tokens
        ),
        "vocab": {1, 2, 3},
        "test_input": [1, 2, 3],
        "expected_output": [0, 1 / 2, 4 / 3],
        "max_seq_len": 3,
    },
]
# The make_nary_sequencemap cases above use the call signature
# lib.make_nary_sequencemap(f, *sops), where f is applied position-wise.
# Cases for the causal setting: the shared UNIVERSAL_TEST_CASES plus cases
# whose expected output depends on causality.
CAUSAL_TEST_CASES = [
    *UNIVERSAL_TEST_CASES,
    # The selector compares every pair of positions with Comparison.TRUE, yet
    # the expected width at position i is i + 1. Only positions up to and
    # including i are counted in this causal setting.
    {
        "testcase_name": "selector_width",
        "program": rasp.SelectorWidth(
            rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
        ),
        "vocab": {"a", "b", "c", "d"},
        "test_input": list("abcd"),
        "expected_output": [1, 2, 3, 4],
        "max_seq_len": 5,
    },
]