DeepMind LMI Team
# Tracr: TRAnsformer Compiler for RASP.
Tracr is a compiler for converting RASP programs
([Weiss et al. 2021](https://arxiv.org/abs/2106.06981))
into transformer weights.
Directory structure:
* `rasp` contains an implementation of RASP embedded in Python.
* `compiler` contains the compiler itself.
* `transformer` contains the implementation of the transformer.
* `craft` contains the intermediate representation used by the compiler:
essentially a small linear algebra-based library with named dimensions.
This is not an officially supported Google product.
## Installation
Installation is currently a bit manual. First, install dependencies:
```
pip3 install chex einops dm-haiku networkx
```
Second, clone the repo:
```
git clone https://github.com/deepmind/tracr
```
Third, put the resulting folder somewhere in your `PYTHONPATH`
(e.g., by placing the `tracr` checkout in the root of your project folder).
This will be made easier in the future.
## Usage example: RASP `reverse` program
Consider the RASP `reverse` program:
```
opp_index = length - indices - 1;
flip = select(indices, opp_index, ==);
reverse = aggregate(flip, tokens);
```
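For intuition about what these three lines compute, here is a minimal plain-Python sketch of the usual RASP semantics (it does not use Tracr). A selector is a boolean matrix indexed by query position and key position, and `aggregate` averages the values at the selected keys, which simply copies the value when exactly one key is selected. The helpers `select`, `aggregate` and `reverse_reference` are hypothetical names for this illustration, and returning `None` for a row that selects nothing is an assumption of the sketch.

```python
from typing import Callable, List, Optional, Sequence


def select(keys: Sequence, queries: Sequence,
           predicate: Callable[[object, object], bool]) -> List[List[bool]]:
    """Selector matrix: row q, column k is predicate(keys[k], queries[q])."""
    return [[predicate(k, q) for k in keys] for q in queries]


def aggregate(selector: List[List[bool]],
              values: Sequence[float]) -> List[Optional[float]]:
    """For each query row, averages the values at the selected key positions."""
    out: List[Optional[float]] = []
    for row in selector:
        chosen = [v for v, is_selected in zip(values, row) if is_selected]
        # Hypothetical convention for this sketch: an empty row yields None.
        out.append(sum(chosen) / len(chosen) if chosen else None)
    return out


def reverse_reference(tokens: Sequence[float]) -> List[Optional[float]]:
    length = len(tokens)
    indices = list(range(length))
    opp_index = [length - i - 1 for i in indices]  # length - indices - 1
    flip = select(indices, opp_index, lambda k, q: k == q)
    return aggregate(flip, tokens)


print(reverse_reference([1, 2, 3]))  # [3.0, 2.0, 1.0]
```

Each query position `q` selects exactly the key whose index equals `length - q - 1`, so the average degenerates to a copy of the token at the mirrored position.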
To compile this with Tracr, we would first implement the program using Tracr's
RASP library:
```python
from tracr.rasp import rasp
length = make_length() # `length` is not a primitive in our implementation.
opp_index = length - rasp.indices - 1
flip = rasp.Select(rasp.indices, opp_index, rasp.Comparison.EQ)
reverse = rasp.Aggregate(flip, rasp.tokens)
```
Where:
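```python
from tracr.rasp import rasp

def make_length():
    """Sequence length at every position: the width of an all-True selector."""
    all_true_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
    return rasp.SelectorWidth(all_true_selector)
```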
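`SelectorWidth` returns, for each query position, the number of key positions its selector row selects. The plain-Python sketch below uses hypothetical helpers (not Tracr's API) to show why an all-`TRUE` selector yields the sequence length at every position, and how the same primitive with an equality comparison instead counts token occurrences.

```python
from typing import Callable, List, Sequence


def select(keys: Sequence, queries: Sequence,
           predicate: Callable[[object, object], bool]) -> List[List[bool]]:
    """Selector matrix: row q, column k is predicate(keys[k], queries[q])."""
    return [[predicate(k, q) for k in keys] for q in queries]


def selector_width(selector: List[List[bool]]) -> List[int]:
    """Number of selected keys in each query row."""
    return [sum(row) for row in selector]


def make_length_reference(tokens: Sequence) -> List[int]:
    # Mirrors make_length(): an always-true selector selects every key,
    # so each row's width is the sequence length.
    all_true = select(tokens, tokens, lambda k, q: True)
    return selector_width(all_true)


print(make_length_reference(["a", "b", "c"]))  # [3, 3, 3]

# The same primitive with an equality comparison counts token occurrences.
print(selector_width(select([1, 2, 1], [1, 2, 1], lambda k, q: k == q)))  # [2, 1, 2]
```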
We can then compile the RASP program to a transformer with:
```python
from tracr.compiler import compiling

bos = "BOS"
model = compiling.compile_rasp_to_model(
    reverse,
    vocab={1, 2, 3},
    max_seq_len=5,
    compiler_bos=bos,
)
```
This yields a transformer as a [Haiku](https://github.com/deepmind/dm-haiku) model.
This model isn't intended to provide _everything_ you might need, but rather serves
as a kind of "documentation-in-code" for the semantics of the generated parameters.
The expectation is that the user can then write or contribute an adapter that converts
parameters from this reference model to another transformer implementation.
Using this model we can perform a forward pass:
```python
>>> out = model.apply([bos, 1, 2, 3])
>>> out.decoded
["BOS", 3, 2, 1]
```
Success! We have a transformer that reverses its input tokens.
Note: compiled models always expect a BOS token in order to support
selectors which don't attend to any of the input tokens. This is necessary to
preserve intuitive RASP semantics; the alternative would have been to treat
all-False selector rows as equivalent to all-True (which is what softmax in an
attention layer would naturally do). For more details, see our paper.
You can also inspect some of the intermediate activations of the model, using
`out.residuals`, `out.layer_outputs`, and `out.attn_logits`.
For more examples of RASP programs we can compile, check out
[compiler/lib.py](compiler/lib.py).
For an interactive example of compiling a model and visualizing its computation,
check out the notebook at
[examples/Visualize\_Tracr\_Models.ipynb](examples/Visualize_Tracr_Models.ipynb).
## Developer README
If you'd like to extend Tracr to fit your purposes, here's some information on
how Tracr works under the hood.
### How Tracr works conceptually
To compile a program, Tracr does the following.
1. **Trace RASP program into a graph representation.** This involves creating
a graph node for each RASP expression and inferring dependencies between
these graph nodes.
2. **Infer bases.** Tracr is designed to have each node output to a separate
subspace of the residual stream. To do this, we first infer the set of all
possible values that each node can take; then, using that information, we
decide on a subspace for each node and augment each node in the graph
with the basis vectors for that node's subspace.
3. **Convert nodes to Craft components.** Craft is the name of our internal
intermediate representation that does linear algebra on named subspaces. In
this stage, each expression node is converted to a Craft component that
actually performs the linear algebra operations necessary to implement the
expression. This includes converting _sequence operators_ to MLP weights,
and _selectors_ to weights of attention heads. (We compute the appropriate
weights directly using the theory of universal approximation for MLPs - no
gradient descent required!)
4. **Convert Craft graph to Craft model.** In this stage, we convert from
a graph representation to a layout that looks more like an actual
transformer. At this stage, we essentially have a working model, but
with the linear algebra done using Craft rather than JAX + Haiku.
5. **Convert Craft model to Haiku model.** Finally, we convert our
intermediate representation of the model to a full Haiku model.
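The claim in stage 3 that weights can be computed directly is easiest to see for a categorical-to-categorical map. The sketch below (plain Python, not Tracr's code) builds a two-layer ReLU MLP with one hidden neuron per possible input value: the first layer detects which one-hot input direction is active, and the second writes the one-hot encoding of the function's output. For categorical inputs this lookup table is exact; Tracr's MLPs for numerical representations need a genuine approximation and are not covered here. The function names are hypothetical.

```python
from typing import Callable, List, Sequence, Tuple

Matrix = List[List[float]]


def one_hot(value, space: Sequence) -> List[float]:
    return [1.0 if v == value else 0.0 for v in space]


def build_categorical_mlp(input_space: Sequence, output_space: Sequence,
                          fn: Callable) -> Tuple[Matrix, Matrix]:
    """Weights (w1, w2) with relu(one_hot(x) @ w1) @ w2 == one_hot(fn(x))."""
    n_in, n_out = len(input_space), len(output_space)
    out_index = {value: j for j, value in enumerate(output_space)}
    # Layer 1: hidden neuron i is active exactly when the input is input_space[i].
    w1 = [[1.0 if i == h else 0.0 for h in range(n_in)] for i in range(n_in)]
    # Layer 2: hidden neuron i writes to the output direction of fn(input_space[i]).
    w2 = [[0.0] * n_out for _ in range(n_in)]
    for i, value in enumerate(input_space):
        w2[i][out_index[fn(value)]] = 1.0
    return w1, w2


def vec_mat(vec: Sequence[float], mat: Matrix) -> List[float]:
    """Row vector times matrix."""
    return [sum(v * mat[i][j] for i, v in enumerate(vec))
            for j in range(len(mat[0]))]


def apply_mlp(w1: Matrix, w2: Matrix, x: Sequence[float]) -> List[float]:
    hidden = [max(0.0, h) for h in vec_mat(x, w1)]  # ReLU
    return vec_mat(hidden, w2)


# Example: an "is odd" map from positions {0, ..., 4} to {False, True}.
inputs = [0, 1, 2, 3, 4]
outputs = [False, True]
w1, w2 = build_categorical_mlp(inputs, outputs, lambda x: x % 2 == 1)
print(apply_mlp(w1, w2, one_hot(3, inputs)))  # [0.0, 1.0]
```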
Two details worth expanding on here are subspaces and corresponding bases.
Each node writes to a separate subspace of the residual stream,
where each subspace is simply a unique chunk of the residual stream vector.
For example, the first node might write to the first 5 components of
the residual stream, the second node to the next 5, and so on. As for how
each node's values are actually embedded in its subspace, Tracr employs two
different kinds of bases:
* **Categorical representation** - in which each possible value of the node is
  represented as a unique one-hot vector in that node's subspace. This
  is the representation used by default.
* **Numerical representation** - in which the node's value is stored as a
  scalar, namely the magnitude along a single dimension of its subspace. This
  is necessary for some uses of the `aggregate` operation (essentially, ones
  which involve taking a mean), and some other operations are represented more
  efficiently with this representation.
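To make the layout concrete, here is a minimal plain-Python sketch for the `reverse` example compiled with `vocab={1, 2, 3}` and `max_seq_len=5`. It infers a value set for each node by applying each operation to every combination of input values, gives each categorical node one residual dimension per possible value and a numerical node a single dimension, and then embeds one position. The node `token_mean` is hypothetical (the `reverse` program has no numerical node); the assumed range of `length` and the block ordering are illustrative choices, and Tracr itself labels basis directions by name rather than by integer offset.

```python
from typing import Dict, List, Set, Tuple

VOCAB = {1, 2, 3}
MAX_SEQ_LEN = 5

# Value-set inference (assumed here: length ranges over 1..MAX_SEQ_LEN).
indices_values = set(range(MAX_SEQ_LEN))
length_values = set(range(1, MAX_SEQ_LEN + 1))
opp_index_values = {n - i - 1 for n in length_values for i in indices_values}

# (name, encoding, value set). "token_mean" is a hypothetical numerical node.
Node = Tuple[str, str, Set]
NODES: List[Node] = [
    ("tokens", "categorical", VOCAB),
    ("indices", "categorical", indices_values),
    ("length", "categorical", length_values),
    ("opp_index", "categorical", opp_index_values),
    ("token_mean", "numerical", set()),
]


def allocate(nodes: List[Node]) -> Dict[str, Tuple[int, int]]:
    """Give each node a half-open [start, end) block of residual dimensions."""
    layout, offset = {}, 0
    for name, encoding, values in nodes:
        width = len(values) if encoding == "categorical" else 1
        layout[name] = (offset, offset + width)
        offset += width
    return layout


def embed(layout: Dict[str, Tuple[int, int]], nodes: List[Node],
          assignment: Dict[str, float]) -> List[float]:
    """Residual-stream vector for one position, given each node's value there."""
    vec = [0.0] * max(end for _, end in layout.values())
    for name, encoding, values in nodes:
        start, _ = layout[name]
        if encoding == "categorical":
            # One-hot: one direction per possible value, in sorted order.
            vec[start + sorted(values).index(assignment[name])] = 1.0
        else:
            # Numerical: the value is the magnitude along a single direction.
            vec[start] = float(assignment[name])
    return vec


layout = allocate(NODES)
print(layout)
# Position 0 of the input [1, 2, 3]: token 1, index 0, length 3, opp_index 2.
vec = embed(layout, NODES, {"tokens": 1, "indices": 0, "length": 3,
                            "opp_index": 2, "token_mean": 2.0})
print([i for i, x in enumerate(vec) if x != 0.0])  # [0, 3, 10, 19, 22]
```

Note that naive value-set inference gives `opp_index` nine possible values (-4 through 4), even though only 0 through 4 occur in practice; the categorical subspace grows with the size of the inferred value set.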
A final detail is BOS tokens. The compiler relies on beginning-of-sequence
tokens in order to implement a number of operations. This is why token
sequences fed into the final model _must_ start with a BOS token.
### How Tracr works in practice
The flow of compilation execution begins in
[`compiler/compiling.py`](compiler/compiling.py), in the
`compile_rasp_to_model` function. This function is fairly short and maps
directly to the stages outlined above, so don't be afraid to read the source!
## Running tests
We use [`absltest`](https://abseil.io/docs/python/guides/testing), which is
`unittest`-compatible, and is therefore in turn `pytest`-compatible.
First, install test dependencies:
```
pip3 install absl-py pytest
```
```
# We use `python3 -m pytest` instead of just `pytest` so that the working directory is
# added to PYTHONPATH.
# -ra: Report names of tests that failed, were skipped, etc.
python3 -m pytest -ra
```
This should take about 60 seconds. If you install `pytest-xdist`, you can run them in
parallel with:
```
python3 -m pytest -ra -n auto
```
However, currently this only shaves off about 10 seconds, since it's bottlenecked by a
single long-running test.