NeelNanda committed on
Commit
c46567d
·
1 Parent(s): 4d24b96

Made compatible with Python 3.8

Browse files
tracr/compiler/assemble.py CHANGED
@@ -15,7 +15,8 @@
15
  """Assemble weights of a transformer model from a craft residual stack."""
16
 
17
  import dataclasses
18
- from typing import Any, Callable, Optional, Protocol
 
19
 
20
  import chex
21
  import einops
@@ -32,11 +33,11 @@ from tracr.transformer import model
32
 
33
  @chex.dataclass
34
  class AssembledTransformerModelOutput:
35
- decoded: list[Any] # length T.
36
  unembedded: jax.Array # [B, T] B = 1 always.
37
- layer_outputs: list[jax.Array] # [B, T, D]
38
- residuals: list[jax.Array] # [B, T, D]
39
- attn_logits: list[jax.Array] # [B, T, T, H]
40
  transformer_output: jax.Array # [B, T, D]
41
  input_embeddings: jax.Array
42
 
@@ -58,11 +59,11 @@ class AssembledTransformerModel:
58
  get_compiled_model: Callable[[], model.CompiledTransformerModel]
59
  params: hk.Params
60
  model_config: model.TransformerConfig
61
- residual_labels: list[str]
62
  input_encoder: Optional[encoder.Encoder] = None
63
  output_encoder: Optional[encoder.Encoder] = None
64
 
65
- def apply(self, tokens: list[bases.Value]) -> AssembledTransformerModelOutput:
66
  """Returns output from running the model on a set of input tokens."""
67
  if self.input_encoder:
68
  tokens = self.input_encoder.encode(tokens)
@@ -97,12 +98,12 @@ class EmbeddingModules:
97
 
98
  def _get_model_config_and_module_names(
99
  craft_model: transformers.SeriesWithResiduals
100
- ) -> tuple[model.TransformerConfig, list[str]]:
101
  """Returns model config and locations (in params) for halflayers."""
102
 
103
- multi_attn_heads: list[list[transformers.AttentionHead]] = []
104
- mlps: list[transformers.MLP] = []
105
- module_names: list[str] = []
106
 
107
  candidate_module_names = []
108
  for layer in range(len(craft_model.blocks)):
 
15
  """Assemble weights of a transformer model from a craft residual stack."""
16
 
17
  import dataclasses
18
+ from typing import Any, Callable, Optional, List, Tuple
19
+ from typing_extensions import Protocol
20
 
21
  import chex
22
  import einops
 
33
 
34
  @chex.dataclass
35
  class AssembledTransformerModelOutput:
36
+ decoded: List[Any] # length T.
37
  unembedded: jax.Array # [B, T] B = 1 always.
38
+ layer_outputs: List[jax.Array] # [B, T, D]
39
+ residuals: List[jax.Array] # [B, T, D]
40
+ attn_logits: List[jax.Array] # [B, T, T, H]
41
  transformer_output: jax.Array # [B, T, D]
42
  input_embeddings: jax.Array
43
 
 
59
  get_compiled_model: Callable[[], model.CompiledTransformerModel]
60
  params: hk.Params
61
  model_config: model.TransformerConfig
62
+ residual_labels: List[str]
63
  input_encoder: Optional[encoder.Encoder] = None
64
  output_encoder: Optional[encoder.Encoder] = None
65
 
66
+ def apply(self, tokens: List[bases.Value]) -> AssembledTransformerModelOutput:
67
  """Returns output from running the model on a set of input tokens."""
68
  if self.input_encoder:
69
  tokens = self.input_encoder.encode(tokens)
 
98
 
99
  def _get_model_config_and_module_names(
100
  craft_model: transformers.SeriesWithResiduals
101
+ ) -> Tuple[model.TransformerConfig, List[str]]:
102
  """Returns model config and locations (in params) for halflayers."""
103
 
104
+ multi_attn_heads: List[List[transformers.AttentionHead]] = []
105
+ mlps: List[transformers.MLP] = []
106
+ module_names: List[str] = []
107
 
108
  candidate_module_names = []
109
  for layer in range(len(craft_model.blocks)):
tracr/compiler/basis_inference.py CHANGED
@@ -16,6 +16,7 @@
16
 
17
  import dataclasses
18
  import itertools
 
19
 
20
  import networkx as nx
21
  from tracr.compiler import nodes
@@ -34,12 +35,12 @@ class InferBasesOutput:
34
  def infer_bases(
35
  graph: nx.DiGraph,
36
  sink: Node,
37
- vocab: set[rasp.Value],
38
  max_seq_len: int,
39
  ) -> None:
40
  """Infers in-place the possible output values and vector bases of the SOps."""
41
 
42
- def compute_value_set(sop: rasp.SOp) -> set[rasp.Value]:
43
  """Computes value set using already-computed predecessor value sets."""
44
  if sop is rasp.tokens:
45
  return vocab
 
16
 
17
  import dataclasses
18
  import itertools
19
+ from typing import Set
20
 
21
  import networkx as nx
22
  from tracr.compiler import nodes
 
35
  def infer_bases(
36
  graph: nx.DiGraph,
37
  sink: Node,
38
+ vocab: Set[rasp.Value],
39
  max_seq_len: int,
40
  ) -> None:
41
  """Infers in-place the possible output values and vector bases of the SOps."""
42
 
43
+ def compute_value_set(sop: rasp.SOp) -> Set[rasp.Value]:
44
  """Computes value set using already-computed predecessor value sets."""
45
  if sop is rasp.tokens:
46
  return vocab
tracr/compiler/compiling.py CHANGED
@@ -13,6 +13,7 @@
13
  # limitations under the License.
14
  # ==============================================================================
15
  """Combines all steps of compiling a RASP program."""
 
16
 
17
  from tracr.compiler import assemble
18
  from tracr.compiler import basis_inference
@@ -29,7 +30,7 @@ COMPILER_PAD = "compiler_pad"
29
 
30
  def compile_rasp_to_model(
31
  program: rasp.SOp,
32
- vocab: set[rasp.Value],
33
  max_seq_len: int,
34
  causal: bool = False,
35
  compiler_bos: str = COMPILER_BOS,
 
13
  # limitations under the License.
14
  # ==============================================================================
15
  """Combines all steps of compiling a RASP program."""
16
+ from typing import Set
17
 
18
  from tracr.compiler import assemble
19
  from tracr.compiler import basis_inference
 
30
 
31
  def compile_rasp_to_model(
32
  program: rasp.SOp,
33
+ vocab: Set[rasp.Value],
34
  max_seq_len: int,
35
  causal: bool = False,
36
  compiler_bos: str = COMPILER_BOS,
tracr/compiler/craft_graph_to_model.py CHANGED
@@ -15,7 +15,7 @@
15
  """Create a craft model from a computational graph."""
16
 
17
  import collections
18
- from typing import Sequence
19
 
20
  import networkx as nx
21
  from tracr.compiler import nodes
@@ -105,7 +105,7 @@ def _all_mlp_nodes(node_list: Sequence[Node]) -> bool:
105
 
106
 
107
  def _allocate_modules_to_layers(graph: nx.DiGraph,
108
- sources: Sequence[Node]) -> dict[int, int]:
109
  """Allocate all nodes in compute graph to layers.
110
 
111
  First, computes the longest path from the input to each node that is a model
@@ -128,9 +128,9 @@ def _allocate_modules_to_layers(graph: nx.DiGraph,
128
  A dict mapping from node ids to layer indices, where 0, 1, 2, 3, ...
129
  are in the order attention, mlp, attention, mlp, ...
130
  """
131
- layer_allocation: dict[int, int] = collections.defaultdict(lambda: -1)
132
- depth_by_node_id: dict[int, int] = dict()
133
- nodes_by_depth: dict[int, list[Node]] = collections.defaultdict(list)
134
 
135
  # Compute depth of all model components (longest path from source to node)
136
  for node_id, node in graph.nodes.items():
 
15
  """Create a craft model from a computational graph."""
16
 
17
  import collections
18
+ from typing import Sequence, List, Dict
19
 
20
  import networkx as nx
21
  from tracr.compiler import nodes
 
105
 
106
 
107
  def _allocate_modules_to_layers(graph: nx.DiGraph,
108
+ sources: Sequence[Node]) -> Dict[int, int]:
109
  """Allocate all nodes in compute graph to layers.
110
 
111
  First, computes the longest path from the input to each node that is a model
 
128
  A dict mapping from node ids to layer indices, where 0, 1, 2, 3, ...
129
  are in the order attention, mlp, attention, mlp, ...
130
  """
131
+ layer_allocation: Dict[int, int] = collections.defaultdict(lambda: -1)
132
+ depth_by_node_id: Dict[int, int] = dict()
133
+ nodes_by_depth: Dict[int, List[Node]] = collections.defaultdict(list)
134
 
135
  # Compute depth of all model components (longest path from source to node)
136
  for node_id, node in graph.nodes.items():
tracr/compiler/lib.py CHANGED
@@ -14,7 +14,7 @@
14
  # ==============================================================================
15
  """RASP programs only using the subset of RASP supported by the compiler."""
16
 
17
- from typing import Sequence
18
 
19
  from tracr.rasp import rasp
20
 
@@ -95,7 +95,7 @@ def make_pair_balance(sop: rasp.SOp, open_token: str,
95
  return pair_balance.named("pair_balance")
96
 
97
 
98
- def make_shuffle_dyck(pairs: list[str]) -> rasp.SOp:
99
  """Returns 1 if a set of parentheses are balanced, 0 else.
100
 
101
  (As implemented in the RASP paper.)
 
14
  # ==============================================================================
15
  """RASP programs only using the subset of RASP supported by the compiler."""
16
 
17
+ from typing import Sequence, List
18
 
19
  from tracr.rasp import rasp
20
 
 
95
  return pair_balance.named("pair_balance")
96
 
97
 
98
+ def make_shuffle_dyck(pairs: List[str]) -> rasp.SOp:
99
  """Returns 1 if a set of parentheses are balanced, 0 else.
100
 
101
  (As implemented in the RASP paper.)
tracr/compiler/nodes.py CHANGED
@@ -14,9 +14,9 @@
14
  # ==============================================================================
15
  """Documents the data stored in nodes after each compiler pass."""
16
 
17
- from typing import Any
18
 
19
- Node = dict[str, Any]
20
  NodeID = str
21
 
22
  # RASP -> Graph
 
14
  # ==============================================================================
15
  """Documents the data stored in nodes after each compiler pass."""
16
 
17
+ from typing import Any, Dict
18
 
19
+ Node = Dict[str, Any]
20
  NodeID = str
21
 
22
  # RASP -> Graph
tracr/compiler/rasp_to_graph.py CHANGED
@@ -16,6 +16,7 @@
16
 
17
  import dataclasses
18
  import queue
 
19
 
20
  import networkx as nx
21
  from tracr.compiler import nodes
@@ -29,14 +30,14 @@ NodeID = nodes.NodeID
29
  class ExtractRaspGraphOutput:
30
  graph: nx.DiGraph
31
  sink: Node # the program's output.
32
- sources: list[Node] # the primitive S-Ops.
33
 
34
 
35
  def extract_rasp_graph(tip: rasp.SOp) -> ExtractRaspGraphOutput:
36
  """Converts a RASP program into a graph representation."""
37
  expr_queue = queue.Queue()
38
  graph = nx.DiGraph()
39
- sources: list[NodeID] = []
40
 
41
  def ensure_node(expr: rasp.RASPExpr) -> NodeID:
42
  """Finds or creates a graph node corresponding to expr; returns its ID."""
 
16
 
17
  import dataclasses
18
  import queue
19
+ from typing import List
20
 
21
  import networkx as nx
22
  from tracr.compiler import nodes
 
30
  class ExtractRaspGraphOutput:
31
  graph: nx.DiGraph
32
  sink: Node # the program's output.
33
+ sources: List[Node] # the primitive S-Ops.
34
 
35
 
36
  def extract_rasp_graph(tip: rasp.SOp) -> ExtractRaspGraphOutput:
37
  """Converts a RASP program into a graph representation."""
38
  expr_queue = queue.Queue()
39
  graph = nx.DiGraph()
40
+ sources: List[NodeID] = []
41
 
42
  def ensure_node(expr: rasp.RASPExpr) -> NodeID:
43
  """Finds or creates a graph node corresponding to expr; returns its ID."""
tracr/compiler/rasp_to_transformer_integration_test.py CHANGED
@@ -38,7 +38,7 @@ class CompilerIntegrationTest(tests_common.VectorFnTestCase):
38
  for actual, expected in zip(actual_seq, expected_seq):
39
  if expected is not None and actual != expected:
40
  self.fail(f"{actual_seq} does not match (ignoring Nones) "
41
- f"{expected_seq=}")
42
 
43
  @parameterized.named_parameters(
44
  dict(
 
38
  for actual, expected in zip(actual_seq, expected_seq):
39
  if expected is not None and actual != expected:
40
  self.fail(f"{actual_seq} does not match (ignoring Nones) "
41
+ f"expected_seq={expected_seq}")
42
 
43
  @parameterized.named_parameters(
44
  dict(
tracr/craft/bases.py CHANGED
@@ -243,5 +243,5 @@ def ensure_dims(
243
  ) -> None:
244
  """Raises ValueError if vs has the wrong number of dimensions."""
245
  if vs.num_dims != num_dims:
246
- raise ValueError(f"{name} must have {num_dims=}, "
247
  f"but got {vs.num_dims}: {vs.basis}")
 
243
  ) -> None:
244
  """Raises ValueError if vs has the wrong number of dimensions."""
245
  if vs.num_dims != num_dims:
246
+ raise ValueError(f"{name} must have num_dims={num_dims}, "
247
  f"but got {vs.num_dims}: {vs.basis}")
tracr/craft/chamber/categorical_attn.py CHANGED
@@ -14,7 +14,8 @@
14
  # ==============================================================================
15
  """Attention head for categorical inputs."""
16
 
17
- from typing import Optional, Protocol
 
18
 
19
  from tracr.craft import bases
20
  from tracr.craft import transformers
 
14
  # ==============================================================================
15
  """Attention head for categorical inputs."""
16
 
17
+ from typing import Optional
18
+ from typing_extensions import Protocol
19
 
20
  from tracr.craft import bases
21
  from tracr.craft import transformers
tracr/craft/chamber/numerical_mlp.py CHANGED
@@ -16,7 +16,7 @@
16
 
17
  import dataclasses
18
 
19
- from typing import Callable, Iterable
20
 
21
  from tracr.craft import bases
22
  from tracr.craft import transformers
@@ -35,7 +35,7 @@ class DiscretisingLayerMaterials:
35
  """
36
  action: Callable[[bases.BasisDirection], bases.VectorInBasis]
37
  hidden_space: bases.VectorSpaceWithBasis
38
- output_values: list[float]
39
 
40
 
41
  def _get_discretising_layer(input_value_set: Iterable[float],
 
16
 
17
  import dataclasses
18
 
19
+ from typing import Callable, Iterable, List
20
 
21
  from tracr.craft import bases
22
  from tracr.craft import transformers
 
35
  """
36
  action: Callable[[bases.BasisDirection], bases.VectorInBasis]
37
  hidden_space: bases.VectorSpaceWithBasis
38
+ output_values: List[float]
39
 
40
 
41
  def _get_discretising_layer(input_value_set: Iterable[float],
tracr/craft/transformers.py CHANGED
@@ -16,7 +16,7 @@
16
 
17
  import abc
18
  import dataclasses
19
- from typing import Iterable, Optional, Sequence, Union
20
 
21
  import numpy as np
22
 
@@ -111,7 +111,7 @@ class AttentionHead(Block):
111
  @dataclasses.dataclass
112
  class MultiAttentionHead(Block):
113
  """Applies attention heads in parallel."""
114
- sub_blocks: list[Union[AttentionHead, "MultiAttentionHead"]]
115
 
116
  def __post_init__(self):
117
  spaces = [block.residual_space for block in self.sub_blocks]
@@ -182,7 +182,7 @@ HalfLayerBlock = Union[MLP, AttentionHead, MultiAttentionHead]
182
  @dataclasses.dataclass
183
  class SeriesWithResiduals(Block):
184
  """A series of blocks with residual connections."""
185
- blocks: list[HalfLayerBlock]
186
 
187
  def __post_init__(self):
188
  spaces = [block.residual_space for block in self.blocks]
 
16
 
17
  import abc
18
  import dataclasses
19
+ from typing import Iterable, Optional, Sequence, Union, List
20
 
21
  import numpy as np
22
 
 
111
  @dataclasses.dataclass
112
  class MultiAttentionHead(Block):
113
  """Applies attention heads in parallel."""
114
+ sub_blocks: List[Union[AttentionHead, "MultiAttentionHead"]]
115
 
116
  def __post_init__(self):
117
  spaces = [block.residual_space for block in self.sub_blocks]
 
182
  @dataclasses.dataclass
183
  class SeriesWithResiduals(Block):
184
  """A series of blocks with residual connections."""
185
+ blocks: List[HalfLayerBlock]
186
 
187
  def __post_init__(self):
188
  spaces = [block.residual_space for block in self.blocks]
tracr/craft/vectorspace_fns.py CHANGED
@@ -65,7 +65,7 @@ class Linear(VectorFunction):
65
 
66
  def __call__(self, x: VectorInBasis) -> VectorInBasis:
67
  if x not in self.input_space:
68
- raise TypeError(f"{x=} not in {self.input_space=}.")
69
  return VectorInBasis(
70
  basis_directions=sorted(self.output_space.basis),
71
  magnitudes=x.magnitudes @ self.matrix,
@@ -84,8 +84,8 @@ class Linear(VectorFunction):
84
  for i, direction in enumerate(input_space.basis):
85
  out_vector = action(direction)
86
  if out_vector not in output_space:
87
- raise TypeError(f"image of {direction} from {input_space=} "
88
- f"is not in {output_space=}")
89
  matrix[i, :] = out_vector.magnitudes
90
 
91
  return Linear(input_space, output_space, matrix)
@@ -140,9 +140,9 @@ class ScalarBilinear:
140
  def __call__(self, x: VectorInBasis, y: VectorInBasis) -> float:
141
  """Describes the action of the operator on vectors."""
142
  if x not in self.left_space:
143
- raise TypeError(f"{x=} not in {self.left_space=}.")
144
  if y not in self.right_space:
145
- raise TypeError(f"{y=} not in {self.right_space=}.")
146
  return (x.magnitudes.T @ self.matrix @ y.magnitudes).item()
147
 
148
  @classmethod
 
65
 
66
  def __call__(self, x: VectorInBasis) -> VectorInBasis:
67
  if x not in self.input_space:
68
+ raise TypeError(f"x={x} not in self.input_space={self.input_space}.")
69
  return VectorInBasis(
70
  basis_directions=sorted(self.output_space.basis),
71
  magnitudes=x.magnitudes @ self.matrix,
 
84
  for i, direction in enumerate(input_space.basis):
85
  out_vector = action(direction)
86
  if out_vector not in output_space:
87
+ raise TypeError(f"image of {direction} from input_space={input_space} "
88
+ f"is not in output_space={output_space}")
89
  matrix[i, :] = out_vector.magnitudes
90
 
91
  return Linear(input_space, output_space, matrix)
 
140
  def __call__(self, x: VectorInBasis, y: VectorInBasis) -> float:
141
  """Describes the action of the operator on vectors."""
142
  if x not in self.left_space:
143
+ raise TypeError(f"x={x} not in self.left_space={self.left_space}.")
144
  if y not in self.right_space:
145
+ raise TypeError(f"y={y} not in self.right_space={self.right_space}.")
146
  return (x.magnitudes.T @ self.matrix @ y.magnitudes).item()
147
 
148
  @classmethod
tracr/rasp/rasp.py CHANGED
@@ -16,7 +16,7 @@
16
 
17
  Every object in the RASP language is a function.
18
 
19
- The most important type is S-Op, which is a function list[Value] -> list[Value].
20
 
21
  An S-Op represents a state inside the residual stream of the transformer.
22
  Therefore, any RASP program that represents a transformer computation must
@@ -26,11 +26,12 @@ end of the computation. In particular, given an S-Op `x`,
26
  at location `x` when the transformer is fed [1, 2, 3] as input.
27
 
28
  A secondary (but still important) type is Selector, which is a function
29
- list[Value] -> list[list[bool]]. Given a Selector `sel`, sel([1, 2, 3])
30
  represents something like an attention matrix in the transformer.
31
 
32
  For a full reference on RASP, see https://arxiv.org/abs/2106.06981.
33
  """
 
34
 
35
  import abc
36
  import collections.abc
@@ -38,13 +39,14 @@ import copy
38
  import enum
39
  import functools
40
  import itertools
41
- from typing import (Any, Callable, Generic, Mapping, Optional, Protocol,
42
  Sequence, TypeVar, Union)
 
43
  from absl import logging
44
 
45
  import numpy as np
46
 
47
- SelectorValue = list[list[bool]]
48
  NumericValue = Union[int, float]
49
  Value = Union[None, int, float, str, bool]
50
  VT = TypeVar("VT", bound=Value)
@@ -63,7 +65,7 @@ _ENCODING_KEY = "encoding"
63
  # that key is accessed.
64
  #
65
  # See the `default_name` annotator for a full example.
66
- DEFAULT_ANNOTATORS: dict[str, "Annotator"] = {}
67
 
68
 
69
  class Annotator(Protocol):
@@ -81,7 +83,7 @@ class _Annotations(collections.abc.Mapping):
81
 
82
  def __init__(self, expr, **kwargs: Any):
83
  self._expr = expr
84
- self._inner_dict: dict[str, Any] = {**kwargs}
85
 
86
  def __getitem__(self, key: str) -> Any:
87
  if key not in self._inner_dict:
@@ -758,7 +760,7 @@ _default_name_by_class = {
758
  }
759
 
760
 
761
- def default_name(expr: RASPExpr) -> dict[str, str]:
762
  for cls, name in _default_name_by_class.items():
763
  if isinstance(expr, cls):
764
  return name
@@ -905,7 +907,7 @@ class DefaultRASPEvaluator(abc.ABC):
905
 
906
 
907
  def _get_selected(
908
- selector_row: list[bool],
909
  values: Sequence[VT],
910
  ) -> Sequence[VT]:
911
  """Helper for aggregate. [T T F], [a b c] -> [a b]."""
 
16
 
17
  Every object in the RASP language is a function.
18
 
19
+ The most important type is S-Op, which is a function List[Value] -> List[Value].
20
 
21
  An S-Op represents a state inside the residual stream of the transformer.
22
  Therefore, any RASP program that represents a transformer computation must
 
26
  at location `x` when the transformer is fed [1, 2, 3] as input.
27
 
28
  A secondary (but still important) type is Selector, which is a function
29
+ List[Value] -> List[List[bool]]. Given a Selector `sel`, sel([1, 2, 3])
30
  represents something like an attention matrix in the transformer.
31
 
32
  For a full reference on RASP, see https://arxiv.org/abs/2106.06981.
33
  """
34
+ import pdb
35
 
36
  import abc
37
  import collections.abc
 
39
  import enum
40
  import functools
41
  import itertools
42
+ from typing import (Any, Callable, Generic, Mapping, Optional, List, Dict,
43
  Sequence, TypeVar, Union)
44
+ from typing_extensions import Protocol
45
  from absl import logging
46
 
47
  import numpy as np
48
 
49
+ SelectorValue = List[List[bool]]
50
  NumericValue = Union[int, float]
51
  Value = Union[None, int, float, str, bool]
52
  VT = TypeVar("VT", bound=Value)
 
65
  # that key is accessed.
66
  #
67
  # See the `default_name` annotator for a full example.
68
+ DEFAULT_ANNOTATORS: Dict[str, "Annotator"] = {}
69
 
70
 
71
  class Annotator(Protocol):
 
83
 
84
  def __init__(self, expr, **kwargs: Any):
85
  self._expr = expr
86
+ self._inner_dict: Dict[str, Any] = {**kwargs}
87
 
88
  def __getitem__(self, key: str) -> Any:
89
  if key not in self._inner_dict:
 
760
  }
761
 
762
 
763
+ def default_name(expr: RASPExpr) -> Dict[str, str]:
764
  for cls, name in _default_name_by_class.items():
765
  if isinstance(expr, cls):
766
  return name
 
907
 
908
 
909
  def _get_selected(
910
+ selector_row: List[bool],
911
  values: Sequence[VT],
912
  ) -> Sequence[VT]:
913
  """Helper for aggregate. [T T F], [a b c] -> [a b]."""
tracr/transformer/encoder.py CHANGED
@@ -15,7 +15,7 @@
15
  """Basic encoder for inputs with a fixed vocabulary."""
16
 
17
  import abc
18
- from typing import Any, Sequence, Optional
19
 
20
  from tracr.craft import bases
21
 
@@ -28,11 +28,11 @@ class Encoder(abc.ABC):
28
  """
29
 
30
  @abc.abstractmethod
31
- def encode(self, inputs: list[Any]) -> list[Any]:
32
  return list()
33
 
34
  @abc.abstractmethod
35
- def decode(self, encodings: list[Any]) -> list[Any]:
36
  return list()
37
 
38
  @property
@@ -55,10 +55,10 @@ class Encoder(abc.ABC):
55
  class NumericalEncoder(Encoder):
56
  """Encodes numerical variables (simply using the identity mapping)."""
57
 
58
- def encode(self, inputs: list[float]) -> list[float]:
59
  return inputs
60
 
61
- def decode(self, encodings: list[float]) -> list[float]:
62
  return encodings
63
 
64
 
@@ -93,7 +93,7 @@ class CategoricalEncoder(Encoder):
93
  self._pad_token = pad_token
94
  self._max_seq_len = max_seq_len
95
 
96
- def encode(self, inputs: list[bases.Value]) -> list[int]:
97
  if self.enforce_bos and inputs[0] != self.bos_token:
98
  raise ValueError("First input token must be BOS token. "
99
  f"Should be '{self.bos_token}', but was '{inputs[0]}'.")
@@ -101,12 +101,12 @@ class CategoricalEncoder(Encoder):
101
  raise ValueError(f"Inputs {missing} not found in encoding ",
102
  self.encoding_map.keys())
103
  if self._max_seq_len is not None and len(inputs) > self._max_seq_len:
104
- raise ValueError(f"{inputs=} are longer than the maximum "
105
  f"sequence length {self._max_seq_len}")
106
 
107
  return [self.encoding_map[x] for x in inputs]
108
 
109
- def decode(self, encodings: list[int]) -> list[bases.Value]:
110
  """Recover the tokens that corresponds to `ids`. Inverse of __call__."""
111
  decoding_map = {val: key for key, val in self.encoding_map.items()}
112
  if missing := set(encodings) - set(decoding_map.keys()):
 
15
  """Basic encoder for inputs with a fixed vocabulary."""
16
 
17
  import abc
18
+ from typing import Any, Sequence, Optional, List
19
 
20
  from tracr.craft import bases
21
 
 
28
  """
29
 
30
  @abc.abstractmethod
31
+ def encode(self, inputs: List[Any]) -> List[Any]:
32
  return list()
33
 
34
  @abc.abstractmethod
35
+ def decode(self, encodings: List[Any]) -> List[Any]:
36
  return list()
37
 
38
  @property
 
55
  class NumericalEncoder(Encoder):
56
  """Encodes numerical variables (simply using the identity mapping)."""
57
 
58
+ def encode(self, inputs: List[float]) -> List[float]:
59
  return inputs
60
 
61
+ def decode(self, encodings: List[float]) -> List[float]:
62
  return encodings
63
 
64
 
 
93
  self._pad_token = pad_token
94
  self._max_seq_len = max_seq_len
95
 
96
+ def encode(self, inputs: List[bases.Value]) -> List[int]:
97
  if self.enforce_bos and inputs[0] != self.bos_token:
98
  raise ValueError("First input token must be BOS token. "
99
  f"Should be '{self.bos_token}', but was '{inputs[0]}'.")
 
101
  raise ValueError(f"Inputs {missing} not found in encoding ",
102
  self.encoding_map.keys())
103
  if self._max_seq_len is not None and len(inputs) > self._max_seq_len:
104
+ raise ValueError(f"inputs={inputs} are longer than the maximum "
105
  f"sequence length {self._max_seq_len}")
106
 
107
  return [self.encoding_map[x] for x in inputs]
108
 
109
+ def decode(self, encodings: List[int]) -> List[bases.Value]:
110
  """Recover the tokens that corresponds to `ids`. Inverse of __call__."""
111
  decoding_map = {val: key for key, val in self.encoding_map.items()}
112
  if missing := set(encodings) - set(decoding_map.keys()):
tracr/transformer/model.py CHANGED
@@ -26,7 +26,7 @@ Forked from: haiku.examples.transformer.model
26
 
27
  import collections
28
  import dataclasses
29
- from typing import Callable, Optional
30
 
31
  import chex
32
  import haiku as hk
@@ -44,9 +44,9 @@ CallableHaikuModule = Callable[..., jax.Array]
44
 
45
  @chex.dataclass
46
  class TransformerOutput:
47
- layer_outputs: list[jax.Array] # [B, T, D]
48
- residuals: list[jax.Array] # [B, T, D]
49
- attn_logits: list[jax.Array] # [B, H, T, T]
50
  output: jax.Array # [B, T, D]
51
  input_embeddings: jax.Array # [B, T, D]
52
 
 
26
 
27
  import collections
28
  import dataclasses
29
+ from typing import Callable, Optional, List
30
 
31
  import chex
32
  import haiku as hk
 
44
 
45
  @chex.dataclass
46
  class TransformerOutput:
47
+ layer_outputs: List[jax.Array] # [B, T, D]
48
+ residuals: List[jax.Array] # [B, T, D]
49
+ attn_logits: List[jax.Array] # [B, H, T, T]
50
  output: jax.Array # [B, T, D]
51
  input_embeddings: jax.Array # [B, T, D]
52