# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MLPs to compute arbitrary numerical functions by discretising."""
import dataclasses
from typing import Callable, Iterable, List
from tracr.craft import bases
from tracr.craft import transformers
from tracr.craft import vectorspace_fns
from tracr.utils import errors
@dataclasses.dataclass
class DiscretisingLayerMaterials:
"""Provides components for a hidden layer that discretises the input.
Attributes:
action: Function acting on basis directions that defines the computation.
hidden_space: Vector space of the hidden representation of the layer.
output_values: Set of output values that correspond to the discretisation.
"""
action: Callable[[bases.BasisDirection], bases.VectorInBasis]
hidden_space: bases.VectorSpaceWithBasis
output_values: List[float]
def _get_discretising_layer(
    input_value_set: Iterable[float],
    f: Callable[[float], float],
    hidden_name: bases.Name,
    one_direction: bases.BasisDirection,
    large_number: float,
) -> DiscretisingLayerMaterials:
"""Creates a hidden layer that discretises the input of f(x) into a value set.
The input is split up into a distinct region around each value in
`input_value_set`:
elements of value set: v0 | v1 | v2 | v3 | v4 | ...
thresholds: t0 t1 t2 t3 t4
The hidden layer has two activations per threshold:
hidden_k_1 = ReLU(L * (x - threshold[k]) + 1)
hidden_k_2 = ReLU(L * (x - threshold[k]))
Note that hidden_k_1 - hidden_k_2 is:
1 if x >= threshold[k] + 1/L
0 if x <= threshold[k]
between 0 and 1 if threshold[k] < x < threshold[k] + 1/L
So as long as we choose L a big enough number, we have
hidden_k_1 - hidden_k_2 = 1 if x >= threshold[k].
i.e. we know in which region the input value is.
Args:
input_value_set: Set of discrete input values.
f: Function to approximate.
hidden_name: Name for hidden dimensions.
one_direction: Auxiliary dimension that must contain 1 in the input.
large_number: Large number L that determines accuracy of the computation.
Returns:
DiscretisingLayerMaterials containing all components for the layer.
"""
  output_values, sorted_values = [], []
  for x in sorted(input_value_set):
    res = errors.ignoring_arithmetic_errors(f)(x)
    if res is not None:
      output_values.append(res)
      sorted_values.append(x)
  num_vals = len(sorted_values)

  # Thresholds are the midpoints between neighbouring input values.
  value_thresholds = [
      (sorted_values[i] + sorted_values[i + 1]) / 2 for i in range(num_vals - 1)
  ]

  # One "start" direction, plus two directions (k, 0) and (k, 1) per threshold.
  hidden_directions = [bases.BasisDirection(f"{hidden_name}start")]
  for k in range(1, num_vals):
    dir0 = bases.BasisDirection(hidden_name, (k, 0))
    dir1 = bases.BasisDirection(hidden_name, (k, 1))
    hidden_directions.extend([dir0, dir1])
  hidden_space = bases.VectorSpaceWithBasis(hidden_directions)
  def action(direction: bases.BasisDirection) -> bases.VectorInBasis:
    # hidden_k_0 = ReLU(L * (x - threshold[k]) + 1)
    # hidden_k_1 = ReLU(L * (x - threshold[k]))
    if direction == one_direction:
      hidden = hidden_space.vector_from_basis_direction(
          bases.BasisDirection(f"{hidden_name}start"))
    else:
      hidden = hidden_space.null_vector()
    for k in range(1, num_vals):
      vec0 = hidden_space.vector_from_basis_direction(
          bases.BasisDirection(hidden_name, (k, 0)))
      vec1 = hidden_space.vector_from_basis_direction(
          bases.BasisDirection(hidden_name, (k, 1)))
      if direction == one_direction:
        hidden += (1 - large_number * value_thresholds[k - 1]) * vec0
        hidden -= large_number * value_thresholds[k - 1] * vec1
      else:
        hidden += large_number * vec0 + large_number * vec1
    return hidden

  return DiscretisingLayerMaterials(
      action=action, hidden_space=hidden_space, output_values=output_values)
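

# A minimal standalone sketch (not part of the original library) of the
# thresholding identity used above: for large L, the difference
# ReLU(L * (x - t) + 1) - ReLU(L * (x - t)) approximates the step function
# 1[x >= t]. The helper `_demo_step_function` is hypothetical and exists
# only to illustrate the construction.
def _demo_step_function(x: float, threshold: float,
                        large_number: float = 100.0) -> float:
  def relu(z: float) -> float:
    return max(z, 0.0)
  hidden_0 = relu(large_number * (x - threshold) + 1)  # hidden_k_0
  hidden_1 = relu(large_number * (x - threshold))      # hidden_k_1
  return hidden_0 - hidden_1  # ~1 if x >= threshold, ~0 if x < threshold

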
def map_numerical_mlp(
    f: Callable[[float], float],
    input_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis,
    input_value_set: Iterable[float],
    one_space: bases.VectorSpaceWithBasis,
    large_number: float = 100,
    hidden_name: bases.Name = "__hidden__",
) -> transformers.MLP:
"""Returns an MLP that encodes any function of a single variable f(x).
This is implemented by discretising the input according to input_value_set
and defining thresholds that determine which part of the input range will
is allocated to which value in input_value_set.
elements of value set: v0 | v1 | v2 | v3 | v4 | ...
thresholds: t0 t1 t2 t3 t4
The MLP computes two hidden activations per threshold:
hidden_k_0 = ReLU(L * (x - threshold[k]) + 1)
hidden_k_1 = ReLU(L * (x - threshold[k]))
Note that hidden_k_1 - hidden_k_2 is:
1 if x >= threshold[k] + 1/L
0 if x <= threshold[k]
between 0 and 1 if threshold[k] < x < threshold[k] + 1/L
So as long as we choose L a big enough number, we have
hidden_k_0 - hidden_k_1 = 1 if x >= threshold[k].
The MLP then computes the output as:
output = f(input[0]) +
sum((hidden_k_0 - hidden_k_1) * (f(input[k]) - f(input[k-1]))
for all k=0,1,...)
This sum will be (by a telescoping sums argument)
f(input[0]) if x <= threshold[0]
f(input[k]) if threshold[k-1] < x <= threshold[k] for some other k
f(input[-1]) if x > threshold[-1]
which approximates f() up to an accuracy given by input_value_set and L.
Args:
f: Function to approximate.
input_space: 1-d vector space that encodes the input x.
output_space: 1-d vector space to write the output to.
input_value_set: Set of values the input can take.
one_space: Auxiliary 1-d vector space that must contain 1 in the input.
large_number: Large number L that determines accuracy of the computation.
Note that too large values of L can lead to numerical issues, particularly
during inference on GPU/TPU.
hidden_name: Name for hidden dimensions.
"""
  bases.ensure_dims(input_space, num_dims=1, name="input_space")
  bases.ensure_dims(output_space, num_dims=1, name="output_space")
  bases.ensure_dims(one_space, num_dims=1, name="one_space")

  input_space = bases.join_vector_spaces(input_space, one_space)
  out_vec = output_space.vector_from_basis_direction(output_space.basis[0])

  discretising_layer = _get_discretising_layer(
      input_value_set=input_value_set,
      f=f,
      hidden_name=hidden_name,
      one_direction=one_space.basis[0],
      large_number=large_number)
  first_layer = vectorspace_fns.Linear.from_action(
      input_space, discretising_layer.hidden_space, discretising_layer.action)

  def second_layer_action(
      direction: bases.BasisDirection) -> bases.VectorInBasis:
    # output = sum(
    #     (hidden_k_0 - hidden_k_1) * (f(input[k]) - f(input[k-1]))
    #     for all k)
    if direction.name == f"{hidden_name}start":
      return discretising_layer.output_values[0] * out_vec
    k, i = direction.value
    # Add hidden_k_0 and subtract hidden_k_1.
    sign = {0: 1, 1: -1}[i]
    return sign * (discretising_layer.output_values[k] -
                   discretising_layer.output_values[k - 1]) * out_vec

  second_layer = vectorspace_fns.Linear.from_action(
      discretising_layer.hidden_space, output_space, second_layer_action)

  return transformers.MLP(first_layer, second_layer)
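

# A minimal usage sketch (an assumption for illustration, not from the
# original file): build an MLP that squares a scalar drawn from {0, 1, 2, 3}.
# The direction names "x", "x^2" and "one" are hypothetical.
#
#   input_space = bases.VectorSpaceWithBasis([bases.BasisDirection("x")])
#   output_space = bases.VectorSpaceWithBasis([bases.BasisDirection("x^2")])
#   one_space = bases.VectorSpaceWithBasis([bases.BasisDirection("one")])
#   mlp = map_numerical_mlp(
#       f=lambda x: x**2,
#       input_space=input_space,
#       output_space=output_space,
#       input_value_set={0, 1, 2, 3},
#       one_space=one_space)

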
def map_numerical_to_categorical_mlp(
    f: Callable[[float], float],
    input_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis,
    input_value_set: Iterable[float],
    one_space: bases.VectorSpaceWithBasis,
    large_number: float = 100,
    hidden_name: bases.Name = "__hidden__",
) -> transformers.MLP:
"""Returns an MLP to compute f(x) from a numerical to a categorical variable.
Uses a set of possible output values, and rounds f(x) to the closest value
in this set to create a categorical output variable.
The output is discretised the same way as in `map_numerical_mlp`.
Args:
f: Function to approximate.
input_space: 1-d vector space that encodes the input x.
output_space: n-d vector space to write categorical output to. The output
directions need to encode the possible output values.
input_value_set: Set of values the input can take.
one_space: Auxiliary 1-d space that must contain 1 in the input.
large_number: Large number L that determines accuracy of the computation.
hidden_name: Name for hidden dimensions.
"""
  bases.ensure_dims(input_space, num_dims=1, name="input_space")
  bases.ensure_dims(one_space, num_dims=1, name="one_space")

  input_space = bases.join_vector_spaces(input_space, one_space)

  vec_by_out_val = dict()
  for d in output_space.basis:
    # TODO(b/255937603): Do a similar assert in other places where we expect
    # categorical basis directions to encode values.
    assert d.value is not None, ("output directions need to encode "
                                 "possible output values")
    vec_by_out_val[d.value] = output_space.vector_from_basis_direction(d)

  discretising_layer = _get_discretising_layer(
      input_value_set=input_value_set,
      f=f,
      hidden_name=hidden_name,
      one_direction=one_space.basis[0],
      large_number=large_number)

  assert set(discretising_layer.output_values).issubset(
      set(vec_by_out_val.keys()))

  first_layer = vectorspace_fns.Linear.from_action(
      input_space, discretising_layer.hidden_space, discretising_layer.action)
  def second_layer_action(
      direction: bases.BasisDirection) -> bases.VectorInBasis:
    """Computes output value and returns corresponding output direction."""
    if direction.name == f"{hidden_name}start":
      return vec_by_out_val[discretising_layer.output_values[0]]
    else:
      k, i = direction.value
      # Add hidden_k_0 and subtract hidden_k_1.
      sign = {0: 1, 1: -1}[i]
      out_k = discretising_layer.output_values[k]
      out_k_m_1 = discretising_layer.output_values[k - 1]
      return sign * (vec_by_out_val[out_k] - vec_by_out_val[out_k_m_1])

  second_layer = vectorspace_fns.Linear.from_action(
      discretising_layer.hidden_space, output_space, second_layer_action)

  return transformers.MLP(first_layer, second_layer)
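

# A minimal usage sketch (an assumption for illustration, not from the
# original file): the categorical output directions carry the possible values
# of f in their `value` field, here for f(x) = x % 2 over inputs {0, 1, 2, 3}.
# The direction names "x", "parity" and "one" are hypothetical.
#
#   input_space = bases.VectorSpaceWithBasis([bases.BasisDirection("x")])
#   output_space = bases.VectorSpaceWithBasis(
#       [bases.BasisDirection("parity", 0), bases.BasisDirection("parity", 1)])
#   one_space = bases.VectorSpaceWithBasis([bases.BasisDirection("one")])
#   mlp = map_numerical_to_categorical_mlp(
#       f=lambda x: x % 2,
#       input_space=input_space,
#       output_space=output_space,
#       input_value_set={0, 1, 2, 3},
#       one_space=one_space)

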
def linear_sequence_map_numerical_mlp(
    input1_basis_direction: bases.BasisDirection,
    input2_basis_direction: bases.BasisDirection,
    output_basis_direction: bases.BasisDirection,
    input1_factor: float,
    input2_factor: float,
    hidden_name: bases.Name = "__hidden__",
) -> transformers.MLP:
"""Returns an MLP that encodes a linear function f(x, y) = a*x + b*y.
Args:
input1_basis_direction: Basis direction that encodes the input x.
input2_basis_direction: Basis direction that encodes the input y.
output_basis_direction: Basis direction to write the output to.
input1_factor: Linear factor a for input x.
input2_factor: Linear factor a for input y.
hidden_name: Name for hidden dimensions.
"""
  input_space = bases.VectorSpaceWithBasis(
      [input1_basis_direction, input2_basis_direction])
  output_space = bases.VectorSpaceWithBasis([output_basis_direction])
  out_vec = output_space.vector_from_basis_direction(output_basis_direction)

  # Each input gets a positive and a negative hidden direction, so that
  # ReLU(x) - ReLU(-x) = x recovers the signed value after the nonlinearity.
  hidden_directions = [
      bases.BasisDirection(f"{hidden_name}x", 1),
      bases.BasisDirection(f"{hidden_name}x", -1),
      bases.BasisDirection(f"{hidden_name}y", 1),
      bases.BasisDirection(f"{hidden_name}y", -1)
  ]
  hidden_space = bases.VectorSpaceWithBasis(hidden_directions)
  x_pos_vec, x_neg_vec, y_pos_vec, y_neg_vec = (
      hidden_space.vector_from_basis_direction(d) for d in hidden_directions)

  def first_layer_action(
      direction: bases.BasisDirection) -> bases.VectorInBasis:
    output = hidden_space.null_vector()
    if direction == input1_basis_direction:
      output += x_pos_vec - x_neg_vec
    if direction == input2_basis_direction:
      output += y_pos_vec - y_neg_vec
    return output

  first_layer = vectorspace_fns.Linear.from_action(input_space, hidden_space,
                                                   first_layer_action)

  def second_layer_action(
      direction: bases.BasisDirection) -> bases.VectorInBasis:
    # The direction's value (+1 or -1) supplies the sign for the ReLU pair.
    if direction.name == f"{hidden_name}x":
      return input1_factor * direction.value * out_vec
    if direction.name == f"{hidden_name}y":
      return input2_factor * direction.value * out_vec
    return output_space.null_vector()

  second_layer = vectorspace_fns.Linear.from_action(hidden_space, output_space,
                                                    second_layer_action)

  return transformers.MLP(first_layer, second_layer)
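

# A standalone sketch (not part of the original library) of the identity the
# MLP above relies on: since x == relu(x) - relu(-x), a ReLU layer can pass a
# signed value through unchanged, letting the second layer form a*x + b*y.
# `_demo_linear_map` is hypothetical and exists only to illustrate this.
def _demo_linear_map(x: float, y: float, a: float, b: float) -> float:
  def relu(z: float) -> float:
    return max(z, 0.0)
  # First layer: split each input into positive and negative parts.
  x_pos, x_neg = relu(x), relu(-x)
  y_pos, y_neg = relu(y), relu(-y)
  # Second layer: recombine with signs +1/-1 and scale by the linear factors.
  return a * (x_pos - x_neg) + b * (y_pos - y_neg)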