File size: 9,011 Bytes
9bdaa77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Create a craft model from a computational graph."""

import collections
from typing import Sequence

import networkx as nx
from tracr.compiler import nodes
from tracr.craft import bases
from tracr.craft import transformers
from tracr.rasp import rasp

# Local shorthands for the node types defined in `tracr.compiler.nodes`; they
# are used throughout the type annotations in this module.
Node = nodes.Node
NodeID = nodes.NodeID


def _get_longest_path_length_to_node(graph: nx.DiGraph, sources: Sequence[Node],
                                     node: Node) -> int:
  """Measures the longest source-to-`node` path, counting only SOp nodes.

  Every simple path from each source to `node` is enumerated. A path's length
  is the number of SOp-valued nodes on it (both endpoints included) minus one,
  and the maximum over all sources and all paths is returned.

  Args:
    graph: DAG to search; nodes on the searched paths must carry `nodes.EXPR`.
    sources: Candidate start nodes; the result is maximised over all of them.
    node: Target node.

  Returns:
    0 if `node` is itself one of `sources`; otherwise the longest path length
    as defined above, or -1 if no source can reach `node`.
  """
  if node in sources:
    return 0

  def sop_count(id_path: Sequence[NodeID]) -> int:
    # Non-SOp nodes on a path do not contribute to its length.
    return sum(
        1 for path_node_id in id_path
        if isinstance(graph.nodes[path_node_id][nodes.EXPR], rasp.SOp))

  longest = -1
  for source in sources:
    candidate_paths = nx.all_simple_paths(graph, source[nodes.ID],
                                          node[nodes.ID])
    # A source with no path to `node` yields -2 here, which never beats -1.
    best_for_source = max(map(sop_count, candidate_paths), default=-1) - 1
    longest = max(longest, best_for_source)
  return longest


def _node_is_attn(node: Node) -> bool:
  """Tells whether `node` carries an attention block (single or multi-head)."""
  if nodes.MODEL_BLOCK not in node:
    return False
  attention_types = (transformers.AttentionHead,
                     transformers.MultiAttentionHead)
  return isinstance(node[nodes.MODEL_BLOCK], attention_types)


def _node_is_mlp(node: Node) -> bool:
  """Tells whether `node` carries an MLP block."""
  if nodes.MODEL_BLOCK not in node:
    return False
  return isinstance(node[nodes.MODEL_BLOCK], transformers.MLP)


def _node_is_residual_block(node: Node) -> bool:
  """Tells whether `node` carries an attention-then-MLP residual block.

  The node qualifies only if its model block is a SeriesWithResiduals holding
  exactly two blocks: an attention block (single or multi-head) first and an
  MLP second.
  """
  block = node[nodes.MODEL_BLOCK] if nodes.MODEL_BLOCK in node else None
  if not block or not isinstance(block, transformers.SeriesWithResiduals):
    return False
  if len(block.blocks) != 2:
    return False
  first, second = block.blocks
  first_is_attn = isinstance(
      first, (transformers.AttentionHead, transformers.MultiAttentionHead))
  return first_is_attn and isinstance(second, transformers.MLP)


def _all_attn_nodes(node_list: Sequence[Node]) -> bool:
  """Checks that every node is an attention layer (vacuously True if empty)."""
  return all(_node_is_attn(candidate) for candidate in node_list)


def _all_mlp_nodes(node_list: Sequence[Node]) -> bool:
  """Checks that every node is an MLP layer (vacuously True if empty)."""
  return all(_node_is_mlp(candidate) for candidate in node_list)


def _allocate_modules_to_layers(graph: nx.DiGraph,
                                sources: Sequence[Node]) -> dict[int, int]:
  """Allocate all nodes in compute graph to layers.

  First, computes the longest path from the input to each node that is a model
  component (not input and output nodes). The longest path to a model component
  (its "depth") determines a layer in which we can place it while ensuring that
  all necessary previous computations have already happened.

  This assumes layers are arranged as [Attention, MLP, Attention, MLP, ...]

  In the special case where there are only Attention layers at one depth level
  and only MLP layers in the next depth level, they are treated as if they
  were at the same depth because attention layers always come before MLP
  layers for the same depth.

  Only model components with depth >= 1 are allocated. A component with depth
  -1 (no path from any source) gets no entry in the result.
  NOTE(review): such components are dropped silently; consider raising.

  Args:
    graph: RASP graph with craft blocks.
    sources: List of input nodes

  Returns:
    A dict mapping from node ids to layer indices, where 0, 1, 2, 3, ...
    are in the order attention, mlp, attention, mlp, ... Residual blocks
    (attention followed by MLP) are mapped to an even (attention) index. The
    dict is a defaultdict, so looking up an unallocated id yields -1. It is
    empty if the graph has no model components.
  """
  layer_allocation: dict[int, int] = collections.defaultdict(lambda: -1)
  depth_by_node_id: dict[int, int] = dict()
  nodes_by_depth: dict[int, list[Node]] = collections.defaultdict(list)

  # Compute depth of all model components (longest path from source to node)
  for node_id, node in graph.nodes.items():
    if (_node_is_mlp(node) or _node_is_attn(node)
        or _node_is_residual_block(node)):
      # Node is a model component
      longest_path_len = _get_longest_path_length_to_node(graph, sources, node)
      depth_by_node_id[node_id] = longest_path_len
      nodes_by_depth[longest_path_len].append(node)

  if not nodes_by_depth:
    # No model components at all: nothing to allocate. Without this guard the
    # min()/max() calls below fail with "arg is an empty sequence".
    return layer_allocation

  # If at level `depth` there are only attention heads and at level `depth + 1`
  # there are only MLPs, we can condense them into one level. Empty levels count
  # as both attention-only and MLP-only, so gaps between depths are closed too.
  # TODO(b/255936816): Think about improving this heuristic. The heuristic is
  # not optimal, and only catches very basic opportunities for optimization. It
  # is easy to come up with opportunities for optimization that it does not
  # catch.
  #
  # Condensing starts at depth 1 at the earliest. Only depths >= 1 are ever
  # allocated below, and starting lower (min_depth is -1 when some component is
  # unreachable from the sources) could shift reachable nodes down to depth 0,
  # where they would silently be left unallocated.
  max_depth = max(nodes_by_depth.keys())
  depth = max(min(nodes_by_depth.keys()), 1)
  while depth < max_depth:
    if _all_attn_nodes(nodes_by_depth[depth]) and _all_mlp_nodes(
        nodes_by_depth[depth + 1]):
      # Condense by decrementing the depth of all nodes starting from depth+1
      for update_depth in range(depth + 1, max_depth + 1):
        for node in nodes_by_depth[update_depth]:
          node_id = node[nodes.ID]
          depth_by_node_id[node_id] = update_depth - 1
        nodes_by_depth[update_depth - 1].extend(nodes_by_depth[update_depth])
        nodes_by_depth[update_depth] = []
      max_depth -= 1
    depth += 1

  # Allocate nodes to layers by depth, ensuring attn -> mlp -> attn -> mlp ...
  # Each depth level owns one (attention, MLP) pair of layers.
  current_layer = 0
  current_depth = 1
  for node_id, depth in sorted(depth_by_node_id.items(), key=lambda x: x[1]):
    while depth > current_depth:
      current_depth += 1
      current_layer += 2
    # Nodes with depth < 1 never match here and stay unallocated.
    if depth == current_depth:
      if _node_is_residual_block(graph.nodes[node_id]):
        layer_allocation[node_id] = current_layer
      else:
        is_mlp = _node_is_mlp(graph.nodes[node_id])
        layer_allocation[node_id] = current_layer + int(is_mlp)

  return layer_allocation


def craft_graph_to_model(
    graph: nx.DiGraph,
    sources: Sequence[Node]) -> transformers.SeriesWithResiduals:
  """Translates a RASP graph with craft blocks into a full craft model.

  1. Allocates modules to layers, assuming layers alternate in the order
     attention, MLP, attention, MLP, ...
  2. Builds the residual stream as the join of every block's residual space
     and assigns it to every block.
  3. Combines the blocks of each layer (attention blocks into a
     MultiAttentionHead, MLPs via MLP.combine_in_parallel) and assembles the
     layers into a craft model.

  Args:
    graph: RASP graph with craft blocks.
    sources: List of input nodes

  Returns:
    A craft model that can be compiled to model weights.

  Note:
    This function does not itself check that nodes carry craft blocks. Nodes
    without a MODEL_BLOCK are skipped, as are nodes that received no layer
    allocation. The `residual_space` attribute of every placed block is
    overwritten in place, so the blocks stored in `graph` are mutated.
  """
  layer_allocation = _allocate_modules_to_layers(graph, sources)
  blocks_by_layer = collections.defaultdict(list)
  model_blocks = []

  residual_space = bases.VectorSpaceWithBasis([])

  for node_id, layer_no in layer_allocation.items():
    node = graph.nodes[node_id]
    block = node[nodes.MODEL_BLOCK] if nodes.MODEL_BLOCK in node else None

    if _node_is_residual_block(node):
      # Split the residual block: its attention half goes into the allocated
      # (even, attention) layer and its MLP half into the layer right after.
      assert isinstance(block, transformers.SeriesWithResiduals)
      assert len(block.blocks) == 2
      attn_block, mlp_block = block.blocks
      residual_space = bases.join_vector_spaces(residual_space,
                                                attn_block.residual_space,
                                                mlp_block.residual_space)
      blocks_by_layer[layer_no].append(attn_block)
      blocks_by_layer[layer_no + 1].append(mlp_block)
    elif block:
      residual_space = bases.join_vector_spaces(residual_space,
                                                block.residual_space)
      blocks_by_layer[layer_no].append(block)

  for layer_no, layer_blocks in sorted(
      blocks_by_layer.items(), key=lambda x: x[0]):
    # All blocks share one residual stream, so widen each block's space to it.
    for block in layer_blocks:
      block.residual_space = residual_space

    if layer_blocks:
      if layer_no % 2 == 0:  # Attention Layer
        multi_head_attn = transformers.MultiAttentionHead(layer_blocks)
        model_blocks.append(multi_head_attn)
      else:  # MLP Layer
        parallel_mlp = transformers.MLP.combine_in_parallel(layer_blocks)
        model_blocks.append(parallel_mlp)

  return transformers.SeriesWithResiduals(model_blocks)