# RASP-Synthesis / rasp/rasp_test.py
# DeepMind LMI Team
# Internal change
# 9bdaa77
# raw
# history blame
# 18.4 kB
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for rasp.rasp."""
import itertools
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from tracr.rasp import rasp
# Note that the example text labels must match their default names.
def _SOP_PRIMITIVE_EXAMPLES():  # pylint: disable=invalid-name
  """Returns a fresh list of (label, expr) pairs for the primitive S-Ops."""
  labels = ("tokens", "length", "indices")
  primitives = (rasp.tokens, rasp.length, rasp.indices)
  return list(zip(labels, primitives))
def _NONPRIMITIVE_SOP_EXAMPLES():  # pylint: disable=invalid-name
  """Returns a fresh list of (label, expr) pairs for composite S-Ops.

  New expression objects are built on every call. The two examples that take
  a selector each receive their own `Select` instance.
  """

  def tokens_eq_indices():
    # A new selector per call so that no two examples share one object.
    return rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ)

  return [
      ("map", rasp.Map(lambda x: x, rasp.tokens)),
      (
          "sequence_map",
          rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, rasp.tokens),
      ),
      (
          "linear_sequence_map",
          rasp.LinearSequenceMap(rasp.tokens, rasp.tokens, 0.1, 0.2),
      ),
      ("aggregate", rasp.Aggregate(tokens_eq_indices(), rasp.tokens)),
      ("selector_width", rasp.SelectorWidth(tokens_eq_indices())),
  ]
def _SOP_EXAMPLES():  # pylint: disable=invalid-name
  """Returns the primitive S-Op examples followed by the composite ones."""
  return [*_SOP_PRIMITIVE_EXAMPLES(), *_NONPRIMITIVE_SOP_EXAMPLES()]
def _SELECTOR_EXAMPLES():  # pylint: disable=invalid-name
  """Returns a fresh list of (label, expr) pairs, one per selector type."""

  def tokens_eq_indices():
    return rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ)

  def indices_leq_tokens():
    return rasp.Select(rasp.indices, rasp.tokens, rasp.Comparison.LEQ)

  # Each factory call yields a new selector, so no example shares operands
  # with another.
  return [
      ("select", tokens_eq_indices()),
      (
          "selector_and",
          rasp.SelectorAnd(tokens_eq_indices(), indices_leq_tokens()),
      ),
      (
          "selector_or",
          rasp.SelectorOr(tokens_eq_indices(), indices_leq_tokens()),
      ),
      ("selector_not", rasp.SelectorNot(tokens_eq_indices())),
  ]
def _ALL_EXAMPLES():  # pylint: disable=invalid-name
  """Returns every S-Op example followed by every selector example."""
  return [*_SOP_EXAMPLES(), *_SELECTOR_EXAMPLES()]
class LabelTest(parameterized.TestCase):
  """Tests for the labels and default names of RASP expressions."""

  def test_primitive_labels(self):
    # Checked in the same order as the individual assertions used to run.
    expected_labels = (
        (rasp.tokens, "tokens"),
        (rasp.indices, "indices"),
        (rasp.length, "length"),
    )
    for primitive, expected_label in expected_labels:
      self.assertEqual(primitive.label, expected_label)

  @parameterized.parameters(*_ALL_EXAMPLES())
  def test_default_names(self, default_name: str, expr: rasp.RASPExpr):
    # Each example's text label doubles as the name the expression is
    # expected to carry when it was never explicitly renamed.
    actual_name = expr.name
    self.assertEqual(actual_name, default_name)
class SOpTest(parameterized.TestCase):
  """Tests for S-Ops."""

  @parameterized.parameters(
      ("hello", ["h", "e", "l", "l", "o"]),
      ("h", ["h"]),
      (["h", "e", "l", "l", "o"], ["h", "e", "l", "l", "o"]),
      (["h"], ["h"]),
      ([1, 2], [1, 2]),
      ([0.1, 0.2], [0.1, 0.2]),
  )
  def test_tokens(self, input_sequence, expected):
    # `tokens` returns the input elements as a list, for both string and list
    # inputs.
    self.assertEqual(rasp.tokens(input_sequence), expected)

  @parameterized.parameters(
      ("hello", [0, 1, 2, 3, 4]),
      ("h", [0]),
      (["h", "e", "l", "l", "o"], [0, 1, 2, 3, 4]),
      (["h"], [0]),
      ([1, 2], [0, 1]),
      ([0.1, 0.2], [0, 1]),
  )
  def test_indices(self, input_sequence, expected):
    # `indices` returns the zero-based position of every input element.
    self.assertEqual(rasp.indices(input_sequence), expected)

  @parameterized.parameters(
      ("hello", [5, 5, 5, 5, 5]),
      ("h", [1]),
      (["h", "e", "l", "l", "o"], [5, 5, 5, 5, 5]),
      (["h"], [1]),
      ([1, 2], [2, 2]),
      ([0.1, 0.2], [2, 2]),
  )
  def test_length(self, input_sequence, expected):
    # `length` repeats the input length once per position.
    self.assertEqual(rasp.length(input_sequence), expected)

  def test_prims_are_sops(self):
    for primitive in (rasp.tokens, rasp.indices, rasp.length):
      self.assertIsInstance(primitive, rasp.SOp)

  def test_prims_are_raspexprs(self):
    for primitive in (rasp.tokens, rasp.indices, rasp.length):
      self.assertIsInstance(primitive, rasp.RASPExpr)

  @parameterized.parameters(
      (lambda x: x + "a", "hello", ["ha", "ea", "la", "la", "oa"]),
      (lambda x: x + "t", "h", ["ht"]),
      (lambda x: x + 1, [1, 2], [2, 3]),
      (lambda x: x / 2, [0.1, 0.2], [0.05, 0.1]),
  )
  def test_map(self, f, input_sequence, expected):
    mapped = rasp.Map(f, rasp.tokens)
    self.assertEqual(mapped(input_sequence), expected)

  def test_nested_elementwise_ops_results_in_only_one_map_object(self):
    map_sop = ((rasp.tokens * 2) + 2) / 2
    # Identity check on purpose: `==` between S-Ops is overloaded to build a
    # new S-Op (see test_numeric_dunders) rather than return a bool, so an
    # assertEqual here could never fail.
    self.assertIs(map_sop.inner, rasp.tokens)
    self.assertEqual(map_sop([1]), [2])

  @parameterized.parameters(
      (lambda x, y: x + y, "hello", ["hh", "ee", "ll", "ll", "oo"]),
      (lambda x, y: x + y, "h", ["hh"]),
      (lambda x, y: x + y, [1, 2], [2, 4]),
      (lambda x, y: x * y, [1, 2], [1, 4]),
  )
  def test_sequence_map(self, f, input_sequence, expected):
    # Both operands are `tokens`, so `f` receives each element twice.
    combined = rasp.SequenceMap(f, rasp.tokens, rasp.tokens)
    self.assertEqual(combined(input_sequence), expected)

  def test_sequence_map_with_same_inputs_logs_warning(self):
    with self.assertLogs(level="WARNING"):
      rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, rasp.tokens)

  @parameterized.parameters(
      (1, 1, [1, 2], [2, 4]),
      (1, -1, [1, 2], [0, 0]),
      (1, -2, [1, 2], [-1, -2]),
  )
  def test_linear_sequence_map(self, fst_fac, snd_fac, input_sequence,
                               expected):
    # Expected values equal fst_fac * token + snd_fac * token.
    combined = rasp.LinearSequenceMap(rasp.tokens, rasp.tokens, fst_fac,
                                      snd_fac)
    self.assertEqual(combined(input_sequence), expected)

  @parameterized.parameters(
      ([5, 5, 5, 5, 5], "hello", [5, 5, 5, 5, 5]),
      (["e"], "h", ["e"]),
      ([1, 2, 3, 4, 5], ["h", "e", "l", "l", "o"], [1, 2, 3, 4, 5]),
      ([2, 2], [1, 2], [2, 2]),
  )
  def test_constant(self, const, input_sequence, expected):
    # A constant S-Op returns its own value whatever the input contains.
    self.assertEqual(rasp.ConstantSOp(const)(input_sequence), expected)

  def test_constant_complains_if_sizes_dont_match(self):
    with self.assertRaisesRegex(
        ValueError,
        r"^.*Constant len .* doesn't match input len .*$",
    ):
      rasp.ConstantSOp([1, 2, 3])("longer string")

  def test_can_turn_off_constant_complaints(self):
    # Passes as long as the call below does not raise.
    rasp.ConstantSOp([1, 2, 3], check_length=False)("longer string")

  def test_numeric_dunders(self):
    # We don't check all the cases here -- only a few representative ones.
    self.assertEqual((rasp.tokens > 1)([0, 1, 2]), [0, 0, 1])
    self.assertEqual((1 < rasp.tokens)([0, 1, 2]), [0, 0, 1])
    self.assertEqual((rasp.tokens < 1)([0, 1, 2]), [1, 0, 0])
    self.assertEqual((1 > rasp.tokens)([0, 1, 2]), [1, 0, 0])
    self.assertEqual((rasp.tokens == 1)([0, 1, 2]), [0, 1, 0])
    self.assertEqual((rasp.tokens + 1)([0, 1, 2]), [1, 2, 3])
    self.assertEqual((1 + rasp.tokens)([0, 1, 2]), [1, 2, 3])

  def test_dunders_with_sop(self):
    # Operators whose operands are both S-Ops combine them position-wise.
    self.assertEqual((rasp.tokens + rasp.indices)([0, 1, 2]), [0, 2, 4])
    self.assertEqual(
        (rasp.length - 1 - rasp.indices)([0, 1, 2]),
        [2, 1, 0],
    )
    self.assertEqual((rasp.length * rasp.length)([0, 1, 2]), [9, 9, 9])

  def test_logical_dunders(self):
    # S-Op on the left, Python bool on the right.
    self.assertEqual((rasp.tokens & True)([True, False]), [True, False])
    self.assertEqual((rasp.tokens & False)([True, False]), [False, False])
    self.assertEqual((rasp.tokens | True)([True, False]), [True, True])
    self.assertEqual((rasp.tokens | False)([True, False]), [True, False])

    # Python bool on the left, S-Op on the right (reflected operators).
    self.assertEqual((True & rasp.tokens)([True, False]), [True, False])
    self.assertEqual((False & rasp.tokens)([True, False]), [False, False])
    self.assertEqual((True | rasp.tokens)([True, False]), [True, True])
    self.assertEqual((False | rasp.tokens)([True, False]), [True, False])

    # Negation.
    self.assertEqual((~rasp.tokens)([True, False]), [False, True])

    # Two S-Op operands; the constants cover the full two-input truth table.
    self.assertEqual(
        (rasp.ConstantSOp([True, True, False, False])
         & rasp.ConstantSOp([True, False, True, False]))([1, 1, 1, 1]),
        [True, False, False, False],
    )
    self.assertEqual(
        (rasp.ConstantSOp([True, True, False, False])
         | rasp.ConstantSOp([True, False, True, False]))([1, 1, 1, 1]),
        [True, True, True, False],
    )
class EncodingTest(parameterized.TestCase):
  """Tests for SOp encodings."""

  @parameterized.named_parameters(*_SOP_EXAMPLES())
  def test_all_sops_are_categorical_by_default(self, sop: rasp.SOp):
    # `sop` comes straight from the example list, with no encoding applied.
    categorical_by_default = rasp.is_categorical(sop)
    self.assertTrue(categorical_by_default)

  @parameterized.named_parameters(*_SOP_EXAMPLES())
  def test_is_numerical(self, sop: rasp.SOp):
    as_numerical = rasp.numerical(sop)
    self.assertTrue(rasp.is_numerical(as_numerical))
    as_categorical = rasp.categorical(sop)
    self.assertFalse(rasp.is_numerical(as_categorical))

  @parameterized.named_parameters(*_SOP_EXAMPLES())
  def test_is_categorical(self, sop: rasp.SOp):
    as_categorical = rasp.categorical(sop)
    self.assertTrue(rasp.is_categorical(as_categorical))
    as_numerical = rasp.numerical(sop)
    self.assertFalse(rasp.is_categorical(as_numerical))

  @parameterized.named_parameters(*_SOP_EXAMPLES())
  def test_double_encoding_annotations_overwrites_encoding(self, sop: rasp.SOp):
    # Both annotated S-Ops are built before anything is asserted, so the first
    # check also covers `num_sop` after `categorical` was applied to it.
    num_sop = rasp.numerical(sop)
    cat_num_sop = rasp.categorical(num_sop)
    still_numerical = rasp.is_numerical(num_sop)
    self.assertTrue(still_numerical)
    now_categorical = rasp.is_categorical(cat_num_sop)
    self.assertTrue(now_categorical)
class SelectorTest(parameterized.TestCase):
  """Tests for Selectors."""

  def test_select_eq_has_correct_value(self):
    expected = [
        [True, False, False],
        [False, True, False],
        [False, False, True],
    ]
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
    self.assertEqual(selector("hey"), expected)

  def test_select_lt_has_correct_value(self):
    expected = [
        [False, False],
        [True, False],
    ]
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.LT)
    self.assertEqual(selector([0, 1]), expected)

  def test_select_leq_has_correct_value(self):
    expected = [
        [True, False],
        [True, True],
    ]
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.LEQ)
    self.assertEqual(selector([0, 1]), expected)

  def test_select_gt_has_correct_value(self):
    expected = [
        [False, True],
        [False, False],
    ]
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.GT)
    self.assertEqual(selector([0, 1]), expected)

  def test_select_geq_has_correct_value(self):
    expected = [
        [True, True],
        [False, True],
    ]
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.GEQ)
    self.assertEqual(selector([0, 1]), expected)

  def test_select_neq_has_correct_value(self):
    expected = [
        [False, True],
        [True, False],
    ]
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.NEQ)
    self.assertEqual(selector([0, 1]), expected)

  def test_select_true_has_correct_value(self):
    expected = [
        [True, True],
        [True, True],
    ]
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
    self.assertEqual(selector([0, 1]), expected)

  def test_select_false_has_correct_value(self):
    expected = [
        [False, False],
        [False, False],
    ]
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.FALSE)
    self.assertEqual(selector([0, 1]), expected)

  def test_selector_and_gets_simplified_when_keys_and_queries_match(self):
    fst = rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.GEQ)
    snd = rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.LEQ)
    selector = rasp.selector_and(fst, snd)

    # Same keys and queries on both sides: the result is one plain Select.
    self.assertIsInstance(selector, rasp.Select)
    self.assertIs(selector.keys, rasp.tokens)
    self.assertIs(selector.queries, rasp.indices)

  def test_selector_and_doesnt_get_simplified_when_keys_queries_different(self):
    fst = rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.GEQ)
    snd = rasp.Select(rasp.indices, rasp.tokens, rasp.Comparison.LEQ)
    selector = rasp.selector_and(fst, snd)

    self.assertIsInstance(selector, rasp.SelectorAnd)

  def test_selector_and_gets_simplified_when_keys_are_full(self):
    fst = rasp.Select(rasp.Full(1), rasp.indices, rasp.Comparison.GEQ)
    snd = rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.LEQ)
    selector = rasp.selector_and(fst, snd)

    # The simplified selector keeps the non-Full keys.
    self.assertIsInstance(selector, rasp.Select)
    self.assertIs(selector.keys, rasp.tokens)
    self.assertIs(selector.queries, rasp.indices)

  def test_selector_and_gets_simplified_when_queries_are_full(self):
    fst = rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.GEQ)
    snd = rasp.Select(rasp.tokens, rasp.Full(1), rasp.Comparison.LEQ)
    selector = rasp.selector_and(fst, snd)

    # The simplified selector keeps the non-Full queries.
    self.assertIsInstance(selector, rasp.Select)
    self.assertIs(selector.keys, rasp.tokens)
    self.assertIs(selector.queries, rasp.indices)

  @parameterized.parameters(
      itertools.product(
          (rasp.tokens, rasp.indices, rasp.Full(1)),
          (rasp.tokens, rasp.indices, rasp.Full(1)),
          list(rasp.Comparison),
          (rasp.tokens, rasp.indices, rasp.Full(1)),
          (rasp.tokens, rasp.indices, rasp.Full(1)),
          list(rasp.Comparison),
      ))
  def test_simplified_selector_and_works_the_same_way_as_not(
      self, fst_k, fst_q, fst_p, snd_k, snd_q, snd_p):
    # Covers every combination of keys, queries and predicate for both
    # operands of the conjunction.
    fst = rasp.Select(fst_k, fst_q, fst_p)
    snd = rasp.Select(snd_k, snd_q, snd_p)

    with_simplification = rasp.selector_and(fst, snd)
    simplified = with_simplification([0, 1, 2, 3])
    without_simplification = rasp.selector_and(fst, snd, simplify=False)
    not_simplified = without_simplification([0, 1, 2, 3])

    np.testing.assert_array_equal(
        np.array(simplified),
        np.array(not_simplified),
    )

  def test_select_is_selector(self):
    select = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
    self.assertIsInstance(select, rasp.Selector)

  def test_select_is_raspexpr(self):
    select = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
    self.assertIsInstance(select, rasp.RASPExpr)

  def test_constant_selector(self):
    # The constant selector returns its own matrix, not one derived from the
    # [1, 2] input.
    selector = rasp.ConstantSelector([[True, True], [False, False]])
    self.assertEqual(selector([1, 2]), [[True, True], [False, False]])
class CopyTest(parameterized.TestCase):
  """Tests for copying RASP expressions."""

  @parameterized.named_parameters(*_ALL_EXAMPLES())
  def test_copy_preserves_name(self, expr: rasp.RASPExpr):
    expr = expr.named("foo")
    self.assertEqual(expr.copy().name, expr.name)

  @parameterized.named_parameters(*_ALL_EXAMPLES())
  def test_renaming_copy_doesnt_rename_original(self, expr: rasp.RASPExpr):
    expr = expr.named("foo")
    # The renamed copy is discarded; only the original's name is checked.
    expr.copy().named("bar")
    self.assertEqual(expr.name, "foo")

  @parameterized.named_parameters(*_ALL_EXAMPLES())
  def test_renaming_original_doesnt_rename_copy(self, expr: rasp.RASPExpr):
    expr = expr.named("foo")
    copy = expr.copy()
    # The copy is taken before the rename, so it must keep the old name.
    expr.named("bar")
    self.assertEqual(copy.name, "foo")

  @parameterized.named_parameters(*_ALL_EXAMPLES())
  def test_copy_changes_id(self, expr: rasp.RASPExpr):
    self.assertNotEqual(expr.copy().unique_id, expr.unique_id)

  @parameterized.named_parameters(*_ALL_EXAMPLES())
  def test_copy_preserves_child_ids(self, expr: rasp.RASPExpr):
    copy_child_ids = [c.unique_id for c in expr.copy().children]
    child_ids = [c.unique_id for c in expr.children]
    # Compare the complete lists instead of zipping them: zip() stops at the
    # shorter input, so a copy with missing or extra children would otherwise
    # go unnoticed.
    self.assertEqual(child_ids, copy_child_ids)
class AggregateTest(parameterized.TestCase):
  """Tests for Aggregate.

  In the cases below each selector row yields one output position. A row
  selecting a single entry returns that entry, a row selecting both numeric
  entries returns their mean, and a row selecting nothing returns `default`.
  """

  @parameterized.parameters(
      {
          "selector": rasp.ConstantSelector([
              [True, False],
              [False, True],
          ]),
          "sop": rasp.ConstantSOp(["h", "e"]),
          "default": None,
          "expected_value": ["h", "e"],
      },
      {
          "selector": rasp.ConstantSelector([
              [False, True],
              [False, False],
          ]),
          "sop": rasp.ConstantSOp(["h", "e"]),
          "default": None,
          "expected_value": ["e", None],
      },
      {
          "selector": rasp.ConstantSelector([
              [True, False],
              [False, False],
          ]),
          "sop": rasp.ConstantSOp(["h", "e"]),
          "default": None,
          "expected_value": ["h", None],
      },
      {
          "selector": rasp.ConstantSelector([
              [True, True],
              [False, True],
          ]),
          "sop": rasp.ConstantSOp([0, 1]),
          "default": 0,
          "expected_value": [0.5, 1],
      },
      {
          "selector": rasp.ConstantSelector([
              [False, False],
              [True, True],
          ]),
          "sop": rasp.ConstantSOp([0, 1]),
          "default": 0,
          "expected_value": [0, 0.5],
      },
      {
          "selector": rasp.ConstantSelector([
              [False, False],
              [True, True],
          ]),
          "sop": rasp.ConstantSOp([0, 1]),
          "default": None,
          "expected_value": [None, 0.5],
      },
  )
  def test_aggregate_on_size_2_inputs(self, selector, sop, default,
                                      expected_value):
    aggregate = rasp.Aggregate(selector, sop, default)
    # The [0, 0] input only fixes the sequence length; its values are
    # overridden by the constant selector and constant S-Op.
    actual_value = aggregate([0, 0])
    self.assertEqual(actual_value, expected_value)
class RaspProgramTest(parameterized.TestCase):
  """Each testcase implements and tests a RASP program."""

  def test_has_prev(self):

    def has_prev(seq: rasp.SOp) -> rasp.SOp:
      """Builds an S-Op flagging elements already seen at an earlier index."""
      same_value = rasp.Select(seq, seq, rasp.Comparison.EQ)
      position_lt = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LT)
      prev_copy = rasp.SelectorAnd(same_value, position_lt)
      # Aggregating a constant 1 is positive only where the combined selector
      # picked at least one position; otherwise the default 0 is returned.
      found_fraction = rasp.Aggregate(prev_copy, rasp.Full(1), default=0)
      return found_fraction > 0

    program = has_prev(rasp.tokens)
    self.assertEqual(program("hello"), [0, 0, 0, 1, 0])

    program = has_prev(rasp.tokens)
    self.assertEqual(program("helllo"), [0, 0, 0, 1, 1, 0])

    program = has_prev(rasp.tokens)
    self.assertEqual(
        program([0, 2, 3, 2, 1, 0, 2]),
        [0, 0, 0, 1, 0, 1, 1],
    )
if __name__ == "__main__":
  # Run all test cases through absl's test runner when executed as a script.
  absltest.main()