Spaces:
Sleeping
Sleeping
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MLPs to compute arbitrary numerical functions by discretising.""" | |
import dataclasses | |
from typing import Callable, Iterable, List | |
from tracr.craft import bases | |
from tracr.craft import transformers | |
from tracr.craft import vectorspace_fns | |
from tracr.utils import errors | |
@dataclasses.dataclass
class DiscretisingLayerMaterials:
  """Provides components for a hidden layer that discretises the input.

  Instances are built with keyword arguments by `_get_discretising_layer`,
  which is why this must be a dataclass (it supplies the `__init__`).

  Attributes:
    action: Function acting on basis directions that defines the (linear part
      of the) computation from the input space into `hidden_space`.
    hidden_space: Vector space of the hidden representation of the layer.
    output_values: List of f(v) values, one per retained input value v, in
      ascending order of v. May contain duplicates if f is not injective.
  """
  action: Callable[[bases.BasisDirection], bases.VectorInBasis]
  hidden_space: bases.VectorSpaceWithBasis
  output_values: List[float]
def _get_discretising_layer(input_value_set: Iterable[float],
                            f: Callable[[float], float],
                            hidden_name: bases.Name,
                            one_direction: bases.BasisDirection,
                            large_number: float) -> DiscretisingLayerMaterials:
  """Creates a hidden layer that discretises the input of f(x) into a value set.

  The input is split up into a distinct region around each value in
  `input_value_set` (sorted ascending):

    elements of value set:  v0 | v1 | v2 | v3 | v4 | ...
    thresholds:                t0   t1   t2   t3

  where t_{k-1} = (v_{k-1} + v_k) / 2 is the midpoint between neighbours.

  For every k = 1, ..., n-1 the hidden layer has two activations:

    hidden_k_0 = ReLU(L * (x - t_{k-1}) + 1)
    hidden_k_1 = ReLU(L * (x - t_{k-1}))

  plus a single `{hidden_name}start` activation that receives the constant 1
  from `one_direction`. (`action` defines the linear pre-activations; the ReLU
  is presumably applied by the MLP this layer is used in.)

  Note that hidden_k_0 - hidden_k_1 is:

    1                      if x >= t_{k-1}
    0                      if x <= t_{k-1} - 1/L
    L * (x - t_{k-1}) + 1  if t_{k-1} - 1/L < x < t_{k-1}

  i.e. a step function telling us whether x lies above t_{k-1}, so we know
  in which region the input value is. For inputs drawn from the value set the
  step is exact whenever 1/L is at most half the smallest gap between
  neighbouring values, i.e. L >= 2 / min_gap.

  Input values for which `errors.ignoring_arithmetic_errors(f)` returns None
  (f could not be evaluated there) are dropped before thresholds are built.

  Args:
    input_value_set: Set of discrete input values.
    f: Function to approximate.
    hidden_name: Name for hidden dimensions.
    one_direction: Auxiliary dimension that must contain 1 in the input.
    large_number: Large number L that determines accuracy of the computation.

  Returns:
    DiscretisingLayerMaterials containing all components for the layer; its
    `output_values[i]` is f applied to the i-th retained input value.

  Raises:
    ValueError: if f cannot be evaluated on any element of `input_value_set`
      (including when `input_value_set` is empty).
  """
  safe_f = errors.ignoring_arithmetic_errors(f)
  output_values, sorted_values = [], []
  for x in sorted(input_value_set):
    res = safe_f(x)
    if res is not None:
      output_values.append(res)
      sorted_values.append(x)

  # Both callers index output_values[0]; fail here with a clear message
  # instead of an opaque IndexError deep inside Linear.from_action.
  if not sorted_values:
    raise ValueError(
        "f could not be evaluated on any element of input_value_set; "
        "cannot build a discretising layer from an empty value set.")

  num_vals = len(sorted_values)
  # value_thresholds[k - 1] is the midpoint between sorted_values[k - 1] and
  # sorted_values[k]; it is the threshold used by hidden unit pair k.
  value_thresholds = [
      (sorted_values[i] + sorted_values[i + 1]) / 2 for i in range(num_vals - 1)
  ]

  hidden_directions = [bases.BasisDirection(f"{hidden_name}start")]
  for k in range(1, num_vals):
    dir0 = bases.BasisDirection(hidden_name, (k, 0))
    dir1 = bases.BasisDirection(hidden_name, (k, 1))
    hidden_directions.extend([dir0, dir1])
  hidden_space = bases.VectorSpaceWithBasis(hidden_directions)

  def action(direction: bases.BasisDirection) -> bases.VectorInBasis:
    # hidden_k_0 = ReLU(L * (x - threshold[k - 1]) + 1)
    # hidden_k_1 = ReLU(L * (x - threshold[k - 1]))
    # The constant-1 direction supplies the biases; every other direction is
    # treated as the input x and supplies the L * x terms.
    if direction == one_direction:
      hidden = hidden_space.vector_from_basis_direction(
          bases.BasisDirection(f"{hidden_name}start"))
    else:
      hidden = hidden_space.null_vector()
    for k in range(1, num_vals):
      vec0 = hidden_space.vector_from_basis_direction(
          bases.BasisDirection(hidden_name, (k, 0)))
      vec1 = hidden_space.vector_from_basis_direction(
          bases.BasisDirection(hidden_name, (k, 1)))
      if direction == one_direction:
        hidden += (1 - large_number * value_thresholds[k - 1]) * vec0
        hidden -= large_number * value_thresholds[k - 1] * vec1
      else:
        hidden += large_number * vec0 + large_number * vec1
    return hidden

  return DiscretisingLayerMaterials(
      action=action, hidden_space=hidden_space, output_values=output_values)
def map_numerical_mlp(
    f: Callable[[float], float],
    input_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis,
    input_value_set: Iterable[float],
    one_space: bases.VectorSpaceWithBasis,
    large_number: float = 100,
    hidden_name: bases.Name = "__hidden__",
) -> transformers.MLP:
  """Returns an MLP that encodes any function of a single variable f(x).

  This is implemented by discretising the input according to input_value_set
  and defining thresholds that determine which part of the input range is
  allocated to which value in input_value_set. With the values sorted as
  v0 < v1 < ..., threshold t_{k-1} is the midpoint between v_{k-1} and v_k:

    elements of value set:  v0 | v1 | v2 | v3 | v4 | ...
    thresholds:                t0   t1   t2   t3

  The MLP computes two hidden activations per threshold (k = 1, 2, ...):

    hidden_k_0 = ReLU(L * (x - t_{k-1}) + 1)
    hidden_k_1 = ReLU(L * (x - t_{k-1}))

  Note that hidden_k_0 - hidden_k_1 is:

    1                 if x >= t_{k-1}
    0                 if x <= t_{k-1} - 1/L
    between 0 and 1   if t_{k-1} - 1/L < x < t_{k-1}

  So hidden_k_0 - hidden_k_1 acts as a step function that is 1 iff x lies at
  or above t_{k-1}, up to a ramp of width 1/L just below the threshold.

  The MLP then computes the output as:

    output = f(v0) + sum((hidden_k_0 - hidden_k_1) * (f(v_k) - f(v_{k-1}))
                         for k = 1, 2, ...)

  This sum will be (by a telescoping sums argument)

    f(v0)     if x <= t0 - 1/L
    f(v_k)    if t_{k-1} <= x <= t_k - 1/L for some other k
    f(v_last) if x >= t_last

  which approximates f() up to an accuracy given by input_value_set and L.
  For inputs that are themselves elements of input_value_set the result is
  exact whenever L >= 2 / (smallest gap between neighbouring values).

  Args:
    f: Function to approximate.
    input_space: 1-d vector space that encodes the input x.
    output_space: 1-d vector space to write the output to.
    input_value_set: Set of values the input can take. Values on which f
      cannot be evaluated are dropped from the discretisation.
    one_space: Auxiliary 1-d vector space that must contain 1 in the input.
    large_number: Large number L that determines accuracy of the computation.
      Note that too large values of L can lead to numerical issues, particularly
      during inference on GPU/TPU.
    hidden_name: Name for hidden dimensions.

  Returns:
    A two-layer `transformers.MLP` reading from `input_space` joined with
    `one_space` and writing to `output_space`.
  """
  bases.ensure_dims(input_space, num_dims=1, name="input_space")
  bases.ensure_dims(output_space, num_dims=1, name="output_space")
  bases.ensure_dims(one_space, num_dims=1, name="one_space")

  # The first layer reads x together with the constant-1 direction, which
  # provides the bias terms of the discretising layer.
  input_space = bases.join_vector_spaces(input_space, one_space)
  out_vec = output_space.vector_from_basis_direction(output_space.basis[0])

  discretising_layer = _get_discretising_layer(
      input_value_set=input_value_set,
      f=f,
      hidden_name=hidden_name,
      one_direction=one_space.basis[0],
      large_number=large_number)
  first_layer = vectorspace_fns.Linear.from_action(
      input_space, discretising_layer.hidden_space, discretising_layer.action)

  def second_layer_action(
      direction: bases.BasisDirection) -> bases.VectorInBasis:
    # output = f(v0) + sum(
    #     (hidden_k_0 - hidden_k_1) * (f(v_k) - f(v_{k-1}))
    #     for k = 1, 2, ...)
    if direction.name == f"{hidden_name}start":
      return discretising_layer.output_values[0] * out_vec
    k, i = direction.value
    # add hidden_k_0 and subtract hidden_k_1
    sign = {0: 1, 1: -1}[i]
    return sign * (discretising_layer.output_values[k] -
                   discretising_layer.output_values[k - 1]) * out_vec

  second_layer = vectorspace_fns.Linear.from_action(
      discretising_layer.hidden_space, output_space, second_layer_action)

  return transformers.MLP(first_layer, second_layer)
def map_numerical_to_categorical_mlp(
    f: Callable[[float], float],
    input_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis,
    input_value_set: Iterable[float],
    one_space: bases.VectorSpaceWithBasis,
    large_number: float = 100,
    hidden_name: bases.Name = "__hidden__",
) -> transformers.MLP:
  """Returns an MLP to compute f(x) from a numerical to a categorical variable.

  The numerical input x is discretised the same way as in `map_numerical_mlp`:
  each value v in `input_value_set` owns the region between the midpoints to
  its neighbours. The MLP then writes a one-hot vector selecting the output
  direction whose value equals f(v) for the region x falls into.

  f(v) is NOT rounded: every value f takes on `input_value_set` must be
  encoded exactly by the value of some direction in `output_space`.

  Args:
    f: Function to approximate.
    input_space: 1-d vector space that encodes the input x.
    output_space: n-d vector space to write categorical output to. The output
      directions need to encode the possible output values (via their
      `value`); if several directions share a value, the last one is used.
    input_value_set: Set of values the input can take. Values on which f
      cannot be evaluated are dropped from the discretisation.
    one_space: Auxiliary 1-d space that must contain 1 in the input.
    large_number: Large number L that determines accuracy of the computation.
    hidden_name: Name for hidden dimensions.

  Returns:
    A two-layer `transformers.MLP` reading from `input_space` joined with
    `one_space` and writing to `output_space`.

  Raises:
    AssertionError: if an output direction has no value, or if some f(v) has
      no matching output direction (only when assertions are enabled).
  """
  bases.ensure_dims(input_space, num_dims=1, name="input_space")
  bases.ensure_dims(one_space, num_dims=1, name="one_space")

  # The first layer reads x together with the constant-1 bias direction.
  input_space = bases.join_vector_spaces(input_space, one_space)

  vec_by_out_val = dict()
  for d in output_space.basis:
    # TODO(b/255937603): Do a similar assert in other places where we expect
    # categorical basis directions to encode values.
    assert d.value is not None, ("output directions need to encode "
                                 "possible output values")
    vec_by_out_val[d.value] = output_space.vector_from_basis_direction(d)

  discretising_layer = _get_discretising_layer(
      input_value_set=input_value_set,
      f=f,
      hidden_name=hidden_name,
      one_direction=one_space.basis[0],
      large_number=large_number)

  # Every value f produces must map to an existing output direction.
  missing_values = set(discretising_layer.output_values) - set(vec_by_out_val)
  assert not missing_values, (
      f"output_space has no direction for output value(s) {missing_values}")

  first_layer = vectorspace_fns.Linear.from_action(
      input_space, discretising_layer.hidden_space, discretising_layer.action)

  def second_layer_action(
      direction: bases.BasisDirection) -> bases.VectorInBasis:
    """Computes output value and returns corresponding output direction."""
    if direction.name == f"{hidden_name}start":
      return vec_by_out_val[discretising_layer.output_values[0]]
    else:
      k, i = direction.value
      # add hidden_k_0 and subtract hidden_k_1
      sign = {0: 1, 1: -1}[i]
      out_k = discretising_layer.output_values[k]
      out_k_m_1 = discretising_layer.output_values[k - 1]
      # Telescoping: switching on step k moves the one-hot from
      # f(v_{k-1})'s direction to f(v_k)'s direction.
      return sign * (vec_by_out_val[out_k] - vec_by_out_val[out_k_m_1])

  second_layer = vectorspace_fns.Linear.from_action(
      discretising_layer.hidden_space, output_space, second_layer_action)

  return transformers.MLP(first_layer, second_layer)
def linear_sequence_map_numerical_mlp(
    input1_basis_direction: bases.BasisDirection,
    input2_basis_direction: bases.BasisDirection,
    output_basis_direction: bases.BasisDirection,
    input1_factor: float,
    input2_factor: float,
    hidden_name: bases.Name = "__hidden__",
) -> transformers.MLP:
  """Returns an MLP that encodes a linear function f(x, y) = a*x + b*y.

  The hidden layer holds x, -x, y and -y in four separate directions, so that
  after the MLP's ReLU nonlinearity it holds ReLU(x), ReLU(-x), ReLU(y) and
  ReLU(-y). The second layer recombines them as

    a * (ReLU(x) - ReLU(-x)) + b * (ReLU(y) - ReLU(-y)) = a*x + b*y,

  which is exact for inputs of either sign.

  Args:
    input1_basis_direction: Basis direction that encodes the input x.
    input2_basis_direction: Basis direction that encodes the input y.
    output_basis_direction: Basis direction to write the output to.
    input1_factor: Linear factor a for input x.
    input2_factor: Linear factor b for input y.
    hidden_name: Name for hidden dimensions.

  Returns:
    A two-layer `transformers.MLP` from the space spanned by the two input
    directions to the space spanned by `output_basis_direction`.
  """
  input_space = bases.VectorSpaceWithBasis(
      [input1_basis_direction, input2_basis_direction])
  output_space = bases.VectorSpaceWithBasis([output_basis_direction])
  out_vec = output_space.vector_from_basis_direction(output_basis_direction)

  # The direction value (+1 / -1) doubles as the sign used to recombine the
  # ReLU halves in the second layer.
  hidden_directions = [
      bases.BasisDirection(f"{hidden_name}x", 1),
      bases.BasisDirection(f"{hidden_name}x", -1),
      bases.BasisDirection(f"{hidden_name}y", 1),
      bases.BasisDirection(f"{hidden_name}y", -1)
  ]
  hidden_space = bases.VectorSpaceWithBasis(hidden_directions)

  x_pos_vec, x_neg_vec, y_pos_vec, y_neg_vec = (
      hidden_space.vector_from_basis_direction(d) for d in hidden_directions)

  def first_layer_action(
      direction: bases.BasisDirection) -> bases.VectorInBasis:
    output = hidden_space.null_vector()
    if direction == input1_basis_direction:
      output += x_pos_vec - x_neg_vec
    if direction == input2_basis_direction:
      output += y_pos_vec - y_neg_vec
    return output

  first_layer = vectorspace_fns.Linear.from_action(input_space, hidden_space,
                                                   first_layer_action)

  def second_layer_action(
      direction: bases.BasisDirection) -> bases.VectorInBasis:
    if direction.name == f"{hidden_name}x":
      return input1_factor * direction.value * out_vec
    if direction.name == f"{hidden_name}y":
      return input2_factor * direction.value * out_vec
    return output_space.null_vector()

  second_layer = vectorspace_fns.Linear.from_action(hidden_space, output_space,
                                                    second_layer_action)

  return transformers.MLP(first_layer, second_layer)