# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MLPs that compute functions of one-hot encoded categorical variables."""

from typing import Callable, Optional

import numpy as np

from tracr.craft import bases
from tracr.craft import transformers
from tracr.craft import vectorspace_fns

_ONE_SPACE = bases.VectorSpaceWithBasis.from_names(["one"])


def map_categorical_mlp(
    input_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis,
    operation: Callable[[bases.BasisDirection], bases.BasisDirection],
) -> transformers.MLP:
  """Returns an MLP that encodes any categorical function of a single variable f(x).

  The first layer is a lookup table that writes the one-hot encoding of f(x)
  into a hidden space with the same basis as `output_space`:
    hidden_k = sum(input_i for all i in input space with f(i) = k)
  For one-hot (non-negative) inputs these activations are non-negative, so the
  ReLU between the layers leaves them unchanged. The second layer is the
  identity projection from the hidden space onto `output_space`.

  Input directions whose image under `operation` is not in `output_space`
  contribute nothing (they map to the zero vector).

  Args:
    input_space: space containing the input x.
    output_space: space containing possible outputs.
    operation: A function operating on basis directions.

  Returns:
    The MLP computing f(x).
  """

  def operation_fn(direction):
    # Defensive guard; only `input_space` directions get a non-zero image.
    if direction in input_space:
      output_direction = operation(direction)
      if output_direction in output_space:
        return output_space.vector_from_basis_direction(output_direction)
    return output_space.null_vector()

  first_layer = vectorspace_fns.Linear.from_action(input_space, output_space,
                                                   operation_fn)
  second_layer = vectorspace_fns.project(output_space, output_space)
  return transformers.MLP(first_layer, second_layer)


def map_categorical_to_numerical_mlp(
    input_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis,
    operation: Callable[[bases.Value], float],
    hidden_name: Optional[bases.Name] = None,
) -> transformers.MLP:
  """Returns an MLP to compute f(x) from a categorical to a numerical variable.

  The hidden layer is the identity (a relabelled copy of `input_space`) and the
  output combines this with a lookup table
    output = sum(f(i)*input_i for all i in input space)

  The lookup is deliberately applied in the second layer, after the ReLU
  between the layers: the hidden activations are copies of the one-hot input
  and therefore non-negative, so negative values of f are reproduced exactly.
  Applying the lookup before the ReLU would clip every negative f(x) to 0.

  Args:
    input_space: Vector space containing the input x.
    output_space: Vector space to write the numerical output to. Must be 1-d.
    operation: A function mapping the value of an input basis direction to the
      number to output for it.
    hidden_name: Name for hidden dimensions. Defaults to a name derived from
      the name of the single output direction. NOTE(review): it should be
      unique among MLPs that are combined in parallel, otherwise their hidden
      dimensions would be shared.

  Returns:
    The MLP computing f(x).
  """
  bases.ensure_dims(output_space, num_dims=1, name="output_space")
  out_direction = output_space.basis[0]
  out_vec = output_space.vector_from_basis_direction(out_direction)
  if hidden_name is None:
    hidden_name = f"__hidden_{out_direction.name}__"

  def to_hidden(direction):
    # Keep both name and value so distinct input directions stay distinct.
    return bases.BasisDirection(hidden_name, (direction.name, direction.value))

  hidden_space = bases.VectorSpaceWithBasis(
      [to_hidden(direction) for direction in input_space.basis])

  def copy_to_hidden(direction):
    return hidden_space.vector_from_basis_direction(to_hidden(direction))

  def lookup(direction):
    _, input_value = direction.value
    return operation(input_value) * out_vec

  first_layer = vectorspace_fns.Linear.from_action(input_space, hidden_space,
                                                   copy_to_hidden)
  second_layer = vectorspace_fns.Linear.from_action(hidden_space, output_space,
                                                    lookup)
  return transformers.MLP(first_layer, second_layer)


def sequence_map_categorical_mlp(
    input1_space: bases.VectorSpaceWithBasis,
    input2_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis,
    operation: Callable[[bases.BasisDirection, bases.BasisDirection],
                        bases.BasisDirection],
    one_space: bases.VectorSpaceWithBasis = _ONE_SPACE,
    hidden_name: bases.Name = "__hidden__",
) -> transformers.MLP:
  """Returns an MLP that encodes a categorical function of two variables f(x, y).

  The hidden layer of the MLP computes the logical and of all input directions
    hidden_i_j = ReLU(x_i+y_j-1)
  And the output combines this with a lookup table
    output_k = sum(f(i, j)*hidden_i_j for all i,j in input space)

  Args:
    input1_space: Vector space containing the input x.
    input2_space: Vector space containing the input y.
    output_space: Vector space to write outputs to.
    operation: A function operating on basis directions.
    one_space: a reserved 1-d space that always contains a 1.
    hidden_name: Name for hidden dimensions. NOTE(review): two MLPs built with
      the same `hidden_name` over the same pair of inputs would produce
      identical hidden directions; presumably callers pass a unique name.

  Returns:
    The MLP computing f(x, y).

  Raises:
    ValueError: if `input1_space` and `input2_space` share a basis direction.
      `bases.ensure_dims` additionally rejects a `one_space` that is not 1-d.
  """
  bases.ensure_dims(one_space, num_dims=1, name="one_space")

  if not set(input1_space.basis).isdisjoint(input2_space.basis):
    raise ValueError("Input spaces to a SequenceMap must be disjoint. "
                     "If input spaces are the same, use Map instead!")

  input_space = bases.direct_sum(input1_space, input2_space, one_space)

  def to_hidden(x, y):
    return bases.BasisDirection(hidden_name, (x.name, x.value, y.name, y.value))

  def from_hidden(h):
    x_name, x_value, y_name, y_value = h.value
    return (bases.BasisDirection(x_name, x_value),
            bases.BasisDirection(y_name, y_value))

  # One hidden neuron per (x, y) pair, with x as the outer loop.
  hidden_space = bases.VectorSpaceWithBasis([
      to_hidden(dir1, dir2)
      for dir1 in input1_space.basis
      for dir2 in input2_space.basis
  ])

  def logical_and(direction):
    if direction in one_space:
      # The constant 1 supplies the -1 bias of every hidden neuron.
      return bases.VectorInBasis(hidden_space.basis,
                                 -np.ones(hidden_space.num_dims))
    out = hidden_space.null_vector()
    if direction in input1_space:
      for dir2 in input2_space.basis:
        out += hidden_space.vector_from_basis_direction(
            to_hidden(direction, dir2))
    else:
      # Remaining directions of `input_space` come from `input2_space`.
      for dir1 in input1_space.basis:
        out += hidden_space.vector_from_basis_direction(
            to_hidden(dir1, direction))
    return out

  first_layer = vectorspace_fns.Linear.from_action(input_space, hidden_space,
                                                   logical_and)

  def operation_fn(direction):
    dir1, dir2 = from_hidden(direction)
    output_direction = operation(dir1, dir2)
    if output_direction in output_space:
      return output_space.vector_from_basis_direction(output_direction)
    return output_space.null_vector()

  second_layer = vectorspace_fns.Linear.from_action(hidden_space, output_space,
                                                    operation_fn)
  return transformers.MLP(first_layer, second_layer)