Spaces:
Sleeping
Sleeping
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Combines all steps of compiling a RASP program.""" | |
from tracr.compiler import assemble | |
from tracr.compiler import basis_inference | |
from tracr.compiler import craft_graph_to_model | |
from tracr.compiler import craft_model_to_transformer | |
from tracr.compiler import expr_to_craft_graph | |
from tracr.compiler import rasp_to_graph | |
from tracr.craft import bases | |
from tracr.rasp import rasp | |
# Default names of the special tokens the compiler adds to the model's input
# space. They are the defaults for the `compiler_bos` / `compiler_pad`
# parameters of `compile_rasp_to_model`, which rejects a vocab that already
# contains either of them.
COMPILER_BOS = "compiler_bos"
COMPILER_PAD = "compiler_pad"
def compile_rasp_to_model(
    program: rasp.SOp,
    vocab: set[rasp.Value],
    max_seq_len: int,
    causal: bool = False,
    compiler_bos: str = COMPILER_BOS,
    compiler_pad: str = COMPILER_PAD,
    mlp_exactness: int = 100) -> assemble.AssembledTransformerModel:
  """Compile a RASP program to transformer weights.

  Chains the individual compiler stages: extract the computation graph from
  the RASP program, infer bases on that graph, attach craft components to its
  nodes, build a craft model from the graph, and finally convert the craft
  model into an assembled transformer.

  Args:
    program: the RASP program to compile.
    vocab: the set of vocab tokens expected by RASP.
    max_seq_len: the maximum sequence length for the compiled model.
    causal: if True, outputs a model with causal masking.
    compiler_bos: the name of the special BOS token that will be added by the
      compiler. Must not be present in the vocab.
    compiler_pad: the name of the special PAD token that will be added by the
      compiler. Must not be present in the vocab.
    mlp_exactness: Controls the approximation of the MLP layers. In theory,
      larger values yield a better approximation. But too large values can cause
      numerical issues due to large parameter norms. Reasonable values are
      between 1 and 100.

  Returns:
    The compiled model.

  Raises:
    ValueError: if `compiler_bos` or `compiler_pad` is present in `vocab`.
  """
  # The special tokens are added by the compiler itself, so they must not
  # collide with a user-supplied vocab entry.
  if compiler_bos in vocab:
    raise ValueError("Compiler BOS token must not be present in the vocab. "
                     f"Found '{compiler_bos}' in {vocab}")

  if compiler_pad in vocab:
    raise ValueError("Compiler PAD token must not be present in the vocab. "
                     f"Found '{compiler_pad}' in {vocab}")

  extracted = rasp_to_graph.extract_rasp_graph(program)
  graph, sources, sink = extracted.graph, extracted.sources, extracted.sink

  # The return values of the next two calls are ignored; the later stages read
  # their results from `graph` (presumably annotated in place -- confirm
  # against the callee modules).
  basis_inference.infer_bases(
      graph,
      sink,
      vocab,
      max_seq_len,
  )

  expr_to_craft_graph.add_craft_components_to_rasp_graph(
      graph,
      bos_dir=bases.BasisDirection(rasp.tokens.label, compiler_bos),
      mlp_exactness=mlp_exactness,
  )

  craft_model = craft_graph_to_model.craft_graph_to_model(graph, sources)

  return craft_model_to_transformer.craft_model_to_transformer(
      craft_model=craft_model,
      graph=graph,
      sink=sink,
      max_seq_len=max_seq_len,
      causal=causal,
      compiler_bos=compiler_bos,
      compiler_pad=compiler_pad,
  )