RASP-Synthesis / tracr /compiler /craft_graph_to_model.py
mrahtz's picture
Make sure linter doesn't complain about import order
95943b4 unverified
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Create a craft model from a computational graph."""
import collections
from typing import Dict, List, Sequence
import networkx as nx
from tracr.compiler import nodes
from tracr.craft import bases
from tracr.craft import transformers
from tracr.rasp import rasp
# Short local aliases for the node types declared in tracr.compiler.nodes, used
# in the type annotations throughout this module.
Node = nodes.Node
NodeID = nodes.NodeID
def _get_longest_path_length_to_node(graph: nx.DiGraph, sources: Sequence[Node],
                                     node: Node) -> int:
  """Returns the length of the longest path from any source to `node`.

  The length of a path is the number of nodes on it whose expression is a
  `rasp.SOp`, minus one; nodes holding any other kind of expression do not
  contribute to the length.

  Args:
    graph: DAG to compute the longest path in.
    sources: Starting nodes; the result is the maximum over all of them.
    node: Target node.

  Returns:
    0 if `node` is itself one of the sources. Otherwise the largest path length
    found over all simple paths from every source to `node`, or -1 if no source
    has a path to `node`.
  """
  if node in sources:
    return 0

  def count_sops(path: Sequence[NodeID]) -> int:
    # Summing booleans counts how many nodes on the path carry an SOp.
    return sum(
        isinstance(graph.nodes[path_node_id][nodes.EXPR], rasp.SOp)
        for path_node_id in path)

  longest = -1
  for source in sources:
    paths = nx.all_simple_paths(graph, source[nodes.ID], node[nodes.ID])
    sop_counts = [count_sops(path) for path in paths]
    # Without any path the candidate is -2, which never beats the initial -1.
    candidate = max(sop_counts, default=-1) - 1
    longest = max(longest, candidate)
  return longest
def _node_is_attn(node: Node) -> bool:
  """Returns True if the node's model block is a (multi-)attention head."""
  if nodes.MODEL_BLOCK not in node:
    return False
  attention_types = (transformers.AttentionHead,
                     transformers.MultiAttentionHead)
  return isinstance(node[nodes.MODEL_BLOCK], attention_types)
def _node_is_mlp(node: Node) -> bool:
  """Returns True if the node's model block is an MLP."""
  if nodes.MODEL_BLOCK not in node:
    return False
  return isinstance(node[nodes.MODEL_BLOCK], transformers.MLP)
def _node_is_residual_block(node: Node) -> bool:
  """Returns True if the node holds a valid residual block.

  A valid residual block is a `SeriesWithResiduals` made of exactly two
  sub-blocks: an attention layer followed by an MLP.
  """
  block = node[nodes.MODEL_BLOCK] if nodes.MODEL_BLOCK in node else None
  if not block or not isinstance(block, transformers.SeriesWithResiduals):
    return False
  if len(block.blocks) != 2:
    return False

  first, second = block.blocks
  attention_types = (transformers.AttentionHead,
                     transformers.MultiAttentionHead)
  return isinstance(first, attention_types) and isinstance(
      second, transformers.MLP)
def _all_attn_nodes(node_list: Sequence[Node]) -> bool:
  """Returns True iff every node is an attention layer (or the list is empty)."""
  return all(_node_is_attn(candidate) for candidate in node_list)
def _all_mlp_nodes(node_list: Sequence[Node]) -> bool:
  """Returns True iff every node is an MLP layer (or the list is empty)."""
  return all(_node_is_mlp(candidate) for candidate in node_list)
def _allocate_modules_to_layers(graph: nx.DiGraph,
                                sources: Sequence[Node]) -> Dict[NodeID, int]:
  """Allocate all nodes in compute graph to layers.

  First, computes the longest path from the input to each node that is a model
  component (not input and output nodes). The longest path to a model component
  (its "depth") determines a layer in which we can place it while ensuring that
  all necessary previous computations have already happened.

  This assumes layers are arranged as [Attention, MLP, Attention, MLP, ...]

  In the special case where there are only Attention layers at one depth level
  and only MLP layers in the next depth level, they are treated as if they
  are at the same depth because attention layers always come before MLP layers
  for the same depth.

  Args:
    graph: RASP graph with craft blocks.
    sources: List of input nodes

  Returns:
    A dict mapping from node ids to layer indices, where 0, 1, 2, 3, ...
    are in the order attention, mlp, attention, mlp, ... A residual block
    (attention followed by MLP) is mapped to the index of its attention layer.
    Only model components whose depth is at least 1 get an entry; looking up
    any other key in the returned defaultdict yields -1.

  Raises:
    ValueError: If the graph contains no model components at all.
  """
  layer_allocation: Dict[NodeID, int] = collections.defaultdict(lambda: -1)
  depth_by_node_id: Dict[NodeID, int] = {}
  nodes_by_depth: Dict[int, List[Node]] = collections.defaultdict(list)

  # Compute depth of all model components (longest path from source to node)
  for node_id, node in graph.nodes.items():
    if (_node_is_mlp(node) or _node_is_attn(node) or
        _node_is_residual_block(node)):
      # Node is a model component
      longest_path_len = _get_longest_path_length_to_node(graph, sources, node)
      depth_by_node_id[node_id] = longest_path_len
      nodes_by_depth[longest_path_len].append(node)

  # Fail with a descriptive message instead of the opaque
  # "min() arg is an empty sequence" that the depth range below would raise.
  if not nodes_by_depth:
    raise ValueError(
        "Cannot allocate layers: the graph contains no model blocks "
        "(attention, MLP, or attention+MLP residual blocks).")

  # If at level `depth` there are only attention heads and at level `depth + 1`
  # there are only MLPs, we can condense them into one level
  # TODO(b/255936816): Think about improving this heuristic. The heuristic is
  # not optimal, and only catches very basic opportunities for optimization. It
  # is easy to come up with opportunities for optimization that it does not
  # catch.
  min_depth, max_depth = min(nodes_by_depth), max(nodes_by_depth)
  depth = min_depth
  while depth < max_depth:
    if _all_attn_nodes(nodes_by_depth[depth]) and _all_mlp_nodes(
        nodes_by_depth[depth + 1]):
      # Condense by decrementing the depth of all nodes starting from depth+1
      for update_depth in range(depth + 1, max_depth + 1):
        for node in nodes_by_depth[update_depth]:
          node_id = node[nodes.ID]
          depth_by_node_id[node_id] = update_depth - 1
        nodes_by_depth[update_depth - 1].extend(nodes_by_depth[update_depth])
        nodes_by_depth[update_depth] = []
      max_depth -= 1
    depth += 1

  # Allocate nodes to layers by depth, ensuring attn -> mlp -> attn -> mlp ...
  # Depth 1 corresponds to layers 0 (attention) and 1 (MLP); every further
  # depth level advances by one attention/MLP pair, i.e. by two layers.
  current_layer = 0
  current_depth = 1
  for node_id, depth in sorted(depth_by_node_id.items(), key=lambda x: x[1]):
    while depth > current_depth:
      current_depth += 1
      current_layer += 2
    # Nodes with depth below 1 never match and are left unallocated.
    if depth == current_depth:
      if _node_is_residual_block(graph.nodes[node_id]):
        layer_allocation[node_id] = current_layer
      else:
        is_mlp = _node_is_mlp(graph.nodes[node_id])
        layer_allocation[node_id] = current_layer + int(is_mlp)

  return layer_allocation
def craft_graph_to_model(
    graph: nx.DiGraph,
    sources: Sequence[Node]) -> transformers.SeriesWithResiduals:
  """Translates a RASP graph with craft blocks into a full craft model.

  1. Allocates modules to layers, assuming layers in the order attention, mlp,
     attention, mlp, ... (even layer indices are attention, odd ones are MLP).
  2. Joins the residual spaces of all allocated blocks into one shared residual
     space and assigns that space to every block.
  3. Combines the blocks of each layer and returns the layers as a craft model.

  Note that the `residual_space` attribute of the blocks stored in the graph is
  overwritten with the shared residual space.

  Args:
    graph: RASP graph with craft blocks.
    sources: List of input nodes

  Returns:
    A craft model that can be compiled to model weights.

  Raises:
    ValueError: On invalid input (if the craft_graph does not have craft blocks
      already specified)
  """
  allocation = _allocate_modules_to_layers(graph, sources)

  shared_space = bases.VectorSpaceWithBasis([])
  layer_to_blocks = collections.defaultdict(list)

  for node_id, layer_index in allocation.items():
    node = graph.nodes[node_id]
    block = node[nodes.MODEL_BLOCK] if nodes.MODEL_BLOCK in node else None

    if _node_is_residual_block(node):
      assert isinstance(block, transformers.SeriesWithResiduals)
      assert len(block.blocks) == 2
      # A residual block spans two consecutive layers: its attention part goes
      # into the allocated layer and its MLP part into the one right after.
      attn_part, mlp_part = block.blocks
      shared_space = bases.join_vector_spaces(shared_space,
                                              attn_part.residual_space,
                                              mlp_part.residual_space)
      layer_to_blocks[layer_index].append(attn_part)
      layer_to_blocks[layer_index + 1].append(mlp_part)
      continue

    if block:
      shared_space = bases.join_vector_spaces(shared_space,
                                              block.residual_space)
      layer_to_blocks[layer_index].append(block)

  layers = []
  for layer_index in sorted(layer_to_blocks):
    members = layer_to_blocks[layer_index]
    # Every block must operate on the same, fully joined residual space.
    for member in members:
      member.residual_space = shared_space
    if not members:
      continue
    if layer_index % 2 == 0:
      # Even index: attention layer.
      layers.append(transformers.MultiAttentionHead(members))
    else:
      # Odd index: MLP layer.
      layers.append(transformers.MLP.combine_in_parallel(members))

  return transformers.SeriesWithResiduals(layers)