# RASP-Synthesis/tracr/compiler/expr_to_craft_graph.py
# Vladimir Mikulik
# add typing_extensions to list of deps.
# d4d39d0
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Add craft model blocks to graph of RASPExpr."""
from typing import Any, Callable, Optional
import networkx as nx
from tracr.compiler import nodes
from tracr.craft import bases
from tracr.craft.chamber import categorical_attn
from tracr.craft.chamber import categorical_mlp
from tracr.craft.chamber import numerical_mlp
from tracr.craft.chamber import selector_width
from tracr.rasp import rasp
def _transform_fun_to_basis_fun(
    fun: Callable[..., Any],
    output_direction_name: Optional[str] = None) -> Callable[..., Any]:
  """Lifts a function over raw values to one over basis directions.

  The returned callable reads the `.value` attribute of each positional
  argument, applies `fun` to those values and optionally wraps the result.

  Args:
    fun: Function operating on the underlying values.
    output_direction_name: If truthy, the result of `fun` is wrapped in a
      `bases.BasisDirection` carrying this name; otherwise the raw result is
      returned unchanged.

  Returns:
    A function that accepts objects exposing a `.value` attribute.
  """

  def bases_fun(*args):
    outcome = fun(*(direction.value for direction in args))
    if not output_direction_name:
      return outcome
    return bases.BasisDirection(output_direction_name, outcome)

  return bases_fun
def _check_selector_expression(expr, graph):
  """Validates the selector feeding an aggregate or selector-width node.

  Asserts that the selector is a predecessor of `expr` in `graph` and that the
  selector's keys and queries are predecessors of the selector, then requires
  both keys and queries to be categorically encoded.

  Args:
    expr: Expression exposing a `selector` attribute.
    graph: Graph whose nodes are addressed by expression labels.

  Raises:
    ValueError: If the selector's keys or queries are not categorical.
  """
  selector = expr.selector

  # Structural invariants: selector -> expr, and keys / queries -> selector.
  # `predecessors` is called afresh for every membership test.
  assert selector.label in graph.predecessors(expr.label)
  assert selector.keys.label in graph.predecessors(selector.label)
  assert selector.queries.label in graph.predecessors(selector.label)

  both_categorical = (
      rasp.is_categorical(selector.queries) and
      rasp.is_categorical(selector.keys))
  if not both_categorical:
    raise ValueError("Selector keys and queries must be categorical.")
def _selector_attn_fn(sel_expr):
  """Wraps the predicate of a `rasp.Select` as a craft attention function."""
  selector_basis_fn = _transform_fun_to_basis_fun(sel_expr.predicate)

  # Argument order is different in craft / transformers than RASP selectors:
  # craft passes (query, key) while the RASP predicate is called as (key, query).
  def attn_basis_fn(query: bases.BasisDirection,
                    key: bases.BasisDirection) -> bool:
    return selector_basis_fn(key, query)

  return attn_basis_fn


def _map_block(graph, node_id, node, expr, one_space, mlp_exactness):
  """Builds the MLP block implementing an element-wise `rasp.Map`."""
  inner_expr, inner_node = expr.inner, graph.nodes[expr.inner.label]
  assert inner_expr.label in graph.predecessors(node_id)

  input_space = bases.VectorSpaceWithBasis(inner_node[nodes.OUTPUT_BASIS])
  output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])

  # Dispatch on the (input encoding, output encoding) pair.
  if rasp.is_categorical(inner_expr) and rasp.is_categorical(expr):
    basis_fun = _transform_fun_to_basis_fun(expr.f, expr.label)
    return categorical_mlp.map_categorical_mlp(
        input_space=input_space,
        output_space=output_space,
        operation=basis_fun)
  if rasp.is_categorical(inner_expr) and rasp.is_numerical(expr):
    return categorical_mlp.map_categorical_to_numerical_mlp(
        input_space=input_space,
        output_space=output_space,
        operation=expr.f,
    )
  if rasp.is_numerical(inner_expr) and rasp.is_categorical(expr):
    return numerical_mlp.map_numerical_to_categorical_mlp(
        f=expr.f,
        input_space=input_space,
        output_space=output_space,
        input_value_set=inner_node[nodes.VALUE_SET],
        one_space=one_space,
        hidden_name=f"_hidden_{expr.label}_",
        large_number=mlp_exactness)
  if rasp.is_numerical(inner_expr) and rasp.is_numerical(expr):
    return numerical_mlp.map_numerical_mlp(
        f=expr.f,
        input_space=input_space,
        output_space=output_space,
        input_value_set=inner_node[nodes.VALUE_SET],
        one_space=one_space,
        hidden_name=f"_hidden_{expr.label}_",
        large_number=mlp_exactness)
  raise NotImplementedError("Map does no support "
                            f"in_type '{inner_expr.type}' and"
                            f" out_type '{expr.type}'!")


def _sequence_map_block(graph, node_id, node, expr, one_space):
  """Builds the MLP block implementing a (linear) `rasp.SequenceMap`."""
  fst_expr, fst_node = expr.fst, graph.nodes[expr.fst.label]
  snd_expr, snd_node = expr.snd, graph.nodes[expr.snd.label]

  # Check graph structure.
  assert fst_expr.label in graph.predecessors(node_id)
  assert snd_expr.label in graph.predecessors(node_id)

  fst_space = bases.VectorSpaceWithBasis(fst_node[nodes.OUTPUT_BASIS])
  snd_space = bases.VectorSpaceWithBasis(snd_node[nodes.OUTPUT_BASIS])
  out_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])
  operands = (fst_expr, snd_expr, expr)

  if isinstance(expr, rasp.LinearSequenceMap):
    if not all(rasp.is_numerical(x) for x in operands):
      raise NotImplementedError("Linear SequenceMap only supports numerical "
                                "inputs/outputs.")
    # The linear map reads and writes exactly one direction per space.
    assert len(fst_space.basis) == 1
    assert len(snd_space.basis) == 1
    assert len(out_space.basis) == 1
    return numerical_mlp.linear_sequence_map_numerical_mlp(
        input1_basis_direction=fst_space.basis[0],
        input2_basis_direction=snd_space.basis[0],
        output_basis_direction=out_space.basis[0],
        input1_factor=expr.fst_fac,
        input2_factor=expr.snd_fac,
        hidden_name=f"_hidden_{expr.label}_")

  if not all(rasp.is_categorical(x) for x in operands):
    raise NotImplementedError("(Non-linear) SequenceMap only supports "
                              "categorical inputs/outputs.")

  if fst_space == snd_space:
    # Both inputs live in the same space, so the two-argument function is
    # applied to a single value and compiled as an ordinary categorical map.
    basis_fun = _transform_fun_to_basis_fun(lambda x: expr.f(x, x),
                                            expr.label)
    return categorical_mlp.map_categorical_mlp(
        input_space=fst_space, output_space=out_space, operation=basis_fun)

  basis_fun = _transform_fun_to_basis_fun(expr.f, expr.label)
  return categorical_mlp.sequence_map_categorical_mlp(
      input1_space=fst_space,
      input2_space=snd_space,
      output_space=out_space,
      operation=basis_fun,
      one_space=one_space,
      hidden_name=f"_hidden_{expr.label}_")


def _aggregate_block(graph, node_id, node, expr, bos_space, one_space, causal):
  """Builds the attention block implementing a `rasp.Aggregate`."""
  sel_expr = expr.selector
  if not isinstance(sel_expr, rasp.Select):
    raise TypeError("Compiling composite Selectors is not supported. "
                    f"Got a {sel_expr}.")

  queries = graph.nodes[sel_expr.queries.label]
  keys = graph.nodes[sel_expr.keys.label]
  sop = graph.nodes[expr.sop.label]

  _check_selector_expression(expr, graph)
  assert expr.sop.label in graph.predecessors(node_id)

  if rasp.get_encoding(expr.sop) != rasp.get_encoding(expr):
    raise ValueError(
        "sop encoding must match output encoding of the aggregate.")
  if rasp.is_categorical(expr) and expr.default is not None:
    raise ValueError("Default for a categorical aggregate must be None. "
                     f"Got {expr.default}")
  if rasp.is_numerical(expr) and expr.default != 0:
    raise ValueError("Default for a numerical aggregate must be 0. "
                     f"Got {expr.default}")

  output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])
  return categorical_attn.categorical_attn(
      query_space=bases.VectorSpaceWithBasis(queries[nodes.OUTPUT_BASIS]),
      key_space=bases.VectorSpaceWithBasis(keys[nodes.OUTPUT_BASIS]),
      value_space=bases.VectorSpaceWithBasis(sop[nodes.OUTPUT_BASIS]),
      output_space=output_space,
      bos_space=bos_space,
      one_space=one_space,
      attn_fn=_selector_attn_fn(sel_expr),
      default_output=output_space.null_vector(),
      causal=causal,
      always_attend_to_bos=False,
      use_bos_for_default_output=True,
      softmax_coldness=100)


def _selector_width_block(graph, node, expr, bos_space, one_space,
                          mlp_exactness):
  """Builds the block implementing a `rasp.SelectorWidth`."""
  sel_expr = expr.selector
  # Same restriction as for aggregates: fail with a clear error instead of an
  # AttributeError when the selector has no keys / queries of its own.
  if not isinstance(sel_expr, rasp.Select):
    raise TypeError("Compiling composite Selectors is not supported. "
                    f"Got a {sel_expr}.")

  queries = graph.nodes[sel_expr.queries.label]
  keys = graph.nodes[sel_expr.keys.label]

  _check_selector_expression(expr, graph)

  # NOTE(review): the `causal` flag of the entry point is not forwarded here;
  # selector-width blocks are always built with causal=False. Confirm that
  # this is intended before changing it.
  return selector_width.selector_width(
      query_space=bases.VectorSpaceWithBasis(queries[nodes.OUTPUT_BASIS]),
      key_space=bases.VectorSpaceWithBasis(keys[nodes.OUTPUT_BASIS]),
      output_space=bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS]),
      bos_space=bos_space,
      one_space=one_space,
      attn_fn=_selector_attn_fn(sel_expr),
      out_value_set=node[nodes.VALUE_SET],
      categorical_output=rasp.is_categorical(expr),
      causal=False,
      softmax_coldness=100,
      mlp_large_number=mlp_exactness,
      label=expr.label)


def add_craft_components_to_rasp_graph(
    graph: nx.DiGraph,
    bos_dir: bases.BasisDirection = bases.BasisDirection("tokens", "bos"),
    one_dir: bases.BasisDirection = bases.BasisDirection("one"),
    causal: bool = False,
    mlp_exactness: float = 100,
) -> None:
  """Translates expressions to craft blocks and attaches them to the graph.

  Sets the `MODEL_BLOCK` attribute for all nodes in `graph` whose expression
  is a `rasp.SOp`; nodes holding any other expression are skipped. The block
  of `rasp.tokens` and `rasp.indices` is set to `None`.

  Args:
    graph: RASP graph with `VALUE_SET` but not `MODEL_BLOCK` attributes.
    bos_dir: Basis direction representing beginning of sequence (bos) token.
    one_dir: Auxiliary basis direction that must contain 1.
    causal: If True, marks the attention blocks built for aggregates as
      causal. Selector-width blocks are always built with `causal=False`.
    mlp_exactness: Controls the approximation of the MLP layers.

  Raises:
    ValueError: On invalid input (if `MODEL_BLOCK` is set already, or
      `VALUE_SET` is not set already), if selector keys / queries are not
      categorical, or if an aggregate has a mismatching encoding or default.
    TypeError: If an aggregate or selector width uses a composite selector.
    NotImplementedError: If the graph contains an unsupported expression.
  """
  # Both auxiliary spaces are loop-invariant, so build them once.
  one_space = bases.VectorSpaceWithBasis([one_dir])
  bos_space = bases.VectorSpaceWithBasis([bos_dir])

  for node_id, node in graph.nodes.items():
    expr = node[nodes.EXPR]

    # Only s-ops are compiled to blocks; selectors are consumed by the
    # aggregate / selector-width nodes that use them.
    if not isinstance(expr, rasp.SOp):
      continue

    if node.get(nodes.MODEL_BLOCK):
      raise ValueError("Input graph cannot have model blocks set already.")
    if nodes.VALUE_SET not in node:
      raise ValueError(
          "Craft components can only be added after basis inference.")

    if expr is rasp.tokens or expr is rasp.indices:
      # Model inputs need no block of their own.
      block = None
    elif isinstance(expr, rasp.Map):
      block = _map_block(graph, node_id, node, expr, one_space, mlp_exactness)
    elif isinstance(expr, rasp.SequenceMap):
      block = _sequence_map_block(graph, node_id, node, expr, one_space)
    elif isinstance(expr, rasp.Aggregate):
      block = _aggregate_block(graph, node_id, node, expr, bos_space,
                               one_space, causal)
    elif isinstance(expr, rasp.SelectorWidth):
      block = _selector_width_block(graph, node, expr, bos_space, one_space,
                                    mlp_exactness)
    else:
      raise NotImplementedError(f"Expression {expr} cannot be translated to "
                                "a model component.")

    graph.nodes[node_id][nodes.MODEL_BLOCK] = block