# Tracr: TRAnsformer Compiler for RASP.

Tracr is a compiler for converting RASP programs
([Weiss et al. 2021](https://arxiv.org/abs/2106.06981))
into transformer weights. Please see our
[tech report](https://arxiv.org/abs/2301.05062) for a detailed description of
the compiler.

Directory structure:

* `rasp` contains an implementation of RASP embedded in Python.
* `compiler` contains the compiler itself.
* `transformer` contains the implementation of the transformer.
* `craft` contains the intermediate representation used by the compiler:
  essentially a small linear algebra-based library with named dimensions.

This is not an officially supported Google product.


## Installation

Just clone and pip install:

```
git clone https://github.com/deepmind/tracr
cd tracr
pip3 install .
```


## Usage example: RASP `reverse` program

Consider the RASP `reverse` program:

```
opp_index = length - indices - 1;
flip = select(indices, opp_index, ==);
reverse = aggregate(flip, tokens);
```
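For readers new to RASP, the following plain-Python sketch spells out what this program computes on a single input sequence. It shows the semantics only and is not Tracr's API. `select`, `aggregate`, and `reverse_reference` are hypothetical helpers, and `length` is taken from Python's `len`. The sketch's `aggregate` also assumes that each query position selects exactly one key, which holds for `reverse`.

```python
from typing import Callable, List, Sequence


def select(keys: Sequence, queries: Sequence,
           predicate: Callable[[object, object], bool]) -> List[List[bool]]:
    """Selector matrix: entry [q][k] is True when predicate(keys[k], queries[q])."""
    return [[predicate(k, q) for k in keys] for q in queries]


def aggregate(selector: List[List[bool]], values: Sequence) -> list:
    """For each query row, return the value at the single selected key.

    Sketch assumption: every row selects exactly one key (true for `reverse`).
    """
    out = []
    for row in selector:
        chosen = [v for v, selected in zip(values, row) if selected]
        assert len(chosen) == 1, "sketch only handles one selected key per row"
        out.append(chosen[0])
    return out


def reverse_reference(tokens: Sequence) -> list:
    length = len(tokens)
    indices = list(range(length))
    opp_index = [length - i - 1 for i in indices]        # length - indices - 1
    flip = select(indices, opp_index, lambda k, q: k == q)  # select(indices, opp_index, ==)
    return aggregate(flip, tokens)                       # aggregate(flip, tokens)


print(reverse_reference([1, 2, 3]))  # [3, 2, 1]
```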

To compile this with Tracr, we would first implement the program using Tracr's
RASP library:

```python
from tracr.rasp import rasp

length = make_length()  # `length` is not a primitive in our implementation.
opp_index = length - rasp.indices - 1
flip = rasp.Select(rasp.indices, opp_index, rasp.Comparison.EQ)
reverse = rasp.Aggregate(flip, rasp.tokens)
```

Here, `make_length` is defined as follows (it must be defined before the snippet above is run):

```python
def make_length():
  all_true_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
  return rasp.SelectorWidth(all_true_selector)
```
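For each query position, `SelectorWidth` counts how many keys that position's selector row selects. The `TRUE` comparison selects every position, so each position ends up holding the sequence length. The following plain-Python sketch shows this semantics. The helper names are hypothetical and are not part of Tracr.

```python
from typing import List, Sequence


def select_all(keys: Sequence, queries: Sequence) -> List[List[bool]]:
    """Selector that is True everywhere, mirroring a TRUE comparison."""
    return [[True for _ in keys] for _ in queries]


def selector_width(selector: List[List[bool]]) -> List[int]:
    """Number of selected keys in each query row."""
    return [sum(row) for row in selector]


def length_reference(tokens: Sequence) -> List[int]:
    # Every position receives the same value: the sequence length.
    return selector_width(select_all(tokens, tokens))


print(length_reference(["a", "b", "c"]))  # [3, 3, 3]
```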

We can then compile the RASP program to a transformer with:

```python
from tracr.compiler import compiling

bos = "BOS"
model = compiling.compile_rasp_to_model(
    reverse,
    vocab={1, 2, 3},
    max_seq_len=5,
    compiler_bos=bos,
)
```

This yields a transformer as a [Haiku](https://github.com/deepmind/dm-haiku) model.
This model isn't intended to provide _everything_ you might need, but rather serves
as a kind of "documentation-in-code" for the semantics of the generated parameters.
The expectation is that the user can then write or contribute an adapter that converts
parameters from this reference model to another transformer implementation.

Using this model we can perform a forward pass:

```python
>>> out = model.apply([bos, 1, 2, 3])
>>> out.decoded
['BOS', 3, 2, 1]
```

Success! We have a transformer that reverses its input tokens.

Note: compiled models always expect a BOS token in order to support
selectors which don't attend to any of the input tokens. This is necessary to
preserve intuitive RASP semantics; the alternative would have been to treat
all-False selector rows as equivalent to all-True (which is what softmax in an
attention layer would naturally do). For more details, see our paper.
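The BOS workaround can be illustrated with a toy softmax attention over one query row. This sketch rests on three assumptions: a selected key gets a large score, the BOS position always gets half that score, and the BOS value is 0. It illustrates the idea and does not reproduce the exact weights Tracr compiles. Without BOS, an all-False row gets uniform attention and returns the mean of all values, exactly as if it were all-True. With BOS, the same row attends almost entirely to BOS. Rows with at least one selected key are essentially unaffected.

```python
import math
from typing import List, Sequence


def softmax(scores: Sequence[float]) -> List[float]:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(selector_row: Sequence[bool], values: Sequence[float],
           scale: float = 100.0, use_bos: bool = True) -> float:
    """Softmax-average `values` over the selected positions of one query row.

    With use_bos=True, a BOS position with value 0 is prepended whose score is
    half that of a selected key (an assumption of this sketch), so it only
    dominates when nothing else is selected.
    """
    scores = [scale if s else 0.0 for s in selector_row]
    vals = list(values)
    if use_bos:
        scores = [scale / 2] + scores
        vals = [0.0] + vals
    weights = softmax(scores)
    return sum(w * v for w, v in zip(weights, vals))


values = [1.0, 2.0, 3.0]
print(round(attend([False, False, False], values, use_bos=False), 3))  # 2.0
print(round(attend([False, False, False], values), 3))                 # 0.0
print(round(attend([True, False, False], values), 3))                  # 1.0
```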

You can also inspect some of the intermediate activations of the model, using
`out.residuals`, `out.layer_outputs`, and `out.attn_logits`.

For more examples of RASP programs we can compile, check out
[compiler/lib.py](compiler/lib.py).

For an interactive example of compiling a model and visualizing its computation,
check out the notebook at
[examples/Visualize\_Tracr\_Models.ipynb](examples/Visualize_Tracr_Models.ipynb).


## Developer README

If you'd like to extend Tracr to fit your purposes, here's some information on
how Tracr works under the hood.


### How Tracr works conceptually

To compile a program, Tracr does the following.

1. **Trace RASP program into a graph representation.** This involves creating
   a graph node for each RASP expression and inferring dependencies between
   these graph nodes.

2. **Infer bases.** Tracr is designed to have each node output to a separate
   subspace of the residual stream. To do this, we first infer the set of all
   possible token values that each node can take, then using that information,
   decide on a subspace for each node, and augment each node in the graph
   with the basis vectors for that node's subspace.

3. **Convert nodes to Craft components.** Craft is the name of our internal
   intermediate representation that does linear algebra on named subspaces. In
   this stage, each expression node is converted to a Craft component that
   actually performs the linear algebra operations necessary to implement the
   expression. This includes converting _sequence operators_ to MLP weights,
   and _selectors_ to weights of attention heads. (We compute the appropriate
   weights directly using the theory of universal approximation for MLPs - no
   gradient descent required!)

4. **Convert Craft graph to Craft model.** In this stage, we convert from
   a graph representation to a layout that looks more like an actual
   transformer. At this stage, we essentially have a working model, but
   with the linear algebra done using Craft rather than JAX + Haiku.

5. **Convert Craft model to Haiku model.** Finally, we convert our
   intermediate representation of the model to a full Haiku model.
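The following simplified sketch makes step 3 concrete. It shows how a map over a categorical input, for example one that doubles each token of the vocabulary `{1, 2, 3}`, can be turned into MLP weights. It assumes one-hot inputs and outputs and uses one hidden neuron per possible input value, so the MLP acts as an exact lookup table. The function names are hypothetical. Tracr's actual construction may differ in the details, especially for numerical inputs.

```python
from typing import Callable, List, Sequence, Tuple

Matrix = List[List[float]]


def relu(xs: List[float]) -> List[float]:
    return [max(0.0, x) for x in xs]


def vecmat(v: List[float], m: Matrix) -> List[float]:
    """Row vector times matrix: v has len(m) entries, result has len(m[0])."""
    return [sum(v[i] * m[i][j] for i in range(len(v))) for j in range(len(m[0]))]


def compile_categorical_map(f: Callable, in_values: Sequence) -> Tuple[Matrix, Matrix, list]:
    """Hypothetical helper: weights for an MLP computing f on one-hot inputs."""
    out_values = sorted(set(f(x) for x in in_values))
    n_in, n_out = len(in_values), len(out_values)
    # Layer 1: hidden neuron i copies the one-hot slot for in_values[i].
    w1 = [[1.0 if i == j else 0.0 for j in range(n_in)] for i in range(n_in)]
    # Layer 2: hidden neuron i writes 1 into the one-hot slot of f(in_values[i]).
    w2 = [[1.0 if f(in_values[i]) == out_values[j] else 0.0 for j in range(n_out)]
          for i in range(n_in)]
    return w1, w2, out_values


def run_mlp(w1: Matrix, w2: Matrix, in_values: Sequence, out_values: list, x) -> object:
    one_hot = [1.0 if v == x else 0.0 for v in in_values]
    out = vecmat(relu(vecmat(one_hot, w1)), w2)
    return out_values[out.index(max(out))]  # decode the one-hot output


xs = [1, 2, 3]
w1, w2, outs = compile_categorical_map(lambda x: x * 2, xs)
print([run_mlp(w1, w2, xs, outs, x) for x in xs])  # [2, 4, 6]
```

An input outside `in_values` would produce an all-zero output vector in this sketch. This is one reason it helps to know every possible input value in advance (step 2).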

Two details worth expanding on here are subspaces and their corresponding bases.
Each node writes to a separate subspace of the residual stream,
where each subspace is simply a distinct chunk of the residual stream vector.
For example, the first node might write to the first 5 components of
the residual stream, the second node to the next 5, and so on. As for how
values are actually encoded within each node's subspace, Tracr employs two
different kinds of bases:

* **Categorical representation** - in which each unique token value is
  represented as a unique one-hot vector in that node's subspace. This
  is the representation used by default.
* **Numerical representation** - in which each unique token value is
  mapped to a unique scalar value. This is necessary for some uses
  of the `aggregate` operation (essentially, those that involve taking
  a mean), and some other operations are also represented more efficiently
  this way.
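The following sketch illustrates both representations on the `reverse` example with `vocab={1, 2, 3}` and `max_seq_len=5`. It infers the value set of `opp_index` by hand. It then lays out a labelled residual stream in which each node owns a contiguous block of dimensions, and embeds one position. `mean_token` is a hypothetical numerical node, added only to show a one-dimensional numerical subspace. The naive value inference is likewise an assumption for illustration and is not necessarily what Tracr's basis inference does.

```python
from itertools import product
from typing import Dict, List, Tuple

MAX_SEQ_LEN = 5
VOCAB = {1, 2, 3}

# Possible values of each node for the `reverse` program, worked out by hand.
indices_values = set(range(MAX_SEQ_LEN))        # {0, ..., 4}
length_values = set(range(1, MAX_SEQ_LEN + 1))  # {1, ..., 5}
# Naive inference: apply `length - indices - 1` to every combination of inputs.
# This over-approximates (it ignores that index < length), giving {-4, ..., 4}.
opp_index_values = {n - i - 1 for n, i in product(length_values, indices_values)}


def build_layout(nodes: Dict[str, Tuple[str, object]]) -> List[Tuple[str, object]]:
    """Give each node a contiguous block of labelled residual-stream dimensions."""
    layout: List[Tuple[str, object]] = []
    for name, (kind, values) in nodes.items():
        if kind == "categorical":
            # One dimension per possible value (one-hot encoding).
            layout.extend((name, v) for v in sorted(values))
        else:
            # A single dimension whose magnitude is the value itself.
            layout.append((name, None))
    return layout


def embed(layout: List[Tuple[str, object]], assignment: Dict[str, object]) -> List[float]:
    """Residual-stream vector at one position, given each node's value there."""
    vec = []
    for name, label in layout:
        if label is None:
            vec.append(float(assignment[name]))
        else:
            vec.append(1.0 if assignment[name] == label else 0.0)
    return vec


nodes = {
    "tokens": ("categorical", VOCAB),
    "indices": ("categorical", indices_values),
    "opp_index": ("categorical", opp_index_values),
    "mean_token": ("numerical", None),  # hypothetical numerical node
}
layout = build_layout(nodes)
# Position 0 of the input [1, 2, 3]: token 1, index 0, opp_index 3 - 0 - 1 = 2.
vec = embed(layout, {"tokens": 1, "indices": 0, "opp_index": 2, "mean_token": 2.0})
print(len(layout))  # 18
```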

A final detail concerns BOS tokens. The compiler relies on beginning-of-sequence
(BOS) tokens in order to implement a number of operations. This is why token
sequences fed into the final model _must_ start with a BOS token.


### How Tracr works in practice

The flow of compilation execution begins in
[`compiler/compiling.py`](compiler/compiling.py), in the
`compile_rasp_to_model` function. This function is fairly short and maps
directly to the stages outlined above, so don't be afraid to read the source!


### Running tests

We use [`absltest`](https://abseil.io/docs/python/guides/testing), which is
`unittest`-compatible, and is therefore in turn `pytest`-compatible.

First, install test dependencies:

```
pip3 install absl-py pytest
```

Then, in the checkout directory, simply run `pytest`. This should take about 60
seconds.

## Citing Tracr

Please use the following BibTeX entry for our tech report:

```
@article{lindner2023tracr,
  title = {Tracr: Compiled Transformers as a Laboratory for Interpretability},
  author = {Lindner, David and Kramár, János and Rahtz, Matthew and McGrath, Thomas and Mikulik, Vladimir},
  journal={arXiv preprint arXiv:2301.05062},
  year={2023}
}
```