RASP-Synthesis / tracr /compiler /craft_graph_to_model_test.py
Vladimir Mikulik
add typing_extensions to list of deps.
d4d39d0
raw
history blame
6.71 kB
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for compiler.craft_graph_to_model."""
from absl.testing import absltest
from absl.testing import parameterized
import networkx as nx
from tracr.compiler import craft_graph_to_model
from tracr.compiler import nodes
from tracr.compiler import rasp_to_graph
from tracr.craft import bases
from tracr.craft.chamber import categorical_attn
from tracr.craft.chamber import categorical_mlp
from tracr.rasp import rasp
class CraftAllocateModulesToLayersTest(parameterized.TestCase):
def _get_dummy_block(self, block_type):
if block_type == "ATTN":
return categorical_attn.categorical_attn(
query_space=bases.VectorSpaceWithBasis.from_names(["query"]),
key_space=bases.VectorSpaceWithBasis.from_names(["bos", "key"]),
value_space=bases.VectorSpaceWithBasis.from_names(["bos", "value"]),
output_space=bases.VectorSpaceWithBasis.from_names(["output"]),
bos_space=bases.VectorSpaceWithBasis.from_names(["bos"]),
one_space=bases.VectorSpaceWithBasis.from_names(["one"]),
attn_fn=lambda x, y: True,
)
elif block_type == "MLP":
return categorical_mlp.map_categorical_mlp(
input_space=bases.VectorSpaceWithBasis.from_names(["input"]),
output_space=bases.VectorSpaceWithBasis.from_names(["output"]),
operation=lambda x: x,
)
else:
return None
def test_get_longest_path_length_to_node_returns_expected_result(self):
"""Creates a graph and checks the longest path for each node."""
# Node IDs:
# 0 -- 1 -- 2 -- 3 ------------ 4
# / /
# 5 -- 6 ---------- 7 -- 8 -- 9
#
# 10
# Expected return values:
# 0 -- 1 -- 2 -- 3 ------------ 5
# / /
# 0 -- 1 ---------- 2 -- 3 -- 4
#
# -1
graph = nx.DiGraph()
node_ids = list(range(11))
expected_results = [0, 1, 2, 3, 5, 0, 1, 2, 3, 4, -1]
for node_id, res in zip(node_ids, expected_results):
graph.add_node(
node_id, **{
nodes.ID: node_id,
nodes.EXPR: rasp.ConstantSOp(1),
"expected_result": res
})
graph.add_edge(0, 1)
graph.add_edge(1, 2)
graph.add_edge(2, 3)
graph.add_edge(3, 4)
graph.add_edge(5, 6)
graph.add_edge(6, 7)
graph.add_edge(7, 8)
graph.add_edge(8, 9)
graph.add_edge(6, 3)
graph.add_edge(9, 4)
sources = [graph.nodes[0], graph.nodes[5]]
for node_id, node in graph.nodes.items():
result = craft_graph_to_model._get_longest_path_length_to_node(
graph, sources, node)
self.assertEqual(result, node["expected_result"])
def test_allocate_modules_to_layers_returns_expected_result(self):
"""Creates a graph and checks if the correct layer assignment is returned."""
# Computation Graph:
# INPUT -- ATTN -- MLP -- ATTN ------ MLP -- OUTPUT
# / / /
# INPUT -- MLP --- MLP ATTN
# \ /
# ATTN
# Node IDs:
# 0 -- 1 -- 2 -- 3 -- 4 -- 5
# / / /
# 6 -- 7 ---- 8 9
# \ /
# 10
# Expected layer allocation:
# -1 -- 0 -- 3 -- 4 -- 7 -- -1
# / / /
# -1 -- 1 --- 3 6
# \ /
# 4
graph = nx.DiGraph()
node_ids = list(range(11))
types = [
"INPUT", "ATTN", "MLP", "ATTN", "MLP", "OUTPUT", "INPUT", "MLP", "MLP",
"ATTN", "ATTN"
]
expected_results = [-1, 0, 3, 4, 7, -1, -1, 1, 3, 6, 4]
for node_id, node_type, res in zip(node_ids, types, expected_results):
graph.add_node(
node_id, **{
nodes.ID: node_id,
nodes.EXPR: rasp.ConstantSOp(1),
nodes.MODEL_BLOCK: self._get_dummy_block(node_type),
"expected_result": res
})
graph.add_edge(0, 1)
graph.add_edge(1, 2)
graph.add_edge(2, 3)
graph.add_edge(3, 4)
graph.add_edge(4, 5)
graph.add_edge(6, 7)
graph.add_edge(7, 2)
graph.add_edge(7, 8)
graph.add_edge(8, 3)
graph.add_edge(8, 10)
graph.add_edge(9, 4)
graph.add_edge(10, 9)
craft_graph = rasp_to_graph.ExtractRaspGraphOutput(
graph=graph,
sink=graph.nodes[10],
sources=[graph.nodes[0], graph.nodes[6]])
layer_allocation = craft_graph_to_model._allocate_modules_to_layers(
craft_graph.graph, craft_graph.sources)
for node_id, node in graph.nodes.items():
self.assertEqual(layer_allocation[node_id], node["expected_result"])
def test_allocate_modules_to_layers_returns_expected_result_for_chain(self):
"""Tests a chain of alternating attention layers and MLPs."""
# Computation Graph:
# INPUT -- ATTN -- MLP -- ATTN -- MLP -- OUTPUT
# Node IDs:
# 0 -- 1 -- 2 -- 3 -- 4 -- 5
# Expected layer allocation:
# -1 -- 0 -- 1 -- 2 -- 3 -- -1
graph = nx.DiGraph()
node_ids = list(range(11))
types = ["INPUT", "ATTN", "MLP", "ATTN", "MLP", "OUTPUT"]
expected_results = [-1, 0, 1, 2, 3, -1]
for node_id, node_type, res in zip(node_ids, types, expected_results):
graph.add_node(
node_id, **{
nodes.ID: node_id,
nodes.EXPR: rasp.ConstantSOp(1),
nodes.MODEL_BLOCK: self._get_dummy_block(node_type),
"expected_result": res
})
graph.add_edge(0, 1)
graph.add_edge(1, 2)
graph.add_edge(2, 3)
graph.add_edge(3, 4)
graph.add_edge(4, 5)
craft_graph = rasp_to_graph.ExtractRaspGraphOutput(
graph=graph, sink=graph.nodes[5], sources=[graph.nodes[0]])
layer_allocation = craft_graph_to_model._allocate_modules_to_layers(
craft_graph.graph, craft_graph.sources)
for node_id, node in graph.nodes.items():
self.assertEqual(layer_allocation[node_id], node["expected_result"])
if __name__ == "__main__":
absltest.main()