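# NOTE(review): "Spaces: Sleeping" -- hosting-page status text picked up during
# extraction; it is not part of the test module.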
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for rasp.rasp."""

import itertools

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from tracr.rasp import rasp

# Example expressions shared by the parameterized tests in this module.
#
# Each factory returns a list of (label, expression) tuples. The factories are
# zero-argument callables, so every call builds new non-primitive expression
# objects (the primitives `rasp.tokens`, `rasp.length` and `rasp.indices` are
# the shared module-level objects).
#
# Note that the example text labels must match their default names.

# The three primitive S-Ops.
_SOP_PRIMITIVE_EXAMPLES = lambda: [  # pylint: disable=g-long-lambda
    ("tokens", rasp.tokens),
    ("length", rasp.length),
    ("indices", rasp.indices),
]

# One example of each S-Op that is built from other expressions.
_NONPRIMITIVE_SOP_EXAMPLES = lambda: [  # pylint: disable=g-long-lambda
    ("map", rasp.Map(lambda x: x, rasp.tokens)),
    (
        "sequence_map",
        rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, rasp.tokens),
    ),
    (
        "linear_sequence_map",
        rasp.LinearSequenceMap(rasp.tokens, rasp.tokens, 0.1, 0.2),
    ),
    (
        "aggregate",
        rasp.Aggregate(
            rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),
            rasp.tokens,
        ),
    ),
    (
        "selector_width",
        rasp.SelectorWidth(
            rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ)),
    ),
]

# All S-Op examples: primitives followed by non-primitives.
_SOP_EXAMPLES = lambda: _SOP_PRIMITIVE_EXAMPLES() + _NONPRIMITIVE_SOP_EXAMPLES()

# One example of each selector type.
_SELECTOR_EXAMPLES = lambda: [  # pylint: disable=g-long-lambda
    ("select", rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ)),
    ("selector_and",
     rasp.SelectorAnd(
         rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),
         rasp.Select(rasp.indices, rasp.tokens, rasp.Comparison.LEQ),
     )),
    ("selector_or",
     rasp.SelectorOr(
         rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),
         rasp.Select(rasp.indices, rasp.tokens, rasp.Comparison.LEQ),
     )),
    ("selector_not",
     rasp.SelectorNot(
         rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),)),
]

# Every example: S-Ops followed by selectors.
_ALL_EXAMPLES = lambda: _SOP_EXAMPLES() + _SELECTOR_EXAMPLES()


class LabelTest(parameterized.TestCase):
  """Tests for expression labels and default names."""

  def test_primitive_labels(self):
    """The three primitive S-Ops report their own names as labels."""
    self.assertEqual(rasp.tokens.label, "tokens")
    self.assertEqual(rasp.indices.label, "indices")
    self.assertEqual(rasp.length.label, "length")

  # NOTE(review): this method takes (default_name, expr) but had no
  # parameterization decorator in this copy of the file. Restored using the
  # (label, expr) tuples from _ALL_EXAMPLES, whose labels are documented as
  # matching the default names -- confirm against the upstream file.
  @parameterized.parameters(*_ALL_EXAMPLES())
  def test_default_names(self, default_name: str, expr: rasp.RASPExpr):
    """An expression's name equals the label paired with it in the examples."""
    self.assertEqual(expr.name, default_name)


class SOpTest(parameterized.TestCase):
  """Tests for S-Ops."""

  # NOTE(review): test_tokens, test_indices, test_length, test_map,
  # test_sequence_map, test_linear_sequence_map and test_constant take extra
  # arguments but carry no parameterization decorator in this copy of the
  # file. The decorators (and their test data) appear to have been lost in
  # extraction and cannot be reconstructed from what is visible here; restore
  # them from the upstream file before running, otherwise these methods fail
  # with a TypeError for the missing arguments.

  def test_tokens(self, input_sequence, expected):
    """Evaluating `rasp.tokens` on the input yields `expected`."""
    self.assertEqual(rasp.tokens(input_sequence), expected)

  def test_indices(self, input_sequence, expected):
    """Evaluating `rasp.indices` on the input yields `expected`."""
    self.assertEqual(rasp.indices(input_sequence), expected)

  def test_length(self, input_sequence, expected):
    """Evaluating `rasp.length` on the input yields `expected`."""
    self.assertEqual(rasp.length(input_sequence), expected)

  def test_prims_are_sops(self):
    """All primitives are instances of `rasp.SOp`."""
    self.assertIsInstance(rasp.tokens, rasp.SOp)
    self.assertIsInstance(rasp.indices, rasp.SOp)
    self.assertIsInstance(rasp.length, rasp.SOp)

  def test_prims_are_raspexprs(self):
    """All primitives are instances of `rasp.RASPExpr`."""
    self.assertIsInstance(rasp.tokens, rasp.RASPExpr)
    self.assertIsInstance(rasp.indices, rasp.RASPExpr)
    self.assertIsInstance(rasp.length, rasp.RASPExpr)

  def test_map(self, f, input_sequence, expected):
    """`rasp.Map(f, rasp.tokens)` evaluates to `expected` on the input."""
    self.assertEqual(rasp.Map(f, rasp.tokens)(input_sequence), expected)

  def test_nested_elementwise_ops_results_in_only_one_map_object(self):
    """Chained arithmetic on tokens yields an op whose `inner` is tokens."""
    map_sop = ((rasp.tokens * 2) + 2) / 2
    self.assertEqual(map_sop.inner, rasp.tokens)
    self.assertEqual(map_sop([1]), [2])

  def test_sequence_map(self, f, input_sequence, expected):
    """`rasp.SequenceMap(f, tokens, tokens)` evaluates to `expected`."""
    self.assertEqual(
        rasp.SequenceMap(f, rasp.tokens, rasp.tokens)(input_sequence), expected)

  def test_sequence_map_with_same_inputs_logs_warning(self):
    """Building a SequenceMap over the same S-Op twice logs a warning."""
    with self.assertLogs(level="WARNING"):
      rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, rasp.tokens)

  def test_linear_sequence_map(self, fst_fac, snd_fac, input_sequence,
                               expected):
    """`rasp.LinearSequenceMap` with the given factors yields `expected`."""
    self.assertEqual(
        rasp.LinearSequenceMap(rasp.tokens, rasp.tokens, fst_fac,
                               snd_fac)(input_sequence), expected)

  def test_constant(self, const, input_sequence, expected):
    """`rasp.ConstantSOp(const)` evaluates to `expected` on the input."""
    self.assertEqual(rasp.ConstantSOp(const)(input_sequence), expected)

  def test_constant_complains_if_sizes_dont_match(self):
    """A length-3 constant raises ValueError on a longer input."""
    with self.assertRaisesRegex(
        ValueError,
        r"^.*Constant len .* doesn't match input len .*$",):
      rasp.ConstantSOp([1, 2, 3])("longer string")

  def test_can_turn_off_constant_complaints(self):
    """With `check_length=False` the same mismatched call does not raise."""
    rasp.ConstantSOp([1, 2, 3], check_length=False)("longer string")

  def test_numeric_dunders(self):
    """Comparison and addition operators between an S-Op and a scalar."""
    # We don't check all the cases here -- only a few representative ones.
    self.assertEqual(
        (rasp.tokens > 1)([0, 1, 2]),
        [0, 0, 1],
    )
    self.assertEqual(
        (1 < rasp.tokens)([0, 1, 2]),
        [0, 0, 1],
    )
    self.assertEqual(
        (rasp.tokens < 1)([0, 1, 2]),
        [1, 0, 0],
    )
    self.assertEqual(
        (1 > rasp.tokens)([0, 1, 2]),
        [1, 0, 0],
    )
    self.assertEqual(
        (rasp.tokens == 1)([0, 1, 2]),
        [0, 1, 0],
    )
    self.assertEqual(
        (rasp.tokens + 1)([0, 1, 2]),
        [1, 2, 3],
    )
    self.assertEqual(
        (1 + rasp.tokens)([0, 1, 2]),
        [1, 2, 3],
    )

  def test_dunders_with_sop(self):
    """Arithmetic operators where both operands are S-Ops."""
    self.assertEqual(
        (rasp.tokens + rasp.indices)([0, 1, 2]),
        [0, 2, 4],
    )
    self.assertEqual(
        (rasp.length - 1 - rasp.indices)([0, 1, 2]),
        [2, 1, 0],
    )
    self.assertEqual(
        (rasp.length * rasp.length)([0, 1, 2]),
        [9, 9, 9],
    )

  def test_logical_dunders(self):
    """`&`, `|` and `~` on S-Ops, with scalars on either side and with S-Ops."""
    self.assertEqual(
        (rasp.tokens & True)([True, False]),
        [True, False],
    )
    self.assertEqual(
        (rasp.tokens & False)([True, False]),
        [False, False],
    )
    self.assertEqual(
        (rasp.tokens | True)([True, False]),
        [True, True],
    )
    self.assertEqual(
        (rasp.tokens | False)([True, False]),
        [True, False],
    )
    self.assertEqual(
        (True & rasp.tokens)([True, False]),
        [True, False],
    )
    self.assertEqual(
        (False & rasp.tokens)([True, False]),
        [False, False],
    )
    self.assertEqual(
        (True | rasp.tokens)([True, False]),
        [True, True],
    )
    self.assertEqual(
        (False | rasp.tokens)([True, False]),
        [True, False],
    )
    self.assertEqual(
        (~rasp.tokens)([True, False]),
        [False, True],
    )
    self.assertEqual(
        (rasp.ConstantSOp([True, True, False, False])
         & rasp.ConstantSOp([True, False, True, False]))([1, 1, 1, 1]),
        [True, False, False, False],
    )
    self.assertEqual(
        (rasp.ConstantSOp([True, True, False, False])
         | rasp.ConstantSOp([True, False, True, False]))([1, 1, 1, 1]),
        [True, True, True, False],
    )


class EncodingTest(parameterized.TestCase):
  """Tests for SOp encodings."""

  # NOTE(review): every method here takes a single `sop` argument but had no
  # parameterization decorator in this copy of the file. Restored with
  # named_parameters over _SOP_EXAMPLES, whose (name, sop) tuples match the
  # signatures -- confirm against the upstream file.

  @parameterized.named_parameters(*_SOP_EXAMPLES())
  def test_all_sops_are_categorical_by_default(self, sop: rasp.SOp):
    """An S-Op without an explicit encoding is reported as categorical."""
    self.assertTrue(rasp.is_categorical(sop))

  @parameterized.named_parameters(*_SOP_EXAMPLES())
  def test_is_numerical(self, sop: rasp.SOp):
    """`is_numerical` is true after `numerical`, false after `categorical`."""
    self.assertTrue(rasp.is_numerical(rasp.numerical(sop)))
    self.assertFalse(rasp.is_numerical(rasp.categorical(sop)))

  @parameterized.named_parameters(*_SOP_EXAMPLES())
  def test_is_categorical(self, sop: rasp.SOp):
    """`is_categorical` is true after `categorical`, false after `numerical`."""
    self.assertTrue(rasp.is_categorical(rasp.categorical(sop)))
    self.assertFalse(rasp.is_categorical(rasp.numerical(sop)))

  @parameterized.named_parameters(*_SOP_EXAMPLES())
  def test_double_encoding_annotations_overwrites_encoding(self, sop: rasp.SOp):
    """Applying `categorical` to a numerical S-Op yields a categorical one."""
    num_sop = rasp.numerical(sop)
    cat_num_sop = rasp.categorical(num_sop)
    self.assertTrue(rasp.is_numerical(num_sop))
    self.assertTrue(rasp.is_categorical(cat_num_sop))


class SelectorTest(parameterized.TestCase):
  """Tests for Selectors."""

  # Operands and comparisons used to parameterize the simplification test
  # below. Only values that already appear elsewhere in this file are used.
  _SELECT_OPERANDS = (rasp.tokens, rasp.indices, rasp.Full(1))
  _COMPARISONS = (
      rasp.Comparison.EQ,
      rasp.Comparison.LT,
      rasp.Comparison.LEQ,
      rasp.Comparison.GT,
      rasp.Comparison.GEQ,
      rasp.Comparison.NEQ,
      rasp.Comparison.TRUE,
      rasp.Comparison.FALSE,
  )

  def test_select_eq_has_correct_value(self):
    """EQ over tokens with themselves: true only on the diagonal for "hey"."""
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
    self.assertEqual(
        selector("hey"), [
            [True, False, False],
            [False, True, False],
            [False, False, True],
        ])

  def test_select_lt_has_correct_value(self):
    """Expected LT selection matrix on the input [0, 1]."""
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.LT)
    self.assertEqual(selector([0, 1]), [
        [False, False],
        [True, False],
    ])

  def test_select_leq_has_correct_value(self):
    """Expected LEQ selection matrix on the input [0, 1]."""
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.LEQ)
    self.assertEqual(selector([0, 1]), [
        [True, False],
        [True, True],
    ])

  def test_select_gt_has_correct_value(self):
    """Expected GT selection matrix on the input [0, 1]."""
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.GT)
    self.assertEqual(selector([0, 1]), [
        [False, True],
        [False, False],
    ])

  def test_select_geq_has_correct_value(self):
    """Expected GEQ selection matrix on the input [0, 1]."""
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.GEQ)
    self.assertEqual(selector([0, 1]), [
        [True, True],
        [False, True],
    ])

  def test_select_neq_has_correct_value(self):
    """Expected NEQ selection matrix on the input [0, 1]."""
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.NEQ)
    self.assertEqual(selector([0, 1]), [
        [False, True],
        [True, False],
    ])

  def test_select_true_has_correct_value(self):
    """TRUE selects every position pair."""
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
    self.assertEqual(selector([0, 1]), [
        [True, True],
        [True, True],
    ])

  def test_select_false_has_correct_value(self):
    """FALSE selects no position pair."""
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.FALSE)
    self.assertEqual(selector([0, 1]), [
        [False, False],
        [False, False],
    ])

  def test_selector_and_gets_simplified_when_keys_and_queries_match(self):
    """Two selects with identical keys and queries merge into one Select."""
    selector = rasp.selector_and(
        rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.GEQ),
        rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.LEQ),
    )
    self.assertIsInstance(selector, rasp.Select)
    self.assertIs(selector.keys, rasp.tokens)
    self.assertIs(selector.queries, rasp.indices)

  def test_selector_and_doesnt_get_simplified_when_keys_queries_different(self):
    """Selects with swapped keys and queries stay a SelectorAnd."""
    selector = rasp.selector_and(
        rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.GEQ),
        rasp.Select(rasp.indices, rasp.tokens, rasp.Comparison.LEQ),
    )
    self.assertIsInstance(selector, rasp.SelectorAnd)

  def test_selector_and_gets_simplified_when_keys_are_full(self):
    """A `Full` keys operand is dropped in favour of the other select's keys."""
    selector = rasp.selector_and(
        rasp.Select(rasp.Full(1), rasp.indices, rasp.Comparison.GEQ),
        rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.LEQ),
    )
    self.assertIsInstance(selector, rasp.Select)
    self.assertIs(selector.keys, rasp.tokens)
    self.assertIs(selector.queries, rasp.indices)

  def test_selector_and_gets_simplified_when_queries_are_full(self):
    """A `Full` queries operand is dropped in favour of the other's queries."""
    selector = rasp.selector_and(
        rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.GEQ),
        rasp.Select(rasp.tokens, rasp.Full(1), rasp.Comparison.LEQ),
    )
    self.assertIsInstance(selector, rasp.Select)
    self.assertIs(selector.keys, rasp.tokens)
    self.assertIs(selector.queries, rasp.indices)

  # NOTE(review): this method takes six arguments but had no parameterization
  # decorator in this copy of the file. The product below is a reconstruction
  # (keys x queries x comparison for each of the two selects, built from
  # operands and comparisons used elsewhere in this file); it generates
  # 3*3*8*3*3*8 = 5184 cases. Confirm against the upstream file.
  @parameterized.parameters(*itertools.product(
      _SELECT_OPERANDS, _SELECT_OPERANDS, _COMPARISONS,
      _SELECT_OPERANDS, _SELECT_OPERANDS, _COMPARISONS))
  def test_simplified_selector_and_works_the_same_way_as_not(
      self, fst_k, fst_q, fst_p, snd_k, snd_q, snd_p):
    """`selector_and` gives the same matrix with and without `simplify`."""
    fst = rasp.Select(fst_k, fst_q, fst_p)
    snd = rasp.Select(snd_k, snd_q, snd_p)

    simplified = rasp.selector_and(fst, snd)([0, 1, 2, 3])
    not_simplified = rasp.selector_and(fst, snd, simplify=False)([0, 1, 2, 3])

    np.testing.assert_array_equal(
        np.array(simplified),
        np.array(not_simplified),
    )

  def test_select_is_selector(self):
    """A Select is an instance of `rasp.Selector`."""
    self.assertIsInstance(
        rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ),
        rasp.Selector,
    )

  def test_select_is_raspexpr(self):
    """A Select is an instance of `rasp.RASPExpr`."""
    self.assertIsInstance(
        rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ),
        rasp.RASPExpr,
    )

  def test_constant_selector(self):
    """A ConstantSelector returns the matrix it was built with."""
    self.assertEqual(
        rasp.ConstantSelector([[True, True], [False, False]])([1, 2]),
        [[True, True], [False, False]],
    )


class CopyTest(parameterized.TestCase):
  """Tests for copying RASP expressions."""

  # NOTE(review): every method here takes a single `expr` argument but had no
  # parameterization decorator in this copy of the file. Restored with
  # named_parameters over _ALL_EXAMPLES, whose (name, expr) tuples match the
  # signatures -- confirm against the upstream file.

  @parameterized.named_parameters(*_ALL_EXAMPLES())
  def test_copy_preserves_name(self, expr: rasp.RASPExpr):
    """A copy has the same name as the expression it was copied from."""
    expr = expr.named("foo")
    self.assertEqual(expr.copy().name, expr.name)

  @parameterized.named_parameters(*_ALL_EXAMPLES())
  def test_renaming_copy_doesnt_rename_original(self, expr: rasp.RASPExpr):
    """Naming the copy leaves the original's name untouched."""
    expr = expr.named("foo")
    expr.copy().named("bar")
    self.assertEqual(expr.name, "foo")

  @parameterized.named_parameters(*_ALL_EXAMPLES())
  def test_renaming_original_doesnt_rename_copy(self, expr: rasp.RASPExpr):
    """Naming the original after copying leaves the copy's name untouched."""
    expr = expr.named("foo")
    copy = expr.copy()
    expr.named("bar")
    self.assertEqual(copy.name, "foo")

  @parameterized.named_parameters(*_ALL_EXAMPLES())
  def test_copy_changes_id(self, expr: rasp.RASPExpr):
    """A copy gets a `unique_id` different from the original's."""
    self.assertNotEqual(expr.copy().unique_id, expr.unique_id)

  @parameterized.named_parameters(*_ALL_EXAMPLES())
  def test_copy_preserves_child_ids(self, expr: rasp.RASPExpr):
    """The children of a copy keep the `unique_id`s of the original's."""
    copy_child_ids = [c.unique_id for c in expr.copy().children]
    child_ids = [c.unique_id for c in expr.children]
    for child_id, copy_child_id in zip(child_ids, copy_child_ids):
      self.assertEqual(child_id, copy_child_id)


class AggregateTest(parameterized.TestCase):
  """Tests for Aggregate."""

  # NOTE(review): this method takes (selector, sop, default, expected_value)
  # but carries no parameterization decorator in this copy of the file. The
  # decorator and its test data appear to have been lost in extraction and
  # cannot be reconstructed from what is visible here; restore them from the
  # upstream file before running.
  def test_aggregate_on_size_2_inputs(self, selector, sop, default,
                                      expected_value):
    """`rasp.Aggregate(selector, sop, default)` yields `expected_value`."""
    # The 0, 0 input is ignored as it's overridden by the constant SOps.
    self.assertEqual(
        rasp.Aggregate(selector, sop, default)([0, 0]),
        expected_value,
    )


class RaspProgramTest(parameterized.TestCase):
  """Each testcase implements and tests a RASP program."""

  def test_has_prev(self):
    """`has_prev` marks positions whose value already occurred earlier."""

    def has_prev(seq: rasp.SOp) -> rasp.SOp:
      """Builds an S-Op that is positive where `seq` repeats an earlier value."""
      # Select pairs with equal `seq` values AND an index LT relation, i.e.
      # matches at a different, earlier position.
      prev_copy = rasp.SelectorAnd(
          rasp.Select(seq, seq, rasp.Comparison.EQ),
          rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LT),
      )
      # Aggregating a constant 1 over the selection (default 0 when nothing is
      # selected) and thresholding at > 0 turns "any match" into a flag.
      return rasp.Aggregate(prev_copy, rasp.Full(1), default=0) > 0

    self.assertEqual(
        has_prev(rasp.tokens)("hello"),
        [0, 0, 0, 1, 0],
    )
    self.assertEqual(
        has_prev(rasp.tokens)("helllo"),
        [0, 0, 0, 1, 1, 0],
    )
    self.assertEqual(
        has_prev(rasp.tokens)([0, 2, 3, 2, 1, 0, 2]),
        [0, 0, 0, 1, 0, 1, 1],
    )


# Run all tests in this module when it is executed as a script.
if __name__ == "__main__":
  absltest.main()