Spaces:
Sleeping
Sleeping
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Convert craft model into transformer with the correct input/output spaces.""" | |
import networkx as nx | |
from tracr.compiler import assemble | |
from tracr.compiler import nodes | |
from tracr.craft import bases | |
from tracr.craft import transformers | |
from tracr.rasp import rasp | |
from tracr.transformer import encoder | |
def craft_model_to_transformer(
    craft_model: transformers.SeriesWithResiduals,
    graph: nx.DiGraph,
    sink: nodes.Node,
    max_seq_len: int,
    compiler_bos: str,
    compiler_pad: str,
    causal: bool = False,
) -> assemble.AssembledTransformerModel:
    """Turn a craft model into a transformer model.

    Builds the token, index and output vector spaces and assembles the craft
    model with `assemble.assemble_craft_model`. It then attaches an input
    encoder and an output encoder to the assembled model.

    Args:
        craft_model: The craft model to assemble.
        graph: The compiled graph. The value set stored on the
            `rasp.tokens` node supplies the program's token vocabulary.
        sink: The graph node of the program's output. Its expression decides
            whether the output is categorical, and its output basis defines
            the output space.
        max_seq_len: Number of positions in the index space
            (`range(max_seq_len)`). When a BOS token is used, the input
            encoder accepts one extra position for it.
        compiler_bos: Beginning-of-sequence token added to the vocabulary.
            If None, the input encoder does not enforce a BOS token and no
            BOS entry is added to the vocabulary.
        compiler_pad: Padding token added to the vocabulary and passed to the
            input encoder. If None, no pad entry is added to the vocabulary.
        causal: Forwarded unchanged to `assemble.assemble_craft_model`.

    Returns:
        The assembled model with `input_encoder` and `output_encoder` set.
    """
    uses_bos = compiler_bos is not None

    # Extend the program vocabulary with the compiler's special tokens. Only
    # tokens that are actually set are added; unioning a `None` would put a
    # bogus `None` value into the token vocabulary.
    special_tokens = {
        token for token in (compiler_bos, compiler_pad) if token is not None
    }
    tokens_value_set = graph.nodes[rasp.tokens.label][nodes.VALUE_SET].union(
        special_tokens)
    tokens_space = bases.VectorSpaceWithBasis.from_values(
        rasp.tokens.label, tokens_value_set)

    indices_space = bases.VectorSpaceWithBasis.from_values(
        rasp.indices.label, range(max_seq_len))

    categorical_output = rasp.is_categorical(sink[nodes.EXPR])
    output_space = bases.VectorSpaceWithBasis(sink[nodes.OUTPUT_BASIS])

    assembled_model = assemble.assemble_craft_model(
        craft_model=craft_model,
        tokens_space=tokens_space,
        indices_space=indices_space,
        output_space=output_space,
        categorical_output=categorical_output,
        causal=causal,
    )

    # With a BOS token the encoder must accept one position more than the
    # program's maximum sequence length.
    assembled_model.input_encoder = encoder.CategoricalEncoder(
        basis=tokens_space.basis,
        enforce_bos=uses_bos,
        bos_token=compiler_bos,
        pad_token=compiler_pad,
        max_seq_len=max_seq_len + 1 if uses_bos else max_seq_len,
    )

    if categorical_output:
        # Outputs are decoded straight from the output basis; no BOS/pad
        # handling applies on the output side.
        assembled_model.output_encoder = encoder.CategoricalEncoder(
            basis=output_space.basis,
            enforce_bos=False,
            bos_token=None,
            pad_token=None)
    else:
        assembled_model.output_encoder = encoder.NumericalEncoder()

    return assembled_model