Spaces:
Sleeping
Sleeping
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Tests for compiler.expr_to_craft_graph.""" | |
from absl.testing import absltest | |
from absl.testing import parameterized | |
from tracr.compiler import basis_inference | |
from tracr.compiler import expr_to_craft_graph | |
from tracr.compiler import lib | |
from tracr.compiler import nodes | |
from tracr.compiler import rasp_to_graph | |
from tracr.craft import bases | |
from tracr.craft import transformers | |
from tracr.rasp import rasp | |
class ExprToCraftGraphTest(parameterized.TestCase):
  """Tests for `expr_to_craft_graph.add_craft_components_to_rasp_graph`."""

  def _check_block_types_are_correct(self, graph):
    """Asserts each SOp node holds the craft block type its expression needs.

    Map and SequenceMap nodes must carry a `transformers.MLP`; Aggregate
    nodes must carry a `transformers.AttentionHead`. Other SOp kinds are not
    checked here.

    Args:
      graph: RASP graph that has already been through
        `add_craft_components_to_rasp_graph`.
    """
    for node in graph.nodes.values():
      expr = node[nodes.EXPR]
      if not isinstance(expr, rasp.SOp):
        continue
      block = node[nodes.MODEL_BLOCK]
      if isinstance(expr, (rasp.Map, rasp.SequenceMap)):
        self.assertIsInstance(block, transformers.MLP)
      elif isinstance(expr, rasp.Aggregate):
        self.assertIsInstance(block, transformers.AttentionHead)

  def _get_input_space_from_node(self, node):
    """Returns the vector space a node's craft block reads from.

    Args:
      node: Node attribute dict taken from the RASP graph.

    Returns:
      For an MLP, the input space of its first layer (`fst`). For an attention
      head, the join of both operand spaces of `w_qk` and the input space of
      `w_ov`. For any other block type, None.
    """
    block = node[nodes.MODEL_BLOCK]
    if isinstance(block, transformers.MLP):
      return block.fst.input_space
    if isinstance(block, transformers.AttentionHead):
      # An attention head reads from both sides of its QK bilinear map as well
      # as from the input of its OV map, so all three must be included.
      return bases.join_vector_spaces(block.w_qk.left_space,
                                      block.w_qk.right_space,
                                      block.w_ov.input_space)
    return None

  def _check_spaces_are_consistent(self, graph):
    """Check that for each edge the output is a subspace of the input.

    For every edge u -> v where both endpoints are SOps, the space spanned by
    u's OUTPUT_BASIS must be contained in the input space of v's craft block.

    Args:
      graph: RASP graph with bases inferred and craft components added.
    """
    for u, v in graph.edges:
      u_node, v_node = graph.nodes[u], graph.nodes[v]
      if not (isinstance(u_node[nodes.EXPR], rasp.SOp) and
              isinstance(v_node[nodes.EXPR], rasp.SOp)):
        continue
      u_out_space = bases.VectorSpaceWithBasis(u_node[nodes.OUTPUT_BASIS])
      v_in_space = self._get_input_space_from_node(v_node)
      # Fail with a readable message instead of an opaque error from inside
      # `issubspace` when v's block is of a type we cannot inspect.
      self.assertIsNotNone(
          v_in_space, msg=f"Cannot determine the input space of node {v!r}.")
      self.assertTrue(
          u_out_space.issubspace(v_in_space),
          msg=(f"Output space of node {u!r} is not a subspace of the input "
               f"space of node {v!r}."))

  # Without this decorator the test method would be invoked with no `program`
  # argument and fail with a TypeError before checking anything.
  @parameterized.named_parameters(
      ("categorical_map",
       rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens))),
      ("categorical_sequence_map",
       rasp.categorical(
           rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, rasp.indices))),
      ("length", lib.make_length()),
      ("reverse", lib.make_reverse(rasp.tokens)),
      ("hist", lib.make_hist()),
      ("frac_prevs", lib.make_frac_prevs(rasp.tokens == 1)),
  )
  def test_compiling_rasp_programs(self, program):
    """Adds craft components to `program`'s graph and checks the result."""
    vocab = {0, 1, 2}
    extracted = rasp_to_graph.extract_rasp_graph(program)
    basis_inference.infer_bases(
        extracted.graph,
        extracted.sink,
        vocab,
        max_seq_len=3,
    )
    expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph)
    self._check_block_types_are_correct(extracted.graph)
    self._check_spaces_are_consistent(extracted.graph)

  def test_add_craft_components_raises_value_error_if_called_before_basis_inference(
      self):
    """Adding craft components to a graph without inferred bases must fail."""
    program = rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens))
    extracted = rasp_to_graph.extract_rasp_graph(program)
    with self.assertRaisesRegex(
        ValueError,
        r"^.*Craft components can only be added after basis inference.*$"):
      expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph)

  def test_add_craft_components_raises_value_error_if_called_twice(self):
    """A second call on a graph that already has model blocks must fail."""
    vocab = {0, 1, 2}
    program = rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens))
    extracted = rasp_to_graph.extract_rasp_graph(program)
    basis_inference.infer_bases(
        extracted.graph,
        extracted.sink,
        vocab,
        max_seq_len=1,
    )
    expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph)
    with self.assertRaisesRegex(
        ValueError, r"^.*Input graph cannot have model blocks set already.*$"):
      expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph)
if __name__ == "__main__":
  # Entry point when the module is executed directly: hand control to
  # absltest so the TestCase classes above are run.
  absltest.main()