Spaces:
Sleeping
Sleeping
Made compatible with Python 3.8
Browse files- tracr/compiler/assemble.py +12 -11
- tracr/compiler/basis_inference.py +3 -2
- tracr/compiler/compiling.py +2 -1
- tracr/compiler/craft_graph_to_model.py +5 -5
- tracr/compiler/lib.py +2 -2
- tracr/compiler/nodes.py +2 -2
- tracr/compiler/rasp_to_graph.py +3 -2
- tracr/compiler/rasp_to_transformer_integration_test.py +1 -1
- tracr/craft/bases.py +1 -1
- tracr/craft/chamber/categorical_attn.py +2 -1
- tracr/craft/chamber/numerical_mlp.py +2 -2
- tracr/craft/transformers.py +3 -3
- tracr/craft/vectorspace_fns.py +5 -5
- tracr/rasp/rasp.py +10 -8
- tracr/transformer/encoder.py +8 -8
- tracr/transformer/model.py +4 -4
tracr/compiler/assemble.py
CHANGED
@@ -15,7 +15,8 @@
|
|
15 |
"""Assemble weights of a transformer model from a craft residual stack."""
|
16 |
|
17 |
import dataclasses
|
18 |
-
from typing import Any, Callable, Optional,
|
|
|
19 |
|
20 |
import chex
|
21 |
import einops
|
@@ -32,11 +33,11 @@ from tracr.transformer import model
|
|
32 |
|
33 |
@chex.dataclass
|
34 |
class AssembledTransformerModelOutput:
|
35 |
-
decoded:
|
36 |
unembedded: jax.Array # [B, T] B = 1 always.
|
37 |
-
layer_outputs:
|
38 |
-
residuals:
|
39 |
-
attn_logits:
|
40 |
transformer_output: jax.Array # [B, T, D]
|
41 |
input_embeddings: jax.Array
|
42 |
|
@@ -58,11 +59,11 @@ class AssembledTransformerModel:
|
|
58 |
get_compiled_model: Callable[[], model.CompiledTransformerModel]
|
59 |
params: hk.Params
|
60 |
model_config: model.TransformerConfig
|
61 |
-
residual_labels:
|
62 |
input_encoder: Optional[encoder.Encoder] = None
|
63 |
output_encoder: Optional[encoder.Encoder] = None
|
64 |
|
65 |
-
def apply(self, tokens:
|
66 |
"""Returns output from running the model on a set of input tokens."""
|
67 |
if self.input_encoder:
|
68 |
tokens = self.input_encoder.encode(tokens)
|
@@ -97,12 +98,12 @@ class EmbeddingModules:
|
|
97 |
|
98 |
def _get_model_config_and_module_names(
|
99 |
craft_model: transformers.SeriesWithResiduals
|
100 |
-
) ->
|
101 |
"""Returns model config and locations (in params) for halflayers."""
|
102 |
|
103 |
-
multi_attn_heads:
|
104 |
-
mlps:
|
105 |
-
module_names:
|
106 |
|
107 |
candidate_module_names = []
|
108 |
for layer in range(len(craft_model.blocks)):
|
|
|
15 |
"""Assemble weights of a transformer model from a craft residual stack."""
|
16 |
|
17 |
import dataclasses
|
18 |
+
from typing import Any, Callable, Optional, List, Tuple
|
19 |
+
from typing_extensions import Protocol
|
20 |
|
21 |
import chex
|
22 |
import einops
|
|
|
33 |
|
34 |
@chex.dataclass
|
35 |
class AssembledTransformerModelOutput:
|
36 |
+
decoded: List[Any] # length T.
|
37 |
unembedded: jax.Array # [B, T] B = 1 always.
|
38 |
+
layer_outputs: List[jax.Array] # [B, T, D]
|
39 |
+
residuals: List[jax.Array] # [B, T, D]
|
40 |
+
attn_logits: List[jax.Array] # [B, T, T, H]
|
41 |
transformer_output: jax.Array # [B, T, D]
|
42 |
input_embeddings: jax.Array
|
43 |
|
|
|
59 |
get_compiled_model: Callable[[], model.CompiledTransformerModel]
|
60 |
params: hk.Params
|
61 |
model_config: model.TransformerConfig
|
62 |
+
residual_labels: List[str]
|
63 |
input_encoder: Optional[encoder.Encoder] = None
|
64 |
output_encoder: Optional[encoder.Encoder] = None
|
65 |
|
66 |
+
def apply(self, tokens: List[bases.Value]) -> AssembledTransformerModelOutput:
|
67 |
"""Returns output from running the model on a set of input tokens."""
|
68 |
if self.input_encoder:
|
69 |
tokens = self.input_encoder.encode(tokens)
|
|
|
98 |
|
99 |
def _get_model_config_and_module_names(
|
100 |
craft_model: transformers.SeriesWithResiduals
|
101 |
+
) -> Tuple[model.TransformerConfig, List[str]]:
|
102 |
"""Returns model config and locations (in params) for halflayers."""
|
103 |
|
104 |
+
multi_attn_heads: List[List[transformers.AttentionHead]] = []
|
105 |
+
mlps: List[transformers.MLP] = []
|
106 |
+
module_names: List[str] = []
|
107 |
|
108 |
candidate_module_names = []
|
109 |
for layer in range(len(craft_model.blocks)):
|
tracr/compiler/basis_inference.py
CHANGED
@@ -16,6 +16,7 @@
|
|
16 |
|
17 |
import dataclasses
|
18 |
import itertools
|
|
|
19 |
|
20 |
import networkx as nx
|
21 |
from tracr.compiler import nodes
|
@@ -34,12 +35,12 @@ class InferBasesOutput:
|
|
34 |
def infer_bases(
|
35 |
graph: nx.DiGraph,
|
36 |
sink: Node,
|
37 |
-
vocab:
|
38 |
max_seq_len: int,
|
39 |
) -> None:
|
40 |
"""Infers in-place the possible output values and vector bases of the SOps."""
|
41 |
|
42 |
-
def compute_value_set(sop: rasp.SOp) ->
|
43 |
"""Computes value set using already-computed predecessor value sets."""
|
44 |
if sop is rasp.tokens:
|
45 |
return vocab
|
|
|
16 |
|
17 |
import dataclasses
|
18 |
import itertools
|
19 |
+
from typing import Set
|
20 |
|
21 |
import networkx as nx
|
22 |
from tracr.compiler import nodes
|
|
|
35 |
def infer_bases(
|
36 |
graph: nx.DiGraph,
|
37 |
sink: Node,
|
38 |
+
vocab: Set[rasp.Value],
|
39 |
max_seq_len: int,
|
40 |
) -> None:
|
41 |
"""Infers in-place the possible output values and vector bases of the SOps."""
|
42 |
|
43 |
+
def compute_value_set(sop: rasp.SOp) -> Set[rasp.Value]:
|
44 |
"""Computes value set using already-computed predecessor value sets."""
|
45 |
if sop is rasp.tokens:
|
46 |
return vocab
|
tracr/compiler/compiling.py
CHANGED
@@ -13,6 +13,7 @@
|
|
13 |
# limitations under the License.
|
14 |
# ==============================================================================
|
15 |
"""Combines all steps of compiling a RASP program."""
|
|
|
16 |
|
17 |
from tracr.compiler import assemble
|
18 |
from tracr.compiler import basis_inference
|
@@ -29,7 +30,7 @@ COMPILER_PAD = "compiler_pad"
|
|
29 |
|
30 |
def compile_rasp_to_model(
|
31 |
program: rasp.SOp,
|
32 |
-
vocab:
|
33 |
max_seq_len: int,
|
34 |
causal: bool = False,
|
35 |
compiler_bos: str = COMPILER_BOS,
|
|
|
13 |
# limitations under the License.
|
14 |
# ==============================================================================
|
15 |
"""Combines all steps of compiling a RASP program."""
|
16 |
+
from typing import Set
|
17 |
|
18 |
from tracr.compiler import assemble
|
19 |
from tracr.compiler import basis_inference
|
|
|
30 |
|
31 |
def compile_rasp_to_model(
|
32 |
program: rasp.SOp,
|
33 |
+
vocab: Set[rasp.Value],
|
34 |
max_seq_len: int,
|
35 |
causal: bool = False,
|
36 |
compiler_bos: str = COMPILER_BOS,
|
tracr/compiler/craft_graph_to_model.py
CHANGED
@@ -15,7 +15,7 @@
|
|
15 |
"""Create a craft model from a computational graph."""
|
16 |
|
17 |
import collections
|
18 |
-
from typing import Sequence
|
19 |
|
20 |
import networkx as nx
|
21 |
from tracr.compiler import nodes
|
@@ -105,7 +105,7 @@ def _all_mlp_nodes(node_list: Sequence[Node]) -> bool:
|
|
105 |
|
106 |
|
107 |
def _allocate_modules_to_layers(graph: nx.DiGraph,
|
108 |
-
sources: Sequence[Node]) ->
|
109 |
"""Allocate all nodes in compute graph to layers.
|
110 |
|
111 |
First, computes the longest path from the input to each node that is a model
|
@@ -128,9 +128,9 @@ def _allocate_modules_to_layers(graph: nx.DiGraph,
|
|
128 |
A dict mapping from node ids to layer indices, where 0, 1, 2, 3, ...
|
129 |
are in the order attention, mlp, attention, mlp, ...
|
130 |
"""
|
131 |
-
layer_allocation:
|
132 |
-
depth_by_node_id:
|
133 |
-
nodes_by_depth:
|
134 |
|
135 |
# Compute depth of all model components (longest path from source to node)
|
136 |
for node_id, node in graph.nodes.items():
|
|
|
15 |
"""Create a craft model from a computational graph."""
|
16 |
|
17 |
import collections
|
18 |
+
from typing import Sequence, List, Dict
|
19 |
|
20 |
import networkx as nx
|
21 |
from tracr.compiler import nodes
|
|
|
105 |
|
106 |
|
107 |
def _allocate_modules_to_layers(graph: nx.DiGraph,
|
108 |
+
sources: Sequence[Node]) -> Dict[int, int]:
|
109 |
"""Allocate all nodes in compute graph to layers.
|
110 |
|
111 |
First, computes the longest path from the input to each node that is a model
|
|
|
128 |
A dict mapping from node ids to layer indices, where 0, 1, 2, 3, ...
|
129 |
are in the order attention, mlp, attention, mlp, ...
|
130 |
"""
|
131 |
+
layer_allocation: Dict[int, int] = collections.defaultdict(lambda: -1)
|
132 |
+
depth_by_node_id: Dict[int, int] = dict()
|
133 |
+
nodes_by_depth: Dict[int, List[Node]] = collections.defaultdict(list)
|
134 |
|
135 |
# Compute depth of all model components (longest path from source to node)
|
136 |
for node_id, node in graph.nodes.items():
|
tracr/compiler/lib.py
CHANGED
@@ -14,7 +14,7 @@
|
|
14 |
# ==============================================================================
|
15 |
"""RASP programs only using the subset of RASP supported by the compiler."""
|
16 |
|
17 |
-
from typing import Sequence
|
18 |
|
19 |
from tracr.rasp import rasp
|
20 |
|
@@ -95,7 +95,7 @@ def make_pair_balance(sop: rasp.SOp, open_token: str,
|
|
95 |
return pair_balance.named("pair_balance")
|
96 |
|
97 |
|
98 |
-
def make_shuffle_dyck(pairs:
|
99 |
"""Returns 1 if a set of parentheses are balanced, 0 else.
|
100 |
|
101 |
(As implemented in the RASP paper.)
|
|
|
14 |
# ==============================================================================
|
15 |
"""RASP programs only using the subset of RASP supported by the compiler."""
|
16 |
|
17 |
+
from typing import Sequence, List
|
18 |
|
19 |
from tracr.rasp import rasp
|
20 |
|
|
|
95 |
return pair_balance.named("pair_balance")
|
96 |
|
97 |
|
98 |
+
def make_shuffle_dyck(pairs: List[str]) -> rasp.SOp:
|
99 |
"""Returns 1 if a set of parentheses are balanced, 0 else.
|
100 |
|
101 |
(As implemented in the RASP paper.)
|
tracr/compiler/nodes.py
CHANGED
@@ -14,9 +14,9 @@
|
|
14 |
# ==============================================================================
|
15 |
"""Documents the data stored in nodes after each compiler pass."""
|
16 |
|
17 |
-
from typing import Any
|
18 |
|
19 |
-
Node =
|
20 |
NodeID = str
|
21 |
|
22 |
# RASP -> Graph
|
|
|
14 |
# ==============================================================================
|
15 |
"""Documents the data stored in nodes after each compiler pass."""
|
16 |
|
17 |
+
from typing import Any, Dict
|
18 |
|
19 |
+
Node = Dict[str, Any]
|
20 |
NodeID = str
|
21 |
|
22 |
# RASP -> Graph
|
tracr/compiler/rasp_to_graph.py
CHANGED
@@ -16,6 +16,7 @@
|
|
16 |
|
17 |
import dataclasses
|
18 |
import queue
|
|
|
19 |
|
20 |
import networkx as nx
|
21 |
from tracr.compiler import nodes
|
@@ -29,14 +30,14 @@ NodeID = nodes.NodeID
|
|
29 |
class ExtractRaspGraphOutput:
|
30 |
graph: nx.DiGraph
|
31 |
sink: Node # the program's output.
|
32 |
-
sources:
|
33 |
|
34 |
|
35 |
def extract_rasp_graph(tip: rasp.SOp) -> ExtractRaspGraphOutput:
|
36 |
"""Converts a RASP program into a graph representation."""
|
37 |
expr_queue = queue.Queue()
|
38 |
graph = nx.DiGraph()
|
39 |
-
sources:
|
40 |
|
41 |
def ensure_node(expr: rasp.RASPExpr) -> NodeID:
|
42 |
"""Finds or creates a graph node corresponding to expr; returns its ID."""
|
|
|
16 |
|
17 |
import dataclasses
|
18 |
import queue
|
19 |
+
from typing import List
|
20 |
|
21 |
import networkx as nx
|
22 |
from tracr.compiler import nodes
|
|
|
30 |
class ExtractRaspGraphOutput:
|
31 |
graph: nx.DiGraph
|
32 |
sink: Node # the program's output.
|
33 |
+
sources: List[Node] # the primitive S-Ops.
|
34 |
|
35 |
|
36 |
def extract_rasp_graph(tip: rasp.SOp) -> ExtractRaspGraphOutput:
|
37 |
"""Converts a RASP program into a graph representation."""
|
38 |
expr_queue = queue.Queue()
|
39 |
graph = nx.DiGraph()
|
40 |
+
sources: List[NodeID] = []
|
41 |
|
42 |
def ensure_node(expr: rasp.RASPExpr) -> NodeID:
|
43 |
"""Finds or creates a graph node corresponding to expr; returns its ID."""
|
tracr/compiler/rasp_to_transformer_integration_test.py
CHANGED
@@ -38,7 +38,7 @@ class CompilerIntegrationTest(tests_common.VectorFnTestCase):
|
|
38 |
for actual, expected in zip(actual_seq, expected_seq):
|
39 |
if expected is not None and actual != expected:
|
40 |
self.fail(f"{actual_seq} does not match (ignoring Nones) "
|
41 |
-
f"{expected_seq
|
42 |
|
43 |
@parameterized.named_parameters(
|
44 |
dict(
|
|
|
38 |
for actual, expected in zip(actual_seq, expected_seq):
|
39 |
if expected is not None and actual != expected:
|
40 |
self.fail(f"{actual_seq} does not match (ignoring Nones) "
|
41 |
+
f"expected_seq={expected_seq}")
|
42 |
|
43 |
@parameterized.named_parameters(
|
44 |
dict(
|
tracr/craft/bases.py
CHANGED
@@ -243,5 +243,5 @@ def ensure_dims(
|
|
243 |
) -> None:
|
244 |
"""Raises ValueError if vs has the wrong number of dimensions."""
|
245 |
if vs.num_dims != num_dims:
|
246 |
-
raise ValueError(f"{name} must have {num_dims
|
247 |
f"but got {vs.num_dims}: {vs.basis}")
|
|
|
243 |
) -> None:
|
244 |
"""Raises ValueError if vs has the wrong number of dimensions."""
|
245 |
if vs.num_dims != num_dims:
|
246 |
+
raise ValueError(f"{name} must have num_dims={num_dims}, "
|
247 |
f"but got {vs.num_dims}: {vs.basis}")
|
tracr/craft/chamber/categorical_attn.py
CHANGED
@@ -14,7 +14,8 @@
|
|
14 |
# ==============================================================================
|
15 |
"""Attention head for categorical inputs."""
|
16 |
|
17 |
-
from typing import Optional
|
|
|
18 |
|
19 |
from tracr.craft import bases
|
20 |
from tracr.craft import transformers
|
|
|
14 |
# ==============================================================================
|
15 |
"""Attention head for categorical inputs."""
|
16 |
|
17 |
+
from typing import Optional
|
18 |
+
from typing_extensions import Protocol
|
19 |
|
20 |
from tracr.craft import bases
|
21 |
from tracr.craft import transformers
|
tracr/craft/chamber/numerical_mlp.py
CHANGED
@@ -16,7 +16,7 @@
|
|
16 |
|
17 |
import dataclasses
|
18 |
|
19 |
-
from typing import Callable, Iterable
|
20 |
|
21 |
from tracr.craft import bases
|
22 |
from tracr.craft import transformers
|
@@ -35,7 +35,7 @@ class DiscretisingLayerMaterials:
|
|
35 |
"""
|
36 |
action: Callable[[bases.BasisDirection], bases.VectorInBasis]
|
37 |
hidden_space: bases.VectorSpaceWithBasis
|
38 |
-
output_values:
|
39 |
|
40 |
|
41 |
def _get_discretising_layer(input_value_set: Iterable[float],
|
|
|
16 |
|
17 |
import dataclasses
|
18 |
|
19 |
+
from typing import Callable, Iterable, List
|
20 |
|
21 |
from tracr.craft import bases
|
22 |
from tracr.craft import transformers
|
|
|
35 |
"""
|
36 |
action: Callable[[bases.BasisDirection], bases.VectorInBasis]
|
37 |
hidden_space: bases.VectorSpaceWithBasis
|
38 |
+
output_values: List[float]
|
39 |
|
40 |
|
41 |
def _get_discretising_layer(input_value_set: Iterable[float],
|
tracr/craft/transformers.py
CHANGED
@@ -16,7 +16,7 @@
|
|
16 |
|
17 |
import abc
|
18 |
import dataclasses
|
19 |
-
from typing import Iterable, Optional, Sequence, Union
|
20 |
|
21 |
import numpy as np
|
22 |
|
@@ -111,7 +111,7 @@ class AttentionHead(Block):
|
|
111 |
@dataclasses.dataclass
|
112 |
class MultiAttentionHead(Block):
|
113 |
"""Applies attention heads in parallel."""
|
114 |
-
sub_blocks:
|
115 |
|
116 |
def __post_init__(self):
|
117 |
spaces = [block.residual_space for block in self.sub_blocks]
|
@@ -182,7 +182,7 @@ HalfLayerBlock = Union[MLP, AttentionHead, MultiAttentionHead]
|
|
182 |
@dataclasses.dataclass
|
183 |
class SeriesWithResiduals(Block):
|
184 |
"""A series of blocks with residual connections."""
|
185 |
-
blocks:
|
186 |
|
187 |
def __post_init__(self):
|
188 |
spaces = [block.residual_space for block in self.blocks]
|
|
|
16 |
|
17 |
import abc
|
18 |
import dataclasses
|
19 |
+
from typing import Iterable, Optional, Sequence, Union, List
|
20 |
|
21 |
import numpy as np
|
22 |
|
|
|
111 |
@dataclasses.dataclass
|
112 |
class MultiAttentionHead(Block):
|
113 |
"""Applies attention heads in parallel."""
|
114 |
+
sub_blocks: List[Union[AttentionHead, "MultiAttentionHead"]]
|
115 |
|
116 |
def __post_init__(self):
|
117 |
spaces = [block.residual_space for block in self.sub_blocks]
|
|
|
182 |
@dataclasses.dataclass
|
183 |
class SeriesWithResiduals(Block):
|
184 |
"""A series of blocks with residual connections."""
|
185 |
+
blocks: List[HalfLayerBlock]
|
186 |
|
187 |
def __post_init__(self):
|
188 |
spaces = [block.residual_space for block in self.blocks]
|
tracr/craft/vectorspace_fns.py
CHANGED
@@ -65,7 +65,7 @@ class Linear(VectorFunction):
|
|
65 |
|
66 |
def __call__(self, x: VectorInBasis) -> VectorInBasis:
|
67 |
if x not in self.input_space:
|
68 |
-
raise TypeError(f"{x
|
69 |
return VectorInBasis(
|
70 |
basis_directions=sorted(self.output_space.basis),
|
71 |
magnitudes=x.magnitudes @ self.matrix,
|
@@ -84,8 +84,8 @@ class Linear(VectorFunction):
|
|
84 |
for i, direction in enumerate(input_space.basis):
|
85 |
out_vector = action(direction)
|
86 |
if out_vector not in output_space:
|
87 |
-
raise TypeError(f"image of {direction} from {input_space
|
88 |
-
f"is not in {output_space
|
89 |
matrix[i, :] = out_vector.magnitudes
|
90 |
|
91 |
return Linear(input_space, output_space, matrix)
|
@@ -140,9 +140,9 @@ class ScalarBilinear:
|
|
140 |
def __call__(self, x: VectorInBasis, y: VectorInBasis) -> float:
|
141 |
"""Describes the action of the operator on vectors."""
|
142 |
if x not in self.left_space:
|
143 |
-
raise TypeError(f"{x
|
144 |
if y not in self.right_space:
|
145 |
-
raise TypeError(f"{y
|
146 |
return (x.magnitudes.T @ self.matrix @ y.magnitudes).item()
|
147 |
|
148 |
@classmethod
|
|
|
65 |
|
66 |
def __call__(self, x: VectorInBasis) -> VectorInBasis:
|
67 |
if x not in self.input_space:
|
68 |
+
raise TypeError(f"x={x} not in self.input_space={self.input_space}.")
|
69 |
return VectorInBasis(
|
70 |
basis_directions=sorted(self.output_space.basis),
|
71 |
magnitudes=x.magnitudes @ self.matrix,
|
|
|
84 |
for i, direction in enumerate(input_space.basis):
|
85 |
out_vector = action(direction)
|
86 |
if out_vector not in output_space:
|
87 |
+
raise TypeError(f"image of {direction} from input_space={input_space} "
|
88 |
+
f"is not in output_space={output_space}")
|
89 |
matrix[i, :] = out_vector.magnitudes
|
90 |
|
91 |
return Linear(input_space, output_space, matrix)
|
|
|
140 |
def __call__(self, x: VectorInBasis, y: VectorInBasis) -> float:
|
141 |
"""Describes the action of the operator on vectors."""
|
142 |
if x not in self.left_space:
|
143 |
+
raise TypeError(f"x={x} not in self.left_space={self.left_space}.")
|
144 |
if y not in self.right_space:
|
145 |
+
raise TypeError(f"y={y} not in self.right_space={self.right_space}.")
|
146 |
return (x.magnitudes.T @ self.matrix @ y.magnitudes).item()
|
147 |
|
148 |
@classmethod
|
tracr/rasp/rasp.py
CHANGED
@@ -16,7 +16,7 @@
|
|
16 |
|
17 |
Every object in the RASP language is a function.
|
18 |
|
19 |
-
The most important type is S-Op, which is a function
|
20 |
|
21 |
An S-Op represents a state inside the residual stream of the transformer.
|
22 |
Therefore, any RASP program that represents a transformer computation must
|
@@ -26,11 +26,12 @@ end of the computation. In particular, given an S-Op `x`,
|
|
26 |
at location `x` when the transformer is fed [1, 2, 3] as input.
|
27 |
|
28 |
A secondary (but still important) type is Selector, which is a function
|
29 |
-
|
30 |
represents something like an attention matrix in the transformer.
|
31 |
|
32 |
For a full reference on RASP, see https://arxiv.org/abs/2106.06981.
|
33 |
"""
|
|
|
34 |
|
35 |
import abc
|
36 |
import collections.abc
|
@@ -38,13 +39,14 @@ import copy
|
|
38 |
import enum
|
39 |
import functools
|
40 |
import itertools
|
41 |
-
from typing import (Any, Callable, Generic, Mapping, Optional,
|
42 |
Sequence, TypeVar, Union)
|
|
|
43 |
from absl import logging
|
44 |
|
45 |
import numpy as np
|
46 |
|
47 |
-
SelectorValue =
|
48 |
NumericValue = Union[int, float]
|
49 |
Value = Union[None, int, float, str, bool]
|
50 |
VT = TypeVar("VT", bound=Value)
|
@@ -63,7 +65,7 @@ _ENCODING_KEY = "encoding"
|
|
63 |
# that key is accessed.
|
64 |
#
|
65 |
# See the `default_name` annotator for a full example.
|
66 |
-
DEFAULT_ANNOTATORS:
|
67 |
|
68 |
|
69 |
class Annotator(Protocol):
|
@@ -81,7 +83,7 @@ class _Annotations(collections.abc.Mapping):
|
|
81 |
|
82 |
def __init__(self, expr, **kwargs: Any):
|
83 |
self._expr = expr
|
84 |
-
self._inner_dict:
|
85 |
|
86 |
def __getitem__(self, key: str) -> Any:
|
87 |
if key not in self._inner_dict:
|
@@ -758,7 +760,7 @@ _default_name_by_class = {
|
|
758 |
}
|
759 |
|
760 |
|
761 |
-
def default_name(expr: RASPExpr) ->
|
762 |
for cls, name in _default_name_by_class.items():
|
763 |
if isinstance(expr, cls):
|
764 |
return name
|
@@ -905,7 +907,7 @@ class DefaultRASPEvaluator(abc.ABC):
|
|
905 |
|
906 |
|
907 |
def _get_selected(
|
908 |
-
selector_row:
|
909 |
values: Sequence[VT],
|
910 |
) -> Sequence[VT]:
|
911 |
"""Helper for aggregate. [T T F], [a b c] -> [a b]."""
|
|
|
16 |
|
17 |
Every object in the RASP language is a function.
|
18 |
|
19 |
+
The most important type is S-Op, which is a function List[Value] -> List[Value].
|
20 |
|
21 |
An S-Op represents a state inside the residual stream of the transformer.
|
22 |
Therefore, any RASP program that represents a transformer computation must
|
|
|
26 |
at location `x` when the transformer is fed [1, 2, 3] as input.
|
27 |
|
28 |
A secondary (but still important) type is Selector, which is a function
|
29 |
+
List[Value] -> List[List[bool]]. Given a Selector `sel`, sel([1, 2, 3])
|
30 |
represents something like an attention matrix in the transformer.
|
31 |
|
32 |
For a full reference on RASP, see https://arxiv.org/abs/2106.06981.
|
33 |
"""
|
34 |
+
import pdb
|
35 |
|
36 |
import abc
|
37 |
import collections.abc
|
|
|
39 |
import enum
|
40 |
import functools
|
41 |
import itertools
|
42 |
+
from typing import (Any, Callable, Generic, Mapping, Optional, List, Dict,
|
43 |
Sequence, TypeVar, Union)
|
44 |
+
from typing_extensions import Protocol
|
45 |
from absl import logging
|
46 |
|
47 |
import numpy as np
|
48 |
|
49 |
+
SelectorValue = List[List[bool]]
|
50 |
NumericValue = Union[int, float]
|
51 |
Value = Union[None, int, float, str, bool]
|
52 |
VT = TypeVar("VT", bound=Value)
|
|
|
65 |
# that key is accessed.
|
66 |
#
|
67 |
# See the `default_name` annotator for a full example.
|
68 |
+
DEFAULT_ANNOTATORS: Dict[str, "Annotator"] = {}
|
69 |
|
70 |
|
71 |
class Annotator(Protocol):
|
|
|
83 |
|
84 |
def __init__(self, expr, **kwargs: Any):
|
85 |
self._expr = expr
|
86 |
+
self._inner_dict: Dict[str, Any] = {**kwargs}
|
87 |
|
88 |
def __getitem__(self, key: str) -> Any:
|
89 |
if key not in self._inner_dict:
|
|
|
760 |
}
|
761 |
|
762 |
|
763 |
+
def default_name(expr: RASPExpr) -> Dict[str, str]:
|
764 |
for cls, name in _default_name_by_class.items():
|
765 |
if isinstance(expr, cls):
|
766 |
return name
|
|
|
907 |
|
908 |
|
909 |
def _get_selected(
|
910 |
+
selector_row: List[bool],
|
911 |
values: Sequence[VT],
|
912 |
) -> Sequence[VT]:
|
913 |
"""Helper for aggregate. [T T F], [a b c] -> [a b]."""
|
tracr/transformer/encoder.py
CHANGED
@@ -15,7 +15,7 @@
|
|
15 |
"""Basic encoder for inputs with a fixed vocabulary."""
|
16 |
|
17 |
import abc
|
18 |
-
from typing import Any, Sequence, Optional
|
19 |
|
20 |
from tracr.craft import bases
|
21 |
|
@@ -28,11 +28,11 @@ class Encoder(abc.ABC):
|
|
28 |
"""
|
29 |
|
30 |
@abc.abstractmethod
|
31 |
-
def encode(self, inputs:
|
32 |
return list()
|
33 |
|
34 |
@abc.abstractmethod
|
35 |
-
def decode(self, encodings:
|
36 |
return list()
|
37 |
|
38 |
@property
|
@@ -55,10 +55,10 @@ class Encoder(abc.ABC):
|
|
55 |
class NumericalEncoder(Encoder):
|
56 |
"""Encodes numerical variables (simply using the identity mapping)."""
|
57 |
|
58 |
-
def encode(self, inputs:
|
59 |
return inputs
|
60 |
|
61 |
-
def decode(self, encodings:
|
62 |
return encodings
|
63 |
|
64 |
|
@@ -93,7 +93,7 @@ class CategoricalEncoder(Encoder):
|
|
93 |
self._pad_token = pad_token
|
94 |
self._max_seq_len = max_seq_len
|
95 |
|
96 |
-
def encode(self, inputs:
|
97 |
if self.enforce_bos and inputs[0] != self.bos_token:
|
98 |
raise ValueError("First input token must be BOS token. "
|
99 |
f"Should be '{self.bos_token}', but was '{inputs[0]}'.")
|
@@ -101,12 +101,12 @@ class CategoricalEncoder(Encoder):
|
|
101 |
raise ValueError(f"Inputs {missing} not found in encoding ",
|
102 |
self.encoding_map.keys())
|
103 |
if self._max_seq_len is not None and len(inputs) > self._max_seq_len:
|
104 |
-
raise ValueError(f"{inputs
|
105 |
f"sequence length {self._max_seq_len}")
|
106 |
|
107 |
return [self.encoding_map[x] for x in inputs]
|
108 |
|
109 |
-
def decode(self, encodings:
|
110 |
"""Recover the tokens that corresponds to `ids`. Inverse of __call__."""
|
111 |
decoding_map = {val: key for key, val in self.encoding_map.items()}
|
112 |
if missing := set(encodings) - set(decoding_map.keys()):
|
|
|
15 |
"""Basic encoder for inputs with a fixed vocabulary."""
|
16 |
|
17 |
import abc
|
18 |
+
from typing import Any, Sequence, Optional, List
|
19 |
|
20 |
from tracr.craft import bases
|
21 |
|
|
|
28 |
"""
|
29 |
|
30 |
@abc.abstractmethod
|
31 |
+
def encode(self, inputs: List[Any]) -> List[Any]:
|
32 |
return list()
|
33 |
|
34 |
@abc.abstractmethod
|
35 |
+
def decode(self, encodings: List[Any]) -> List[Any]:
|
36 |
return list()
|
37 |
|
38 |
@property
|
|
|
55 |
class NumericalEncoder(Encoder):
|
56 |
"""Encodes numerical variables (simply using the identity mapping)."""
|
57 |
|
58 |
+
def encode(self, inputs: List[float]) -> List[float]:
|
59 |
return inputs
|
60 |
|
61 |
+
def decode(self, encodings: List[float]) -> List[float]:
|
62 |
return encodings
|
63 |
|
64 |
|
|
|
93 |
self._pad_token = pad_token
|
94 |
self._max_seq_len = max_seq_len
|
95 |
|
96 |
+
def encode(self, inputs: List[bases.Value]) -> List[int]:
|
97 |
if self.enforce_bos and inputs[0] != self.bos_token:
|
98 |
raise ValueError("First input token must be BOS token. "
|
99 |
f"Should be '{self.bos_token}', but was '{inputs[0]}'.")
|
|
|
101 |
raise ValueError(f"Inputs {missing} not found in encoding ",
|
102 |
self.encoding_map.keys())
|
103 |
if self._max_seq_len is not None and len(inputs) > self._max_seq_len:
|
104 |
+
raise ValueError(f"inputs={inputs} are longer than the maximum "
|
105 |
f"sequence length {self._max_seq_len}")
|
106 |
|
107 |
return [self.encoding_map[x] for x in inputs]
|
108 |
|
109 |
+
def decode(self, encodings: List[int]) -> List[bases.Value]:
|
110 |
"""Recover the tokens that corresponds to `ids`. Inverse of __call__."""
|
111 |
decoding_map = {val: key for key, val in self.encoding_map.items()}
|
112 |
if missing := set(encodings) - set(decoding_map.keys()):
|
tracr/transformer/model.py
CHANGED
@@ -26,7 +26,7 @@ Forked from: haiku.examples.transformer.model
|
|
26 |
|
27 |
import collections
|
28 |
import dataclasses
|
29 |
-
from typing import Callable, Optional
|
30 |
|
31 |
import chex
|
32 |
import haiku as hk
|
@@ -44,9 +44,9 @@ CallableHaikuModule = Callable[..., jax.Array]
|
|
44 |
|
45 |
@chex.dataclass
|
46 |
class TransformerOutput:
|
47 |
-
layer_outputs:
|
48 |
-
residuals:
|
49 |
-
attn_logits:
|
50 |
output: jax.Array # [B, T, D]
|
51 |
input_embeddings: jax.Array # [B, T, D]
|
52 |
|
|
|
26 |
|
27 |
import collections
|
28 |
import dataclasses
|
29 |
+
from typing import Callable, Optional, List
|
30 |
|
31 |
import chex
|
32 |
import haiku as hk
|
|
|
44 |
|
45 |
@chex.dataclass
|
46 |
class TransformerOutput:
|
47 |
+
layer_outputs: List[jax.Array] # [B, T, D]
|
48 |
+
residuals: List[jax.Array] # [B, T, D]
|
49 |
+
attn_logits: List[jax.Array] # [B, H, T, T]
|
50 |
output: jax.Array # [B, T, D]
|
51 |
input_embeddings: jax.Array # [B, T, D]
|
52 |
|