# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Create a craft model from a computational graph."""

import collections
from typing import Dict, List, Sequence

import networkx as nx
from tracr.compiler import nodes
from tracr.craft import bases
from tracr.craft import transformers
from tracr.rasp import rasp

Node = nodes.Node
NodeID = nodes.NodeID


def _get_longest_path_length_to_node(graph: nx.DiGraph, sources: Sequence[Node],
                                     node: Node) -> int:
  """Returns the length of the longest path from sources to node.

  Only SOps count towards the length of a path.

  Args:
    graph: DAG to compute the longest path in.
    sources: List of starting nodes; the longest path is a maximum over all of
      them.
    node: Target node.

  Returns:
    Number of steps needed for the longest path from the sources to the node,
    or -1 if there is no path from any of the sources to the target node.
  """
  if node in sources:
    return 0
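
  # Example (informal): along a hypothetical path tokens -> select -> aggregate,
  # the select node is not an SOp and contributes nothing, so num_sops below
  # only counts the SOps (tokens and aggregate) before the final -1 correction.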

  def num_sops(path: Sequence[NodeID]) -> int:
    num = 0
    for node_id in path:
      if isinstance(graph.nodes[node_id][nodes.EXPR], rasp.SOp):
        num += 1
    return num

  result = -1
  for source in sources:
    all_paths = nx.all_simple_paths(graph, source[nodes.ID], node[nodes.ID])
    longest_path_len = max(map(num_sops, all_paths), default=-1) - 1
    if longest_path_len > result:
      result = longest_path_len
  return result


def _node_is_attn(node: Node) -> bool:
  """Returns True if node is an attention layer."""
  return nodes.MODEL_BLOCK in node and isinstance(
      node[nodes.MODEL_BLOCK],
      (transformers.AttentionHead, transformers.MultiAttentionHead))


def _node_is_mlp(node: Node) -> bool:
  """Returns True if node is an MLP layer."""
  return nodes.MODEL_BLOCK in node and isinstance(node[nodes.MODEL_BLOCK],
                                                  transformers.MLP)


def _node_is_residual_block(node: Node) -> bool:
  """Returns True if node is a valid residual block (Attn followed by MLP)."""
  block = node[nodes.MODEL_BLOCK] if nodes.MODEL_BLOCK in node else None
  if block and isinstance(block, transformers.SeriesWithResiduals):
    if len(block.blocks) == 2:
      attn, mlp = block.blocks
      if (isinstance(
          attn,
          (transformers.AttentionHead, transformers.MultiAttentionHead)) and
          isinstance(mlp, transformers.MLP)):
        return True
  return False


def _all_attn_nodes(node_list: Sequence[Node]) -> bool:
  """Returns True iff all nodes are attention layers (or node_list is empty)."""
  for node in node_list:
    if not _node_is_attn(node):
      return False
  return True


def _all_mlp_nodes(node_list: Sequence[Node]) -> bool:
  """Returns True iff all nodes are MLP layers (or node_list is empty)."""
  for node in node_list:
    if not _node_is_mlp(node):
      return False
  return True


def _allocate_modules_to_layers(graph: nx.DiGraph,
                                sources: Sequence[Node]) -> Dict[int, int]:
  """Allocates all nodes in the compute graph to layers.

  First, computes the longest path from the input to each node that is a model
  component (not an input or output node). The longest path to a model
  component (its "depth") determines the layer in which we can place it while
  ensuring that all necessary previous computations have already happened.

  This assumes layers are arranged as [Attention, MLP, Attention, MLP, ...].

  In the special case where one depth level contains only attention layers and
  the next depth level contains only MLP layers, they are treated as if they
  were at the same depth, because for a given depth the attention layer always
  comes before the MLP layer.

  Args:
    graph: RASP graph with craft blocks.
    sources: List of input nodes.

  Returns:
    A dict mapping from node ids to layer indices, where layers 0, 1, 2, 3, ...
    are in the order attention, mlp, attention, mlp, ...
  """
  layer_allocation: Dict[int, int] = collections.defaultdict(lambda: -1)
  depth_by_node_id: Dict[int, int] = dict()
  nodes_by_depth: Dict[int, List[Node]] = collections.defaultdict(list)

  # Compute depth of all model components (longest path from source to node).
  for node_id, node in graph.nodes.items():
    if (_node_is_mlp(node) or _node_is_attn(node)
        or _node_is_residual_block(node)):
      # Node is a model component.
      longest_path_len = _get_longest_path_length_to_node(graph, sources, node)
      depth_by_node_id[node_id] = longest_path_len
      nodes_by_depth[longest_path_len].append(node)

  # If at level `depth` there are only attention heads and at level `depth + 1`
  # there are only MLPs, we can condense them into one level.
  # TODO(b/255936816): Think about improving this heuristic. The heuristic is
  # not optimal, and only catches very basic opportunities for optimization. It
  # is easy to come up with opportunities for optimization that it does not
  # catch.
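  # For instance (informal): if one depth level held only attention heads and
  # the next held only MLPs, every node at the later depth levels would be
  # shifted up by one, so those attention heads and MLPs would share a single
  # attention-then-MLP layer pair.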
  min_depth, max_depth = min(nodes_by_depth.keys()), max(nodes_by_depth.keys())
  depth = min_depth
  while depth < max_depth:
    if _all_attn_nodes(nodes_by_depth[depth]) and _all_mlp_nodes(
        nodes_by_depth[depth + 1]):
      # Condense by decrementing the depth of all nodes starting from depth+1.
      for update_depth in range(depth + 1, max_depth + 1):
        for node in nodes_by_depth[update_depth]:
          node_id = node[nodes.ID]
          depth_by_node_id[node_id] = update_depth - 1
        nodes_by_depth[update_depth - 1].extend(nodes_by_depth[update_depth])
        nodes_by_depth[update_depth] = []
      max_depth -= 1
    depth += 1
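
  # Informal example of the mapping implemented below: depth 1 corresponds to
  # the layer pair (0, 1), depth 2 to (2, 3), and in general depth d to
  # attention layer 2 * (d - 1) and MLP layer 2 * (d - 1) + 1. A residual block
  # (attention followed by MLP) is given the attention layer index and is later
  # split across the pair by craft_graph_to_model.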

  # Allocate nodes to layers by depth, ensuring attn -> mlp -> attn -> mlp ...
  current_layer = 0
  current_depth = 1
  for node_id, depth in sorted(depth_by_node_id.items(), key=lambda x: x[1]):
    while depth > current_depth:
      current_depth += 1
      current_layer += 2
    if depth == current_depth:
      if _node_is_residual_block(graph.nodes[node_id]):
        layer_allocation[node_id] = current_layer
      else:
        is_mlp = _node_is_mlp(graph.nodes[node_id])
        layer_allocation[node_id] = current_layer + int(is_mlp)

  return layer_allocation


def craft_graph_to_model(
    graph: nx.DiGraph,
    sources: Sequence[Node]) -> transformers.SeriesWithResiduals:
  """Translates a RASP graph with craft blocks into a full craft model.

  1. Allocates modules to layers, assuming layers are in the order
     attention, mlp, attention, mlp, ...
  2. Creates subspaces for all inputs and outputs, and builds the residual
     stream.
  3. Assembles everything into a craft model and returns it.

  Args:
    graph: RASP graph with craft blocks.
    sources: List of input nodes.

  Returns:
    A craft model that can be compiled to model weights.

  Raises:
    ValueError: On invalid input (if the craft_graph does not have craft blocks
      already specified).
  """
  layer_allocation = _allocate_modules_to_layers(graph, sources)
  blocks_by_layer = collections.defaultdict(list)
  model_blocks = []

  residual_space = bases.VectorSpaceWithBasis([])

  for node_id, layer_no in layer_allocation.items():
    node = graph.nodes[node_id]
    block = node[nodes.MODEL_BLOCK] if nodes.MODEL_BLOCK in node else None

    if _node_is_residual_block(node):
      assert isinstance(block, transformers.SeriesWithResiduals)
      assert len(block.blocks) == 2
      residual_space = bases.join_vector_spaces(residual_space,
                                                block.blocks[0].residual_space,
                                                block.blocks[1].residual_space)
      blocks_by_layer[layer_no].append(block.blocks[0])
      blocks_by_layer[layer_no + 1].append(block.blocks[1])
    elif block:
      residual_space = bases.join_vector_spaces(
          residual_space, node[nodes.MODEL_BLOCK].residual_space)
      blocks_by_layer[layer_no].append(block)

  for layer_no, layer_blocks in sorted(
      blocks_by_layer.items(), key=lambda x: x[0]):
    for block in layer_blocks:
      block.residual_space = residual_space
    if layer_blocks:
      if layer_no % 2 == 0:  # Attention layer
        multi_head_attn = transformers.MultiAttentionHead(layer_blocks)
        model_blocks.append(multi_head_attn)
      else:  # MLP layer
        parallel_mlp = transformers.MLP.combine_in_parallel(layer_blocks)
        model_blocks.append(parallel_mlp)

  return transformers.SeriesWithResiduals(model_blocks)
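
# Informal usage sketch: in the compiler pipeline this module is expected to
# run after a RASP program has been turned into a graph and a craft block has
# been attached to every SOp node; the stage names below are assumptions used
# purely for illustration, not APIs defined in this file.
#
#   extracted = rasp_to_graph.extract_rasp_graph(program)  # hypothetical step
#   ... infer bases and attach craft blocks to extracted.graph ...
#   model = craft_graph_to_model(extracted.graph, extracted.sources)
#
# The returned transformers.SeriesWithResiduals alternates attention and MLP
# layers and can be compiled to concrete transformer weights by later compiler
# stages.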