DeepMind LMI Team committed · Commit 9bdaa77 · 1 Parent(s): b4ed985

Internal change

PiperOrigin-RevId: 495328606

This view is limited to 50 files because it contains too many changes; see the raw diff for the full change.
- CONTRIBUTING.md +28 -0
- LICENSE +202 -0
- README.md +208 -1
- compiler/__init__.py +19 -0
- compiler/assemble.py +335 -0
- compiler/assemble_test.py +120 -0
- compiler/basis_inference.py +106 -0
- compiler/basis_inference_test.py +140 -0
- compiler/compiling.py +92 -0
- compiler/craft_graph_to_model.py +238 -0
- compiler/craft_graph_to_model_test.py +194 -0
- compiler/craft_model_to_transformer.py +76 -0
- compiler/expr_to_craft_graph.py +277 -0
- compiler/expr_to_craft_graph_test.py +121 -0
- compiler/lib.py +371 -0
- compiler/lib_test.py +40 -0
- compiler/nodes.py +32 -0
- compiler/rasp_to_craft_integration_test.py +254 -0
- compiler/rasp_to_graph.py +67 -0
- compiler/rasp_to_graph_test.py +71 -0
- compiler/rasp_to_transformer_integration_test.py +214 -0
- compiler/test_cases.py +357 -0
- craft/bases.py +247 -0
- craft/bases_test.py +158 -0
- craft/chamber/categorical_attn.py +167 -0
- craft/chamber/categorical_attn_test.py +229 -0
- craft/chamber/categorical_mlp.py +168 -0
- craft/chamber/categorical_mlp_test.py +164 -0
- craft/chamber/numerical_mlp.py +334 -0
- craft/chamber/numerical_mlp_test.py +233 -0
- craft/chamber/selector_width.py +144 -0
- craft/chamber/selector_width_test.py +155 -0
- craft/tests_common.py +33 -0
- craft/transformers.py +197 -0
- craft/transformers_test.py +160 -0
- craft/vectorspace_fns.py +162 -0
- craft/vectorspace_fns_test.py +166 -0
- examples/Visualize_Tracr_Models.ipynb +262 -0
- rasp/causal_eval.py +39 -0
- rasp/causal_eval_test.py +61 -0
- rasp/rasp.py +932 -0
- rasp/rasp_test.py +580 -0
- transformer/attention.py +160 -0
- transformer/compressed_model.py +185 -0
- transformer/compressed_model_test.py +318 -0
- transformer/encoder.py +135 -0
- transformer/encoder_test.py +123 -0
- transformer/model.py +199 -0
- transformer/model_test.py +275 -0
- utils/debugging.py +28 -0
CONTRIBUTING.md
ADDED
@@ -0,0 +1,28 @@

# How to Contribute

We welcome your contributions to this project. Please read the guidance below
first.

## Contributor License Agreement

Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution;
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to <https://cla.developers.google.com/> to see
your current agreements on file or to sign a new one.

You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.

## Code reviews

All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.

## Community Guidelines

This project follows [Google's Open Source Community
Guidelines](https://opensource.google/conduct/).
LICENSE
ADDED
@@ -0,0 +1,202 @@

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
README.md
CHANGED
@@ -1 +1,208 @@
- #

# Tracr: TRAnsformer Compiler for RASP.

Tracr is a compiler for converting RASP programs
([Weiss et al. 2021](https://arxiv.org/abs/2106.06981))
into transformer weights.

Directory structure:

* `rasp` contains an implementation of RASP embedded in Python.
* `compiler` contains the compiler itself.
* `transformer` contains the implementation of the transformer.
* `craft` contains the intermediate representation used by the compiler:
  essentially a small linear algebra-based library with named dimensions.

This is not an officially supported Google product.


## Installation

Installation is currently a bit manual. First, install dependencies:

```
pip3 install chex einops dm-haiku networkx
```

Second, clone the repo:

```
git clone https://github.com/deepmind/tracr
```

Third, put the resulting folder somewhere in your `PYTHONPATH`
(e.g. by placing the `tracr` checkout in the root of your project folder).

This will be made easier in the future.
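As a quick, optional sanity check that the checkout is importable (a sketch; the `sys.path` line stands in for setting `PYTHONPATH` and assumes the `tracr` folder sits in the current directory):

```python
import sys

sys.path.insert(0, ".")  # stand-in for PYTHONPATH; assumes `tracr/` is in cwd

from tracr.rasp import rasp  # should import without error

print(rasp.tokens)
```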
## Usage example: RASP `reverse` program

Consider the RASP `reverse` program:

```
opp_index = length - indices - 1;
flip = select(indices, opp_index, ==);
reverse = aggregate(flip, tokens);
```

To compile this with Tracr, we would first implement the program using Tracr's
RASP library:

```python
from tracr.rasp import rasp

length = make_length()  # `length` is not a primitive in our implementation.
opp_index = length - rasp.indices - 1
flip = rasp.Select(rasp.indices, opp_index, rasp.Comparison.EQ)
reverse = rasp.Aggregate(flip, rasp.tokens)
```

Where:

```python
def make_length():
  all_true_selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
  return rasp.SelectorWidth(all_true_selector)
```

We can then compile the RASP program to a transformer with:

```python
from tracr.compiler import compiling

bos = "BOS"
model = compiling.compile_rasp_to_model(
    reverse,
    vocab={1, 2, 3},
    max_seq_len=5,
    compiler_bos=bos,
)
```

This yields a transformer as a [Haiku](https://github.com/deepmind/dm-haiku) model.
This model isn't intended to provide _everything_ you might need, but rather serves
as a kind of "documentation-in-code" for the semantics of the generated parameters.
The expectation is that the user can then write or contribute an adapter that converts
parameters from this reference model to another transformer implementation.

Using this model we can perform a forward pass:

```python
>>> out = model.apply([bos, 1, 2, 3])
>>> out.decoded
["BOS", 3, 2, 1]
```

Success! We have a transformer that reverses its input tokens.

Note: compiled models always expect a BOS token in order to support
selectors which don't attend to any of the input tokens. This is necessary to
preserve intuitive RASP semantics; the alternative would have been to treat
all-False selector rows as equivalent to all-True (which is what softmax in an
attention layer would naturally do). For more details, see our paper.

You can also inspect some of the intermediate activations of the model, using
`out.residuals`, `out.layer_outputs`, and `out.attn_logits`.
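For example, to print the shape of the residual stream after each layer (a short sketch; the field shapes follow the annotations on `AssembledTransformerModelOutput` in `compiler/assemble.py`):

```python
out = model.apply([bos, 1, 2, 3])

# `out.residuals` holds one array per layer, each of shape [batch, seq, width].
for layer, resid in enumerate(out.residuals):
  print(f"residual stream after layer {layer}: {resid.shape}")
```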
For more examples of RASP programs we can compile, check out
[compiler/lib.py](compiler/lib.py).

For an interactive example of compiling a model and visualizing its computation,
check out the notebook at
[examples/Visualize\_Tracr\_Models.ipynb](examples/Visualize_Tracr_Models.ipynb).


## Developer README

If you'd like to extend Tracr to fit your purposes, here's some information on
how Tracr works under the hood.


### How Tracr works conceptually

To compile a program, Tracr does the following (a code sketch of the first two
stages appears after this list):

1. **Trace RASP program into a graph representation.** This involves creating
   a graph node for each RASP expression and inferring dependencies between
   these graph nodes.

2. **Infer bases.** Tracr is designed to have each node output to a separate
   subspace of the residual stream. To do this, we first infer the set of all
   possible token values that each node can take, then, using that information,
   decide on a subspace for each node, and augment each node in the graph
   with the basis vectors for that node's subspace.

3. **Convert nodes to Craft components.** Craft is the name of our internal
   intermediate representation that does linear algebra on named subspaces. In
   this stage, each expression node is converted to a Craft component that
   actually performs the linear algebra operations necessary to implement the
   expression. This includes converting _sequence operators_ to MLP weights,
   and _selectors_ to weights of attention heads. (We compute the appropriate
   weights directly using the theory of universal approximation for MLPs - no
   gradient descent required!)

4. **Convert Craft graph to Craft model.** In this stage, we convert from
   a graph representation to a layout that looks more like an actual
   transformer. At this stage, we essentially have a working model, but
   with the linear algebra done using Craft rather than JAX + Haiku.

5. **Convert Craft model to Haiku model.** Finally, we convert our
   intermediate representation of the model to a full Haiku model.
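Here is a minimal sketch of the first two stages (the calls and signatures follow the tests in this commit; the remaining stages live in `expr_to_craft_graph.py`, `craft_graph_to_model.py`, and `craft_model_to_transformer.py`, and `compile_rasp_to_model` chains all five):

```python
from tracr.compiler import basis_inference
from tracr.compiler import rasp_to_graph

# Stage 1: trace the RASP program (e.g. `reverse` from above) into a graph.
extracted = rasp_to_graph.extract_rasp_graph(reverse)

# Stage 2: infer value sets and output bases for every node, in place.
basis_inference.infer_bases(
    extracted.graph,
    extracted.sink,
    vocab={1, 2, 3},
    max_seq_len=5,
)
```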
Two details worth expanding on here are subspaces and corresponding bases.
Each node writes to a separate subspace of the residual stream,
where each subspace is simply a unique chunk of the residual stream vector.
For example, the first node might write to the first 5 components of
the residual stream; the second node the next 5; and so on. In terms of the
embeddings actually associated with each node, Tracr employs two
different kinds of bases:

* **Categorical representation** - in which each unique token value is
  represented as a unique one-hot vector in that node's subspace. This
  is the representation used by default.
* **Numerical representation** - in which each unique token value is
  mapped to a unique scalar value. This is necessary for some uses
  of the `aggregate` operation - essentially, ones which involve taking
  a mean - and some other operations are represented more efficiently
  with this representation.
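To make the two representations concrete, here is a toy illustration in plain NumPy (hypothetical values; Tracr's actual encoding goes through its `craft/bases.py` machinery):

```python
import numpy as np

# Categorical: a node with value set {0, 1, 2} gets a 3-dim subspace,
# and the value 1 is written as a one-hot vector in that subspace.
categorical_one = np.array([0.0, 1.0, 0.0])

# Numerical: the same node gets a 1-dim subspace, and a value is written
# directly as a scalar - so a mean such as 0.5 is representable.
numerical_mean = np.array([0.5])
```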
A final detail is BOS tokens. The compiler relies on beginning-of-sequence
tokens in order to implement a number of operations. This is why token
sequences fed into the final model _must_ start with a BOS token.


### How Tracr works in practice

The flow of compilation execution begins in
[`compiler/compiling.py`](compiler/compiling.py), in the
`compile_rasp_to_model` function. This function is fairly short and maps
directly to the stages outlined above, so don't be afraid to read the source!


## Running tests

We use [`absltest`](https://abseil.io/docs/python/guides/testing), which is
`unittest`-compatible, and is therefore in turn `pytest`-compatible.

First, install test dependencies:

```
pip3 install absl-py pytest
```

Then run the tests:

```
# We use `python3 -m pytest` instead of just `pytest` so that the working directory is
# added to PYTHONPATH.
# -ra: Report names of tests that failed, were skipped, etc.
python3 -m pytest -ra
```

This should take about 60 seconds. If you install `pytest-xdist`, you can run them in
parallel with:

```
python3 -m pytest -ra -n auto
```

However, currently this only shaves off about 10 seconds, since it's bottlenecked by a
single long-running test.
compiler/__init__.py
ADDED
@@ -0,0 +1,19 @@

# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides the main compiler function as a public import."""

from tracr.compiler.compiling import compile_rasp_to_model

__all__ = ["compile_rasp_to_model"]
compiler/assemble.py
ADDED
@@ -0,0 +1,335 @@

# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Assemble weights of a transformer model from a craft residual stack."""

import dataclasses
from typing import Any, Callable, Optional, Protocol

import chex
import einops
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from tracr.craft import bases
from tracr.craft import transformers
from tracr.craft import vectorspace_fns
from tracr.transformer import encoder
from tracr.transformer import model


@chex.dataclass
class AssembledTransformerModelOutput:
  decoded: list[Any]  # length T.
  unembedded: jax.Array  # [B, T]  B = 1 always.
  layer_outputs: list[jax.Array]  # [B, T, D]
  residuals: list[jax.Array]  # [B, T, D]
  attn_logits: list[jax.Array]  # [B, T, T, H]
  transformer_output: jax.Array  # [B, T, D]
  input_embeddings: jax.Array


class ModelForward(Protocol):

  def __call__(
      self,
      params: hk.Params,
      emb: jax.Array,
  ) -> model.CompiledTransformerModelOutput:
    """A hk-transformed forward pass through the compiled model."""


@dataclasses.dataclass
class AssembledTransformerModel:
  """Model architecture and parameters from assembling a model."""
  forward: ModelForward
  get_compiled_model: Callable[[], model.CompiledTransformerModel]
  params: hk.Params
  model_config: model.TransformerConfig
  residual_labels: list[str]
  input_encoder: Optional[encoder.Encoder] = None
  output_encoder: Optional[encoder.Encoder] = None

  def apply(self, tokens: list[bases.Value]) -> AssembledTransformerModelOutput:
    """Returns output from running the model on a set of input tokens."""
    if self.input_encoder:
      tokens = self.input_encoder.encode(tokens)
    tokens = jnp.array([tokens])
    output = self.forward(self.params, tokens)
    decoded = output.unembedded_output[0].tolist()
    if self.output_encoder:
      decoded = self.output_encoder.decode(decoded)

    if self.input_encoder.bos_token:
      # Special case for decoding the bos token position, for which the output
      # decoder might have unspecified behavior.
      decoded = [self.input_encoder.bos_token] + decoded[1:]

    return AssembledTransformerModelOutput(
        decoded=decoded,
        unembedded=output.unembedded_output,
        layer_outputs=output.transformer_output.layer_outputs,
        residuals=output.transformer_output.residuals,
        attn_logits=output.transformer_output.attn_logits,
        transformer_output=output.transformer_output.output,
        input_embeddings=output.transformer_output.input_embeddings)


@dataclasses.dataclass
class EmbeddingModules:
  """Modules for embedding tokens and positions and unembedding results."""
  token_embed: model.CallableHaikuModule
  pos_embed: model.CallableHaikuModule
  unembed: model.CallableHaikuModule


def _get_model_config_and_module_names(
    craft_model: transformers.SeriesWithResiduals
) -> tuple[model.TransformerConfig, list[str]]:
  """Returns model config and locations (in params) for halflayers."""

  multi_attn_heads: list[list[transformers.AttentionHead]] = []
  mlps: list[transformers.MLP] = []
  module_names: list[str] = []

  candidate_module_names = []
  for layer in range(len(craft_model.blocks)):
    candidate_module_names.append(f"transformer/layer_{layer}/attn")
    candidate_module_names.append(f"transformer/layer_{layer}/mlp")
  candidate_module_names = iter(candidate_module_names)

  for module in craft_model.blocks:
    if isinstance(module, transformers.MLP):
      mlps.append(module)
      layer_type = "mlp"
    else:
      multi_attn_heads.append(list(module.as_multi().heads()))
      layer_type = "attn"
    # Find next layer with the necessary type. Modules in-between, that are not
    # added to module_names, will be disabled later by setting all weights to 0.
    module_name = next(candidate_module_names)
    while layer_type not in module_name:
      module_name = next(candidate_module_names)
    module_names.append(module_name)

  num_layers = int(module_names[-1].split("_")[1].split("/")[0]) + 1
  heads = sum(multi_attn_heads, [])

  if multi_attn_heads:
    num_heads = max(len(heads) for heads in multi_attn_heads)
    key_size = max(max(head.w_qk.matrix.shape) for head in heads)
  else:
    num_heads, key_size = 1, 1

  if mlps:
    mlp_hidden_size = max(mlp.fst.output_space.num_dims for mlp in mlps)
  else:
    mlp_hidden_size = 1

  model_config = model.TransformerConfig(
      num_heads=num_heads,
      num_layers=num_layers,
      key_size=key_size,
      mlp_hidden_size=mlp_hidden_size,
      dropout_rate=0.,
      activation_function=jax.nn.relu,
      layer_norm=False,
      causal=False,
  )

  return model_config, module_names


def _make_embedding_modules(
    residual_space: bases.VectorSpaceWithBasis,
    tokens_space: bases.VectorSpaceWithBasis,
    indices_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis) -> EmbeddingModules:
  """Creates embedding and unembedding modules from vector spaces.

  Args:
    residual_space: Full residual space of the model.
    tokens_space: Subspace to embed tokens to.
    indices_space: Subspace to embed indices/position embeddings to.
    output_space: Subspace to unembed outputs from.

  Returns:
    EmbeddingModules containing modules for token embeddings, position
    embeddings and unembeddings.
  """
  tokens_to_res = vectorspace_fns.project(tokens_space, residual_space)

  # If we use the 'one' direction, make sure all inputs have a 1 here.
  one_dir = bases.BasisDirection("one")
  if one_dir in residual_space:
    one_to_res = vectorspace_fns.Linear.from_action(
        tokens_space, residual_space,
        lambda x: residual_space.vector_from_basis_direction(one_dir))
    tokens_to_res = vectorspace_fns.Linear.combine_in_parallel(
        [tokens_to_res, one_to_res])

  # Token embeddings.
  res_to_out = vectorspace_fns.project(residual_space, output_space)
  token_embed = hk.Embed(
      embedding_matrix=tokens_to_res.matrix, name="token_embed")

  # Positional embeddings.
  index_to_res = vectorspace_fns.project(indices_space, residual_space)
  # The zeroth position should not have any positional embeddings,
  # so we add one line of padding at the zeroth position.
  pos_matrix = np.concatenate(
      [np.zeros((1, residual_space.num_dims)), index_to_res.matrix], axis=0)
  pos_embed = hk.Embed(embedding_matrix=pos_matrix, name="pos_embed")

  def unembed(x, use_unembed_argmax):
    out = x @ res_to_out.matrix
    if use_unembed_argmax:
      return jnp.argmax(out, axis=-1)
    elif out.shape[-1] == 1:
      return out.squeeze(-1)
    return out

  unembed_mod = hk.to_module(unembed)()
  return EmbeddingModules(
      token_embed=token_embed, pos_embed=pos_embed, unembed=unembed_mod)


def assemble_craft_model(
    craft_model: transformers.SeriesWithResiduals,
    tokens_space: bases.VectorSpaceWithBasis,
    indices_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis,
    categorical_output: bool,
    causal: bool = False,
) -> AssembledTransformerModel:
  """Assembles the given components into a Haiku model with parameters.

  Args:
    craft_model: Model to assemble weights for.
    tokens_space: Vectorspace to embed the input tokens to.
    indices_space: Vectorspace to embed the indices to (position encodings).
    output_space: Vectorspace that the model will write outputs to that should
      be unembedded.
    categorical_output: Whether the output is categorical. If True, we take an
      argmax when unembedding.
    causal: Whether to output a causally-masked model.

  Returns:
    An AssembledTransformerModel that contains the model and parameters of the
    assembled transformer.
  """
  # TODO(b/255936413): Make embeddings only retain the tokens and indices that
  # are actually used.
  # TODO(b/255936496): Think about enabling layer norm and reversing it somehow

  model_config, module_names = _get_model_config_and_module_names(craft_model)
  model_config.causal = causal

  residual_space = bases.join_vector_spaces(craft_model.residual_space,
                                            tokens_space, indices_space,
                                            output_space)
  residual_labels = [str(basis_dir) for basis_dir in residual_space.basis]

  # Build model with embedding and unembedding layers.
  def get_compiled_model():
    transformer = model.Transformer(model_config)
    embed_modules = _make_embedding_modules(
        residual_space=residual_space,
        tokens_space=tokens_space,
        indices_space=indices_space,
        output_space=output_space)
    return model.CompiledTransformerModel(
        transformer=transformer,
        token_embed=embed_modules.token_embed,
        position_embed=embed_modules.pos_embed,
        unembed=embed_modules.unembed,
        use_unembed_argmax=categorical_output)

  @hk.without_apply_rng
  @hk.transform
  def forward(emb):
    compiled_model = get_compiled_model()
    return compiled_model(emb, use_dropout=False)

  params = forward.init(jax.random.PRNGKey(0), jnp.array([[1, 2, 3]]))

  for key in params:
    if "transformer" in key:
      for par in params[key]:
        params[key][par] = np.zeros_like(params[key][par])

  # Assemble attention and MLP weights.
  project = lambda space: vectorspace_fns.project(residual_space, space).matrix

  for module_name, module in zip(module_names, craft_model.blocks):
    if isinstance(module, transformers.MLP):
      hidden_size = module.fst.output_space.num_dims
      residual_to_fst_input = project(module.fst.input_space)
      snd_output_to_residual = project(module.snd.output_space).T
      params[f"{module_name}/linear_1"]["w"][:, :hidden_size] = (
          residual_to_fst_input @ module.fst.matrix)
      params[f"{module_name}/linear_2"]["w"][:hidden_size, :] = (
          module.snd.matrix @ snd_output_to_residual)
    else:  # Attention module
      query, key, value, linear = [], [], [], []
      for head in module.as_multi().heads():
        key_size = head.w_qk.matrix.shape[1]
        query_mat = np.zeros((residual_space.num_dims, model_config.key_size))
        residual_to_query = project(head.w_qk.left_space)
        query_mat[:, :key_size] = residual_to_query @ head.w_qk.matrix
        query.append(query_mat)

        key_mat = np.zeros((residual_space.num_dims, model_config.key_size))
        key_mat[:, :key_size] = project(head.w_qk.right_space)
        key.append(key_mat)

        value_size = head.w_ov.matrix.shape[1]
        value_mat = np.zeros((residual_space.num_dims, model_config.key_size))
        residual_to_ov_input = project(head.w_ov.input_space)
        value_mat[:, :value_size] = residual_to_ov_input @ head.w_ov.matrix
        value.append(value_mat)

        linear_mat = np.zeros((model_config.key_size, residual_space.num_dims))
        linear_mat[:value_size, :] = project(head.w_ov.output_space).T
        linear.append(linear_mat)

      # Fill up heads that are not used with zero weights.
      for _ in range(model_config.num_heads - module.as_multi().num_heads):
        query.append(np.zeros_like(query[0]))
        key.append(np.zeros_like(key[0]))
        value.append(np.zeros_like(value[0]))
        linear.append(np.zeros_like(linear[0]))

      query = einops.rearrange(query,
                               "heads input output -> input (heads output)")
      key = einops.rearrange(key, "heads input output -> input (heads output)")
      value = einops.rearrange(value,
                               "heads input output -> input (heads output)")
      linear = einops.rearrange(linear,
                                "heads input output -> (heads input) output")

      params[f"{module_name}/query"]["w"][:, :] = query
      params[f"{module_name}/key"]["w"][:, :] = key
      params[f"{module_name}/value"]["w"][:, :] = value
      params[f"{module_name}/linear"]["w"][:, :] = linear

  params = jax.tree_util.tree_map(jnp.array, params)
  return AssembledTransformerModel(
      forward=forward.apply,
      get_compiled_model=get_compiled_model,
      params=params,
      model_config=model_config,
      residual_labels=residual_labels,
  )
compiler/assemble_test.py
ADDED
@@ -0,0 +1,120 @@

# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for transformer.assemble."""

from absl.testing import absltest
from absl.testing import parameterized
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from tracr.compiler import assemble
from tracr.craft import bases


class AssembleTest(parameterized.TestCase):

  def test_token_embedding_produces_correct_embedding(self):
    # Token embeddings should be one-hot embeddings of the input integers
    # into the token subspace of residual_space.
    input_space = bases.VectorSpaceWithBasis.from_values("0inp", range(2))
    indices_space = bases.VectorSpaceWithBasis.from_values("1ind", range(3))
    output_space = bases.VectorSpaceWithBasis.from_values("2out", range(2))
    residual_space = bases.join_vector_spaces(input_space, indices_space,
                                              output_space)

    @hk.without_apply_rng
    @hk.transform
    def token_pos_embed(tokens):
      embed_modules = assemble._make_embedding_modules(
          residual_space=residual_space,
          tokens_space=input_space,
          indices_space=indices_space,
          output_space=output_space)
      return embed_modules.token_embed(tokens)

    tokens = jnp.array([0, 0, 1])
    expected_token_embeddings = jnp.array([[1, 0, 0, 0, 0, 0, 0],
                                           [1, 0, 0, 0, 0, 0, 0],
                                           [0, 1, 0, 0, 0, 0, 0]])

    params = token_pos_embed.init(jax.random.PRNGKey(0), tokens)
    embeddings = token_pos_embed.apply(params, tokens)
    np.testing.assert_allclose(embeddings, expected_token_embeddings)

  def test_position_embedding_produces_correct_embedding(self):
    # Position embeddings should be one-hot embeddings of the input integers
    # (representing indices) into the indices subspace of residual_space.
    input_space = bases.VectorSpaceWithBasis.from_values("0inp", range(2))
    indices_space = bases.VectorSpaceWithBasis.from_values("1ind", range(3))
    output_space = bases.VectorSpaceWithBasis.from_values("2out", range(2))
    residual_space = bases.join_vector_spaces(input_space, indices_space,
                                              output_space)

    @hk.without_apply_rng
    @hk.transform
    def token_pos_embed(tokens):
      embed_modules = assemble._make_embedding_modules(
          residual_space=residual_space,
          tokens_space=input_space,
          indices_space=indices_space,
          output_space=output_space)
      return embed_modules.pos_embed(jnp.indices(tokens.shape)[-1])

    tokens = jnp.array([3, 0, 0, 1])
    expected_pos_embeddings = jnp.array([[0, 0, 0, 0, 0, 0, 0],
                                         [0, 0, 1, 0, 0, 0, 0],
                                         [0, 0, 0, 1, 0, 0, 0],
                                         [0, 0, 0, 0, 1, 0, 0]])

    params = token_pos_embed.init(jax.random.PRNGKey(0), tokens)
    embeddings = token_pos_embed.apply(params, tokens)
    np.testing.assert_allclose(embeddings, expected_pos_embeddings)

  def test_unembedding(self):
    # Prepend numbers to preserve basis order [input, index, output].
    input_space = bases.VectorSpaceWithBasis.from_values("0inp", range(2))
    indices_space = bases.VectorSpaceWithBasis.from_values("1ind", range(3))
    output_space = bases.VectorSpaceWithBasis.from_values("2out", range(2))
    residual_space = bases.join_vector_spaces(input_space, indices_space,
                                              output_space)

    @hk.without_apply_rng
    @hk.transform
    def unembed(embeddings):
      embed_modules = assemble._make_embedding_modules(
          residual_space=residual_space,
          tokens_space=input_space,
          indices_space=indices_space,
          output_space=output_space)
      return embed_modules.unembed(embeddings, use_unembed_argmax=True)

    embeddings = jnp.array([
        # pylint: disable=g-no-space-after-comment
        #inp | indices | out  | < spaces
        #0  1  0  1  2  0  1    < values in spaces
        [0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 1]
    ])
    expected_tokens = jnp.array([1, 0, 1])

    params = unembed.init(jax.random.PRNGKey(0), embeddings)
    tokens = unembed.apply(params, embeddings)
    np.testing.assert_allclose(tokens, expected_tokens)


if __name__ == "__main__":
  absltest.main()
compiler/basis_inference.py
ADDED
@@ -0,0 +1,106 @@

# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inferring the vector spaces taken on by certain operations."""

import dataclasses
import itertools

import networkx as nx
from tracr.compiler import nodes
from tracr.craft import bases
from tracr.rasp import rasp
from tracr.utils import errors

Node = nodes.Node


@dataclasses.dataclass
class InferBasesOutput:
  graph: nx.DiGraph


def infer_bases(
    graph: nx.DiGraph,
    sink: Node,
    vocab: set[rasp.Value],
    max_seq_len: int,
) -> None:
  """Infers in-place the possible output values and vector bases of the SOps."""

  def compute_value_set(sop: rasp.SOp) -> set[rasp.Value]:
    """Computes value set using already-computed predecessor value sets."""
    if sop is rasp.tokens:
      return vocab
    elif sop is rasp.indices:
      return set(range(max_seq_len))
    elif isinstance(sop, rasp.SelectorWidth):
      return set(range(0, max_seq_len + 1))
    elif isinstance(sop, rasp.Full):
      return {sop.fill}
    elif isinstance(sop, rasp.Map):
      inner_value_set = graph.nodes[sop.inner.label][nodes.VALUE_SET]
      out = set()
      for x in inner_value_set:
        res = errors.ignoring_arithmetic_errors(sop.f)(x)
        if res is not None:
          out.add(res)
      return out
    elif isinstance(sop, rasp.SequenceMap):
      f_ignore_error = errors.ignoring_arithmetic_errors(sop.f)
      fst_value_set = graph.nodes[sop.fst.label][nodes.VALUE_SET]
      snd_value_set = graph.nodes[sop.snd.label][nodes.VALUE_SET]
      out = set()
      for l, r in itertools.product(fst_value_set, snd_value_set):
        res = f_ignore_error(l, r)
        if res is not None:
          out.add(res)
      return out
    elif isinstance(sop, rasp.Aggregate):
      if rasp.is_categorical(sop):
        # Simply pass on the value set of the underlying S-Op.
        return graph.nodes[sop.sop.label][nodes.VALUE_SET]
      elif rasp.is_numerical(sop):
        # TODO(b/255936408): This doesn't work if we average arbitrary values.
        # But most examples only average binary variables.
        sop_value_set = graph.nodes[sop.sop.label][nodes.VALUE_SET]
        if {int(x) for x in sop_value_set} != {0, 1}:
          raise NotImplementedError(
              "Attention patterns can currently only "
              "average binary variables. Not:", sop_value_set)

        value_set = set()
        for value in sop_value_set:
          for length in range(1, max_seq_len + 1):
            value_set.add(value / length)
        return value_set
    raise ValueError(f"Unsupported S-Op: {sop}")

  for node_id in nx.dfs_postorder_nodes(graph.reverse(), sink[nodes.ID]):
    expr = graph.nodes[node_id][nodes.EXPR]

    if not isinstance(expr, rasp.SOp):
      # Only S-Ops have output vector spaces.
      continue

    value_set = compute_value_set(expr)
    graph.nodes[node_id][nodes.VALUE_SET] = value_set

    if rasp.is_categorical(expr):
      out_space = bases.VectorSpaceWithBasis.from_values(expr.label, value_set)
    elif rasp.is_numerical(expr):
      out_space = bases.VectorSpaceWithBasis.from_names([expr.label])
    else:
      raise ValueError(f"Unsupported S-Op type: {expr.type}")
    graph.nodes[node_id][nodes.OUTPUT_BASIS] = out_space.basis
compiler/basis_inference_test.py
ADDED
@@ -0,0 +1,140 @@

# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for compiler.basis_inference."""

from absl.testing import absltest
from absl.testing import parameterized
from tracr.compiler import basis_inference
from tracr.compiler import nodes
from tracr.compiler import rasp_to_graph
from tracr.rasp import rasp


class InferBasesTest(parameterized.TestCase):

  def test_arithmetic_error_logs_warning(self):
    program = rasp.numerical(rasp.Map(lambda x: 1 / x, rasp.tokens))
    extracted = rasp_to_graph.extract_rasp_graph(program)
    vocab = {0, 1, 2}
    with self.assertLogs(level="WARNING"):
      basis_inference.infer_bases(
          extracted.graph,
          extracted.sink,
          vocab,
          max_seq_len=1,
      )

  @parameterized.parameters(({1, 2, 3}, {2, 3, 4}), ({0, 5}, {1, 6}))
  def test_one_edge(self, vocab, expected_value_set):
    program = rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens))
    extracted = rasp_to_graph.extract_rasp_graph(program)

    basis_inference.infer_bases(
        extracted.graph,
        extracted.sink,
        vocab,
        max_seq_len=1,
    )

    self.assertSetEqual(
        extracted.graph.nodes[program.label][nodes.VALUE_SET],
        expected_value_set,
    )

  def test_primitive_close_to_tip(self):
    intermediate = rasp.categorical(rasp.tokens + 1)
    intermediate = rasp.categorical(intermediate + intermediate)
    program = rasp.categorical(intermediate + rasp.indices)
    extracted = rasp_to_graph.extract_rasp_graph(program)

    basis_inference.infer_bases(
        extracted.graph,
        extracted.sink,
        {0, 1},
        max_seq_len=2,
    )

    self.assertSetEqual(
        extracted.graph.nodes[program.label][nodes.VALUE_SET],
        {2, 3, 4, 5},
    )
    self.assertSetEqual(
        extracted.graph.nodes[intermediate.label][nodes.VALUE_SET],
        {2, 3, 4},
    )

  def test_categorical_aggregate(self):
    program = rasp.categorical(
        rasp.Aggregate(
            rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),
            rasp.indices,
        ))

    extracted = rasp_to_graph.extract_rasp_graph(program)

    basis_inference.infer_bases(
        extracted.graph,
        extracted.sink,
        {0, 1},
        max_seq_len=3,
    )

    self.assertSetEqual(
        extracted.graph.nodes[program.label][nodes.VALUE_SET],
        {0, 1, 2},
    )

  def test_numerical_aggregate(self):
    program = rasp.numerical(
        rasp.Aggregate(
            rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),
            rasp.indices,
        ))

    extracted = rasp_to_graph.extract_rasp_graph(program)

    basis_inference.infer_bases(
        extracted.graph,
        extracted.sink,
        {0, 1},
        max_seq_len=2,
    )

    self.assertSetEqual(
        extracted.graph.nodes[program.label][nodes.VALUE_SET],
        {0, 1, 1 / 2},
    )

  def test_selector_width(self):
    program = rasp.SelectorWidth(
        rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ))

    extracted = rasp_to_graph.extract_rasp_graph(program)

    basis_inference.infer_bases(
        extracted.graph,
        extracted.sink,
        {0, 1},
        max_seq_len=2,
    )

    self.assertSetEqual(
        extracted.graph.nodes[program.label][nodes.VALUE_SET],
        {0, 1, 2},
    )


if __name__ == "__main__":
  absltest.main()
compiler/compiling.py
ADDED
@@ -0,0 +1,92 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Combines all steps of compiling a RASP program."""

from tracr.compiler import assemble
from tracr.compiler import basis_inference
from tracr.compiler import craft_graph_to_model
from tracr.compiler import craft_model_to_transformer
from tracr.compiler import expr_to_craft_graph
from tracr.compiler import rasp_to_graph
from tracr.craft import bases
from tracr.rasp import rasp

COMPILER_BOS = "compiler_bos"
COMPILER_PAD = "compiler_pad"


def compile_rasp_to_model(
    program: rasp.SOp,
    vocab: set[rasp.Value],
    max_seq_len: int,
    causal: bool = False,
    compiler_bos: str = COMPILER_BOS,
    compiler_pad: str = COMPILER_PAD,
    mlp_exactness: int = 100) -> assemble.AssembledTransformerModel:
  """Compile a RASP program to transformer weights.

  Args:
    program: the RASP program to compile.
    vocab: the set of vocab tokens expected by RASP.
    max_seq_len: the maximum sequence length for the compiled model.
    causal: if True, outputs a model with causal masking.
    compiler_bos: the name of the special BOS token that will be added by the
      compiler. Must not be present in the vocab.
    compiler_pad: the name of the special PAD token that will be added by the
      compiler. Must not be present in the vocab.
    mlp_exactness: Controls the approximation of the MLP layers. In theory,
      larger values yield a better approximation, but values that are too
      large can cause numerical issues due to large parameter norms.
      Reasonable values are between 1 and 100.

  Returns:
    The compiled model.
  """

  if compiler_bos in vocab:
    raise ValueError("Compiler BOS token must not be present in the vocab. "
                     f"Found '{compiler_bos}' in {vocab}")

  if compiler_pad in vocab:
    raise ValueError("Compiler PAD token must not be present in the vocab. "
                     f"Found '{compiler_pad}' in {vocab}")

  extracted = rasp_to_graph.extract_rasp_graph(program)
  graph, sources, sink = extracted.graph, extracted.sources, extracted.sink

  basis_inference.infer_bases(
      graph,
      sink,
      vocab,
      max_seq_len,
  )

  expr_to_craft_graph.add_craft_components_to_rasp_graph(
      graph,
      bos_dir=bases.BasisDirection(rasp.tokens.label, compiler_bos),
      mlp_exactness=mlp_exactness,
  )

  craft_model = craft_graph_to_model.craft_graph_to_model(graph, sources)

  return craft_model_to_transformer.craft_model_to_transformer(
      craft_model=craft_model,
      graph=graph,
      sink=sink,
      max_seq_len=max_seq_len,
      causal=causal,
      compiler_bos=compiler_bos,
      compiler_pad=compiler_pad,
  )
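For orientation, here is a minimal end-to-end sketch of calling `compile_rasp_to_model`. It assumes the assembled model's `apply` method returns an output object with a `decoded` field (as suggested by the assemble module); the program, vocab, and input sequence are arbitrary example values:

from tracr.compiler import compiling
from tracr.compiler import lib
from tracr.rasp import rasp

# Compile the library's reverse program into transformer weights.
model = compiling.compile_rasp_to_model(
    lib.make_reverse(rasp.tokens),
    vocab={1, 2, 3},
    max_seq_len=5,
)

# Inputs must start with the compiler-added BOS token.
out = model.apply([compiling.COMPILER_BOS, 1, 2, 3])
print(out.decoded)  # Expected to end with the reversed sequence: 3, 2, 1.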
compiler/craft_graph_to_model.py
ADDED
@@ -0,0 +1,238 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Create a craft model from a computational graph."""

import collections
from typing import Sequence

import networkx as nx
from tracr.compiler import nodes
from tracr.craft import bases
from tracr.craft import transformers
from tracr.rasp import rasp

Node = nodes.Node
NodeID = nodes.NodeID


def _get_longest_path_length_to_node(graph: nx.DiGraph,
                                     sources: Sequence[Node],
                                     node: Node) -> int:
  """Returns the length of the longest path from sources to node.

  Only SOps count towards the length of a path.

  Args:
    graph: DAG to compute longest path in.
    sources: List of starting nodes; the longest path is a maximum over all.
    node: Target node.

  Returns:
    Number of steps needed for the longest path from the source to the node, or
    -1 if there is no path from any of the sources to the target node.
  """
  if node in sources:
    return 0

  def num_sops(path: Sequence[NodeID]) -> int:
    num = 0
    for node_id in path:
      if isinstance(graph.nodes[node_id][nodes.EXPR], rasp.SOp):
        num += 1
    return num

  result = -1
  for source in sources:
    all_paths = nx.all_simple_paths(graph, source[nodes.ID], node[nodes.ID])
    longest_path_len = max(map(num_sops, all_paths), default=-1) - 1
    if longest_path_len > result:
      result = longest_path_len
  return result


def _node_is_attn(node: Node) -> bool:
  """Returns True if node is an attention layer."""
  return nodes.MODEL_BLOCK in node and isinstance(
      node[nodes.MODEL_BLOCK],
      (transformers.AttentionHead, transformers.MultiAttentionHead))


def _node_is_mlp(node: Node) -> bool:
  """Returns True if node is an MLP layer."""
  return nodes.MODEL_BLOCK in node and isinstance(node[nodes.MODEL_BLOCK],
                                                  transformers.MLP)


def _node_is_residual_block(node: Node) -> bool:
  """Returns True if node is a valid residual block (Attn followed by MLP)."""
  block = node[nodes.MODEL_BLOCK] if nodes.MODEL_BLOCK in node else None
  if block and isinstance(block, transformers.SeriesWithResiduals):
    if len(block.blocks) == 2:
      attn, mlp = block.blocks
      if (isinstance(
          attn,
          (transformers.AttentionHead, transformers.MultiAttentionHead)) and
          isinstance(mlp, transformers.MLP)):
        return True
  return False


def _all_attn_nodes(node_list: Sequence[Node]) -> bool:
  """Returns True iff all nodes are attention layers (or node_list is empty)."""
  for node in node_list:
    if not _node_is_attn(node):
      return False
  return True


def _all_mlp_nodes(node_list: Sequence[Node]) -> bool:
  """Returns True iff all nodes are MLP layers (or node_list is empty)."""
  for node in node_list:
    if not _node_is_mlp(node):
      return False
  return True


def _allocate_modules_to_layers(graph: nx.DiGraph,
                                sources: Sequence[Node]) -> dict[int, int]:
  """Allocates all nodes in the compute graph to layers.

  First, computes the longest path from the input to each node that is a model
  component (not input and output nodes). The longest path to a model component
  (its "depth") determines a layer in which we can place it while ensuring that
  all necessary previous computations have already happened.

  This assumes layers are arranged as [Attention, MLP, Attention, MLP, ...].

  In the special case where one depth level contains only attention layers and
  the next depth level contains only MLP layers, the two levels are treated as
  a single depth, because attention layers always come before MLP layers
  within the same depth.

  Args:
    graph: RASP graph with craft blocks.
    sources: List of input nodes.

  Returns:
    A dict mapping from node ids to layer indices, where 0, 1, 2, 3, ...
    are in the order attention, mlp, attention, mlp, ...
  """
  layer_allocation: dict[int, int] = collections.defaultdict(lambda: -1)
  depth_by_node_id: dict[int, int] = dict()
  nodes_by_depth: dict[int, list[Node]] = collections.defaultdict(list)

  # Compute depth of all model components (longest path from source to node).
  for node_id, node in graph.nodes.items():
    if (_node_is_mlp(node) or _node_is_attn(node)
        or _node_is_residual_block(node)):
      # Node is a model component.
      longest_path_len = _get_longest_path_length_to_node(graph, sources, node)
      depth_by_node_id[node_id] = longest_path_len
      nodes_by_depth[longest_path_len].append(node)

  # If at level `depth` there are only attention heads and at level `depth + 1`
  # there are only MLPs, we can condense them into one level.
  # TODO(b/255936816): Think about improving this heuristic. The heuristic is
  # not optimal, and only catches very basic opportunities for optimization. It
  # is easy to come up with opportunities for optimization that it does not
  # catch.
  min_depth, max_depth = min(nodes_by_depth.keys()), max(nodes_by_depth.keys())
  depth = min_depth
  while depth < max_depth:
    if _all_attn_nodes(nodes_by_depth[depth]) and _all_mlp_nodes(
        nodes_by_depth[depth + 1]):
      # Condense by decrementing the depth of all nodes starting from depth+1.
      for update_depth in range(depth + 1, max_depth + 1):
        for node in nodes_by_depth[update_depth]:
          node_id = node[nodes.ID]
          depth_by_node_id[node_id] = update_depth - 1
        nodes_by_depth[update_depth - 1].extend(nodes_by_depth[update_depth])
        nodes_by_depth[update_depth] = []
      max_depth -= 1
    depth += 1

  # Allocate nodes to layers by depth, ensuring attn -> mlp -> attn -> mlp ...
  current_layer = 0
  current_depth = 1
  for node_id, depth in sorted(depth_by_node_id.items(), key=lambda x: x[1]):
    while depth > current_depth:
      current_depth += 1
      current_layer += 2
    if depth == current_depth:
      if _node_is_residual_block(graph.nodes[node_id]):
        layer_allocation[node_id] = current_layer
      else:
        is_mlp = _node_is_mlp(graph.nodes[node_id])
        layer_allocation[node_id] = current_layer + int(is_mlp)

  return layer_allocation


def craft_graph_to_model(
    graph: nx.DiGraph,
    sources: Sequence[Node]) -> transformers.SeriesWithResiduals:
  """Translates a RASP graph with craft blocks into a full craft model.

  1. Allocates modules to layers, assuming layers alternate in the order
     [attention, MLP, attention, MLP, ...].
  2. Creates subspaces for all inputs and outputs, and builds the residual
     stream.
  3. Assembles everything into a craft model and returns it.

  Args:
    graph: RASP graph with craft blocks.
    sources: List of input nodes.

  Returns:
    A craft model that can be compiled to model weights.

  Raises:
    ValueError: On invalid input (if the craft_graph does not have craft blocks
      already specified).
  """
  layer_allocation = _allocate_modules_to_layers(graph, sources)
  blocks_by_layer = collections.defaultdict(list)
  model_blocks = []

  residual_space = bases.VectorSpaceWithBasis([])

  for node_id, layer_no in layer_allocation.items():
    node = graph.nodes[node_id]
    block = node[nodes.MODEL_BLOCK] if nodes.MODEL_BLOCK in node else None

    if _node_is_residual_block(node):
      assert isinstance(block, transformers.SeriesWithResiduals)
      assert len(block.blocks) == 2
      residual_space = bases.join_vector_spaces(residual_space,
                                                block.blocks[0].residual_space,
                                                block.blocks[1].residual_space)
      blocks_by_layer[layer_no].append(block.blocks[0])
      blocks_by_layer[layer_no + 1].append(block.blocks[1])
    elif block:
      residual_space = bases.join_vector_spaces(
          residual_space, node[nodes.MODEL_BLOCK].residual_space)
      blocks_by_layer[layer_no].append(block)

  for layer_no, layer_blocks in sorted(
      blocks_by_layer.items(), key=lambda x: x[0]):
    for block in layer_blocks:
      block.residual_space = residual_space

    if layer_blocks:
      if layer_no % 2 == 0:  # Attention Layer
        multi_head_attn = transformers.MultiAttentionHead(layer_blocks)
        model_blocks.append(multi_head_attn)
      else:  # MLP Layer
        parallel_mlp = transformers.MLP.combine_in_parallel(layer_blocks)
        model_blocks.append(parallel_mlp)

  return transformers.SeriesWithResiduals(model_blocks)
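Once depths are computed and condensed, the allocation above reduces to simple arithmetic: each depth level occupies one (attention, MLP) pair of layers, and an MLP node lands in the odd slot of its pair. A standalone sketch of that mapping (depths are 1-based, matching `_allocate_modules_to_layers`):

def layer_for(depth: int, is_mlp: bool) -> int:
  # Each depth level spans layers 2*(depth-1) and 2*(depth-1)+1.
  return 2 * (depth - 1) + int(is_mlp)

assert layer_for(depth=1, is_mlp=False) == 0  # first attention layer
assert layer_for(depth=1, is_mlp=True) == 1   # first MLP layer
assert layer_for(depth=2, is_mlp=True) == 3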
compiler/craft_graph_to_model_test.py
ADDED
@@ -0,0 +1,194 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for compiler.craft_graph_to_model."""

from absl.testing import absltest
from absl.testing import parameterized
import networkx as nx
from tracr.compiler import craft_graph_to_model
from tracr.compiler import nodes
from tracr.compiler import rasp_to_graph
from tracr.craft import bases
from tracr.craft.chamber import categorical_attn
from tracr.craft.chamber import categorical_mlp
from tracr.rasp import rasp


class CraftAllocateModulesToLayersTest(parameterized.TestCase):

  def _get_dummy_block(self, block_type):
    if block_type == "ATTN":
      return categorical_attn.categorical_attn(
          query_space=bases.VectorSpaceWithBasis.from_names(["query"]),
          key_space=bases.VectorSpaceWithBasis.from_names(["bos", "key"]),
          value_space=bases.VectorSpaceWithBasis.from_names(["bos", "value"]),
          output_space=bases.VectorSpaceWithBasis.from_names(["output"]),
          bos_space=bases.VectorSpaceWithBasis.from_names(["bos"]),
          one_space=bases.VectorSpaceWithBasis.from_names(["one"]),
          attn_fn=lambda x, y: True,
      )
    elif block_type == "MLP":
      return categorical_mlp.map_categorical_mlp(
          input_space=bases.VectorSpaceWithBasis.from_names(["input"]),
          output_space=bases.VectorSpaceWithBasis.from_names(["output"]),
          operation=lambda x: x,
      )
    else:
      return None

  def test_get_longest_path_length_to_node_returns_expected_result(self):
    """Creates a graph and checks the longest path for each node."""

    # Node IDs:
    #   0 -- 1 -- 2 -- 3 ------------ 4
    #                 /              /
    #   5 -- 6 ---------- 7 -- 8 -- 9
    #
    #   10
    # Expected return values:
    #   0 -- 1 -- 2 -- 3 ------------ 5
    #                 /              /
    #   0 -- 1 ---------- 2 -- 3 -- 4
    #
    #   -1

    graph = nx.DiGraph()
    node_ids = list(range(11))
    expected_results = [0, 1, 2, 3, 5, 0, 1, 2, 3, 4, -1]
    for node_id, res in zip(node_ids, expected_results):
      graph.add_node(
          node_id, **{
              nodes.ID: node_id,
              nodes.EXPR: rasp.ConstantSOp(1),
              "expected_result": res
          })
    graph.add_edge(0, 1)
    graph.add_edge(1, 2)
    graph.add_edge(2, 3)
    graph.add_edge(3, 4)
    graph.add_edge(5, 6)
    graph.add_edge(6, 7)
    graph.add_edge(7, 8)
    graph.add_edge(8, 9)
    graph.add_edge(6, 3)
    graph.add_edge(9, 4)
    sources = [graph.nodes[0], graph.nodes[5]]

    for node_id, node in graph.nodes.items():
      result = craft_graph_to_model._get_longest_path_length_to_node(
          graph, sources, node)
      self.assertEqual(result, node["expected_result"])

  def test_allocate_modules_to_layers_returns_expected_result(self):
    """Creates a graph and checks if the correct layer assignment is returned."""

    # Computation Graph:
    #   INPUT -- ATTN -- MLP -- ATTN ------ MLP -- OUTPUT
    #                   /      /           /
    #   INPUT -- MLP ----- MLP        ATTN
    #                         \      /
    #                          ATTN
    # Node IDs:
    #   0 -- 1 -- 2 -- 3 -- 4 -- 5
    #            /    /    /
    #   6 -- 7 ---- 8    9
    #                \  /
    #                 10
    # Expected layer allocation:
    #   -1 -- 0 -- 3 -- 4 -- 7 -- -1
    #             /    /    /
    #   -1 -- 1 --- 3    6
    #                \  /
    #                 4

    graph = nx.DiGraph()
    node_ids = list(range(11))
    types = [
        "INPUT", "ATTN", "MLP", "ATTN", "MLP", "OUTPUT", "INPUT", "MLP", "MLP",
        "ATTN", "ATTN"
    ]
    expected_results = [-1, 0, 3, 4, 7, -1, -1, 1, 3, 6, 4]
    for node_id, node_type, res in zip(node_ids, types, expected_results):
      graph.add_node(
          node_id, **{
              nodes.ID: node_id,
              nodes.EXPR: rasp.ConstantSOp(1),
              nodes.MODEL_BLOCK: self._get_dummy_block(node_type),
              "expected_result": res
          })

    graph.add_edge(0, 1)
    graph.add_edge(1, 2)
    graph.add_edge(2, 3)
    graph.add_edge(3, 4)
    graph.add_edge(4, 5)
    graph.add_edge(6, 7)
    graph.add_edge(7, 2)
    graph.add_edge(7, 8)
    graph.add_edge(8, 3)
    graph.add_edge(8, 10)
    graph.add_edge(9, 4)
    graph.add_edge(10, 9)

    craft_graph = rasp_to_graph.ExtractRaspGraphOutput(
        graph=graph,
        sink=graph.nodes[10],
        sources=[graph.nodes[0], graph.nodes[6]])

    layer_allocation = craft_graph_to_model._allocate_modules_to_layers(
        craft_graph.graph, craft_graph.sources)
    for node_id, node in graph.nodes.items():
      self.assertEqual(layer_allocation[node_id], node["expected_result"])

  def test_allocate_modules_to_layers_returns_expected_result_for_chain(self):
    """Tests a chain of alternating attention layers and MLPs."""

    # Computation Graph:
    #   INPUT -- ATTN -- MLP -- ATTN -- MLP -- OUTPUT
    # Node IDs:
    #   0 -- 1 -- 2 -- 3 -- 4 -- 5
    # Expected layer allocation:
    #   -1 -- 0 -- 1 -- 2 -- 3 -- -1

    graph = nx.DiGraph()
    node_ids = list(range(11))
    types = ["INPUT", "ATTN", "MLP", "ATTN", "MLP", "OUTPUT"]
    expected_results = [-1, 0, 1, 2, 3, -1]
    for node_id, node_type, res in zip(node_ids, types, expected_results):
      graph.add_node(
          node_id, **{
              nodes.ID: node_id,
              nodes.EXPR: rasp.ConstantSOp(1),
              nodes.MODEL_BLOCK: self._get_dummy_block(node_type),
              "expected_result": res
          })

    graph.add_edge(0, 1)
    graph.add_edge(1, 2)
    graph.add_edge(2, 3)
    graph.add_edge(3, 4)
    graph.add_edge(4, 5)

    craft_graph = rasp_to_graph.ExtractRaspGraphOutput(
        graph=graph, sink=graph.nodes[5], sources=[graph.nodes[0]])

    layer_allocation = craft_graph_to_model._allocate_modules_to_layers(
        craft_graph.graph, craft_graph.sources)
    for node_id, node in graph.nodes.items():
      self.assertEqual(layer_allocation[node_id], node["expected_result"])


if __name__ == "__main__":
  absltest.main()
compiler/craft_model_to_transformer.py
ADDED
@@ -0,0 +1,76 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convert craft model into transformer with the correct input/output spaces."""

import networkx as nx
from tracr.compiler import assemble
from tracr.compiler import nodes
from tracr.craft import bases
from tracr.craft import transformers
from tracr.rasp import rasp
from tracr.transformer import encoder


def craft_model_to_transformer(
    craft_model: transformers.SeriesWithResiduals,
    graph: nx.DiGraph,
    sink: nodes.Node,
    max_seq_len: int,
    compiler_bos: str,
    compiler_pad: str,
    causal: bool = False,
) -> assemble.AssembledTransformerModel:
  """Turn a craft model into a transformer model."""

  # Add the compiler BOS token.
  tokens_value_set = (
      graph.nodes[rasp.tokens.label][nodes.VALUE_SET].union(
          {compiler_bos, compiler_pad}))
  tokens_space = bases.VectorSpaceWithBasis.from_values(rasp.tokens.label,
                                                        tokens_value_set)

  indices_space = bases.VectorSpaceWithBasis.from_values(
      rasp.indices.label, range(max_seq_len))

  categorical_output = rasp.is_categorical(sink[nodes.EXPR])
  output_space = bases.VectorSpaceWithBasis(sink[nodes.OUTPUT_BASIS])

  assembled_model = assemble.assemble_craft_model(
      craft_model=craft_model,
      tokens_space=tokens_space,
      indices_space=indices_space,
      output_space=output_space,
      categorical_output=categorical_output,
      causal=causal,
  )

  assembled_model.input_encoder = encoder.CategoricalEncoder(
      basis=tokens_space.basis,
      enforce_bos=compiler_bos is not None,
      bos_token=compiler_bos,
      pad_token=compiler_pad,
      max_seq_len=max_seq_len + 1 if compiler_bos is not None else max_seq_len,
  )

  if categorical_output:
    assembled_model.output_encoder = encoder.CategoricalEncoder(
        basis=output_space.basis,
        enforce_bos=False,
        bos_token=None,
        pad_token=None)
  else:
    assembled_model.output_encoder = encoder.NumericalEncoder()

  return assembled_model
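As a rough illustration of what the attached encoders do, the categorical input encoder maps tokens to the indices of their basis directions and back. This sketch assumes `CategoricalEncoder` exposes `encode` and `decode` methods, as suggested by `transformer/encoder.py`; the token values are arbitrary examples:

from tracr.craft import bases
from tracr.transformer import encoder

tokens_space = bases.VectorSpaceWithBasis.from_values(
    "tokens", {"compiler_bos", "a", "b"})
enc = encoder.CategoricalEncoder(
    basis=tokens_space.basis,
    enforce_bos=True,
    bos_token="compiler_bos",
)

ids = enc.encode(["compiler_bos", "a", "b"])  # tokens -> basis indices
assert enc.decode(ids) == ["compiler_bos", "a", "b"]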
compiler/expr_to_craft_graph.py
ADDED
@@ -0,0 +1,277 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Add craft model blocks to graph of RASPExpr."""

from typing import Any, Callable, Optional

import networkx as nx
from tracr.compiler import nodes
from tracr.craft import bases
from tracr.craft.chamber import categorical_attn
from tracr.craft.chamber import categorical_mlp
from tracr.craft.chamber import numerical_mlp
from tracr.craft.chamber import selector_width
from tracr.rasp import rasp


def _transform_fun_to_basis_fun(
    fun: Callable[..., Any],
    output_direction_name: Optional[str] = None) -> Callable[..., Any]:
  """Transforms a function acting on values into one acting on directions."""

  def bases_fun(*args):
    values = [d.value for d in args]
    result = fun(*values)
    if output_direction_name:
      return bases.BasisDirection(output_direction_name, result)
    return result

  return bases_fun


def _check_selector_expression(expr, graph):
  """Check graph structure and encodings for an aggregate or selector width."""
  sel_expr = expr.selector

  # Check graph structure
  assert sel_expr.label in graph.predecessors(expr.label)
  assert sel_expr.keys.label in graph.predecessors(sel_expr.label)
  assert sel_expr.queries.label in graph.predecessors(sel_expr.label)

  if (not rasp.is_categorical(sel_expr.queries) or
      not rasp.is_categorical(sel_expr.keys)):
    raise ValueError("Selector keys and queries must be categorical.")


def add_craft_components_to_rasp_graph(
    graph: nx.DiGraph,
    bos_dir: bases.BasisDirection = bases.BasisDirection("tokens", "bos"),
    one_dir: bases.BasisDirection = bases.BasisDirection("one"),
    causal: bool = False,
    mlp_exactness: float = 100,
) -> None:
  """Translates expressions to craft blocks and attaches them to the graph.

  Sets the `MODEL_BLOCK` attribute for all nodes in `graph`.

  Args:
    graph: RASP graph with `VALUE_SET` but not `MODEL_BLOCK` attributes.
    bos_dir: Basis direction representing beginning of sequence (bos) token.
    one_dir: Auxiliary basis direction that must contain 1.
    causal: If True, marks attention blocks as causal.
    mlp_exactness: Controls the approximation of the MLP layers.

  Raises:
    ValueError: On invalid input (if `MODEL_BLOCK` is set already, or
      `VALUE_SET` is not set already).
    NotImplementedError: If the graph contains an unsupported expression.
  """
  one_space = bases.VectorSpaceWithBasis([one_dir])

  for node_id, node in graph.nodes.items():
    expr = node[nodes.EXPR]

    if not isinstance(expr, rasp.SOp):
      continue

    if nodes.MODEL_BLOCK in node and node[nodes.MODEL_BLOCK]:
      raise ValueError("Input graph cannot have model blocks set already.")
    if nodes.VALUE_SET not in node:
      raise ValueError(
          "Craft components can only be added after basis inference.")

    if expr is rasp.tokens or expr is rasp.indices:
      block = None
    elif isinstance(expr, rasp.Map):
      inner_expr, inner_node = expr.inner, graph.nodes[expr.inner.label]
      assert inner_expr.label in graph.predecessors(node_id)
      input_space = bases.VectorSpaceWithBasis(inner_node[nodes.OUTPUT_BASIS])
      output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])

      if rasp.is_categorical(inner_expr) and rasp.is_categorical(expr):
        basis_fun = _transform_fun_to_basis_fun(expr.f, expr.label)
        block = categorical_mlp.map_categorical_mlp(
            input_space=input_space,
            output_space=output_space,
            operation=basis_fun)
      elif rasp.is_categorical(inner_expr) and rasp.is_numerical(expr):
        block = categorical_mlp.map_categorical_to_numerical_mlp(
            input_space=input_space,
            output_space=output_space,
            operation=expr.f,
        )
      elif rasp.is_numerical(inner_expr) and rasp.is_categorical(expr):
        block = numerical_mlp.map_numerical_to_categorical_mlp(
            f=expr.f,
            input_space=input_space,
            output_space=output_space,
            input_value_set=inner_node[nodes.VALUE_SET],
            one_space=one_space,
            hidden_name=f"_hidden_{expr.label}_",
            large_number=mlp_exactness)
      elif rasp.is_numerical(inner_expr) and rasp.is_numerical(expr):
        block = numerical_mlp.map_numerical_mlp(
            f=expr.f,
            input_space=input_space,
            output_space=output_space,
            input_value_set=inner_node[nodes.VALUE_SET],
            one_space=one_space,
            hidden_name=f"_hidden_{expr.label}_",
            large_number=mlp_exactness)
      else:
        raise NotImplementedError("Map does not support "
                                  f"in_type '{inner_expr.type}' and "
                                  f"out_type '{expr.type}'!")

    elif isinstance(expr, rasp.SequenceMap):
      fst_expr, fst_node = expr.fst, graph.nodes[expr.fst.label]
      snd_expr, snd_node = expr.snd, graph.nodes[expr.snd.label]

      # Check graph structure
      assert fst_expr.label in graph.predecessors(node_id)
      assert snd_expr.label in graph.predecessors(node_id)

      fst_space = bases.VectorSpaceWithBasis(fst_node[nodes.OUTPUT_BASIS])
      snd_space = bases.VectorSpaceWithBasis(snd_node[nodes.OUTPUT_BASIS])
      out_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])

      if (isinstance(expr, rasp.LinearSequenceMap) and
          not all(rasp.is_numerical(x) for x in (fst_expr, snd_expr, expr))):
        raise NotImplementedError("Linear SequenceMap only supports numerical "
                                  "inputs/outputs.")
      elif (
          not isinstance(expr, rasp.LinearSequenceMap) and
          not all(rasp.is_categorical(x) for x in (fst_expr, snd_expr, expr))):
        raise NotImplementedError("(Non-linear) SequenceMap only supports "
                                  "categorical inputs/outputs.")

      if isinstance(expr, rasp.LinearSequenceMap):
        assert len(fst_space.basis) == 1
        assert len(snd_space.basis) == 1
        assert len(out_space.basis) == 1
        block = numerical_mlp.linear_sequence_map_numerical_mlp(
            input1_basis_direction=fst_space.basis[0],
            input2_basis_direction=snd_space.basis[0],
            output_basis_direction=out_space.basis[0],
            input1_factor=expr.fst_fac,
            input2_factor=expr.snd_fac,
            hidden_name=f"_hidden_{expr.label}_")
      elif fst_space == snd_space:
        # It's okay to use the local variable expr.f because it is
        # only used within the same loop iteration to create the MLP.
        # pylint: disable=cell-var-from-loop
        basis_fun = _transform_fun_to_basis_fun(lambda x: expr.f(x, x),
                                                expr.label)
        block = categorical_mlp.map_categorical_mlp(
            input_space=fst_space, output_space=out_space, operation=basis_fun)
      else:
        basis_fun = _transform_fun_to_basis_fun(expr.f, expr.label)
        block = categorical_mlp.sequence_map_categorical_mlp(
            input1_space=fst_space,
            input2_space=snd_space,
            output_space=out_space,
            operation=basis_fun,
            one_space=one_space,
            hidden_name=f"_hidden_{expr.label}_")
    elif isinstance(expr, rasp.Aggregate):
      sel_expr: rasp.Select = expr.selector
      agg_expr: rasp.Aggregate = expr

      if not isinstance(sel_expr, rasp.Select):
        raise TypeError("Compiling composite Selectors is not supported. "
                        f"Got a {sel_expr}.")

      queries = graph.nodes[sel_expr.queries.label]
      keys = graph.nodes[sel_expr.keys.label]
      sop = graph.nodes[agg_expr.sop.label]

      _check_selector_expression(expr, graph)
      assert agg_expr.sop.label in graph.predecessors(node_id)
      if rasp.get_encoding(agg_expr.sop) != rasp.get_encoding(agg_expr):
        raise ValueError(
            "sop encoding must match output encoding of the aggregate.")
      if rasp.is_categorical(agg_expr) and agg_expr.default is not None:
        raise ValueError("Default for a categorical aggregate must be None. "
                         f"Got {agg_expr.default}")
      if rasp.is_numerical(agg_expr) and agg_expr.default != 0:
        raise ValueError("Default for a numerical aggregate must be 0. "
                         f"Got {agg_expr.default}")

      bos_space = bases.VectorSpaceWithBasis([bos_dir])
      one_space = bases.VectorSpaceWithBasis([one_dir])
      query_space = bases.VectorSpaceWithBasis(queries[nodes.OUTPUT_BASIS])
      key_space = bases.VectorSpaceWithBasis(keys[nodes.OUTPUT_BASIS])
      value_space = bases.VectorSpaceWithBasis(sop[nodes.OUTPUT_BASIS])
      output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])

      # Argument order is different in craft / transformers than RASP selectors
      def attn_basis_fn(query: bases.BasisDirection,
                        key: bases.BasisDirection) -> bool:
        # It's okay to use the local variable sel_expr because this function is
        # only used within the same loop iteration to create an attention head.
        # pylint: disable=cell-var-from-loop
        selector_basis_fn = _transform_fun_to_basis_fun(sel_expr.predicate)
        return selector_basis_fn(key, query)

      block = categorical_attn.categorical_attn(
          query_space=query_space,
          key_space=key_space,
          value_space=value_space,
          output_space=output_space,
          bos_space=bos_space,
          one_space=one_space,
          attn_fn=attn_basis_fn,
          default_output=output_space.null_vector(),
          causal=causal,
          always_attend_to_bos=False,
          use_bos_for_default_output=True,
          softmax_coldness=100)
    elif isinstance(expr, rasp.SelectorWidth):
      sel_expr = expr.selector
      queries = graph.nodes[sel_expr.queries.label]
      keys = graph.nodes[sel_expr.keys.label]
      _check_selector_expression(expr, graph)

      bos_space = bases.VectorSpaceWithBasis([bos_dir])
      query_space = bases.VectorSpaceWithBasis(queries[nodes.OUTPUT_BASIS])
      key_space = bases.VectorSpaceWithBasis(keys[nodes.OUTPUT_BASIS])
      output_space = bases.VectorSpaceWithBasis(node[nodes.OUTPUT_BASIS])

      # Argument order is different in craft / transformers than RASP selectors
      def attn_basis_fn(query: bases.BasisDirection,
                        key: bases.BasisDirection) -> bool:
        # It's okay to use the local variable sel_expr because this function is
        # only used within the same loop iteration to create an attention head.
        selector_basis_fn = _transform_fun_to_basis_fun(sel_expr.predicate)  # pylint: disable=cell-var-from-loop
        return selector_basis_fn(key, query)

      block = selector_width.selector_width(
          query_space=query_space,
          key_space=key_space,
          output_space=output_space,
          bos_space=bos_space,
          one_space=one_space,
          attn_fn=attn_basis_fn,
          out_value_set=node[nodes.VALUE_SET],
          categorical_output=rasp.is_categorical(expr),
          causal=False,
          softmax_coldness=100,
          mlp_large_number=mlp_exactness,
          label=expr.label)
    else:
      raise NotImplementedError(f"Expression {expr} cannot be translated to "
                                "a model component.")

    graph.nodes[node_id][nodes.MODEL_BLOCK] = block
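To make the direction-wrapping in `_transform_fun_to_basis_fun` concrete: the wrapped function receives `BasisDirection`s, unwraps their values, applies the raw RASP function, and (for categorical outputs) re-wraps the result under the output label. A small standalone copy of the helper, for illustration only:

from tracr.craft import bases

def transform_fun_to_basis_fun(fun, output_direction_name=None):
  # Standalone copy of the private helper above, for illustration only.
  def bases_fun(*args):
    values = [d.value for d in args]
    result = fun(*values)
    if output_direction_name:
      return bases.BasisDirection(output_direction_name, result)
    return result
  return bases_fun

inc = transform_fun_to_basis_fun(lambda x: x + 1, "map_1")
assert inc(bases.BasisDirection("tokens", 2)) == bases.BasisDirection("map_1", 3)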
compiler/expr_to_craft_graph_test.py
ADDED
@@ -0,0 +1,121 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for compiler.expr_to_craft_graph."""

from absl.testing import absltest
from absl.testing import parameterized
from tracr.compiler import basis_inference
from tracr.compiler import expr_to_craft_graph
from tracr.compiler import lib
from tracr.compiler import nodes
from tracr.compiler import rasp_to_graph
from tracr.craft import bases
from tracr.craft import transformers
from tracr.rasp import rasp


class ExprToCraftGraphTest(parameterized.TestCase):

  def _check_block_types_are_correct(self, graph):
    for _, node in graph.nodes.items():
      expr = node[nodes.EXPR]
      if isinstance(expr, rasp.SOp):
        block = node[nodes.MODEL_BLOCK]
        if isinstance(expr, (rasp.Map, rasp.SequenceMap)):
          self.assertIsInstance(block, transformers.MLP)
        elif isinstance(expr, rasp.Aggregate):
          self.assertIsInstance(block, transformers.AttentionHead)

  def _get_input_space_from_node(self, node):
    block = node[nodes.MODEL_BLOCK]
    if isinstance(block, transformers.MLP):
      return block.fst.input_space
    elif isinstance(block, transformers.AttentionHead):
      return bases.join_vector_spaces(block.w_qk.left_space,
                                      block.w_qk.right_space,
                                      block.w_ov.input_space)
    else:
      return None

  def _check_spaces_are_consistent(self, graph):
    """Check that for each edge the output is a subspace of the input."""
    for u, v in graph.edges:
      u_node, v_node = graph.nodes[u], graph.nodes[v]
      if isinstance(u_node[nodes.EXPR], rasp.SOp) and isinstance(
          v_node[nodes.EXPR], rasp.SOp):
        u_out_basis = u_node[nodes.OUTPUT_BASIS]
        u_out_space = bases.VectorSpaceWithBasis(u_out_basis)
        v_in_space = self._get_input_space_from_node(v_node)
        self.assertTrue(u_out_space.issubspace(v_in_space))

  @parameterized.named_parameters(
      dict(
          testcase_name="single_map",
          program=rasp.Map(lambda x: x + 1, rasp.tokens)),
      dict(
          testcase_name="single_sequence_map",
          program=rasp.SequenceMap(lambda x, y: x + y, rasp.tokens,
                                   rasp.indices)),
      dict(
          testcase_name="single_select_aggregate",
          program=rasp.Aggregate(
              rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),
              rasp.tokens,
          )),
      dict(testcase_name="reverse", program=lib.make_reverse(rasp.tokens)),
      dict(testcase_name="length", program=lib.make_length()))
  def test_compiling_rasp_programs(self, program):
    vocab = {0, 1, 2}
    extracted = rasp_to_graph.extract_rasp_graph(program)
    basis_inference.infer_bases(
        extracted.graph,
        extracted.sink,
        vocab,
        max_seq_len=3,
    )
    expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph)
    self._check_block_types_are_correct(extracted.graph)
    self._check_spaces_are_consistent(extracted.graph)

  def test_add_craft_components_raises_value_error_if_called_before_basis_inference(
      self):
    program = rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens))
    extracted = rasp_to_graph.extract_rasp_graph(program)

    with self.assertRaisesRegex(
        ValueError,
        r"^.*Craft components can only be added after basis inference.*$"):
      expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph)

  def test_add_craft_components_raises_value_error_if_called_twice(self):
    vocab = {0, 1, 2}
    program = rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens))
    extracted = rasp_to_graph.extract_rasp_graph(program)

    basis_inference.infer_bases(
        extracted.graph,
        extracted.sink,
        vocab,
        max_seq_len=1,
    )

    expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph)
    with self.assertRaisesRegex(
        ValueError, r"^.*Input graph cannot have model blocks set already.*$"):
      expr_to_craft_graph.add_craft_components_to_rasp_graph(extracted.graph)


if __name__ == "__main__":
  absltest.main()
compiler/lib.py
ADDED
@@ -0,0 +1,371 @@
1 |
+
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""RASP programs only using the subset of RASP supported by the compiler."""
|
16 |
+
|
17 |
+
from typing import Sequence
|
18 |
+
|
19 |
+
from tracr.rasp import rasp
|
20 |
+
|
21 |
+
### Programs that work only under non-causal evaluation.
|
22 |
+
|
23 |
+
|
24 |
+
def make_length() -> rasp.SOp:
|
25 |
+
"""Creates the `length` SOp using selector width primitive.
|
26 |
+
|
27 |
+
Example usage:
|
28 |
+
length = make_length()
|
29 |
+
length("abcdefg")
|
30 |
+
>> [7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0]
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
length: SOp mapping an input to a sequence, where every element
|
34 |
+
is the length of that sequence.
|
35 |
+
"""
|
36 |
+
all_true_selector = rasp.Select(
|
37 |
+
rasp.tokens, rasp.tokens, rasp.Comparison.TRUE).named("all_true_selector")
|
38 |
+
return rasp.SelectorWidth(all_true_selector).named("length")
|
39 |
+
|
40 |
+
|
41 |
+
length = make_length()
|
42 |
+
|
43 |
+
|
44 |
+
def make_reverse(sop: rasp.SOp) -> rasp.SOp:
|
45 |
+
"""Create an SOp that reverses a sequence, using length primitive.
|
46 |
+
|
47 |
+
Example usage:
|
48 |
+
reverse = make_reverse(rasp.tokens)
|
49 |
+
reverse("Hello")
|
50 |
+
>> ['o', 'l', 'l', 'e', 'H']
|
51 |
+
|
52 |
+
Args:
|
53 |
+
sop: an SOp
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
reverse : SOp that reverses the input sequence.
|
57 |
+
"""
|
58 |
+
opp_idx = (length - rasp.indices).named("opp_idx")
|
59 |
+
opp_idx = (opp_idx - 1).named("opp_idx-1")
|
60 |
+
reverse_selector = rasp.Select(rasp.indices, opp_idx,
|
61 |
+
rasp.Comparison.EQ).named("reverse_selector")
|
62 |
+
return rasp.Aggregate(reverse_selector, sop).named("reverse")
|
63 |
+
|
64 |
+
|
65 |
+
def make_pair_balance(sop: rasp.SOp, open_token: str,
|
66 |
+
close_token: str) -> rasp.SOp:
|
67 |
+
"""Return fraction of previous open tokens minus the fraction of close tokens.
|
68 |
+
|
69 |
+
(As implemented in the RASP paper.)
|
70 |
+
|
71 |
+
If the outputs are always non-negative and end in 0, that implies the input
|
72 |
+
has balanced parentheses.
|
73 |
+
|
74 |
+
Example usage:
|
75 |
+
num_l = make_pair_balance(rasp.tokens, "(", ")")
|
76 |
+
num_l("a()b(c))")
|
77 |
+
>> [0, 1/2, 0, 0, 1/5, 1/6, 0, -1/8]
|
78 |
+
|
79 |
+
Args:
|
80 |
+
sop: Input SOp.
|
81 |
+
open_token: Token that counts positive.
|
82 |
+
close_token: Token that counts negative.
|
83 |
+
|
84 |
+
Returns:
|
85 |
+
pair_balance: SOp mapping an input to a sequence, where every element
|
86 |
+
is the fraction of previous open tokens minus previous close tokens.
|
87 |
+
"""
|
88 |
+
bools_open = rasp.numerical(sop == open_token).named("bools_open")
|
89 |
+
opens = rasp.numerical(make_frac_prevs(bools_open)).named("opens")
|
90 |
+
|
91 |
+
bools_close = rasp.numerical(sop == close_token).named("bools_close")
|
92 |
+
closes = rasp.numerical(make_frac_prevs(bools_close)).named("closes")
|
93 |
+
|
94 |
+
pair_balance = rasp.numerical(rasp.LinearSequenceMap(opens, closes, 1, -1))
|
95 |
+
return pair_balance.named("pair_balance")
|
96 |
+
|
97 |
+
|
98 |
+
def make_shuffle_dyck(pairs: list[str]) -> rasp.SOp:
  """Returns 1 if a set of parentheses is balanced, 0 otherwise.

  (As implemented in the RASP paper.)

  Example usage:
    shuffle_dyck2 = make_shuffle_dyck(pairs=["()", "{}"])
    shuffle_dyck2("({)}")
    >> [1, 1, 1, 1]
    shuffle_dyck2("(){)}")
    >> [0, 0, 0, 0, 0]

  Args:
    pairs: List of pairs of open and close tokens that each should be balanced.
  """
  assert len(pairs) >= 1

  # Compute the running balance of each type of parenthesis.
  balances = []
  for pair in pairs:
    assert len(pair) == 2
    open_token, close_token = pair
    balance = make_pair_balance(
        rasp.tokens, open_token=open_token,
        close_token=close_token).named(f"balance_{pair}")
    balances.append(balance)

  # Check if balances were negative anywhere -> parentheses not balanced.
  any_negative = balances[0] < 0
  for balance in balances[1:]:
    any_negative = any_negative | (balance < 0)

  # Convert to numerical SOp.
  any_negative = rasp.numerical(rasp.Map(lambda x: x,
                                         any_negative)).named("any_negative")

  select_all = rasp.Select(rasp.indices, rasp.indices,
                           rasp.Comparison.TRUE).named("select_all")
  has_neg = rasp.numerical(rasp.Aggregate(select_all, any_negative,
                                          default=0)).named("has_neg")

  # Check if all balances are 0 at the end -> closed all parentheses.
  all_zero = balances[0] == 0
  for balance in balances[1:]:
    all_zero = all_zero & (balance == 0)

  select_last = rasp.Select(rasp.indices, length - 1,
                            rasp.Comparison.EQ).named("select_last")
  last_zero = rasp.Aggregate(select_last, all_zero).named("last_zero")

  not_has_neg = (~has_neg).named("not_has_neg")
  return (last_zero & not_has_neg).named("shuffle_dyck")


def make_shuffle_dyck2() -> rasp.SOp:
  return make_shuffle_dyck(pairs=["()", "{}"]).named("shuffle_dyck2")


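# A worked trace of make_shuffle_dyck (added for illustration; not part of
# the original library). For input "({)}" with pairs=["()", "{}"] the
# running balances are [1, 1/2, 0, 0] for "()" and [0, 1/2, 1/3, 0] for
# "{}": neither ever goes negative and both end at 0, so every position
# outputs 1.
def _demo_shuffle_dyck2() -> None:
  shuffle_dyck2 = make_shuffle_dyck2()
  assert shuffle_dyck2("({)}") == [1, 1, 1, 1]
  # The ")" at position 3 has no matching "(" left, so the "()" balance
  # dips below zero and the whole sequence is rejected.
  assert shuffle_dyck2("(){)}") == [0, 0, 0, 0, 0]

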
def make_hist() -> rasp.SOp:
  """Returns the number of times each token occurs in the input.

  (As implemented in the RASP paper.)

  Example usage:
    hist = make_hist()
    hist("abac")
    >> [2, 1, 2, 1]
  """
  same_tok = rasp.Select(rasp.tokens, rasp.tokens,
                         rasp.Comparison.EQ).named("same_tok")
  return rasp.SelectorWidth(same_tok).named("hist")


def make_sort_unique(vals: rasp.SOp, keys: rasp.SOp) -> rasp.SOp:
  """Returns vals sorted by < relation on keys.

  Only supports unique keys.

  Example usage:
    sort = make_sort_unique(rasp.tokens, rasp.tokens)
    sort([2, 4, 3, 1])
    >> [1, 2, 3, 4]

  Args:
    vals: Values to sort.
    keys: Keys for sorting.
  """
  smaller = rasp.Select(keys, keys, rasp.Comparison.LT).named("smaller")
  target_pos = rasp.SelectorWidth(smaller).named("target_pos")
  sel_new = rasp.Select(target_pos, rasp.indices, rasp.Comparison.EQ)
  return rasp.Aggregate(sel_new, vals).named("sort")


def make_sort(vals: rasp.SOp, keys: rasp.SOp, *, max_seq_len: int,
              min_key: float) -> rasp.SOp:
  """Returns vals sorted by < relation on keys, which don't need to be unique.

  The implementation differs from the RASP paper, as it avoids using
  compositions of selectors to break ties. Instead, it uses the arguments
  max_seq_len and min_key to ensure the keys are unique.

  Note that this approach only works for numerical keys.

  Example usage:
    sort = make_sort(rasp.tokens, rasp.tokens, max_seq_len=5, min_key=1)
    sort([2, 4, 3, 1])
    >> [1, 2, 3, 4]
    sort([2, 4, 1, 2])
    >> [1, 2, 2, 4]

  Args:
    vals: Values to sort.
    keys: Keys for sorting.
    max_seq_len: Maximum sequence length (used to ensure keys are unique).
    min_key: Minimum key value (used to ensure keys are unique).

  Returns:
    Output SOp of sort program.
  """
  keys = rasp.SequenceMap(lambda x, i: x + min_key * i / max_seq_len, keys,
                          rasp.indices)
  return make_sort_unique(vals, keys)


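# Illustrative sketch (not part of the original library): how the key
# perturbation in make_sort breaks ties. Each key receives an offset of
# min_key * i / max_seq_len, which is strictly smaller than min_key, so
# equal keys become unique in input order while keys that differ by at
# least min_key are never reordered.
def _demo_sort_tie_breaking() -> None:
  max_seq_len, min_key = 4, 1
  keys = [2, 4, 1, 2]
  perturbed = [x + min_key * i / max_seq_len for i, x in enumerate(keys)]
  assert perturbed == [2.0, 4.25, 1.5, 2.75]
  # Sorting by the perturbed keys yields the stable order [1, 2, 2, 4].
  assert sorted(range(4), key=lambda j: perturbed[j]) == [2, 0, 3, 1]

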
def make_sort_freq(max_seq_len: int) -> rasp.SOp:
  """Returns tokens sorted by the frequency they appear in the input.

  Tokens that appear the same number of times are output in the same order
  as in the input.

  Example usage:
    sort = make_sort_freq(max_seq_len=5)
    sort([2, 4, 2, 1])
    >> [2, 2, 4, 1]

  Args:
    max_seq_len: Maximum sequence length (used to ensure keys are unique).
  """
  hist = -1 * make_hist().named("hist")
  return make_sort(
      rasp.tokens, hist, max_seq_len=max_seq_len, min_key=1).named("sort_freq")


### Programs that work under both causal and regular evaluation.


def make_frac_prevs(bools: rasp.SOp) -> rasp.SOp:
  """Count the fraction of previous tokens where a specific condition was True.

  (As implemented in the RASP paper.)

  Example usage:
    num_l = make_frac_prevs(rasp.tokens == "l")
    num_l("hello")
    >> [0, 0, 1/3, 1/2, 2/5]

  Args:
    bools: SOp mapping a sequence to a sequence of booleans.

  Returns:
    frac_prevs: SOp mapping an input to a sequence, where every element
      is the fraction of previous "True" tokens.
  """
  bools = rasp.numerical(bools)
  prevs = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ)
  return rasp.numerical(rasp.Aggregate(prevs, bools,
                                       default=0)).named("frac_prevs")


def shift_by(offset: int, /, sop: rasp.SOp) -> rasp.SOp:
  """Returns the sop, shifted by `offset`, None-padded."""
  select_off_by_offset = rasp.Select(rasp.indices, rasp.indices,
                                     lambda k, q: q == k + offset)
  out = rasp.Aggregate(select_off_by_offset, sop, default=None)
  return out.named(f"shift_by({offset})")


def detect_pattern(sop: rasp.SOp, pattern: Sequence[rasp.Value]) -> rasp.SOp:
  """Returns an SOp which is True at the final element of the pattern.

  The first len(pattern) - 1 elements of the output SOp are None-padded.

  detect_pattern(tokens, "abc")("abcabc") == [None, None, T, F, F, T]

  Args:
    sop: the SOp in which to look for patterns.
    pattern: a sequence of values to look for.

  Returns:
    a sop which detects the pattern.
  """

  if len(pattern) < 1:
    raise ValueError(f"Length of `pattern` must be at least 1. Got {pattern}")

  # detectors[i] will be a boolean-valued SOp which is true at position j iff
  # the i'th (from the end) element of the pattern was detected at position j-i.
  detectors = []
  for i, element in enumerate(reversed(pattern)):
    detector = sop == element
    if i != 0:
      detector = shift_by(i, detector)
    detectors.append(detector)

  # All that's left is to take the AND over all detectors.
  pattern_detected = detectors.pop()
  while detectors:
    pattern_detected = pattern_detected & detectors.pop()

  return pattern_detected.named(f"detect_pattern({pattern})")


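# Illustrative check (not part of the original library): detect_pattern on
# "abcabc" with pattern "abc". Position j fires iff sop[j] == "c",
# sop[j-1] == "b" and sop[j-2] == "a"; the first two positions are None
# because the shifted detectors read past the start of the sequence.
def _demo_detect_pattern() -> None:
  out = detect_pattern(rasp.tokens, "abc")("abcabc")
  assert out == [None, None, True, False, False, True]

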
def make_count_less_freq(n: int) -> rasp.SOp:
  """Returns how many tokens appear n times or fewer in the input.

  The output sequence contains this count in each position.

  Example usage:
    count_less_freq = make_count_less_freq(2)
    count_less_freq(["a", "a", "a", "b", "b", "c"])
    >> [3, 3, 3, 3, 3, 3]
    count_less_freq(["a", "a", "c", "b", "b", "c"])
    >> [6, 6, 6, 6, 6, 6]

  Args:
    n: Integer to compare token frequencies to.
  """
  hist = make_hist().named("hist")
  select_less = rasp.Select(hist, hist,
                            lambda x, y: x <= n).named("select_less")
  return rasp.SelectorWidth(select_less).named("count_less_freq")


def make_count(sop, token):
  """Returns the count of `token` in `sop`.

  The output sequence contains this count in each position.

  Example usage:
    count = make_count(tokens, "a")
    count(["a", "a", "a", "b", "b", "c"])
    >> [3, 3, 3, 3, 3, 3]
    count(["c", "a", "b", "c"])
    >> [1, 1, 1, 1]

  Args:
    sop: Sop to count tokens in.
    token: Token to count.
  """
  return rasp.SelectorWidth(rasp.Select(
      sop, sop, lambda k, q: k == token)).named(f"count_{token}")


def make_nary_sequencemap(f, *sops):
  """Returns an SOp that simulates an n-ary SequenceMap.

  Uses multiple binary SequenceMaps to convert n SOps x_1, x_2, ..., x_n
  into a single SOp whose values are n-tuples. The n-ary sequence map
  implementing f is then a Map on this resulting SOp.

  Note that the intermediate variables representing tuples of varying length
  will be encoded categorically, and can become very high-dimensional. So,
  using this function might lead to very large compiled models.

  Args:
    f: Function with n arguments.
    *sops: Sequence of SOps, one for each argument of f.
  """
  values, *sops = sops
  for sop in sops:
    # x is a single entry in the first iteration but a tuple in later ones.
    values = rasp.SequenceMap(
        lambda x, y: (*x, y) if isinstance(x, tuple) else (x, y), values, sop)
  return rasp.Map(lambda args: f(*args), values)
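A minimal usage sketch for make_nary_sequencemap (added for illustration;
not part of this commit). The loop above folds the SOps into tuple-valued
intermediates, so a three-argument function behaves elementwise:

    from tracr.compiler import lib
    from tracr.rasp import rasp

    program = lib.make_nary_sequencemap(
        lambda x, y, z: x + y - z, rasp.tokens, rasp.tokens, rasp.indices)
    assert program([1, 2, 3]) == [2, 3, 4]  # token + token - index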
compiler/lib_test.py
ADDED
@@ -0,0 +1,40 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for compiler.lib."""

from absl.testing import absltest
from absl.testing import parameterized
from tracr.compiler import test_cases
from tracr.rasp import causal_eval
from tracr.rasp import rasp


class LibTest(parameterized.TestCase):

  @parameterized.named_parameters(*test_cases.TEST_CASES)
  def test_program_produces_expected_output(self, program, test_input,
                                            expected_output, **kwargs):
    del kwargs
    self.assertEqual(rasp.evaluate(program, test_input), expected_output)

  @parameterized.named_parameters(*test_cases.CAUSAL_TEST_CASES)
  def test_causal_program_produces_expected_output(self, program, test_input,
                                                   expected_output, **kwargs):
    del kwargs
    self.assertEqual(causal_eval.evaluate(program, test_input), expected_output)


if __name__ == "__main__":
  absltest.main()
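A usage note (added for illustration; not part of this commit): each entry
of test_cases.TEST_CASES supplies the program, test_input and
expected_output arguments above, and the suite can be run directly,
assuming the package is importable as `tracr`:

    python -m tracr.compiler.lib_test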
compiler/nodes.py
ADDED
@@ -0,0 +1,32 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Documents the data stored in nodes after each compiler pass."""

from typing import Any

Node = dict[str, Any]
NodeID = str

# RASP -> Graph
ID = "ID"  # unique ID of the node
EXPR = "EXPR"  # the RASPExpr of the node

# Basis inference
# Note that only S-Op expressions will have these keys set.
VALUE_SET = "VALUE_SET"  # possible values taken on by this SOp.
OUTPUT_BASIS = "OUTPUT_BASIS"  # the corresponding named basis.

# RASP Graph -> Craft Graph
MODEL_BLOCK = "MODEL_BLOCK"  # craft block representing a RASPExpr
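A sketch of the data this module documents (added for illustration; not
part of this commit). After all compiler passes, a graph node for an S-Op
might hold the following entries, where the right-hand values are
hypothetical placeholders:

    node = {
        nodes.ID: "map_1",                 # set when building the RASP graph
        nodes.EXPR: map_expr,              # set when building the RASP graph
        nodes.VALUE_SET: {0, 1, 2},        # set by basis inference
        nodes.OUTPUT_BASIS: output_basis,  # set by basis inference
        nodes.MODEL_BLOCK: craft_block,    # set by the craft-graph pass
    }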
compiler/rasp_to_craft_integration_test.py
ADDED
@@ -0,0 +1,254 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Integration tests for the RASP -> craft stages of the compiler."""

import unittest

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from tracr.compiler import basis_inference
from tracr.compiler import craft_graph_to_model
from tracr.compiler import expr_to_craft_graph
from tracr.compiler import nodes
from tracr.compiler import rasp_to_graph
from tracr.compiler import test_cases
from tracr.craft import bases
from tracr.craft import tests_common
from tracr.rasp import rasp

_BOS_DIRECTION = "rasp_to_craft_integration_test_BOS"
_ONE_DIRECTION = "rasp_to_craft_integration_test_ONE"


def _make_input_space(vocab, max_seq_len):
  tokens_space = bases.VectorSpaceWithBasis.from_values("tokens", vocab)
  indices_space = bases.VectorSpaceWithBasis.from_values(
      "indices", range(max_seq_len))
  one_space = bases.VectorSpaceWithBasis.from_names([_ONE_DIRECTION])
  bos_space = bases.VectorSpaceWithBasis.from_names([_BOS_DIRECTION])
  input_space = bases.join_vector_spaces(tokens_space, indices_space, one_space,
                                         bos_space)

  return input_space


def _embed_input(input_seq, input_space):
  bos_vec = input_space.vector_from_basis_direction(
      bases.BasisDirection(_BOS_DIRECTION))
  one_vec = input_space.vector_from_basis_direction(
      bases.BasisDirection(_ONE_DIRECTION))
  embedded_input = [bos_vec + one_vec]
  for i, val in enumerate(input_seq):
    i_vec = input_space.vector_from_basis_direction(
        bases.BasisDirection("indices", i))
    val_vec = input_space.vector_from_basis_direction(
        bases.BasisDirection("tokens", val))
    embedded_input.append(i_vec + val_vec + one_vec)
  return bases.VectorInBasis.stack(embedded_input)


def _embed_output(output_seq, output_space, categorical_output):
  embedded_output = []
  output_label = output_space.basis[0].name
  for x in output_seq:
    if x is None:
      out_vec = output_space.null_vector()
    elif categorical_output:
      out_vec = output_space.vector_from_basis_direction(
          bases.BasisDirection(output_label, x))
    else:
      out_vec = x * output_space.vector_from_basis_direction(
          output_space.basis[0])
    embedded_output.append(out_vec)
  return bases.VectorInBasis.stack(embedded_output)


class CompilerIntegrationTest(tests_common.VectorFnTestCase):

  @parameterized.named_parameters(
      dict(
          testcase_name="map",
          program=rasp.categorical(rasp.Map(lambda x: x + 1, rasp.tokens))),
      dict(
          testcase_name="sequence_map",
          program=rasp.categorical(
              rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, rasp.indices))),
      dict(
          testcase_name="sequence_map_with_same_input",
          program=rasp.categorical(
              rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, rasp.tokens))),
      dict(
          testcase_name="select_aggregate",
          program=rasp.Aggregate(
              rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ),
              rasp.Map(lambda x: 1, rasp.tokens))))
  def test_rasp_program_and_craft_model_produce_same_output(self, program):
    vocab = {0, 1, 2}
    max_seq_len = 3

    extracted = rasp_to_graph.extract_rasp_graph(program)
    basis_inference.infer_bases(
        extracted.graph,
        extracted.sink,
        vocab,
        max_seq_len=max_seq_len,
    )
    expr_to_craft_graph.add_craft_components_to_rasp_graph(
        extracted.graph,
        bos_dir=bases.BasisDirection(_BOS_DIRECTION),
        one_dir=bases.BasisDirection(_ONE_DIRECTION),
    )
    model = craft_graph_to_model.craft_graph_to_model(extracted.graph,
                                                      extracted.sources)
    input_space = _make_input_space(vocab, max_seq_len)
    output_space = bases.VectorSpaceWithBasis(
        extracted.sink[nodes.OUTPUT_BASIS])

    for val in vocab:
      test_input = _embed_input([val], input_space)
      rasp_output = program([val])
      expected_output = _embed_output(
          output_seq=rasp_output,
          output_space=output_space,
          categorical_output=True)
      test_output = model.apply(test_input).project(output_space)
      self.assertVectorAllClose(
          tests_common.strip_bos_token(test_output), expected_output)

  @parameterized.named_parameters(*test_cases.TEST_CASES)
  def test_compiled_models_produce_expected_output(self, program, vocab,
                                                   test_input, expected_output,
                                                   max_seq_len, **kwargs):
    del kwargs
    categorical_output = rasp.is_categorical(program)

    extracted = rasp_to_graph.extract_rasp_graph(program)
    basis_inference.infer_bases(
        extracted.graph,
        extracted.sink,
        vocab,
        max_seq_len=max_seq_len,
    )
    expr_to_craft_graph.add_craft_components_to_rasp_graph(
        extracted.graph,
        bos_dir=bases.BasisDirection(_BOS_DIRECTION),
        one_dir=bases.BasisDirection(_ONE_DIRECTION),
    )
    model = craft_graph_to_model.craft_graph_to_model(extracted.graph,
                                                      extracted.sources)
    input_space = _make_input_space(vocab, max_seq_len)
    output_space = bases.VectorSpaceWithBasis(
        extracted.sink[nodes.OUTPUT_BASIS])
    if not categorical_output:
      self.assertLen(output_space.basis, 1)

    test_input_vector = _embed_input(test_input, input_space)
    expected_output_vector = _embed_output(
        output_seq=expected_output,
        output_space=output_space,
        categorical_output=categorical_output)
    test_output = model.apply(test_input_vector).project(output_space)
    self.assertVectorAllClose(
        tests_common.strip_bos_token(test_output), expected_output_vector)

  @unittest.expectedFailure
  def test_setting_default_values_can_lead_to_wrong_outputs_in_compiled_model(
      self):
    # This is an example program in which setting a default value for aggregate
    # writes a value to the bos token position, which interferes with a later
    # aggregate operation, causing the compiled model to have the wrong output.

    vocab = {"a", "b"}
    test_input = ["a", "b"]
    max_seq_len = 2

    # RASP: [False, True]
    # compiled: [False, False, True]
    not_a = rasp.Map(lambda x: x != "a", rasp.tokens)

    # RASP:
    # [[True, False],
    #  [False, False]]
    # compiled:
    # [[False, True, False],
    #  [True, False, False]]
    sel1 = rasp.Select(rasp.tokens, rasp.tokens,
                       lambda k, q: k == "a" and q == "a")

    # RASP: [False, True]
    # compiled: [True, False, True]
    agg1 = rasp.Aggregate(sel1, not_a, default=True)

    # RASP:
    # [[False, True]
    #  [True, True]]
    # compiled:
    # [[True, False, False]
    #  [True, False, False]]
    # because pre-softmax we get
    # [[1.5, 1, 1]
    #  [1.5, 1, 1]]
    # instead of
    # [[0.5, 1, 1]
    #  [0.5, 1, 1]]
    # because agg1 = True is stored on the BOS token position.
    sel2 = rasp.Select(agg1, agg1, lambda k, q: k or q)

    # RASP: [1, 0.5]
    # compiled: [1, 1, 1]
    program = rasp.numerical(
        rasp.Aggregate(sel2, rasp.numerical(not_a), default=1))
    expected_output = [1, 0.5]

    # The RASP program gives the correct output.
    program_output = program(test_input)
    np.testing.assert_allclose(program_output, expected_output)

    extracted = rasp_to_graph.extract_rasp_graph(program)
    basis_inference.infer_bases(
        extracted.graph,
        extracted.sink,
        vocab,
        max_seq_len=max_seq_len,
    )
    expr_to_craft_graph.add_craft_components_to_rasp_graph(
        extracted.graph,
        bos_dir=bases.BasisDirection(_BOS_DIRECTION),
        one_dir=bases.BasisDirection(_ONE_DIRECTION),
    )
    model = craft_graph_to_model.craft_graph_to_model(extracted.graph,
                                                      extracted.sources)

    input_space = _make_input_space(vocab, max_seq_len)
    output_space = bases.VectorSpaceWithBasis(
        extracted.sink[nodes.OUTPUT_BASIS])

    test_input_vector = _embed_input(test_input, input_space)
    expected_output_vector = _embed_output(
        output_seq=expected_output,
        output_space=output_space,
        categorical_output=True)
    compiled_model_output = model.apply(test_input_vector).project(output_space)

    # The compiled craft model should give the correct output, but the value
    # written to the BOS position makes it fail (hence `expectedFailure`).
    self.assertVectorAllClose(
        tests_common.strip_bos_token(compiled_model_output),
        expected_output_vector)


if __name__ == "__main__":
  absltest.main()
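The tests above drive the intermediate compiler stages by hand. As a
compact summary (added for illustration; not part of this commit), the
calls chain as follows for a given program, vocab and max_seq_len:

    extracted = rasp_to_graph.extract_rasp_graph(program)
    basis_inference.infer_bases(
        extracted.graph, extracted.sink, vocab, max_seq_len=max_seq_len)
    expr_to_craft_graph.add_craft_components_to_rasp_graph(
        extracted.graph,
        bos_dir=bases.BasisDirection(_BOS_DIRECTION),
        one_dir=bases.BasisDirection(_ONE_DIRECTION))
    model = craft_graph_to_model.craft_graph_to_model(extracted.graph,
                                                      extracted.sources)
    output = model.apply(_embed_input(test_input, input_space))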
compiler/rasp_to_graph.py
ADDED
@@ -0,0 +1,67 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converting a RaspExpr to a graph."""

import dataclasses
import queue

import networkx as nx
from tracr.compiler import nodes
from tracr.rasp import rasp

Node = nodes.Node
NodeID = nodes.NodeID


@dataclasses.dataclass
class ExtractRaspGraphOutput:
  graph: nx.DiGraph
  sink: Node  # the program's output.
  sources: list[Node]  # the primitive S-Ops.


def extract_rasp_graph(tip: rasp.SOp) -> ExtractRaspGraphOutput:
  """Converts a RASP program into a graph representation."""
  expr_queue = queue.Queue()
  graph = nx.DiGraph()
  sources: list[Node] = []

  def ensure_node(expr: rasp.RASPExpr) -> NodeID:
    """Finds or creates a graph node corresponding to expr; returns its ID."""
    node_id = expr.label
    if node_id not in graph:
      graph.add_node(node_id, **{nodes.ID: node_id, nodes.EXPR: expr})

    return node_id

  # Breadth-first search over the RASP expression graph.

  def visit_raspexpr(expr: rasp.RASPExpr):
    parent_id = ensure_node(expr)

    for child_expr in expr.children:
      expr_queue.put(child_expr)
      child_id = ensure_node(child_expr)
      graph.add_edge(child_id, parent_id)

    if not expr.children:
      sources.append(graph.nodes[parent_id])

  expr_queue.put(tip)
  sink = graph.nodes[ensure_node(tip)]
  while not expr_queue.empty():
    visit_raspexpr(expr_queue.get())

  return ExtractRaspGraphOutput(graph=graph, sink=sink, sources=sources)
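A minimal usage sketch (added for illustration; not part of this commit):
extracting and inspecting the graph of a two-node program.

    from tracr.compiler import nodes
    from tracr.compiler import rasp_to_graph
    from tracr.rasp import rasp

    program = rasp.Map(lambda x: x + 1, rasp.tokens)
    extracted = rasp_to_graph.extract_rasp_graph(program)
    # A single edge runs from the `tokens` source to the `Map` sink.
    assert len(extracted.graph.edges) == 1
    assert extracted.sink[nodes.EXPR] is program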
compiler/rasp_to_graph_test.py
ADDED
@@ -0,0 +1,71 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for compiler.rasp_to_graph."""

from absl.testing import absltest
from absl.testing import parameterized
from tracr.compiler import nodes
from tracr.compiler import rasp_to_graph
from tracr.rasp import rasp


class ExtractRaspGraphTest(parameterized.TestCase):

  def test_primitives_have_no_edges(self):
    tokens_graph = rasp_to_graph.extract_rasp_graph(rasp.tokens).graph
    self.assertEmpty(tokens_graph.edges)

    indices_graph = rasp_to_graph.extract_rasp_graph(rasp.indices).graph
    self.assertEmpty(indices_graph.edges)

    full_graph = rasp_to_graph.extract_rasp_graph(rasp.Full(1)).graph
    self.assertEmpty(full_graph.edges)

  def test_one_edge(self):
    program = rasp.Map(lambda x: x + 1, rasp.tokens)

    graph = rasp_to_graph.extract_rasp_graph(program).graph

    self.assertLen(graph.edges, 1)
    (u, v), = graph.edges
    self.assertEqual(graph.nodes[u][nodes.EXPR], rasp.tokens)
    self.assertEqual(graph.nodes[v][nodes.EXPR], program)

  def test_aggregate(self):
    program = rasp.Aggregate(
        rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),
        rasp.indices,
    )

    extracted = rasp_to_graph.extract_rasp_graph(program)

    # Expected graph:
    #
    # indices ----------------
    #         \               \
    #          select -------- program
    #         /
    # tokens -

    self.assertLen(extracted.graph.edges, 4)
    self.assertEqual(extracted.sink[nodes.EXPR], program)
    for source in extracted.sources:
      self.assertIn(
          source[nodes.EXPR],
          [rasp.tokens, rasp.indices],
      )


if __name__ == "__main__":
  absltest.main()
compiler/rasp_to_transformer_integration_test.py
ADDED
@@ -0,0 +1,214 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Integration tests for the full RASP -> transformer compilation."""

from absl.testing import absltest
from absl.testing import parameterized
import jax
import numpy as np

from tracr.compiler import compiling
from tracr.compiler import lib
from tracr.compiler import test_cases
from tracr.craft import tests_common
from tracr.rasp import rasp

_COMPILER_BOS = "rasp_to_transformer_integration_test_BOS"
_COMPILER_PAD = "rasp_to_transformer_integration_test_PAD"

# Force float32 precision on TPU, which otherwise defaults to bfloat16.
jax.config.update("jax_default_matmul_precision", "float32")


class CompilerIntegrationTest(tests_common.VectorFnTestCase):

  def assertSequenceEqualWhenExpectedIsNotNone(self, actual_seq, expected_seq):
    for actual, expected in zip(actual_seq, expected_seq):
      if expected is not None and actual != expected:
        self.fail(f"{actual_seq} does not match (ignoring Nones) "
                  f"{expected_seq=}")

  @parameterized.named_parameters(
      dict(
          testcase_name="map",
          program=rasp.Map(lambda x: x + 1, rasp.tokens)),
      dict(
          testcase_name="sequence_map",
          program=rasp.SequenceMap(lambda x, y: x + y, rasp.tokens,
                                   rasp.indices)),
      dict(
          testcase_name="sequence_map_with_same_input",
          program=rasp.SequenceMap(lambda x, y: x + y, rasp.tokens,
                                   rasp.tokens)),
      dict(
          testcase_name="select_aggregate",
          program=rasp.Aggregate(
              rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ),
              rasp.Map(lambda x: 1, rasp.tokens))))
  def test_rasp_program_and_transformer_produce_same_output(self, program):
    vocab = {0, 1, 2}
    max_seq_len = 3
    assembled_model = compiling.compile_rasp_to_model(
        program, vocab, max_seq_len, compiler_bos=_COMPILER_BOS)

    test_outputs = {}
    rasp_outputs = {}
    for val in vocab:
      test_outputs[val] = assembled_model.apply([_COMPILER_BOS, val]).decoded[1]
      rasp_outputs[val] = program([val])[0]

    with self.subTest(val=0):
      self.assertEqual(test_outputs[0], rasp_outputs[0])
    with self.subTest(val=1):
      self.assertEqual(test_outputs[1], rasp_outputs[1])
    with self.subTest(val=2):
      self.assertEqual(test_outputs[2], rasp_outputs[2])

  @parameterized.named_parameters(*test_cases.TEST_CASES)
  def test_compiled_models_produce_expected_output(self, program, vocab,
                                                   test_input, expected_output,
                                                   max_seq_len, **kwargs):
    del kwargs
    assembled_model = compiling.compile_rasp_to_model(
        program, vocab, max_seq_len, compiler_bos=_COMPILER_BOS)
    test_output = assembled_model.apply([_COMPILER_BOS] + test_input)

    if isinstance(expected_output[0], (int, float)):
      np.testing.assert_allclose(
          test_output.decoded[1:], expected_output, atol=1e-7, rtol=0.005)
    else:
      self.assertSequenceEqualWhenExpectedIsNotNone(test_output.decoded[1:],
                                                    expected_output)

  @parameterized.named_parameters(*test_cases.CAUSAL_TEST_CASES)
  def test_compiled_causal_models_produce_expected_output(
      self, program, vocab, test_input, expected_output, max_seq_len, **kwargs):
    del kwargs
    assembled_model = compiling.compile_rasp_to_model(
        program,
        vocab,
        max_seq_len,
        causal=True,
        compiler_bos=_COMPILER_BOS,
        compiler_pad=_COMPILER_PAD)
    test_output = assembled_model.apply([_COMPILER_BOS] + test_input)

    if isinstance(expected_output[0], (int, float)):
      np.testing.assert_allclose(
          test_output.decoded[1:], expected_output, atol=1e-7, rtol=0.005)
    else:
      self.assertSequenceEqualWhenExpectedIsNotNone(test_output.decoded[1:],
                                                    expected_output)

  @parameterized.named_parameters(
      dict(
          testcase_name="reverse_1",
          program=lib.make_reverse(rasp.tokens),
          vocab={"a", "b", "c", "d"},
          test_input=list("abcd"),
          expected_output=list("dcba"),
          max_seq_len=5),
      dict(
          testcase_name="reverse_2",
          program=lib.make_reverse(rasp.tokens),
          vocab={"a", "b", "c", "d"},
          test_input=list("abc"),
          expected_output=list("cba"),
          max_seq_len=5),
      dict(
          testcase_name="reverse_3",
          program=lib.make_reverse(rasp.tokens),
          vocab={"a", "b", "c", "d"},
          test_input=list("ad"),
          expected_output=list("da"),
          max_seq_len=5),
      dict(
          testcase_name="reverse_4",
          program=lib.make_reverse(rasp.tokens),
          vocab={"a", "b", "c", "d"},
          test_input=["c"],
          expected_output=["c"],
          max_seq_len=5),
      dict(
          testcase_name="length_categorical_1",
          program=rasp.categorical(lib.make_length()),
          vocab={"a", "b", "c", "d"},
          test_input=list("abc"),
          expected_output=[3, 3, 3],
          max_seq_len=5),
      dict(
          testcase_name="length_categorical_2",
          program=rasp.categorical(lib.make_length()),
          vocab={"a", "b", "c", "d"},
          test_input=list("ad"),
          expected_output=[2, 2],
          max_seq_len=5),
      dict(
          testcase_name="length_categorical_3",
          program=rasp.categorical(lib.make_length()),
          vocab={"a", "b", "c", "d"},
          test_input=["c"],
          expected_output=[1],
          max_seq_len=5),
      dict(
          testcase_name="length_numerical_1",
          program=rasp.numerical(lib.make_length()),
          vocab={"a", "b", "c", "d"},
          test_input=list("abc"),
          expected_output=[3, 3, 3],
          max_seq_len=5),
      dict(
          testcase_name="length_numerical_2",
          program=rasp.numerical(lib.make_length()),
          vocab={"a", "b", "c", "d"},
          test_input=list("ad"),
          expected_output=[2, 2],
          max_seq_len=5),
      dict(
          testcase_name="length_numerical_3",
          program=rasp.numerical(lib.make_length()),
          vocab={"a", "b", "c", "d"},
          test_input=["c"],
          expected_output=[1],
          max_seq_len=5),
  )
  def test_compiled_models_produce_expected_output_with_padding(
      self, program, vocab, test_input, expected_output, max_seq_len, **kwargs):
    del kwargs
    assembled_model = compiling.compile_rasp_to_model(
        program,
        vocab,
        max_seq_len,
        compiler_bos=_COMPILER_BOS,
        compiler_pad=_COMPILER_PAD)

    pad_len = (max_seq_len - len(test_input))
    test_input = test_input + [_COMPILER_PAD] * pad_len
    test_input = [_COMPILER_BOS] + test_input
    test_output = assembled_model.apply(test_input)
    output = test_output.decoded
    output_len = len(output)
    output_stripped = test_output.decoded[1:output_len - pad_len]

    self.assertEqual(output[0], _COMPILER_BOS)
    if isinstance(expected_output[0], (int, float)):
      np.testing.assert_allclose(
          output_stripped, expected_output, atol=1e-7, rtol=0.005)
    else:
      self.assertEqual(output_stripped, expected_output)


if __name__ == "__main__":
  absltest.main()
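End to end, the entry point exercised above reduces to a few lines. A
minimal sketch (added for illustration; not part of this commit):

    from tracr.compiler import compiling
    from tracr.compiler import lib
    from tracr.rasp import rasp

    assembled_model = compiling.compile_rasp_to_model(
        lib.make_reverse(rasp.tokens), {"a", "b", "c", "d"}, max_seq_len=5,
        compiler_bos="BOS")
    out = assembled_model.apply(["BOS", "a", "b", "c"])
    assert out.decoded[1:] == ["c", "b", "a"]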
compiler/test_cases.py
ADDED
@@ -0,0 +1,357 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A set of RASP programs and input/output pairs used in integration tests."""

from tracr.compiler import lib
from tracr.rasp import rasp

UNIVERSAL_TEST_CASES = [
    dict(
        testcase_name="frac_prevs_1",
        program=lib.make_frac_prevs(rasp.tokens == "l"),
        vocab={"h", "e", "l", "o"},
        test_input=list("hello"),
        expected_output=[0.0, 0.0, 1 / 3, 1 / 2, 2 / 5],
        max_seq_len=5),
    dict(
        testcase_name="frac_prevs_2",
        program=lib.make_frac_prevs(rasp.tokens == "("),
        vocab={"a", "b", "c", "(", ")"},
        test_input=list("a()b(c))"),
        expected_output=[0.0, 1 / 2, 1 / 3, 1 / 4, 2 / 5, 2 / 6, 2 / 7, 2 / 8],
        max_seq_len=10),
    dict(
        testcase_name="frac_prevs_3",
        program=lib.make_frac_prevs(rasp.tokens == ")"),
        vocab={"a", "b", "c", "(", ")"},
        test_input=list("a()b(c))"),
        expected_output=[0.0, 0.0, 1 / 3, 1 / 4, 1 / 5, 1 / 6, 2 / 7, 3 / 8],
        max_seq_len=10,
    ),
    dict(
        testcase_name="shift_by_one",
        program=lib.shift_by(1, rasp.tokens),
        vocab={"a", "b", "c", "d"},
        test_input=list("abcd"),
        expected_output=[None, "a", "b", "c"],
        max_seq_len=5,
    ),
    dict(
        testcase_name="shift_by_two",
        program=lib.shift_by(2, rasp.tokens),
        vocab={"a", "b", "c", "d"},
        test_input=list("abcd"),
        expected_output=[None, None, "a", "b"],
        max_seq_len=5,
    ),
    dict(
        testcase_name="detect_pattern_a",
        program=lib.detect_pattern(rasp.tokens, "a"),
        vocab={"a", "b", "c", "d"},
        test_input=list("bacd"),
        expected_output=[False, True, False, False],
        max_seq_len=5,
    ),
    dict(
        testcase_name="detect_pattern_ab",
        program=lib.detect_pattern(rasp.tokens, "ab"),
        vocab={"a", "b"},
        test_input=list("aaba"),
        expected_output=[None, False, True, False],
        max_seq_len=5,
    ),
    dict(
        testcase_name="detect_pattern_ab_2",
        program=lib.detect_pattern(rasp.tokens, "ab"),
        vocab={"a", "b"},
        test_input=list("abaa"),
        expected_output=[None, True, False, False],
        max_seq_len=5,
    ),
    dict(
        testcase_name="detect_pattern_ab_3",
        program=lib.detect_pattern(rasp.tokens, "ab"),
        vocab={"a", "b"},
        test_input=list("aaaa"),
        expected_output=[None, False, False, False],
        max_seq_len=5,
    ),
    dict(
        testcase_name="detect_pattern_abc",
        program=lib.detect_pattern(rasp.tokens, "abc"),
        vocab={"a", "b", "c"},
        test_input=list("abcabc"),
        expected_output=[None, None, True, False, False, True],
        max_seq_len=6,
    ),
]

TEST_CASES = UNIVERSAL_TEST_CASES + [
    dict(
        testcase_name="reverse_1",
        program=lib.make_reverse(rasp.tokens),
        vocab={"a", "b", "c", "d"},
        test_input=list("abcd"),
        expected_output=list("dcba"),
        max_seq_len=5),
    dict(
        testcase_name="reverse_2",
        program=lib.make_reverse(rasp.tokens),
        vocab={"a", "b", "c", "d"},
        test_input=list("abc"),
        expected_output=list("cba"),
        max_seq_len=5),
    dict(
        testcase_name="reverse_3",
        program=lib.make_reverse(rasp.tokens),
        vocab={"a", "b", "c", "d"},
        test_input=list("ad"),
        expected_output=list("da"),
        max_seq_len=5),
    dict(
        testcase_name="reverse_4",
        program=lib.make_reverse(rasp.tokens),
        vocab={"a", "b", "c", "d"},
        test_input=["c"],
        expected_output=["c"],
        max_seq_len=5),
    dict(
        testcase_name="length_categorical_1",
        program=rasp.categorical(lib.make_length()),
        vocab={"a", "b", "c", "d"},
        test_input=list("abc"),
        expected_output=[3, 3, 3],
        max_seq_len=3),
    dict(
        testcase_name="length_categorical_2",
        program=rasp.categorical(lib.make_length()),
        vocab={"a", "b", "c", "d"},
        test_input=list("ad"),
        expected_output=[2, 2],
        max_seq_len=3),
    dict(
        testcase_name="length_categorical_3",
        program=rasp.categorical(lib.make_length()),
        vocab={"a", "b", "c", "d"},
        test_input=["c"],
        expected_output=[1],
        max_seq_len=3),
    dict(
        testcase_name="length_numerical_1",
        program=rasp.numerical(lib.make_length()),
        vocab={"a", "b", "c", "d"},
        test_input=list("abc"),
        expected_output=[3, 3, 3],
        max_seq_len=3),
    dict(
        testcase_name="length_numerical_2",
        program=rasp.numerical(lib.make_length()),
        vocab={"a", "b", "c", "d"},
        test_input=list("ad"),
        expected_output=[2, 2],
        max_seq_len=3),
    dict(
        testcase_name="length_numerical_3",
        program=rasp.numerical(lib.make_length()),
        vocab={"a", "b", "c", "d"},
        test_input=["c"],
        expected_output=[1],
        max_seq_len=3),
    dict(
        testcase_name="pair_balance_1",
        program=lib.make_pair_balance(rasp.tokens, "(", ")"),
        vocab={"a", "b", "c", "(", ")"},
        test_input=list("a()b(c))"),
        expected_output=[0.0, 1 / 2, 0.0, 0.0, 1 / 5, 1 / 6, 0.0, -1 / 8],
        max_seq_len=10),
    dict(
        testcase_name="shuffle_dyck2_1",
        program=lib.make_shuffle_dyck(pairs=["()", "{}"]),
        vocab={"(", ")", "{", "}"},
        test_input=list("({)}"),
        expected_output=[1, 1, 1, 1],
        max_seq_len=5),
    dict(
        testcase_name="shuffle_dyck2_2",
        program=lib.make_shuffle_dyck(pairs=["()", "{}"]),
        vocab={"(", ")", "{", "}"},
        test_input=list("(){)}"),
        expected_output=[0, 0, 0, 0, 0],
        max_seq_len=5),
    dict(
        testcase_name="shuffle_dyck2_3",
        program=lib.make_shuffle_dyck(pairs=["()", "{}"]),
        vocab={"(", ")", "{", "}"},
        test_input=list("{}("),
        expected_output=[0, 0, 0],
        max_seq_len=5),
    dict(
        testcase_name="shuffle_dyck3_1",
        program=lib.make_shuffle_dyck(pairs=["()", "{}", "[]"]),
        vocab={"(", ")", "{", "}", "[", "]"},
        test_input=list("({)[}]"),
        expected_output=[1, 1, 1, 1, 1, 1],
        max_seq_len=6),
    dict(
        testcase_name="shuffle_dyck3_2",
        program=lib.make_shuffle_dyck(pairs=["()", "{}", "[]"]),
        vocab={"(", ")", "{", "}", "[", "]"},
        test_input=list("(){)}"),
        expected_output=[0, 0, 0, 0, 0],
        max_seq_len=6),
    dict(
        testcase_name="shuffle_dyck3_3",
        program=lib.make_shuffle_dyck(pairs=["()", "{}", "[]"]),
        vocab={"(", ")", "{", "}", "[", "]"},
        test_input=list("{}[(]"),
        expected_output=[0, 0, 0, 0, 0],
        max_seq_len=6),
    dict(
        testcase_name="hist",
        program=lib.make_hist(),
        vocab={"a", "b", "c", "d"},
        test_input=list("abac"),
        expected_output=[2, 1, 2, 1],
        max_seq_len=5,
    ),
    dict(
        testcase_name="sort_unique_1",
        program=lib.make_sort_unique(vals=rasp.tokens, keys=rasp.tokens),
        vocab={1, 2, 3, 4},
        test_input=[2, 4, 3, 1],
        expected_output=[1, 2, 3, 4],
        max_seq_len=5),
    dict(
        testcase_name="sort_unique_2",
        program=lib.make_sort_unique(vals=rasp.tokens, keys=1 - rasp.indices),
        vocab={"a", "b", "c", "d"},
        test_input=list("abcd"),
        expected_output=["d", "c", "b", "a"],
        max_seq_len=5),
    dict(
        testcase_name="sort_1",
        program=lib.make_sort(
            vals=rasp.tokens, keys=rasp.tokens, max_seq_len=5, min_key=1),
        vocab={1, 2, 3, 4},
        test_input=[2, 4, 3, 1],
        expected_output=[1, 2, 3, 4],
        max_seq_len=5),
    dict(
        testcase_name="sort_2",
        program=lib.make_sort(
            vals=rasp.tokens, keys=1 - rasp.indices, max_seq_len=5, min_key=1),
        vocab={"a", "b", "c", "d"},
        test_input=list("abcd"),
        expected_output=["d", "c", "b", "a"],
        max_seq_len=5),
    dict(
        testcase_name="sort_3",
        program=lib.make_sort(
            vals=rasp.tokens, keys=rasp.tokens, max_seq_len=5, min_key=1),
        vocab={1, 2, 3, 4},
        test_input=[2, 4, 1, 2],
        expected_output=[1, 2, 2, 4],
        max_seq_len=5),
    dict(
        testcase_name="sort_freq",
        program=lib.make_sort_freq(max_seq_len=5),
        vocab={1, 2, 3, 4},
        test_input=[2, 4, 2, 1],
        expected_output=[2, 2, 4, 1],
        max_seq_len=5),
    dict(
        testcase_name="make_count_less_freq_categorical_1",
        program=lib.make_count_less_freq(n=2),
        vocab={"a", "b", "c", "d"},
        test_input=["a", "a", "a", "b", "b", "c"],
        expected_output=[3, 3, 3, 3, 3, 3],
        max_seq_len=6),
    dict(
        testcase_name="make_count_less_freq_categorical_2",
        program=lib.make_count_less_freq(n=2),
        vocab={"a", "b", "c", "d"},
        test_input=["a", "a", "c", "b", "b", "c"],
        expected_output=[6, 6, 6, 6, 6, 6],
        max_seq_len=6),
    dict(
        testcase_name="make_count_less_freq_numerical_1",
        program=rasp.numerical(lib.make_count_less_freq(n=2)),
        vocab={"a", "b", "c", "d"},
        test_input=["a", "a", "a", "b", "b", "c"],
        expected_output=[3, 3, 3, 3, 3, 3],
        max_seq_len=6),
    dict(
        testcase_name="make_count_less_freq_numerical_2",
        program=rasp.numerical(lib.make_count_less_freq(n=2)),
        vocab={"a", "b", "c", "d"},
        test_input=["a", "a", "c", "b", "b", "c"],
        expected_output=[6, 6, 6, 6, 6, 6],
        max_seq_len=6),
    dict(
        testcase_name="make_count_1",
        program=lib.make_count(rasp.tokens, "a"),
        vocab={"a", "b", "c"},
        test_input=["a", "a", "a", "b", "b", "c"],
        expected_output=[3, 3, 3, 3, 3, 3],
        max_seq_len=8,
    ),
    dict(
        testcase_name="make_count_2",
        program=lib.make_count(rasp.tokens, "a"),
        vocab={"a", "b", "c"},
        test_input=["c", "a", "b", "c"],
        expected_output=[1, 1, 1, 1],
        max_seq_len=8,
    ),
    dict(
        testcase_name="make_count_3",
        program=lib.make_count(rasp.tokens, "a"),
        vocab={"a", "b", "c"},
        test_input=["b", "b", "c"],
        expected_output=[0, 0, 0],
        max_seq_len=8,
    ),
    dict(
        testcase_name="make_nary_sequencemap_1",
        program=lib.make_nary_sequencemap(
            lambda x, y, z: x + y - z, rasp.tokens, rasp.tokens, rasp.indices),
        vocab={1, 2, 3},
        test_input=[1, 2, 3],
        expected_output=[2, 3, 4],
        max_seq_len=5,
    ),
    dict(
        testcase_name="make_nary_sequencemap_2",
        program=lib.make_nary_sequencemap(
            lambda x, y, z: x * y / z, rasp.indices, rasp.indices, rasp.tokens),
        vocab={1, 2, 3},
        test_input=[1, 2, 3],
        expected_output=[0, 1 / 2, 4 / 3],
        max_seq_len=3,
    )
]


CAUSAL_TEST_CASES = UNIVERSAL_TEST_CASES + [
    dict(
        testcase_name="selector_width",
        program=rasp.SelectorWidth(
            rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)),
        vocab={"a", "b", "c", "d"},
        test_input=list("abcd"),
        expected_output=[1, 2, 3, 4],
        max_seq_len=5),
]
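New programs are covered by appending another dict of the same shape. A
hypothetical example (added for illustration; not part of this commit):

    dict(
        testcase_name="hist_single_token",
        program=lib.make_hist(),
        vocab={"a", "b"},
        test_input=list("aaa"),
        expected_output=[3, 3, 3],
        max_seq_len=5),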
craft/bases.py
ADDED
@@ -0,0 +1,247 @@
+# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Vectors and bases."""
+
+import dataclasses
+from typing import Sequence, Union, Optional, Iterable
+
+import numpy as np
+
+Name = Union[int, str]
+Value = Union[int, float, bool, str, tuple]
+
+
+@dataclasses.dataclass(frozen=True)
+class BasisDirection:
+  """Represents a basis direction (no magnitude) in a vector space.
+
+  Attributes:
+    name: a unique name for this direction.
+    value: used to hold a value one-hot-encoded by this direction. e.g.,
+      [BasisDirection("vs_1", True), BasisDirection("vs_1", False)] would be
+      basis directions of a subspace called "vs_1" which one-hot-encodes the
+      values True and False. If provided, considered part of the name for the
+      purpose of disambiguating directions.
+  """
+  name: Name
+  value: Optional[Value] = None
+
+  def __str__(self):
+    if self.value is None:
+      return str(self.name)
+    return f"{self.name}:{self.value}"
+
+  def __lt__(self, other: "BasisDirection") -> bool:
+    try:
+      return (self.name, self.value) < (other.name, other.value)
+    except TypeError:
+      return str(self) < str(other)
+
+
+@dataclasses.dataclass
+class VectorInBasis:
+  """A vector (or array of vectors) in a given basis.
+
+  When magnitudes are 1-d, this is a vector.
+  When magnitudes are (n+1)-d, this is an array of vectors,
+  where the -1th dimension is the basis dimension.
+  """
+  basis_directions: Sequence[BasisDirection]
+  magnitudes: np.ndarray
+
+  def __post_init__(self):
+    """Sort basis directions."""
+    if len(self.basis_directions) != self.magnitudes.shape[-1]:
+      raise ValueError(
+          "Last dimension of magnitudes must be the same as number "
+          f"of basis directions. Was {len(self.basis_directions)} "
+          f"and {self.magnitudes.shape[-1]}.")
+
+    sort_idx = np.argsort(self.basis_directions)
+    self.basis_directions = [self.basis_directions[i] for i in sort_idx]
+    self.magnitudes = np.take(self.magnitudes, sort_idx, -1)
+
+  def __add__(self, other: "VectorInBasis") -> "VectorInBasis":
+    if self.basis_directions != other.basis_directions:
+      raise TypeError(f"Adding incompatible bases: {self} + {other}")
+    magnitudes = self.magnitudes + other.magnitudes
+    return VectorInBasis(self.basis_directions, magnitudes)
+
+  def __radd__(self, other: "VectorInBasis") -> "VectorInBasis":
+    if self.basis_directions != other.basis_directions:
+      raise TypeError(f"Adding incompatible bases: {other} + {self}")
+    return self + other
+
+  def __sub__(self, other: "VectorInBasis") -> "VectorInBasis":
+    if self.basis_directions != other.basis_directions:
+      raise TypeError(f"Subtracting incompatible bases: {self} - {other}")
+    magnitudes = self.magnitudes - other.magnitudes
+    return VectorInBasis(self.basis_directions, magnitudes)
+
+  def __rsub__(self, other: "VectorInBasis") -> "VectorInBasis":
+    if self.basis_directions != other.basis_directions:
+      raise TypeError(f"Subtracting incompatible bases: {other} - {self}")
+    magnitudes = other.magnitudes - self.magnitudes
+    return VectorInBasis(self.basis_directions, magnitudes)
+
+  def __mul__(self, scalar: float) -> "VectorInBasis":
+    return VectorInBasis(self.basis_directions, self.magnitudes * scalar)
+
+  def __rmul__(self, scalar: float) -> "VectorInBasis":
+    return self * scalar
+
+  def __truediv__(self, scalar: float) -> "VectorInBasis":
+    return VectorInBasis(self.basis_directions, self.magnitudes / scalar)
+
+  def __neg__(self) -> "VectorInBasis":
+    return (-1) * self
+
+  def __eq__(self, other: "VectorInBasis") -> bool:
+    return ((self.basis_directions == other.basis_directions) and
+            (self.magnitudes.shape == other.magnitudes.shape) and
+            (np.all(self.magnitudes == other.magnitudes)))
+
+  @classmethod
+  def sum(cls, vectors: Sequence["VectorInBasis"]) -> "VectorInBasis":
+    return cls(vectors[0].basis_directions,
+               np.sum([x.magnitudes for x in vectors], axis=0))
+
+  @classmethod
+  def stack(cls,
+            vectors: Sequence["VectorInBasis"],
+            axis: int = 0) -> "VectorInBasis":
+    for v in vectors[1:]:
+      if v.basis_directions != vectors[0].basis_directions:
+        raise TypeError(f"Stacking incompatible bases: {vectors[0]} + {v}")
+    return cls(vectors[0].basis_directions,
+               np.stack([v.magnitudes for v in vectors], axis=axis))
+
+  def project(
+      self, basis: Union["VectorSpaceWithBasis", Sequence[BasisDirection]]
+  ) -> "VectorInBasis":
+    """Projects to the basis."""
+    if isinstance(basis, VectorSpaceWithBasis):
+      basis = basis.basis
+    components = []
+    for direction in basis:
+      if direction in self.basis_directions:
+        components.append(
+            self.magnitudes[..., self.basis_directions.index(direction)])
+      else:
+        components.append(np.zeros_like(self.magnitudes[..., 0]))
+    return VectorInBasis(list(basis), np.stack(components, axis=-1))
+
+
+@dataclasses.dataclass
+class VectorSpaceWithBasis:
+  """A vector subspace in a given basis."""
+  basis: Sequence[BasisDirection]
+
+  def __post_init__(self):
+    """Keep basis directions sorted."""
+    self.basis = sorted(self.basis)
+
+  @property
+  def num_dims(self) -> int:
+    return len(self.basis)
+
+  def __contains__(self, item: Union[VectorInBasis, BasisDirection]) -> bool:
+    if isinstance(item, BasisDirection):
+      return item in self.basis
+
+    return set(self.basis) == set(item.basis_directions)
+
+  def issubspace(self, other: "VectorSpaceWithBasis") -> bool:
+    return set(self.basis).issubset(set(other.basis))
+
+  def basis_vectors(self) -> Sequence[VectorInBasis]:
+    basis_vector_magnitudes = list(np.eye(self.num_dims))
+    return [VectorInBasis(self.basis, m) for m in basis_vector_magnitudes]
+
+  def vector_from_basis_direction(
+      self, basis_direction: BasisDirection) -> VectorInBasis:
+    i = self.basis.index(basis_direction)
+    return VectorInBasis(self.basis, np.eye(self.num_dims)[i])
+
+  def null_vector(self) -> VectorInBasis:
+    return VectorInBasis(self.basis, np.zeros(self.num_dims))
+
+  @classmethod
+  def from_names(cls, names: Sequence[Name]) -> "VectorSpaceWithBasis":
+    """Creates a VectorSpace from a list of names for its basis directions."""
+    return cls([BasisDirection(n) for n in names])
+
+  @classmethod
+  def from_values(
+      cls,
+      name: Name,
+      values: Iterable[Value],
+  ) -> "VectorSpaceWithBasis":
+    """Creates a VectorSpace from a list of values for its basis directions."""
+    return cls([BasisDirection(name, v) for v in values])
+
+
+def direct_sum(*vs: VectorSpaceWithBasis) -> VectorSpaceWithBasis:
+  """Create a direct sum of the vector spaces.
+
+  Assumes the basis elements of all input vector spaces are
+  orthogonal to each other. Maintains the order of the bases.
+
+  Args:
+    *vs: the vector spaces to sum.
+
+  Returns:
+    the combined vector space.
+
+  Raises:
+    Value error in case of overlapping bases.
+  """
+  # Take the union of all the bases:
+  total_basis = sum([v.basis for v in vs], [])
+
+  if len(total_basis) != len(set(total_basis)):
+    raise ValueError("Overlapping bases!")
+
+  return VectorSpaceWithBasis(total_basis)
+
+
+def join_vector_spaces(*vs: VectorSpaceWithBasis) -> VectorSpaceWithBasis:
+  """Joins a set of vector spaces allowing them to overlap.
+
+  Assumes the basis elements of all input vector spaces are
+  orthogonal to each other. Does not maintain the order of the bases but
+  sorts them.
+
+  Args:
+    *vs: the vector spaces to sum.
+
+  Returns:
+    the combined vector space.
+  """
+  # Take the union of all the bases:
+  total_basis = list(set().union(*[set(v.basis) for v in vs]))
+  total_basis = sorted(total_basis)
+  return VectorSpaceWithBasis(total_basis)
+
+
+def ensure_dims(
+    vs: VectorSpaceWithBasis,
+    num_dims: int,
+    name: str = "vector space",
+) -> None:
+  """Raises ValueError if vs has the wrong number of dimensions."""
+  if vs.num_dims != num_dims:
+    raise ValueError(f"{name} must have {num_dims=}, "
+                     f"but got {vs.num_dims}: {vs.basis}")
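Taken together, bases.py defines craft's core data model: named one-hot basis directions, vector spaces as sorted direction lists, and vectors that can be added, scaled, projected, and compared independently of direction order. A minimal usage sketch of the API above (illustrative only):

import numpy as np
from tracr.craft import bases

# Two disjoint subspaces of a residual stream.
tokens = bases.VectorSpaceWithBasis.from_values("token", ["a", "b"])
flags = bases.VectorSpaceWithBasis.from_names(["one"])

# direct_sum requires disjoint bases; join_vector_spaces also allows overlap.
residual = bases.direct_sum(tokens, flags)

# One-hot vector for token "a", embedded in the combined space.
v = residual.vector_from_basis_direction(bases.BasisDirection("token", "a"))

# Project back down to the token subspace; absent directions become zeros.
print(v.project(tokens).magnitudes)  # -> array([1., 0.])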
craft/bases_test.py
ADDED
@@ -0,0 +1,158 @@
+# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for bases."""
+
+from absl.testing import absltest
+import numpy as np
+from tracr.craft import bases
+from tracr.craft import tests_common
+
+
+class VectorInBasisTest(tests_common.VectorFnTestCase):
+
+  def test_shape_mismatch_raises_value_error(self):
+    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
+    regex = (r"^.*Last dimension of magnitudes must be the same as number of "
+             r"basis directions.*$")
+    with self.assertRaisesRegex(ValueError, regex):
+      bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
+    with self.assertRaisesRegex(ValueError, regex):
+      bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]]))
+
+  def test_equal(self):
+    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
+    v1 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
+    v2 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
+    self.assertEqual(v1, v2)
+    self.assertEqual(v2, v1)
+    v3 = bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]]))
+    v4 = bases.VectorInBasis(vs1.basis, np.array([[0, 1, 2, 3], [1, 2, 3, 4]]))
+    self.assertEqual(v3, v4)
+    self.assertEqual(v4, v3)
+    v5 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
+    v6 = bases.VectorInBasis(vs1.basis, np.array([1, 1, 1, 1]))
+    self.assertNotEqual(v5, v6)
+    self.assertNotEqual(v6, v5)
+    v7 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
+    v8 = bases.VectorInBasis(vs1.basis, np.array([[1, 2, 3, 4], [1, 1, 1, 1]]))
+    self.assertNotEqual(v7, v8)
+    self.assertNotEqual(v8, v7)
+    vs2 = bases.VectorSpaceWithBasis.from_names(["e", "f", "g", "h"])
+    v9 = bases.VectorInBasis(vs1.basis, np.array([1, 2, 3, 4]))
+    v10 = bases.VectorInBasis(vs2.basis, np.array([1, 2, 3, 4]))
+    self.assertNotEqual(v9, v10)
+    self.assertNotEqual(v10, v9)
+
+  def test_dunders(self):
+    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"])
+    v = bases.VectorInBasis(vs1.basis, np.array([0, 1, 2]))
+    three = bases.VectorInBasis(vs1.basis, np.array([3, 3, 3]))
+    five = bases.VectorInBasis(vs1.basis, np.array([5, 5, 5]))
+    v_times_5 = bases.VectorInBasis(vs1.basis, np.array([0, 5, 10]))
+    self.assertEqual(5 * v, v_times_5)
+    self.assertEqual(v * 5, v_times_5)
+    self.assertEqual(5.0 * v, v_times_5)
+    self.assertEqual(v * 5.0, v_times_5)
+    v_by_2 = bases.VectorInBasis(vs1.basis, np.array([0, 0.5, 1]))
+    self.assertEqual(v / 2, v_by_2)
+    self.assertEqual(v / 2.0, v_by_2)
+    self.assertEqual(1 / 2 * v, v_by_2)
+    v_plus_3 = bases.VectorInBasis(vs1.basis, np.array([3, 4, 5]))
+    self.assertEqual(v + three, v_plus_3)
+    self.assertEqual(three + v, v_plus_3)
+    v_minus_5 = bases.VectorInBasis(vs1.basis, np.array([-5, -4, -3]))
+    self.assertEqual(v - five, v_minus_5)
+    minus_v = bases.VectorInBasis(vs1.basis, np.array([0, -1, -2]))
+    self.assertEqual(-v, minus_v)
+
+
+class ProjectionTest(tests_common.VectorFnTestCase):
+
+  def test_direct_sum_produces_expected_result(self):
+    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
+    vs2 = bases.VectorSpaceWithBasis.from_names(["d", "c"])
+    vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "d", "c"])
+    self.assertEqual(bases.direct_sum(vs1, vs2), vs3)
+
+  def test_join_vector_spaces_produces_expected_result(self):
+    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
+    vs2 = bases.VectorSpaceWithBasis.from_names(["d", "c"])
+    vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
+    self.assertEqual(bases.join_vector_spaces(vs1, vs2), vs3)
+
+    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
+    vs2 = bases.VectorSpaceWithBasis.from_names(["b", "d", "c"])
+    vs3 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
+    self.assertEqual(bases.join_vector_spaces(vs1, vs2), vs3)
+
+  def test_compare_vectors_with_differently_ordered_basis_vectors(self):
+    basis1 = ["a", "b", "c", "d"]
+    basis1 = [bases.BasisDirection(x) for x in basis1]
+    basis2 = ["b", "d", "a", "c"]
+    basis2 = [bases.BasisDirection(x) for x in basis2]
+    vs1 = bases.VectorSpaceWithBasis(basis1)
+    vs2 = bases.VectorSpaceWithBasis(basis2)
+    v1 = bases.VectorInBasis(basis1, np.array([1, 2, 3, 4]))
+    v2 = bases.VectorInBasis(basis2, np.array([2, 4, 1, 3]))
+    self.assertEqual(v1, v2)
+    self.assertEqual(v1 - v2, vs1.null_vector())
+    self.assertEqual(v1 - v2, vs2.null_vector())
+    self.assertEqual(v1 + v2, 2 * v2)
+    self.assertIn(v1, vs1)
+    self.assertIn(v1, vs2)
+    self.assertIn(v2, vs1)
+    self.assertIn(v2, vs2)
+
+  def test_compare_vector_arrays_with_differently_ordered_basis_vectors(self):
+    basis1 = ["a", "b", "c", "d"]
+    basis1 = [bases.BasisDirection(x) for x in basis1]
+    basis2 = ["b", "d", "a", "c"]
+    basis2 = [bases.BasisDirection(x) for x in basis2]
+    vs1 = bases.VectorSpaceWithBasis(basis1)
+    vs2 = bases.VectorSpaceWithBasis(basis2)
+    v1 = bases.VectorInBasis(basis1, np.array([[1, 2, 3, 4], [5, 6, 7, 8]]))
+    v2 = bases.VectorInBasis(basis2, np.array([[2, 4, 1, 3], [6, 8, 5, 7]]))
+    null_vec = bases.VectorInBasis.stack([vs1.null_vector(), vs2.null_vector()])
+    self.assertEqual(v1, v2)
+    self.assertEqual(v1 - v2, null_vec)
+    self.assertEqual(v1 + v2, 2 * v2)
+    self.assertIn(v1, vs1)
+    self.assertIn(v1, vs2)
+    self.assertIn(v2, vs1)
+    self.assertIn(v2, vs2)
+
+  def test_projection_to_larger_space(self):
+    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
+    vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
+    a1, b1 = vs1.basis_vectors()
+    a2, b2, _, _ = vs2.basis_vectors()
+
+    self.assertEqual(a1.project(vs2), a2)
+    self.assertEqual(b1.project(vs2), b2)
+
+  def test_projection_to_smaller_space(self):
+    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
+    vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
+    a1, b1, c1, d1 = vs1.basis_vectors()
+    a2, b2 = vs2.basis_vectors()
+
+    self.assertEqual(a1.project(vs2), a2)
+    self.assertEqual(b1.project(vs2), b2)
+    self.assertEqual(c1.project(vs2), vs2.null_vector())
+    self.assertEqual(d1.project(vs2), vs2.null_vector())
+
+
+if __name__ == "__main__":
+  absltest.main()
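The comparison tests above pin down an invariant the rest of craft relies on: because VectorInBasis.__post_init__ and VectorSpaceWithBasis.__post_init__ both sort their basis directions, the same vector expressed over two permutations of one basis is a single canonical object. A small standalone illustration of that property (not part of the commit):

import numpy as np
from tracr.craft import bases

b1 = [bases.BasisDirection(x) for x in ["a", "b"]]
b2 = [bases.BasisDirection(x) for x in ["b", "a"]]

# The same vector written in two orderings of the same basis.
u = bases.VectorInBasis(b1, np.array([1.0, 2.0]))
w = bases.VectorInBasis(b2, np.array([2.0, 1.0]))

assert u == w  # both were sorted into the canonical order on construction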
craft/chamber/categorical_attn.py
ADDED
@@ -0,0 +1,167 @@
+# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Attention head for categorical inputs."""
+
+from typing import Optional, Protocol
+
+from tracr.craft import bases
+from tracr.craft import transformers
+from tracr.craft import vectorspace_fns
+
+
+class QueryKeyToAttnLogit(Protocol):
+
+  def __call__(self, query: bases.BasisDirection,
+               key: bases.BasisDirection) -> bool:
+    pass
+
+
+def categorical_attn(
+    query_space: bases.VectorSpaceWithBasis,
+    key_space: bases.VectorSpaceWithBasis,
+    value_space: bases.VectorSpaceWithBasis,
+    output_space: bases.VectorSpaceWithBasis,
+    bos_space: bases.VectorSpaceWithBasis,
+    one_space: bases.VectorSpaceWithBasis,
+    attn_fn: QueryKeyToAttnLogit,
+    default_output: Optional[bases.VectorInBasis] = None,
+    causal: bool = False,
+    always_attend_to_bos: bool = False,
+    use_bos_for_default_output: bool = True,
+    softmax_coldness: float = 100.,
+) -> transformers.AttentionHead:
+  """Returns an attention head for categorical inputs.
+
+  Assumes the existence of a beginning of sequence token and attends to it
+  always with strength 0.5*softmax_coldness. This allows to implement an
+  arbitrary default value for rows in the attention pattern that are all-zero.
+
+  Attends to the BOS token if all other key-query pairs have zero attention.
+  Hence, the first value in the value sequence will be the default output for
+  such cases.
+
+  Args:
+    query_space: Vector space containing (categorical) query input.
+    key_space: Vector space containing (categorical) key input.
+    value_space: Vector space containing (numerical) value input.
+    output_space: Vector space which will contain (numerical) output.
+    bos_space: 1-d space used to identify the beginning of sequence token.
+    one_space: 1-d space which contains 1 at every position.
+    attn_fn: A selector function f(query, key) operating on the query/key basis
+      directions that defines the attention pattern.
+    default_output: Output to return if attention pattern is all zero.
+    causal: If True, use masked attention.
+    always_attend_to_bos: If True, always attend to the BOS token. If False,
+      only attend to BOS when attending to nothing else.
+    use_bos_for_default_output: If True, assume BOS is not in the value space
+      and output a default value when attending to BOS. If False, assume BOS is
+      in the value space, and map it to the output space like any other token.
+    softmax_coldness: The inverse temperature of the softmax. Default value is
+      high which makes the attention close to a hard maximum.
+  """
+  bases.ensure_dims(bos_space, num_dims=1, name="bos_space")
+  bases.ensure_dims(one_space, num_dims=1, name="one_space")
+  bos_direction = bos_space.basis[0]
+  one_direction = one_space.basis[0]
+
+  # Add bos direction to query, key, and value spaces in case it is missing
+  query_space = bases.join_vector_spaces(query_space, bos_space, one_space)
+  key_space = bases.join_vector_spaces(key_space, bos_space)
+  value_space = bases.join_vector_spaces(value_space, bos_space)
+
+  if always_attend_to_bos:
+    value_basis = value_space.basis
+  else:
+    value_basis = [v for v in value_space.basis if v != bos_direction]
+  assert len(value_basis) == output_space.num_dims
+  value_to_output = dict(zip(value_basis, output_space.basis))
+
+  if default_output is None:
+    default_output = output_space.null_vector()
+  assert default_output in output_space
+
+  def qk_fun(query: bases.BasisDirection, key: bases.BasisDirection) -> float:
+
+    # We want to enforce the following property on our attention patterns:
+    # - if nothing else is attended to, attend to the BOS token.
+    # - otherwise, don't attend to the BOS token.
+    #
+    # We assume that the BOS position always only contains the vector bos + one,
+    # and that any other position has bos coefficient 0.
+    #
+    # We do this as follows:
+    # Let Q and K be subspaces of V containing the query and key vectors,
+    # both disjoint with the BOS space {bos} or the one space {one}.
+    # Suppose we have an attn_fn which defines a bilinear W_QK: V x V -> ℝ,
+    # s.t. W_QK(q, k) = 0 whenever either q or k are bos or one.
+    #
+    # Then define W_new: V x V -> ℝ st:
+    # W_new(one, bos) = 0.5, otherwise 0.
+    #
+    # Now set W_QK' = W_QK + W_new.
+    #
+    # To evaluate the attention to the BOS position:
+    # W_QK'(q, bos + one)
+    # = W_QK'(q, bos) + W_QK'(q, one)
+    # = W_QK(q, bos) + W_QK(q, one) + W_new(q, bos) + W_new(q, one)
+    # = 0 + 0 + W_new(q, bos) + W_new(q, one)
+    # = W_new(q, bos) + W_new(q, one)
+    # = W_new(q' + one, bos) + W_new(q' + one, one)  where q = one + q'
+    # = W_new(q', bos) + W_new(one, bos) + W_new(q', one) + W_new(one, one)
+    # = 0 + 0.5 + 0 + 0
+    # = 0.5
+    #
+    # To evaluate the attention to a non-BOS position:
+    # W_QK'(0 * bos + q, 0 * bos + k)  # s.t. q ∈ Q+{one}, k ∈ K+{one}
+    # = 0*W_QK'(bos, 0*bos + k) + W_QK'(q, 0*bos + k)
+    # = W_QK'(q, 0*bos + k)
+    # = 0*W_QK'(q, bos) + W_QK'(q, k)
+    # = W_QK'(q, k)
+    # = W_QK(q, k)  since W_QK' = W_QK on inputs not containing bos.
+    # = W_QK(q', k')  since W_QK(x, y) = 0 whenever x or y are one.
+    #
+    # Since W_QK(q, k) takes values in 0, 1, a sufficiently high softmax
+    # coldness will give us the desired property. QED
+    #
+    # The following implements this idea.
+    # By replacing 0.5 with 1, we can instead enforce a different property: that
+    # the BOS token is always attended to in addition to whatever else.
+
+    if key == bos_direction and query == one_direction:
+      c = 1. if always_attend_to_bos else 0.5
+      return c * softmax_coldness
+    elif {key, query}.intersection({one_direction, bos_direction}):
+      return 0
+
+    return softmax_coldness * attn_fn(query, key)
+
+  w_qk = vectorspace_fns.ScalarBilinear.from_action(
+      query_space,
+      key_space,
+      qk_fun,
+  )
+
+  def ov_fun(input_dir: bases.BasisDirection) -> bases.VectorInBasis:
+    if use_bos_for_default_output and input_dir == bos_direction:
+      return default_output
+    return output_space.vector_from_basis_direction(value_to_output[input_dir])
+
+  w_ov = vectorspace_fns.Linear.from_action(
+      value_space,
+      output_space,
+      ov_fun,
+  )
+
+  return transformers.AttentionHead(w_qk, w_ov, causal=causal)
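A rough sense of why the default softmax_coldness of 100 makes this head effectively hard: a key selected by attn_fn scores logit 100, the BOS key scores 0.5 * 100 = 50, and unselected keys score 0, so BOS dominates exactly the rows where nothing is selected and is suppressed by a factor of roughly e^50 everywhere else. A back-of-the-envelope check (illustrative only):

import numpy as np

coldness = 100.0
logits = np.array([coldness, 0.5 * coldness])  # [selected key, BOS key]
weights = np.exp(logits - logits.max())
weights /= weights.sum()
print(weights)  # -> approximately [1.0, 1.9e-22]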
craft/chamber/categorical_attn_test.py
ADDED
@@ -0,0 +1,229 @@
+# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for chamber.categorical_attn."""
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import numpy as np
+from tracr.craft import bases
+from tracr.craft import tests_common
+from tracr.craft.chamber import categorical_attn
+
+
+class CategoricalAttnTest(tests_common.VectorFnTestCase):
+
+  @parameterized.parameters([
+      dict(causal=False, input_seq=[1, 2, 3, 4, 5], result_seq=[3, 3, 3, 3, 3]),
+      dict(
+          causal=True,
+          input_seq=[1, 2, 3, 4, 5],
+          result_seq=[1, 1.5, 2, 2.5, 3]),
+      dict(causal=False, input_seq=[10], result_seq=[10]),
+      dict(causal=True, input_seq=[10], result_seq=[10]),
+      dict(causal=False, input_seq=[-1, 0, 1], result_seq=[0, 0, 0]),
+      dict(causal=True, input_seq=[-1, 0, 1], result_seq=[-1, -0.5, 0]),
+  ])
+  def test_categorical_attn_can_implement_select_all(self, causal, input_seq,
+                                                     result_seq):
+    vocab = range(-20, 20)
+    input_space = bases.VectorSpaceWithBasis.from_values("input", vocab)
+
+    output_dir = bases.BasisDirection("output")
+    output_space = bases.VectorSpaceWithBasis([output_dir])
+    output_vec = output_space.vector_from_basis_direction(output_dir)
+
+    bos_dir = bases.BasisDirection("bos_dimension")
+    bos_space = bases.VectorSpaceWithBasis([bos_dir])
+
+    one_dir = bases.BasisDirection("one")
+    one_space = bases.VectorSpaceWithBasis([one_dir])
+
+    value_dir = bases.BasisDirection("value")
+    value_space = bases.VectorSpaceWithBasis([value_dir])
+
+    input_space = bases.join_vector_spaces(input_space, bos_space, one_space)
+    value_space = bases.join_vector_spaces(value_space, bos_space)
+    residual_space = bases.join_vector_spaces(input_space, value_space,
+                                              output_space)
+    one_vec = residual_space.vector_from_basis_direction(one_dir)
+    bos_vec = residual_space.vector_from_basis_direction(bos_dir)
+    value_vec = residual_space.vector_from_basis_direction(value_dir)
+
+    attn = categorical_attn.categorical_attn(
+        key_space=input_space,
+        query_space=input_space,
+        value_space=value_space,
+        output_space=output_space,
+        bos_space=bos_space,
+        one_space=one_space,
+        attn_fn=lambda x, y: True,
+        causal=causal)
+
+    test_inputs = [bos_vec + one_vec]
+    for x in input_seq:
+      test_inputs.append(
+          residual_space.vector_from_basis_direction(
+              bases.BasisDirection("input", x)) + x * value_vec)
+    test_inputs = bases.VectorInBasis.stack(test_inputs)
+
+    # Expect the average of all (previous) tokens
+    expected_results = [x * output_vec for x in result_seq]
+    expected_results = bases.VectorInBasis.stack(expected_results)
+
+    test_outputs = attn.apply(test_inputs).project(output_space)
+
+    self.assertVectorAllClose(
+        tests_common.strip_bos_token(test_outputs), expected_results)
+
+  @parameterized.parameters([
+      dict(causal=False, input_seq=[1, 2, 3, 4, 5], default=0),
+      dict(causal=True, input_seq=[1, 2, 3, 4, 5], default=1),
+      dict(causal=False, input_seq=[10], default=2),
+      dict(causal=True, input_seq=[10], default=-3),
+      dict(causal=False, input_seq=[-1, 0, 1], default=-2),
+      dict(causal=True, input_seq=[-1, 0, 1], default=-1),
+  ])
+  def test_categorical_attn_can_implement_select_none(self, causal, input_seq,
+                                                      default):
+    vocab = range(-20, 20)
+    input_space = bases.VectorSpaceWithBasis.from_values("input", vocab)
+
+    output_dir = bases.BasisDirection("output")
+    output_space = bases.VectorSpaceWithBasis([output_dir])
+    default_vec = default * output_space.vector_from_basis_direction(output_dir)
+
+    bos_dir = bases.BasisDirection("bos_dimension")
+    bos_space = bases.VectorSpaceWithBasis([bos_dir])
+
+    one_dir = bases.BasisDirection("one")
+    one_space = bases.VectorSpaceWithBasis([one_dir])
+
+    value_dir = bases.BasisDirection("value")
+    value_space = bases.VectorSpaceWithBasis([value_dir])
+
+    input_space = bases.join_vector_spaces(input_space, bos_space, one_space)
+    value_space = bases.join_vector_spaces(value_space, bos_space)
+    residual_space = bases.join_vector_spaces(input_space, value_space,
+                                              output_space)
+    value_vec = residual_space.vector_from_basis_direction(value_dir)
+    bos_vec = residual_space.vector_from_basis_direction(bos_dir)
+    one_vec = residual_space.vector_from_basis_direction(one_dir)
+
+    attn = categorical_attn.categorical_attn(
+        key_space=input_space,
+        query_space=input_space,
+        value_space=value_space,
+        output_space=output_space,
+        bos_space=bos_space,
+        one_space=one_space,
+        attn_fn=lambda x, y: False,
+        default_output=default_vec,
+        causal=causal,
+        always_attend_to_bos=False,
+        use_bos_for_default_output=True)
+
+    def make_input(x):
+      return (one_vec + x * value_vec +
+              residual_space.vector_from_basis_direction(
+                  bases.BasisDirection("input", x)))
+
+    test_inputs = bases.VectorInBasis.stack([bos_vec + one_vec] +
+                                            [make_input(x) for x in input_seq])
+
+    # Expect the default value
+    expected_results = [default_vec for x in input_seq]
+    expected_results = bases.VectorInBasis.stack(expected_results)
+
+    test_outputs = attn.apply(test_inputs).project(output_space)
+
+    self.assertVectorAllClose(
+        tests_common.strip_bos_token(test_outputs), expected_results)
+
+  @parameterized.parameters([
+      dict(num_counts=5, input_seq=[1, 4, 3, 2], n=1, result=[4, 3, 2, 1]),
+      dict(num_counts=10, input_seq=[5, 8, 9, 2], n=3, result=[2, 5, 8, 9])
+  ])
+  def test_categorical_attn_can_implement_shift_by_n(self, num_counts,
+                                                     input_seq, n, result):
+    query_prefix = "prefix1"
+    key_prefix = "prefix2"
+    agg_input_prefix = "prefix3"
+    output_prefix = "prefix4"
+
+    bos_direction = bases.BasisDirection("bos")
+    one_direction = bases.BasisDirection("one")
+    query_space = bases.VectorSpaceWithBasis.from_values(
+        query_prefix, range(num_counts))
+    key_space = bases.VectorSpaceWithBasis.from_values(key_prefix,
+                                                       range(num_counts))
+    bos_space = bases.VectorSpaceWithBasis([bos_direction])
+    one_space = bases.VectorSpaceWithBasis([one_direction])
+    key_space = bases.join_vector_spaces(key_space, bos_space)
+
+    agg_input_space = bases.VectorSpaceWithBasis.from_values(
+        agg_input_prefix, range(num_counts))
+    agg_input_space = bases.join_vector_spaces(agg_input_space, bos_space)
+    output_space = bases.VectorSpaceWithBasis.from_values(
+        output_prefix, range(num_counts))
+
+    attn = categorical_attn.categorical_attn(
+        query_space=query_space,
+        key_space=key_space,
+        value_space=agg_input_space,
+        output_space=output_space,
+        bos_space=bos_space,
+        one_space=one_space,
+        attn_fn=lambda q, k: q.value == k.value,
+        default_output=None,
+        always_attend_to_bos=False,
+        use_bos_for_default_output=True,
+        causal=False)
+
+    residual_space = bases.join_vector_spaces(key_space, query_space,
+                                              agg_input_space, output_space,
+                                              one_space)
+
+    seq_len = len(input_seq)
+    query_seq = np.arange(n, seq_len + n) % seq_len
+    key_seq = np.arange(seq_len)
+
+    bos_vec = residual_space.vector_from_basis_direction(bos_direction)
+    one_vec = residual_space.vector_from_basis_direction(one_direction)
+
+    test_inputs = [bos_vec + one_vec]
+    expected_results = []
+    for i in range(seq_len):
+      test_inputs.append(
+          residual_space.vector_from_basis_direction(
+              bases.BasisDirection(query_prefix, query_seq[i])) +
+          residual_space.vector_from_basis_direction(
+              bases.BasisDirection(key_prefix, key_seq[i])) +
+          residual_space.vector_from_basis_direction(
+              bases.BasisDirection(agg_input_prefix, input_seq[i])))
+      expected_results.append(
+          residual_space.vector_from_basis_direction(
+              bases.BasisDirection(output_prefix, result[i])))
+
+    test_inputs = bases.VectorInBasis.stack(test_inputs)
+    expected_results = bases.VectorInBasis.stack(expected_results)
+
+    test_outputs = attn.apply(test_inputs)
+
+    self.assertVectorAllClose(
+        tests_common.strip_bos_token(test_outputs), expected_results)
+
+
+if __name__ == "__main__":
+  absltest.main()
craft/chamber/categorical_mlp.py
ADDED
@@ -0,0 +1,168 @@
+# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""MLP to compute basic linear functions of one-hot encoded integers."""
+
+from typing import Callable
+
+import numpy as np
+
+from tracr.craft import bases
+from tracr.craft import transformers
+from tracr.craft import vectorspace_fns
+
+_ONE_SPACE = bases.VectorSpaceWithBasis.from_names(["one"])
+
+
+def map_categorical_mlp(
+    input_space: bases.VectorSpaceWithBasis,
+    output_space: bases.VectorSpaceWithBasis,
+    operation: Callable[[bases.BasisDirection], bases.BasisDirection],
+) -> transformers.MLP:
+  """Returns an MLP that encodes any categorical function of a single variable f(x).
+
+  The hidden layer is the identity and output combines this with a lookup table
+    output_k = sum(f(i)*input_i for all i in input space)
+
+  Args:
+    input_space: space containing the input x.
+    output_space: space containing possible outputs.
+    operation: A function operating on basis directions.
+  """
+
+  def operation_fn(direction):
+    if direction in input_space:
+      output_direction = operation(direction)
+      if output_direction in output_space:
+        return output_space.vector_from_basis_direction(output_direction)
+    return output_space.null_vector()
+
+  first_layer = vectorspace_fns.Linear.from_action(input_space, output_space,
+                                                   operation_fn)
+
+  second_layer = vectorspace_fns.project(output_space, output_space)
+
+  return transformers.MLP(first_layer, second_layer)
+
+
+def map_categorical_to_numerical_mlp(
+    input_space: bases.VectorSpaceWithBasis,
+    output_space: bases.VectorSpaceWithBasis,
+    operation: Callable[[bases.Value], float],
+) -> transformers.MLP:
+  """Returns an MLP to compute f(x) from a categorical to a numerical variable.
+
+  The hidden layer is the identity and output combines this with a lookup table
+    output = sum(f(i)*input_i for all i in input space)
+
+  Args:
+    input_space: Vector space containing the input x.
+    output_space: Vector space to write the numerical output to.
+    operation: A function operating on basis directions.
+  """
+  bases.ensure_dims(output_space, num_dims=1, name="output_space")
+  out_vec = output_space.vector_from_basis_direction(output_space.basis[0])
+
+  def operation_fn(direction):
+    if direction in input_space:
+      return operation(direction.value) * out_vec
+    return output_space.null_vector()
+
+  first_layer = vectorspace_fns.Linear.from_action(input_space, output_space,
+                                                   operation_fn)
+
+  second_layer = vectorspace_fns.project(output_space, output_space)
+
+  return transformers.MLP(first_layer, second_layer)
+
+
+def sequence_map_categorical_mlp(
+    input1_space: bases.VectorSpaceWithBasis,
+    input2_space: bases.VectorSpaceWithBasis,
+    output_space: bases.VectorSpaceWithBasis,
+    operation: Callable[[bases.BasisDirection, bases.BasisDirection],
+                        bases.BasisDirection],
+    one_space: bases.VectorSpaceWithBasis = _ONE_SPACE,
+    hidden_name: bases.Name = "__hidden__",
+) -> transformers.MLP:
+  """Returns an MLP that encodes a categorical function of two variables f(x, y).
+
+  The hidden layer of the MLP computes the logical and of all input directions
+    hidden_i_j = ReLU(x_i+x_j-1)
+
+  And the output combines this with a lookup table
+    output_k = sum(f(i, j)*hidden_i_j for all i,j in input space)
+
+  Args:
+    input1_space: Vector space containing the input x.
+    input2_space: Vector space containing the input y.
+    output_space: Vector space to write outputs to.
+    operation: A function operating on basis directions.
+    one_space: a reserved 1-d space that always contains a 1.
+    hidden_name: Name for hidden dimensions.
+  """
+  bases.ensure_dims(one_space, num_dims=1, name="one_space")
+
+  if not set(input1_space.basis).isdisjoint(input2_space.basis):
+    raise ValueError("Input spaces to a SequenceMap must be disjoint. "
+                     "If input spaces are the same, use Map instead!")
+
+  input_space = bases.direct_sum(input1_space, input2_space, one_space)
+
+  def to_hidden(x, y):
+    return bases.BasisDirection(hidden_name, (x.name, x.value, y.name, y.value))
+
+  def from_hidden(h):
+    x_name, x_value, y_name, y_value = h.value
+    x_dir = bases.BasisDirection(x_name, x_value)
+    y_dir = bases.BasisDirection(y_name, y_value)
+    return x_dir, y_dir
+
+  hidden_dir = []
+  for dir1 in input1_space.basis:
+    for dir2 in input2_space.basis:
+      hidden_dir.append(to_hidden(dir1, dir2))
+  hidden_space = bases.VectorSpaceWithBasis(hidden_dir)
+
+  def logical_and(direction):
+    if direction in one_space:
+      out = bases.VectorInBasis(hidden_space.basis,
+                                -np.ones(hidden_space.num_dims))
+    elif direction in input1_space:
+      dir1 = direction
+      out = hidden_space.null_vector()
+      for dir2 in input2_space.basis:
+        out += hidden_space.vector_from_basis_direction(to_hidden(dir1, dir2))
+    else:
+      dir2 = direction
+      out = hidden_space.null_vector()
+      for dir1 in input1_space.basis:
+        out += hidden_space.vector_from_basis_direction(to_hidden(dir1, dir2))
+    return out
+
+  first_layer = vectorspace_fns.Linear.from_action(input_space, hidden_space,
+                                                   logical_and)
+
+  def operation_fn(direction):
+    dir1, dir2 = from_hidden(direction)
+    output_direction = operation(dir1, dir2)
+    if output_direction in output_space:
+      return output_space.vector_from_basis_direction(output_direction)
+    else:
+      return output_space.null_vector()
+
+  second_layer = vectorspace_fns.Linear.from_action(hidden_space, output_space,
+                                                    operation_fn)
+
+  return transformers.MLP(first_layer, second_layer)
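The hidden layer of sequence_map_categorical_mlp is a pairwise AND gate: the one-dimension contributes -1 to every hidden unit while each active input direction contributes +1, so only the unit for the (x, y) pair actually present reaches ReLU(1 + 1 - 1) = 1 and every other unit stays at ReLU of a non-positive value, i.e. 0. A quick end-to-end check of that gate, assuming MLP.apply and MLP.residual_space behave as exercised in the tests below (illustrative only):

from tracr.craft import bases
from tracr.craft.chamber import categorical_mlp

in1 = bases.VectorSpaceWithBasis.from_values("x", [0, 1])
in2 = bases.VectorSpaceWithBasis.from_values("y", [0, 1])
out = bases.VectorSpaceWithBasis.from_values("z", [0, 1, 2])

# f(x, y) = x + y, written as an operation on basis directions.
mlp = categorical_mlp.sequence_map_categorical_mlp(
    input1_space=in1,
    input2_space=in2,
    output_space=out,
    operation=lambda a, b: bases.BasisDirection("z", a.value + b.value),
    one_space=bases.VectorSpaceWithBasis.from_names(["one"]))

v = (mlp.residual_space.vector_from_basis_direction(bases.BasisDirection("one"))
     + mlp.residual_space.vector_from_basis_direction(bases.BasisDirection("x", 1))
     + mlp.residual_space.vector_from_basis_direction(bases.BasisDirection("y", 1)))

# Expect a one-hot on the "z":2 output direction.
print(mlp.apply(v).project(out).magnitudes)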
craft/chamber/categorical_mlp_test.py
ADDED
@@ -0,0 +1,164 @@
+# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for chamber.categorical_mlp."""
+
+import math
+from absl.testing import absltest
+from absl.testing import parameterized
+
+from tracr.craft import bases
+from tracr.craft import tests_common
+from tracr.craft.chamber import categorical_mlp
+
+
+class CategoricalInputMlpTest(tests_common.VectorFnTestCase):
+
+  @parameterized.parameters([
+      dict(num_counts=4, x=1, y=2, fun=lambda x, y: x + y, result=3),
+      dict(num_counts=4, x=1, y=0, fun=lambda x, y: x + y + 1, result=2),
+      dict(num_counts=5, x=2, y=1, fun=math.pow, result=2),
+      dict(num_counts=5, x=2, y=2, fun=math.pow, result=4),
+  ])
+  def test_seq_map_categorical_mlp_produces_expected_outcome(
+      self, num_counts, x, y, fun, result):
+    input1_name = "in1"
+    input2_name = "in2"
+    output_name = "out"
+    one_name = "one_dimension"
+
+    in1_space = bases.VectorSpaceWithBasis.from_values(input1_name,
+                                                       range(num_counts + 1))
+    in2_space = bases.VectorSpaceWithBasis.from_values(input2_name,
+                                                       range(num_counts + 1))
+    out_space = bases.VectorSpaceWithBasis.from_values(output_name,
+                                                       range(num_counts + 1))
+
+    def operation(in1, in2):
+      out_val = fun(int(in1.value), int(in2.value))
+      return bases.BasisDirection(output_name, out_val)
+
+    mlp = categorical_mlp.sequence_map_categorical_mlp(
+        input1_space=in1_space,
+        input2_space=in2_space,
+        output_space=out_space,
+        operation=operation,
+        one_space=bases.VectorSpaceWithBasis.from_names([one_name]))
+
+    test_inputs = (
+        mlp.residual_space.vector_from_basis_direction(
+            bases.BasisDirection(one_name)) +
+        mlp.residual_space.vector_from_basis_direction(
+            bases.BasisDirection(input1_name, x)) +
+        mlp.residual_space.vector_from_basis_direction(
+            bases.BasisDirection(input2_name, y)))
+
+    expected_results = mlp.residual_space.vector_from_basis_direction(
+        bases.BasisDirection(output_name, result))
+
+    test_outputs = mlp.apply(test_inputs)
+
+    self.assertVectorAllClose(test_outputs, expected_results)
+
+  def test_seq_map_categorical_mlp_raises_error_with_overlapping_inputs(self):
+    input_name = "in"
+    output_name = "out"
+    one_name = "one_dimension"
+
+    in1_space = bases.VectorSpaceWithBasis.from_values(input_name, range(5))
+    in2_space = bases.VectorSpaceWithBasis.from_values(input_name, range(3, 10))
+    out_space = bases.VectorSpaceWithBasis.from_values(output_name, range(5))
+
+    with self.assertRaisesRegex(
+        ValueError, r".*Input spaces to a SequenceMap must be disjoint.*"):
+      categorical_mlp.sequence_map_categorical_mlp(
+          input1_space=in1_space,
+          input2_space=in1_space,
+          output_space=out_space,
+          operation=lambda x, y: bases.BasisDirection(output_name, 0),
+          one_space=bases.VectorSpaceWithBasis.from_names([one_name]))
+
+    with self.assertRaisesRegex(
+        ValueError, r".*Input spaces to a SequenceMap must be disjoint.*"):
+      categorical_mlp.sequence_map_categorical_mlp(
+          input1_space=in1_space,
+          input2_space=in2_space,
+          output_space=out_space,
+          operation=lambda x, y: bases.BasisDirection(output_name, 0),
+          one_space=bases.VectorSpaceWithBasis.from_names([one_name]))
+
+  @parameterized.parameters([
+      dict(num_counts=5, x=2, fun=lambda x: x, result=2),
+      dict(num_counts=5, x=2, fun=lambda x: math.pow(x, int(2)), result=4),
+      dict(num_counts=5, x=-2, fun=lambda x: math.pow(x, int(2)), result=4),
+      dict(num_counts=5, x=-1, fun=lambda x: math.pow(x, int(3)), result=-1),
+  ])
+  def test_map_categorical_mlp_produces_expected_outcome_computing_powers(
+      self, num_counts, x, fun, result):
+    input_name = "in"
+    output_name = "out"
+
+    in_space = bases.VectorSpaceWithBasis.from_values(
+        input_name, range(-num_counts, num_counts + 1))
+    out_space = bases.VectorSpaceWithBasis.from_values(
+        output_name, range(-num_counts, num_counts + 1))
+
+    def operation(direction):
+      out_val = fun(int(direction.value))
+      return bases.BasisDirection(output_name, out_val)
+
+    mlp = categorical_mlp.map_categorical_mlp(
+        input_space=in_space, output_space=out_space, operation=operation)
+
+    test_inputs = mlp.residual_space.vector_from_basis_direction(
+        bases.BasisDirection(input_name, x))
+
+    expected_results = mlp.residual_space.vector_from_basis_direction(
+        bases.BasisDirection(output_name, result))
+
+    test_outputs = mlp.apply(test_inputs)
+
+    self.assertVectorAllClose(test_outputs, expected_results)
+
+  @parameterized.parameters([
+      dict(x=2, fun=lambda x: x, result=2),
+      dict(x=2, fun=lambda x: math.pow(x, int(2)), result=4),
+      dict(x=1, fun=lambda x: 1 / (x + 1), result=0.5),
+      dict(x=3, fun=lambda x: 1 / (x + 1), result=0.25),
+  ])
+  def test_map_categorical_to_numerical_mlp_produces_expected_outcome(
+      self, x, fun, result):
+
+    in_space = bases.VectorSpaceWithBasis.from_values("in", range(6))
+    out_space = bases.VectorSpaceWithBasis.from_names(["out"])
+
+    mlp = categorical_mlp.map_categorical_to_numerical_mlp(
+        input_space=in_space,
+        output_space=out_space,
+        operation=fun,
+    )
+
+    test_inputs = mlp.residual_space.vector_from_basis_direction(
+        bases.BasisDirection("in", x))
+
+    expected_results = result * mlp.residual_space.vector_from_basis_direction(
+        bases.BasisDirection("out"))
+
+    test_outputs = mlp.apply(test_inputs)
+
+    self.assertVectorAllClose(test_outputs, expected_results)
+
+
+if __name__ == "__main__":
+  absltest.main()
craft/chamber/numerical_mlp.py
ADDED
@@ -0,0 +1,334 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MLPs to compute arbitrary numerical functions by discretising."""

import dataclasses

from typing import Callable, Iterable

from tracr.craft import bases
from tracr.craft import transformers
from tracr.craft import vectorspace_fns
from tracr.utils import errors


@dataclasses.dataclass
class DiscretisingLayerMaterials:
  """Provides components for a hidden layer that discretises the input.

  Attributes:
    action: Function acting on basis directions that defines the computation.
    hidden_space: Vector space of the hidden representation of the layer.
    output_values: Set of output values that correspond to the discretisation.
  """
  action: Callable[[bases.BasisDirection], bases.VectorInBasis]
  hidden_space: bases.VectorSpaceWithBasis
  output_values: list[float]


def _get_discretising_layer(input_value_set: Iterable[float],
                            f: Callable[[float],
                                        float], hidden_name: bases.Name,
                            one_direction: bases.BasisDirection,
                            large_number: float) -> DiscretisingLayerMaterials:
  """Creates a hidden layer that discretises the input of f(x) into a value set.

  The input is split up into a distinct region around each value in
  `input_value_set`:

  elements of value set:  v0 |  v1  |  v2  |  v3  |  v4  | ...
  thresholds:                t0     t1     t2     t3     t4

  The hidden layer has two activations per threshold:
    hidden_k_0 = ReLU(L * (x - threshold[k]) + 1)
    hidden_k_1 = ReLU(L * (x - threshold[k]))

  Note that hidden_k_0 - hidden_k_1 is:
    1                if x >= threshold[k] + 1/L
    0                if x <= threshold[k]
    between 0 and 1  if threshold[k] < x < threshold[k] + 1/L

  So as long as we choose L a big enough number, we have
    hidden_k_0 - hidden_k_1 = 1 if x >= threshold[k],
  i.e. we know in which region the input value is.

  Args:
    input_value_set: Set of discrete input values.
    f: Function to approximate.
    hidden_name: Name for hidden dimensions.
    one_direction: Auxiliary dimension that must contain 1 in the input.
    large_number: Large number L that determines accuracy of the computation.

  Returns:
    DiscretisingLayerMaterials containing all components for the layer.
  """
  output_values, sorted_values = [], []
  for x in sorted(input_value_set):
    res = errors.ignoring_arithmetic_errors(f)(x)
    if res is not None:
      output_values.append(res)
      sorted_values.append(x)

  num_vals = len(sorted_values)
  value_thresholds = [
      (sorted_values[i] + sorted_values[i + 1]) / 2 for i in range(num_vals - 1)
  ]

  hidden_directions = [bases.BasisDirection(f"{hidden_name}start")]
  for k in range(1, num_vals):
    dir0 = bases.BasisDirection(hidden_name, (k, 0))
    dir1 = bases.BasisDirection(hidden_name, (k, 1))
    hidden_directions.extend([dir0, dir1])
  hidden_space = bases.VectorSpaceWithBasis(hidden_directions)

  def action(direction: bases.BasisDirection) -> bases.VectorInBasis:
    # hidden_k_0 = ReLU(L * (x - threshold[k]) + 1)
    # hidden_k_1 = ReLU(L * (x - threshold[k]))
    if direction == one_direction:
      hidden = hidden_space.vector_from_basis_direction(
          bases.BasisDirection(f"{hidden_name}start"))
    else:
      hidden = hidden_space.null_vector()
    for k in range(1, num_vals):
      vec0 = hidden_space.vector_from_basis_direction(
          bases.BasisDirection(hidden_name, (k, 0)))
      vec1 = hidden_space.vector_from_basis_direction(
          bases.BasisDirection(hidden_name, (k, 1)))
      if direction == one_direction:
        hidden += (1 - large_number * value_thresholds[k - 1]) * vec0
        hidden -= large_number * value_thresholds[k - 1] * vec1
      else:
        hidden += large_number * vec0 + large_number * vec1
    return hidden

  return DiscretisingLayerMaterials(
      action=action, hidden_space=hidden_space, output_values=output_values)


def map_numerical_mlp(
    f: Callable[[float], float],
    input_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis,
    input_value_set: Iterable[float],
    one_space: bases.VectorSpaceWithBasis,
    large_number: float = 100,
    hidden_name: bases.Name = "__hidden__",
) -> transformers.MLP:
  """Returns an MLP that encodes any function of a single variable f(x).

  This is implemented by discretising the input according to input_value_set
  and defining thresholds that determine which part of the input range will
  be allocated to which value in input_value_set.

  elements of value set:  v0 |  v1  |  v2  |  v3  |  v4  | ...
  thresholds:                t0     t1     t2     t3     t4

  The MLP computes two hidden activations per threshold:
    hidden_k_0 = ReLU(L * (x - threshold[k]) + 1)
    hidden_k_1 = ReLU(L * (x - threshold[k]))

  Note that hidden_k_0 - hidden_k_1 is:
    1                if x >= threshold[k] + 1/L
    0                if x <= threshold[k]
    between 0 and 1  if threshold[k] < x < threshold[k] + 1/L

  So as long as we choose L a big enough number, we have
    hidden_k_0 - hidden_k_1 = 1 if x >= threshold[k].

  The MLP then computes the output as:
    output = f(input[0]) +
        sum((hidden_k_0 - hidden_k_1) * (f(input[k]) - f(input[k-1]))
            for all k=0,1,...)

  This sum will be (by a telescoping sums argument)
    f(input[0])   if x <= threshold[0]
    f(input[k])   if threshold[k-1] < x <= threshold[k] for some other k
    f(input[-1])  if x > threshold[-1]
  which approximates f() up to an accuracy given by input_value_set and L.

  Args:
    f: Function to approximate.
    input_space: 1-d vector space that encodes the input x.
    output_space: 1-d vector space to write the output to.
    input_value_set: Set of values the input can take.
    one_space: Auxiliary 1-d vector space that must contain 1 in the input.
    large_number: Large number L that determines accuracy of the computation.
      Note that too large values of L can lead to numerical issues,
      particularly during inference on GPU/TPU.
    hidden_name: Name for hidden dimensions.
  """
  bases.ensure_dims(input_space, num_dims=1, name="input_space")
  bases.ensure_dims(output_space, num_dims=1, name="output_space")
  bases.ensure_dims(one_space, num_dims=1, name="one_space")

  input_space = bases.join_vector_spaces(input_space, one_space)
  out_vec = output_space.vector_from_basis_direction(output_space.basis[0])

  discretising_layer = _get_discretising_layer(
      input_value_set=input_value_set,
      f=f,
      hidden_name=hidden_name,
      one_direction=one_space.basis[0],
      large_number=large_number)
  first_layer = vectorspace_fns.Linear.from_action(
      input_space, discretising_layer.hidden_space, discretising_layer.action)

  def second_layer_action(
      direction: bases.BasisDirection) -> bases.VectorInBasis:
    # output = sum(
    #     (hidden_k_0 - hidden_k_1) * (f(input[k]) - f(input[k-1]))
    #     for all k)
    if direction.name == f"{hidden_name}start":
      return discretising_layer.output_values[0] * out_vec
    k, i = direction.value
    # add hidden_k_0 and subtract hidden_k_1
    sign = {0: 1, 1: -1}[i]
    return sign * (discretising_layer.output_values[k] -
                   discretising_layer.output_values[k - 1]) * out_vec

  second_layer = vectorspace_fns.Linear.from_action(
      discretising_layer.hidden_space, output_space, second_layer_action)

  return transformers.MLP(first_layer, second_layer)


def map_numerical_to_categorical_mlp(
    f: Callable[[float], float],
    input_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis,
    input_value_set: Iterable[float],
    one_space: bases.VectorSpaceWithBasis,
    large_number: float = 100,
    hidden_name: bases.Name = "__hidden__",
) -> transformers.MLP:
  """Returns an MLP to compute f(x) from a numerical to a categorical variable.

  Uses a set of possible output values, and rounds f(x) to the closest value
  in this set to create a categorical output variable.

  The output is discretised the same way as in `map_numerical_mlp`.

  Args:
    f: Function to approximate.
    input_space: 1-d vector space that encodes the input x.
    output_space: n-d vector space to write categorical output to. The output
      directions need to encode the possible output values.
    input_value_set: Set of values the input can take.
    one_space: Auxiliary 1-d space that must contain 1 in the input.
    large_number: Large number L that determines accuracy of the computation.
    hidden_name: Name for hidden dimensions.
  """
  bases.ensure_dims(input_space, num_dims=1, name="input_space")
  bases.ensure_dims(one_space, num_dims=1, name="one_space")

  input_space = bases.join_vector_spaces(input_space, one_space)

  vec_by_out_val = dict()
  for d in output_space.basis:
    # TODO(b/255937603): Do a similar assert in other places where we expect
    # categorical basis directions to encode values.
    assert d.value is not None, ("output directions need to encode "
                                 "possible output values")
    vec_by_out_val[d.value] = output_space.vector_from_basis_direction(d)

  discretising_layer = _get_discretising_layer(
      input_value_set=input_value_set,
      f=f,
      hidden_name=hidden_name,
      one_direction=one_space.basis[0],
      large_number=large_number)

  assert set(discretising_layer.output_values).issubset(
      set(vec_by_out_val.keys()))

  first_layer = vectorspace_fns.Linear.from_action(
      input_space, discretising_layer.hidden_space, discretising_layer.action)

  def second_layer_action(
      direction: bases.BasisDirection) -> bases.VectorInBasis:
    """Computes output value and returns corresponding output direction."""
    if direction.name == f"{hidden_name}start":
      return vec_by_out_val[discretising_layer.output_values[0]]
    else:
      k, i = direction.value
      # add hidden_k_0 and subtract hidden_k_1
      sign = {0: 1, 1: -1}[i]
      out_k = discretising_layer.output_values[k]
      out_k_m_1 = discretising_layer.output_values[k - 1]
      return sign * (vec_by_out_val[out_k] - vec_by_out_val[out_k_m_1])

  second_layer = vectorspace_fns.Linear.from_action(
      discretising_layer.hidden_space, output_space, second_layer_action)

  return transformers.MLP(first_layer, second_layer)


def linear_sequence_map_numerical_mlp(
    input1_basis_direction: bases.BasisDirection,
    input2_basis_direction: bases.BasisDirection,
    output_basis_direction: bases.BasisDirection,
    input1_factor: float,
    input2_factor: float,
    hidden_name: bases.Name = "__hidden__",
) -> transformers.MLP:
  """Returns an MLP that encodes a linear function f(x, y) = a*x + b*y.

  Args:
    input1_basis_direction: Basis direction that encodes the input x.
    input2_basis_direction: Basis direction that encodes the input y.
    output_basis_direction: Basis direction to write the output to.
    input1_factor: Linear factor a for input x.
    input2_factor: Linear factor b for input y.
    hidden_name: Name for hidden dimensions.
  """
  input_space = bases.VectorSpaceWithBasis(
      [input1_basis_direction, input2_basis_direction])
  output_space = bases.VectorSpaceWithBasis([output_basis_direction])
  out_vec = output_space.vector_from_basis_direction(output_basis_direction)

  hidden_directions = [
      bases.BasisDirection(f"{hidden_name}x", 1),
      bases.BasisDirection(f"{hidden_name}x", -1),
      bases.BasisDirection(f"{hidden_name}y", 1),
      bases.BasisDirection(f"{hidden_name}y", -1)
  ]
  hidden_space = bases.VectorSpaceWithBasis(hidden_directions)
  x_pos_vec, x_neg_vec, y_pos_vec, y_neg_vec = (
      hidden_space.vector_from_basis_direction(d) for d in hidden_directions)

  def first_layer_action(
      direction: bases.BasisDirection) -> bases.VectorInBasis:
    output = hidden_space.null_vector()
    if direction == input1_basis_direction:
      output += x_pos_vec - x_neg_vec
    if direction == input2_basis_direction:
      output += y_pos_vec - y_neg_vec
    return output

  first_layer = vectorspace_fns.Linear.from_action(input_space, hidden_space,
                                                   first_layer_action)

  def second_layer_action(
      direction: bases.BasisDirection) -> bases.VectorInBasis:
    if direction.name == f"{hidden_name}x":
      return input1_factor * direction.value * out_vec
    if direction.name == f"{hidden_name}y":
      return input2_factor * direction.value * out_vec
    return output_space.null_vector()

  second_layer = vectorspace_fns.Linear.from_action(hidden_space, output_space,
                                                    second_layer_action)

  return transformers.MLP(first_layer, second_layer)
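A minimal usage sketch of `map_numerical_mlp` (an editorial illustration, not part of the commit; it assumes only the APIs defined above): build the discretising MLP for f(x) = x**2 over a small value set and apply it to a single input position.

import numpy as np

from tracr.craft import bases
from tracr.craft.chamber import numerical_mlp

input_space = bases.VectorSpaceWithBasis([bases.BasisDirection("x")])
output_space = bases.VectorSpaceWithBasis([bases.BasisDirection("x_sq")])
one_space = bases.VectorSpaceWithBasis([bases.BasisDirection("one")])

mlp = numerical_mlp.map_numerical_mlp(
    f=lambda x: x**2,
    input_space=input_space,
    output_space=output_space,
    input_value_set={-2, -1, 0, 1, 2},
    one_space=one_space,
)

# The MLP's residual space joins the input, one, and output dimensions.
# The "one" dimension must carry a 1, as the discretising layer requires.
x_vec = mlp.residual_space.vector_from_basis_direction(
    bases.BasisDirection("x"))
one_vec = mlp.residual_space.vector_from_basis_direction(
    bases.BasisDirection("one"))

out = mlp.apply(2 * x_vec + one_vec).project(output_space)
print(out.magnitudes)  # approximately [4.], since f(2) = 4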
craft/chamber/numerical_mlp_test.py
ADDED
@@ -0,0 +1,233 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for chamber.numerical_mlp."""

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from tracr.craft import bases
from tracr.craft import tests_common
from tracr.craft.chamber import numerical_mlp
from tracr.utils import errors


class NumericalMlpTest(tests_common.VectorFnTestCase):

  @parameterized.parameters([
      dict(
          in_value_set={-2, -1, 0, 1, 2, 3},
          x=2,
          function=lambda x: x,
          result=2),
      dict(
          in_value_set={-2, -1, 0, 1, 2, 3},
          x=2,
          function=lambda x: x**2,
          result=4),
      dict(
          in_value_set={-2, -1, 0, 1, 2, 3},
          x=2,
          function=lambda x: x**3,
          result=8),
      dict(
          in_value_set={-2, -1, 0, 1, 2, 3},
          x=-2,
          function=lambda x: x,
          result=-2),
      dict(
          in_value_set={-2, -1, 0, 1, 2, 3},
          x=-2,
          function=lambda x: x**2,
          result=4),
      dict(
          in_value_set={-2, -1, 0, 1, 2, 3},
          x=-2,
          function=lambda x: x**3,
          result=-8),
  ])
  def test_map_numerical_mlp_produces_expected_outcome(self, in_value_set, x,
                                                       function, result):

    input_dir = bases.BasisDirection("input")
    output_dir = bases.BasisDirection("output")
    one_dir = bases.BasisDirection("one")
    input_space = bases.VectorSpaceWithBasis([input_dir])
    output_space = bases.VectorSpaceWithBasis([output_dir])
    one_space = bases.VectorSpaceWithBasis([one_dir])

    mlp = numerical_mlp.map_numerical_mlp(
        f=function,
        input_space=input_space,
        output_space=output_space,
        one_space=one_space,
        input_value_set=in_value_set,
    )

    test_inputs = bases.VectorInBasis(
        basis_directions=[input_dir, output_dir, one_dir],
        magnitudes=np.array([x, 0, 1]))

    expected_results = bases.VectorInBasis(
        basis_directions=[input_dir, output_dir, one_dir],
        magnitudes=np.array([0, result, 0]))

    test_outputs = mlp.apply(test_inputs)

    self.assertVectorAllClose(test_outputs, expected_results)

  @parameterized.parameters([
      dict(in_value_set={0, 1, 2, 3}, x=1, function=lambda x: 1 / x, result=1),
      dict(
          in_value_set={0, 1, 2, 3}, x=2, function=lambda x: 1 / x, result=0.5),
      dict(
          in_value_set={0, 1, 2, 3},
          x=3,
          function=lambda x: 1 / x,
          result=1 / 3),
  ])
  def test_map_numerical_mlp_logs_warning_and_produces_expected_outcome(
      self, in_value_set, x, function, result):

    input_dir = bases.BasisDirection("input")
    output_dir = bases.BasisDirection("output")
    one_dir = bases.BasisDirection("one")
    input_space = bases.VectorSpaceWithBasis([input_dir])
    output_space = bases.VectorSpaceWithBasis([output_dir])
    one_space = bases.VectorSpaceWithBasis([one_dir])

    with self.assertLogs(level="WARNING"):
      mlp = numerical_mlp.map_numerical_mlp(
          f=function,
          input_space=input_space,
          output_space=output_space,
          one_space=one_space,
          input_value_set=in_value_set,
      )

    test_inputs = bases.VectorInBasis(
        basis_directions=[input_dir, output_dir, one_dir],
        magnitudes=np.array([x, 0, 1]))

    expected_results = bases.VectorInBasis(
        basis_directions=[input_dir, output_dir, one_dir],
        magnitudes=np.array([0, result, 0]))

    test_outputs = mlp.apply(test_inputs)

    self.assertVectorAllClose(test_outputs, expected_results)

  @parameterized.parameters([
      dict(in_value_set={0, 1, 2, 3}, x=1, function=lambda x: 1 / x, result=1),
      dict(
          in_value_set={0, 1, 2, 3}, x=2, function=lambda x: 1 / x, result=0.5),
      dict(
          in_value_set={0, 1, 2, 3},
          x=3,
          function=lambda x: 1 / x,
          result=1 / 3),
  ])
  def test_map_numerical_to_categorical_mlp_logs_warning_and_produces_expected_outcome(
      self, in_value_set, x, function, result):

    f_ign = errors.ignoring_arithmetic_errors(function)
    out_value_set = {f_ign(x) for x in in_value_set if f_ign(x) is not None}

    in_space = bases.VectorSpaceWithBasis.from_names(["input"])
    out_space = bases.VectorSpaceWithBasis.from_values("output", out_value_set)
    one_space = bases.VectorSpaceWithBasis.from_names(["one"])

    residual_space = bases.join_vector_spaces(in_space, one_space, out_space)
    in_vec = residual_space.vector_from_basis_direction(in_space.basis[0])
    one_vec = residual_space.vector_from_basis_direction(one_space.basis[0])

    with self.assertLogs(level="WARNING"):
      mlp = numerical_mlp.map_numerical_to_categorical_mlp(
          f=function,
          input_space=in_space,
          output_space=out_space,
          input_value_set=in_value_set,
          one_space=one_space,
      )

    test_inputs = x * in_vec + one_vec
    expected_results = out_space.vector_from_basis_direction(
        bases.BasisDirection("output", result))
    test_outputs = mlp.apply(test_inputs).project(out_space)
    self.assertVectorAllClose(test_outputs, expected_results)

  @parameterized.parameters([
      dict(x_factor=1, y_factor=2, x=1, y=1, result=3),
      dict(x_factor=1, y_factor=2, x=1, y=-1, result=-1),
      dict(x_factor=1, y_factor=-1, x=1, y=1, result=0),
      dict(x_factor=1, y_factor=1, x=3, y=5, result=8),
      dict(x_factor=-2, y_factor=-0.5, x=4, y=1, result=-8.5),
  ])
  def test_linear_sequence_map_produces_expected_result(self, x_factor,
                                                        y_factor, x, y,
                                                        result):

    input1_dir = bases.BasisDirection("input1")
    input2_dir = bases.BasisDirection("input2")
    output_dir = bases.BasisDirection("output")

    mlp = numerical_mlp.linear_sequence_map_numerical_mlp(
        input1_basis_direction=input1_dir,
        input2_basis_direction=input2_dir,
        output_basis_direction=output_dir,
        input1_factor=x_factor,
        input2_factor=y_factor)

    test_inputs = bases.VectorInBasis(
        basis_directions=[input1_dir, input2_dir, output_dir],
        magnitudes=np.array([x, y, 0]))

    expected_results = bases.VectorInBasis(
        basis_directions=[input1_dir, input2_dir, output_dir],
        magnitudes=np.array([0, 0, result]))

    test_outputs = mlp.apply(test_inputs)

    self.assertVectorAllClose(test_outputs, expected_results)

  @parameterized.parameters([
      dict(x_factor=1, y_factor=2, x=1, result=3),
      dict(x_factor=1, y_factor=-1, x=1, result=0),
  ])
  def test_linear_sequence_map_produces_expected_result_with_same_inputs(
      self, x_factor, y_factor, x, result):

    input_dir = bases.BasisDirection("input")
    output_dir = bases.BasisDirection("output")

    mlp = numerical_mlp.linear_sequence_map_numerical_mlp(
        input1_basis_direction=input_dir,
        input2_basis_direction=input_dir,
        output_basis_direction=output_dir,
        input1_factor=x_factor,
        input2_factor=y_factor)

    test_inputs = bases.VectorInBasis(
        basis_directions=[input_dir, output_dir], magnitudes=np.array([x, 0]))

    expected_results = bases.VectorInBasis(
        basis_directions=[input_dir, output_dir],
        magnitudes=np.array([0, result]))

    test_outputs = mlp.apply(test_inputs)

    self.assertVectorAllClose(test_outputs, expected_results)


if __name__ == "__main__":
  absltest.main()
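A worked instance of the warning tests above (an editorial note; the numbers follow directly from `_get_discretising_layer`): with in_value_set = {0, 1, 2, 3} and f(x) = 1/x, evaluating f at x = 0 raises, so errors.ignoring_arithmetic_errors returns None, a warning is logged, and 0 is dropped from the value set.

sorted_values = [1, 2, 3]  # 0 was dropped because f(0) failed
value_thresholds = [
    (sorted_values[i] + sorted_values[i + 1]) / 2
    for i in range(len(sorted_values) - 1)
]
print(value_thresholds)  # [1.5, 2.5]
# Up to the 1/L-wide transition bands, any x <= 1.5 therefore maps to
# f(1) = 1, any 1.5 < x <= 2.5 to f(2) = 0.5, and any x > 2.5 to
# f(3) = 1/3 -- exactly the expected results in the tests above.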
craft/chamber/selector_width.py
ADDED
@@ -0,0 +1,144 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SelectorWidth component consisting of an attention head and an MLP."""

from typing import Iterable
from tracr.craft import bases
from tracr.craft import transformers
from tracr.craft import vectorspace_fns
from tracr.craft.chamber import categorical_attn
from tracr.craft.chamber import numerical_mlp


def selector_width(
    query_space: bases.VectorSpaceWithBasis,
    key_space: bases.VectorSpaceWithBasis,
    output_space: bases.VectorSpaceWithBasis,
    bos_space: bases.VectorSpaceWithBasis,
    one_space: bases.VectorSpaceWithBasis,
    attn_fn: categorical_attn.QueryKeyToAttnLogit,
    out_value_set: Iterable[float],
    categorical_output: bool,
    causal: bool = False,
    softmax_coldness: float = 100.,
    mlp_large_number: float = 100.,
    label: str = "",
) -> transformers.SeriesWithResiduals:
  """Returns a craft block implementing RASP's SelectorWidth primitive.

  The block consists of one attention head and one MLP.

  The attention head implements the attention pattern (attn_fn or key=bos) and
  aggregates the bos dimension over this pattern. The output of this will be
  1/(d+1) in every position, where d is the "width" of the attention pattern,
  i.e. the number of 1s in a row.

  The MLP then computes d from the previous output in all positions except for
  the first BOS position. In the BOS position the MLP removes the output of the
  attention head, to ensure it only contains the encoding of the BOS token
  which is expected by all other model components.

  Args:
    query_space: Vector space containing (categorical) query input.
    key_space: Vector space containing (categorical) key input.
    output_space: Vector space which will contain (numerical or categorical)
      output.
    bos_space: 1-d space used to identify the beginning of sequence token.
    one_space: Auxiliary 1-d vector space that must contain 1 in the input.
    attn_fn: A selector function f(query, key) operating on the query/key basis
      directions that defines the attention pattern to compute the width of.
    out_value_set: Set of possible output values of this SelectorWidth.
    categorical_output: If True, encode the output as a categorical variable.
    causal: If True, use masked attention.
    softmax_coldness: The inverse temperature of the softmax. Default value is
      high which makes the attention close to a hard maximum.
    mlp_large_number: A larger number makes the MLP more accurate.
    label: A name for this block, used to label auxiliary dimensions.
  """
  assert output_space.num_dims == 1 or categorical_output

  attn_out_dir = bases.BasisDirection(f"{label}_selector_width_attn_output")
  attn_out_space = bases.VectorSpaceWithBasis([attn_out_dir])
  attn_out_vec = attn_out_space.vector_from_basis_direction(attn_out_dir)

  attn = categorical_attn.categorical_attn(
      query_space=query_space,
      key_space=key_space,
      value_space=bos_space,
      output_space=attn_out_space,
      bos_space=bos_space,
      one_space=one_space,
      attn_fn=attn_fn,
      default_output=attn_out_space.null_vector(),
      causal=causal,
      always_attend_to_bos=True,
      use_bos_for_default_output=False,
      softmax_coldness=softmax_coldness)

  fun = lambda x: (1 / x) - 1
  in_value_set = {1 / (x + 1) for x in out_value_set}
  if categorical_output:
    mlp = numerical_mlp.map_numerical_to_categorical_mlp(
        f=fun,
        input_space=attn_out_space,
        output_space=output_space,
        input_value_set=in_value_set,
        one_space=one_space,
        hidden_name=f"_hidden_{label}_",
        large_number=mlp_large_number)
  else:
    mlp = numerical_mlp.map_numerical_mlp(
        f=fun,
        input_space=attn_out_space,
        output_space=output_space,
        input_value_set=in_value_set,
        one_space=one_space,
        hidden_name=f"_hidden_{label}_",
        large_number=mlp_large_number)

  # This implementation of selector width writes at each position including
  # the BOS. To ensure that the BOS token position does not contain
  # additional values, we add an mlp to subtract the output of both layers.
  clean_bos_out_space = bases.join_vector_spaces(attn_out_space, output_space)
  vec_to_subtract_from_bos = attn_out_vec.project(clean_bos_out_space)

  if categorical_output:
    # Add the one-hot encoding of the zero value to the vector
    # which will get scrubbed from the BOS position.
    zero_dir = [d for d in output_space.basis if d.value == 0][0]
    zero_vec = clean_bos_out_space.vector_from_basis_direction(zero_dir)
    vec_to_subtract_from_bos += zero_vec

  # Construct an MLP that subtracts vec_to_subtract_from_bos * bos
  # from the residual stream which is vec_to_subtract_from_bos in the
  # bos position and 0 else. vec_to_subtract_from_bos contains what the
  # attention head writes to the bos position.

  hidden_dir = bases.BasisDirection("_hidden_clean_bos_")
  hidden_space = bases.VectorSpaceWithBasis([hidden_dir])
  hidden_vec = hidden_space.vector_from_basis_direction(hidden_dir)

  # It's okay to use the local variables because they are only used within
  # the same loop iteration to create the MLP.
  # pylint: disable=cell-var-from-loop
  first_layer = vectorspace_fns.Linear.from_action(bos_space, hidden_space,
                                                   lambda x: hidden_vec)
  second_layer = vectorspace_fns.Linear.from_action(
      hidden_space, clean_bos_out_space, lambda x: -vec_to_subtract_from_bos)
  # pylint: enable=cell-var-from-loop
  clean_bos_mlp = transformers.MLP(first_layer, second_layer)

  mlp = transformers.MLP.combine_in_parallel([mlp, clean_bos_mlp])
  return transformers.SeriesWithResiduals([attn, mlp])
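The arithmetic at the heart of this block, as a standalone sketch (editorial, not part of the commit): with always_attend_to_bos=True, a query whose selector matches d keys spreads its attention evenly over d + 1 positions (the d keys plus BOS), so aggregating the BOS indicator yields 1 / (d + 1); `fun` then inverts this back to d.

fun = lambda x: (1 / x) - 1

for d in range(4):
  attn_output = 1 / (d + 1)   # what the attention head writes
  print(d, fun(attn_output))  # recovers d: 0.0, 1.0, 2.0, 3.0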
craft/chamber/selector_width_test.py
ADDED
@@ -0,0 +1,155 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for selector_width."""

from absl.testing import absltest
from absl.testing import parameterized
from tracr.craft import bases
from tracr.craft import tests_common
from tracr.craft.chamber import selector_width


class SelectorWidthTest(tests_common.VectorFnTestCase):

  @parameterized.product(
      causal=[False, True],
      categorical_output=[False, True],
      input_seq=[[1, 2, 3, 4, 5], [-1, 0, 1], [10]])
  def test_selector_width_of_select_all_is_length(self, causal,
                                                  categorical_output,
                                                  input_seq):
    vocab = range(-20, 20)
    input_space = bases.VectorSpaceWithBasis.from_values("input", vocab)

    if categorical_output:
      output_space = bases.VectorSpaceWithBasis.from_values("output", range(10))
    else:
      output_space = bases.VectorSpaceWithBasis(
          [bases.BasisDirection("output")])

    bos_dir = bases.BasisDirection("bos_dimension")
    bos_space = bases.VectorSpaceWithBasis([bos_dir])

    one_dir = bases.BasisDirection("one_dimension")
    one_space = bases.VectorSpaceWithBasis([one_dir])

    input_space = bases.join_vector_spaces(input_space, bos_space, one_space)
    residual_space = bases.join_vector_spaces(input_space, output_space)
    bos_vec = residual_space.vector_from_basis_direction(bos_dir)
    one_vec = residual_space.vector_from_basis_direction(one_dir)

    block = selector_width.selector_width(
        query_space=input_space,
        key_space=input_space,
        output_space=output_space,
        bos_space=bos_space,
        one_space=one_space,
        attn_fn=lambda x, y: True,
        out_value_set=set(range(len(input_seq) + 1)),
        categorical_output=categorical_output,
        causal=causal,
        label="select_all")

    test_inputs = [bos_vec + one_vec]
    for x in input_seq:
      test_inputs.append(
          residual_space.vector_from_basis_direction(
              bases.BasisDirection("input", x)) + one_vec)
    test_inputs = bases.VectorInBasis.stack(test_inputs)

    # Expect length of the input sequence
    if causal:
      expected_results = list(range(1, len(input_seq) + 1))
    else:
      expected_results = [len(input_seq) for _ in input_seq]

    if categorical_output:
      expected_results = [
          output_space.vector_from_basis_direction(
              bases.BasisDirection("output", x)) for x in expected_results
      ]
    else:
      output_vec = output_space.vector_from_basis_direction(
          bases.BasisDirection("output"))
      expected_results = [x * output_vec for x in expected_results]

    expected_results = bases.VectorInBasis.stack(expected_results)

    test_outputs = block.apply(test_inputs).project(output_space)
    self.assertVectorAllClose(
        tests_common.strip_bos_token(test_outputs), expected_results)

  @parameterized.product(
      causal=[False, True],
      categorical_output=[False, True],
      input_seq=[[1, 2, 3, 4, 5], [-1, 0, 1], [10]])
  def test_selector_width_of_select_none_is_zero(self, causal,
                                                 categorical_output, input_seq):
    vocab = range(-20, 20)
    input_space = bases.VectorSpaceWithBasis.from_values("input", vocab)

    if categorical_output:
      output_space = bases.VectorSpaceWithBasis.from_values("output", range(10))
    else:
      output_space = bases.VectorSpaceWithBasis(
          [bases.BasisDirection("output")])

    bos_dir = bases.BasisDirection("bos_dimension")
    bos_space = bases.VectorSpaceWithBasis([bos_dir])

    one_dir = bases.BasisDirection("one_dimension")
    one_space = bases.VectorSpaceWithBasis([one_dir])

    input_space = bases.join_vector_spaces(input_space, bos_space, one_space)
    residual_space = bases.join_vector_spaces(input_space, output_space)
    bos_vec = residual_space.vector_from_basis_direction(bos_dir)
    one_vec = residual_space.vector_from_basis_direction(one_dir)

    block = selector_width.selector_width(
        query_space=input_space,
        key_space=input_space,
        output_space=output_space,
        bos_space=bos_space,
        one_space=one_space,
        attn_fn=lambda x, y: False,
        out_value_set=set(range(len(input_seq) + 1)),
        categorical_output=categorical_output,
        causal=causal,
        label="select_none")

    test_inputs = [bos_vec + one_vec]
    for x in input_seq:
      test_inputs.append(
          residual_space.vector_from_basis_direction(
              bases.BasisDirection("input", x)) + one_vec)
    test_inputs = bases.VectorInBasis.stack(test_inputs)

    # Expect zero output
    if categorical_output:
      expected_results = [
          output_space.vector_from_basis_direction(
              bases.BasisDirection("output", 0)) for _ in input_seq
      ]
    else:
      expected_results = [output_space.null_vector() for _ in input_seq]
    expected_results = bases.VectorInBasis.stack(expected_results)

    test_outputs = block.apply(test_inputs).project(output_space)
    self.assertVectorAllClose(
        tests_common.strip_bos_token(test_outputs), expected_results)


if __name__ == "__main__":
  absltest.main()
craft/tests_common.py
ADDED
@@ -0,0 +1,33 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for tests."""

from absl.testing import parameterized
import numpy as np
from tracr.craft import bases


def strip_bos_token(vector: bases.VectorInBasis) -> bases.VectorInBasis:
  """Removes BOS token of a vector."""
  return bases.VectorInBasis(vector.basis_directions, vector.magnitudes[1:])


class VectorFnTestCase(parameterized.TestCase):
  """Asserts for vectors."""

  def assertVectorAllClose(self, v1: bases.VectorInBasis,
                           v2: bases.VectorInBasis):
    self.assertEqual(v1.basis_directions, v2.basis_directions)
    np.testing.assert_allclose(v1.magnitudes, v2.magnitudes, atol=1e-7)
craft/transformers.py
ADDED
@@ -0,0 +1,197 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pieces for making transformers."""

import abc
import dataclasses
from typing import Iterable, Optional, Sequence, Union

import numpy as np

from tracr.craft import bases
from tracr.craft import vectorspace_fns

project = vectorspace_fns.project


def _np_softmax(x, axis=-1):
  x_max = np.max(x, axis=axis, keepdims=True)
  return np.exp(x - x_max) / np.sum(np.exp(x - x_max), axis=axis, keepdims=True)


def _np_relu(x):
  return np.where(x > 0, x, 0)


def relu(x: bases.VectorInBasis) -> bases.VectorInBasis:
  return bases.VectorInBasis(x.basis_directions, _np_relu(x.magnitudes))


class Block(abc.ABC):
  """Transformer block, acting on a sequence of vector space elements.

  Attributes:
    residual_space: Vector space that contains all subspaces the Block
      interacts with. This can be either the full residual space of a model
      or a subspace.
  """
  residual_space: bases.VectorSpaceWithBasis

  @abc.abstractmethod
  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Applies self to an input."""


@dataclasses.dataclass
class AttentionHead(Block):
  """A transformer attention head."""
  w_qk: vectorspace_fns.ScalarBilinear
  w_ov: vectorspace_fns.Linear
  residual_space: Optional[bases.VectorSpaceWithBasis] = None
  causal: bool = False

  def __post_init__(self):
    """Infer residual stream and typecheck subspaces."""
    if self.residual_space is None:
      self.residual_space = bases.join_vector_spaces(self.w_qk.left_space,
                                                     self.w_qk.right_space,
                                                     self.w_ov.input_space,
                                                     self.w_ov.output_space)

    assert self.w_qk.left_space.issubspace(self.residual_space)
    assert self.w_qk.right_space.issubspace(self.residual_space)
    assert self.w_ov.input_space.issubspace(self.residual_space)
    assert self.w_ov.output_space.issubspace(self.residual_space)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    assert x in self.residual_space
    # seq_len x query_space
    queries = x.project(self.w_qk.left_space)
    # seq_len x key_space
    keys = x.project(self.w_qk.right_space)

    attn_matrix = queries.magnitudes @ self.w_qk.matrix @ keys.magnitudes.T

    if self.causal:
      # The 1 gives us the matrix above the diagonal.
      mask = np.triu(np.full_like(attn_matrix, -np.inf), 1)
      attn_matrix = attn_matrix + mask

    attn_weights = _np_softmax(attn_matrix)  # seq_len_from, seq_len_to
    values = self.w_ov_residual(x).magnitudes  # seq_len_to, d_model

    magnitudes = attn_weights @ values  # seq_len_from, d_model
    return bases.VectorInBasis(sorted(self.residual_space.basis), magnitudes)

  def w_ov_residual(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Wov but acting on the residual space."""
    x = project(self.residual_space, self.w_ov.input_space)(x)
    out = self.w_ov(x)
    return project(self.w_ov.output_space, self.residual_space)(out)

  @property
  def num_heads(self) -> int:
    return 1

  def as_multi(self) -> "MultiAttentionHead":
    return MultiAttentionHead([self])


@dataclasses.dataclass
class MultiAttentionHead(Block):
  """Applies attention heads in parallel."""
  sub_blocks: list[Union[AttentionHead, "MultiAttentionHead"]]

  def __post_init__(self):
    spaces = [block.residual_space for block in self.sub_blocks]
    self.residual_space, *others = spaces
    assert all(s == self.residual_space for s in others)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    # each element is seq_len x embedding
    outs = [block.apply(x) for block in self.sub_blocks]
    return bases.VectorInBasis.sum(outs)  # seq_len x embedding

  @property
  def num_heads(self) -> int:
    return sum(sub_block.num_heads for sub_block in self.sub_blocks)

  def heads(self) -> Iterable[AttentionHead]:
    for sub_block in self.sub_blocks:
      if isinstance(sub_block, AttentionHead):
        yield sub_block
      elif isinstance(sub_block, MultiAttentionHead):
        yield from sub_block.heads()
      else:
        raise NotImplementedError()

  def as_multi(self) -> "MultiAttentionHead":
    return self


@dataclasses.dataclass
class MLP(Block):
  """A transformer MLP block."""
  fst: vectorspace_fns.Linear
  snd: vectorspace_fns.Linear
  residual_space: Optional[bases.VectorSpaceWithBasis] = None

  def __post_init__(self):
    """Typecheck subspaces."""
    if self.residual_space is None:
      self.residual_space = bases.join_vector_spaces(self.fst.input_space,
                                                     self.snd.output_space)

    assert self.fst.output_space == self.snd.input_space
    assert self.fst.input_space.issubspace(self.residual_space)
    assert self.snd.output_space.issubspace(self.residual_space)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    assert x in self.residual_space

    x = project(self.residual_space, self.fst.input_space)(x)
    hidden = self.fst(x)
    hidden = relu(hidden)
    out = self.snd(hidden)
    return project(self.snd.output_space, self.residual_space)(out)

  @classmethod
  def combine_in_parallel(cls, mlps: Sequence["MLP"]) -> "MLP":
    fst = vectorspace_fns.Linear.combine_in_parallel(
        [block.fst for block in mlps])
    snd = vectorspace_fns.Linear.combine_in_parallel(
        [block.snd for block in mlps])
    return cls(fst=fst, snd=snd, residual_space=None)


# Block that fits into a half-layer, without residual connections.
HalfLayerBlock = Union[MLP, AttentionHead, MultiAttentionHead]


@dataclasses.dataclass
class SeriesWithResiduals(Block):
  """A series of blocks with residual connections."""
  blocks: list[HalfLayerBlock]

  def __post_init__(self):
    spaces = [block.residual_space for block in self.blocks]
    self.residual_space = bases.join_vector_spaces(*spaces)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    x = x.project(self.residual_space)
    for block in self.blocks:
      x_in = x.project(block.residual_space)
      x_out = block.apply(x_in).project(self.residual_space)
      x = x + x_out
    return x
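A minimal composition sketch (editorial, not part of the commit; the spaces and matrices are toy choices): an MLP that copies an input subspace into an output subspace, wrapped in a SeriesWithResiduals and applied to a two-position sequence.

import numpy as np

from tracr.craft import bases
from tracr.craft import transformers
from tracr.craft import vectorspace_fns as vs_fns

inp = bases.VectorSpaceWithBasis.from_values("in", [1, 2])
out = bases.VectorSpaceWithBasis.from_values("out", [1, 2])
hidden = bases.VectorSpaceWithBasis.from_values("h", [1, 2])

# Identity maps in both layers; the ReLU in between passes the
# non-negative magnitudes below through unchanged.
mlp = transformers.MLP(
    fst=vs_fns.Linear(inp, hidden, np.eye(2)),
    snd=vs_fns.Linear(hidden, out, np.eye(2)))

model = transformers.SeriesWithResiduals([mlp])

seq = bases.VectorInBasis(inp.basis, np.array([
    [1., 0.],
    [0., 2.],
]))

# apply() projects the input into the joint residual space, runs the
# block, and adds its output back via the residual connection.
result = model.apply(seq)
print(result.project(out).magnitudes)  # [[1., 0.], [0., 2.]]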
craft/transformers_test.py
ADDED
@@ -0,0 +1,160 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for transformers."""

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from tracr.craft import bases
from tracr.craft import tests_common
from tracr.craft import transformers
from tracr.craft import vectorspace_fns as vs_fns

# This makes it easier to use comments to annotate dimensions in arrays
# pylint: disable=g-no-space-after-comment


class AttentionHeadTest(tests_common.VectorFnTestCase):

  @parameterized.parameters([
      dict(with_residual_stream=False),
      dict(with_residual_stream=True),
  ])
  def test_attention_head(self, with_residual_stream):
    i = bases.VectorSpaceWithBasis.from_values("i", [1, 2])
    o = bases.VectorSpaceWithBasis.from_values("o", [1, 2])
    q = bases.VectorSpaceWithBasis.from_values("q", [1, 2])
    k = bases.VectorSpaceWithBasis.from_values("p", [1, 2])
    rs = bases.direct_sum(i, o, q, k)

    seq = bases.VectorInBasis(
        rs.basis,
        np.array([
            #i1 i2 o1 o2 q1 q2 p1 p2
            [1, 0, 0, 0, 1, 0, 1, 0],
            [0, 1, 0, 0, 0, 1, 0, 1],
        ]))

    head = transformers.AttentionHead(
        w_qk=vs_fns.ScalarBilinear(q, k,
                                   np.eye(2) * 100),
        w_ov=vs_fns.Linear(i, o, np.eye(2)),
        residual_space=rs if with_residual_stream else None,
        causal=False,
    )

    self.assertVectorAllClose(
        head.apply(seq),
        bases.VectorInBasis(
            rs.basis,
            np.array([
                #i1 i2 o1 o2 q1 q2 p1 p2
                [0, 0, 1, 0, 0, 0, 0, 0],
                [0, 0, 0, 1, 0, 0, 0, 0],
            ])),
    )


class MLPTest(tests_common.VectorFnTestCase):

  @parameterized.parameters([
      dict(with_residual_stream=False, same_in_out=False),
      dict(with_residual_stream=False, same_in_out=True),
      dict(with_residual_stream=True, same_in_out=False),
      dict(with_residual_stream=True, same_in_out=True),
  ])
  def test_mlp(self, with_residual_stream, same_in_out):
    i = bases.VectorSpaceWithBasis.from_values("i", [1, 2])
    if same_in_out:
      o, rs = i, i
      expected_result = np.array([
          #o1 o2
          [1, 0],
          [0, 1],
      ])
    else:
      o = bases.VectorSpaceWithBasis.from_values("o", [1, 2])
      rs = bases.direct_sum(i, o)
      expected_result = np.array([
          #i1 i2 o1 o2
          [0, 0, 1, 0],
          [0, 0, 0, 1],
      ])
    h = bases.VectorSpaceWithBasis.from_values("p", [1, 2])

    seq = bases.VectorInBasis(
        i.basis,
        np.array([
            #i1 i2
            [1, -1],
            [-1, 1],
        ])).project(rs)

    mlp = transformers.MLP(
        fst=vs_fns.Linear(i, h, np.eye(2)),
        snd=vs_fns.Linear(h, o, np.eye(2)),
        residual_space=rs if with_residual_stream else None,
    )

    self.assertEqual(
        mlp.apply(seq),
        bases.VectorInBasis(rs.basis, expected_result),
    )

  def test_combining_mlps(self):
    in12 = bases.VectorSpaceWithBasis.from_values("in", [1, 2])
    in34 = bases.VectorSpaceWithBasis.from_values("in", [3, 4])
    out12 = bases.VectorSpaceWithBasis.from_values("out", [1, 2])
    residual_space = bases.join_vector_spaces(in12, in34, out12)

    h1 = bases.VectorSpaceWithBasis.from_values("h", [1])
    h2 = bases.VectorSpaceWithBasis.from_values("h", [2])

    # MLP1 maps in2 -> h1 -> out1
    mlp1 = transformers.MLP(
        fst=vs_fns.Linear(in12, h1, np.array([[0], [1]])),
        snd=vs_fns.Linear(h1, out12, np.array([[1, 0]])))

    # MLP2 maps in3 -> h2 -> out2
    mlp2 = transformers.MLP(
        fst=vs_fns.Linear(in34, h2, np.array([[1], [0]])),
        snd=vs_fns.Linear(h2, out12, np.array([[0, 1]])))

    mlp = transformers.MLP.combine_in_parallel([mlp1, mlp2])

    seq = bases.VectorInBasis(
        bases.direct_sum(in12, in34).basis,
        np.array([
            #i1 i2 i3 i4
            [1, 2, 0, 0],
            [0, 2, 3, 4],
        ])).project(residual_space)

    expected_result = bases.VectorInBasis(
        out12.basis,
        np.array([
            #o1 o2
            [2, 0],
            [2, 3],
        ]))

    self.assertEqual(
        mlp.apply(seq).project(out12),
        expected_result,
    )


if __name__ == "__main__":
  absltest.main()
craft/vectorspace_fns.py
ADDED
@@ -0,0 +1,162 @@
1 |
+
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Functions on vector spaces."""
|
16 |
+
|
17 |
+
import abc
|
18 |
+
import dataclasses
|
19 |
+
from typing import Callable, Sequence
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
|
23 |
+
from tracr.craft import bases
|
24 |
+
|
25 |
+
VectorSpaceWithBasis = bases.VectorSpaceWithBasis
|
26 |
+
VectorInBasis = bases.VectorInBasis
|
27 |
+
BasisDirection = bases.BasisDirection
|
28 |
+
|
29 |
+
|
30 |
+
class VectorFunction(abc.ABC):
|
31 |
+
"""A function that acts on vectors."""
|
32 |
+
|
33 |
+
input_space: VectorSpaceWithBasis
|
34 |
+
output_space: VectorSpaceWithBasis
|
35 |
+
|
36 |
+
@abc.abstractmethod
|
37 |
+
def __call__(self, x: VectorInBasis) -> VectorInBasis:
|
38 |
+
"""Evaluates the function."""
|
39 |
+
|
40 |
+
|
41 |
+
class Linear(VectorFunction):
|
42 |
+
"""A linear function."""
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
input_space: VectorSpaceWithBasis,
|
47 |
+
output_space: VectorSpaceWithBasis,
|
48 |
+
matrix: np.ndarray,
|
49 |
+
):
|
50 |
+
"""Initialises.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
input_space: The input vector space.
|
54 |
+
output_space: The output vector space.
|
55 |
+
matrix: a [input, output] matrix acting in a (sorted) basis.
|
56 |
+
"""
|
57 |
+
self.input_space = input_space
|
58 |
+
self.output_space = output_space
|
59 |
+
self.matrix = matrix
|
60 |
+
|
61 |
+
def __post_init__(self) -> None:
|
62 |
+
output_size, input_size = self.matrix.shape
|
63 |
+
assert input_size == self.input_space.num_dims
|
64 |
+
assert output_size == self.output_space.num_dims
|
65 |
+
|
66 |
+
def __call__(self, x: VectorInBasis) -> VectorInBasis:
|
67 |
+
if x not in self.input_space:
|
68 |
+
raise TypeError(f"{x=} not in {self.input_space=}.")
|
69 |
+
return VectorInBasis(
|
70 |
+
basis_directions=sorted(self.output_space.basis),
|
71 |
+
magnitudes=x.magnitudes @ self.matrix,
|
72 |
+
)
|
73 |
+
|
74 |
+
@classmethod
|
75 |
+
def from_action(
|
76 |
+
cls,
|
77 |
+
input_space: VectorSpaceWithBasis,
|
78 |
+
output_space: VectorSpaceWithBasis,
|
79 |
+
action: Callable[[BasisDirection], VectorInBasis],
|
80 |
+
) -> "Linear":
|
81 |
+
"""from_action(i, o)(action) creates a Linear."""
|
82 |
+
|
83 |
+
matrix = np.zeros((input_space.num_dims, output_space.num_dims))
|
84 |
+
for i, direction in enumerate(input_space.basis):
|
85 |
+
out_vector = action(direction)
|
86 |
+
if out_vector not in output_space:
|
87 |
+
raise TypeError(f"image of {direction} from {input_space=} "
|
88 |
+
f"is not in {output_space=}")
|
89 |
+
matrix[i, :] = out_vector.magnitudes
|
90 |
+
|
91 |
+
return Linear(input_space, output_space, matrix)
|
92 |
+
|
93 |
+
@classmethod
|
94 |
+
def combine_in_parallel(cls, fns: Sequence["Linear"]) -> "Linear":
|
95 |
+
"""Combines multiple parallel linear functions into a single one."""
|
96 |
+
joint_input_space = bases.join_vector_spaces(
|
97 |
+
*[fn.input_space for fn in fns])
|
98 |
+
joint_output_space = bases.join_vector_spaces(
|
99 |
+
*[fn.output_space for fn in fns])
|
100 |
+
|
101 |
+
def action(x: bases.BasisDirection) -> bases.VectorInBasis:
|
102 |
+
out = joint_output_space.null_vector()
|
103 |
+
for fn in fns:
|
104 |
+
if x in fn.input_space:
|
105 |
+
x_vec = fn.input_space.vector_from_basis_direction(x)
|
106 |
+
out += fn(x_vec).project(joint_output_space)
|
107 |
+
return out
|
108 |
+
|
109 |
+
return cls.from_action(joint_input_space, joint_output_space, action)
|
110 |
+
|
111 |
+
|
112 |
+
def project(
|
113 |
+
from_space: VectorSpaceWithBasis,
|
114 |
+
to_space: VectorSpaceWithBasis,
|
115 |
+
) -> Linear:
|
116 |
+
"""Creates a projection."""
|
117 |
+
|
118 |
+
def action(direction: bases.BasisDirection) -> VectorInBasis:
|
119 |
+
if direction in to_space:
|
120 |
+
return to_space.vector_from_basis_direction(direction)
|
121 |
+
else:
|
122 |
+
return to_space.null_vector()
|
123 |
+
|
124 |
+
return Linear.from_action(from_space, to_space, action=action)
|
125 |
+
|
126 |
+
|
127 |
+
@dataclasses.dataclass
|
128 |
+
class ScalarBilinear:
|
129 |
+
"""A scalar-valued bilinear operator."""
|
130 |
+
left_space: VectorSpaceWithBasis
|
131 |
+
right_space: VectorSpaceWithBasis
|
132 |
+
matrix: np.ndarray
|
133 |
+
|
134 |
+
def __post_init__(self):
|
135 |
+
"""Ensure matrix acts in sorted bases and typecheck sizes."""
|
136 |
+
left_size, right_size = self.matrix.shape
|
137 |
+
assert left_size == self.left_space.num_dims
|
138 |
+
assert right_size == self.right_space.num_dims
|
139 |
+
|
140 |
+
def __call__(self, x: VectorInBasis, y: VectorInBasis) -> float:
|
141 |
+
"""Describes the action of the operator on vectors."""
|
142 |
+
if x not in self.left_space:
|
143 |
+
raise TypeError(f"{x=} not in {self.left_space=}.")
|
144 |
+
if y not in self.right_space:
|
145 |
+
raise TypeError(f"{y=} not in {self.right_space=}.")
|
146 |
+
return (x.magnitudes.T @ self.matrix @ y.magnitudes).item()
|
147 |
+
|
148 |
+
@classmethod
|
149 |
+
def from_action(
|
150 |
+
cls,
|
151 |
+
left_space: VectorSpaceWithBasis,
|
152 |
+
right_space: VectorSpaceWithBasis,
|
153 |
+
action: Callable[[BasisDirection, BasisDirection], float],
|
154 |
+
) -> "ScalarBilinear":
|
155 |
+
"""from_action(l, r)(action) creates a ScalarBilinear."""
|
156 |
+
|
157 |
+
matrix = np.zeros((left_space.num_dims, right_space.num_dims))
|
158 |
+
for i, left_direction in enumerate(left_space.basis):
|
159 |
+
for j, right_direction in enumerate(right_space.basis):
|
160 |
+
matrix[i, j] = action(left_direction, right_direction)
|
161 |
+
|
162 |
+
return ScalarBilinear(left_space, right_space, matrix)
|
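
A minimal usage sketch of `Linear.from_action` (not part of this commit; it uses only APIs defined in this file and in `bases`, and defines a hypothetical `swap` map for illustration):

    import numpy as np
    from tracr.craft import bases
    from tracr.craft import vectorspace_fns as vs_fns

    vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    # Build the swap permutation by giving the image of each basis direction.
    swap = vs_fns.Linear.from_action(
        vs, vs,
        lambda d: vs.vector_from_basis_direction(
            bases.BasisDirection("b" if d.name == "a" else "a")))
    a, b = vs.basis_vectors()
    assert np.allclose(swap(a).magnitudes, b.magnitudes)
    assert np.allclose(swap(b).magnitudes, a.magnitudes)
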
craft/vectorspace_fns_test.py
ADDED
@@ -0,0 +1,166 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for vectorspace_fns."""

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from tracr.craft import bases
from tracr.craft import tests_common
from tracr.craft import vectorspace_fns as vs_fns


class LinearTest(tests_common.VectorFnTestCase):

  def test_identity_from_matrix(self):
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"])
    f = vs_fns.Linear(vs, vs, np.eye(3))
    for v in vs.basis_vectors():
      self.assertEqual(f(v), v)

  def test_identity_from_action(self):
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b", "c"])
    f = vs_fns.Linear.from_action(vs, vs, vs.vector_from_basis_direction)
    for v in vs.basis_vectors():
      self.assertEqual(f(v), v)

  def test_nonidentity(self):
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a = vs.vector_from_basis_direction(bases.BasisDirection("a"))
    b = vs.vector_from_basis_direction(bases.BasisDirection("b"))

    f = vs_fns.Linear(vs, vs, np.array([[0.3, 0.7], [0.2, 0.1]]))

    self.assertEqual(
        f(a), bases.VectorInBasis(vs.basis, np.array([0.3, 0.7])))
    self.assertEqual(
        f(b), bases.VectorInBasis(vs.basis, np.array([0.2, 0.1])))

  def test_different_vector_spaces(self):
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["c", "d"])
    a, b = vs1.basis_vectors()
    c, d = vs2.basis_vectors()

    f = vs_fns.Linear(vs1, vs2, np.eye(2))

    self.assertEqual(f(a), c)
    self.assertEqual(f(b), d)

  def test_combining_linear_functions_with_different_input(self):
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["c", "d"])
    vs = bases.direct_sum(vs1, vs2)
    a = vs.vector_from_basis_direction(bases.BasisDirection("a"))
    b = vs.vector_from_basis_direction(bases.BasisDirection("b"))
    c = vs.vector_from_basis_direction(bases.BasisDirection("c"))
    d = vs.vector_from_basis_direction(bases.BasisDirection("d"))

    f1 = vs_fns.Linear(vs1, vs1, np.array([[0, 1], [1, 0]]))
    f2 = vs_fns.Linear(vs2, vs2, np.array([[1, 0], [0, 0]]))
    f3 = vs_fns.Linear.combine_in_parallel([f1, f2])

    self.assertEqual(
        f3(a), bases.VectorInBasis(vs.basis, np.array([0, 1, 0, 0])))
    self.assertEqual(
        f3(b), bases.VectorInBasis(vs.basis, np.array([1, 0, 0, 0])))
    self.assertEqual(
        f3(c), bases.VectorInBasis(vs.basis, np.array([0, 0, 1, 0])))
    self.assertEqual(
        f3(d), bases.VectorInBasis(vs.basis, np.array([0, 0, 0, 0])))

  def test_combining_linear_functions_with_same_input(self):
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a = vs.vector_from_basis_direction(bases.BasisDirection("a"))
    b = vs.vector_from_basis_direction(bases.BasisDirection("b"))

    f1 = vs_fns.Linear(vs, vs, np.array([[0, 1], [1, 0]]))
    f2 = vs_fns.Linear(vs, vs, np.array([[1, 0], [0, 0]]))
    f3 = vs_fns.Linear.combine_in_parallel([f1, f2])

    self.assertEqual(
        f3(a), bases.VectorInBasis(vs.basis, np.array([1, 1])))
    self.assertEqual(
        f3(b), bases.VectorInBasis(vs.basis, np.array([1, 0])))
    self.assertEqual(f3(a), f1(a) + f2(a))
    self.assertEqual(f3(b), f1(b) + f2(b))


class ProjectionTest(tests_common.VectorFnTestCase):

  def test_projection_to_larger_space(self):
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    a1, b1 = vs1.basis_vectors()
    a2, b2, _, _ = vs2.basis_vectors()

    f = vs_fns.project(vs1, vs2)

    self.assertEqual(f(a1), a2)
    self.assertEqual(f(b1), b2)

  def test_projection_to_smaller_space(self):
    vs1 = bases.VectorSpaceWithBasis.from_names(["a", "b", "c", "d"])
    vs2 = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a1, b1, c1, d1 = vs1.basis_vectors()
    a2, b2 = vs2.basis_vectors()

    f = vs_fns.project(vs1, vs2)

    self.assertEqual(f(a1), a2)
    self.assertEqual(f(b1), b2)
    self.assertEqual(f(c1), vs2.null_vector())
    self.assertEqual(f(d1), vs2.null_vector())


class ScalarBilinearTest(parameterized.TestCase):

  def test_identity_matrix(self):
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a, b = vs.basis_vectors()

    f = vs_fns.ScalarBilinear(vs, vs, np.eye(2))

    self.assertEqual(f(a, a), 1)
    self.assertEqual(f(a, b), 0)
    self.assertEqual(f(b, a), 0)
    self.assertEqual(f(b, b), 1)

  def test_identity_from_action(self):
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a, b = vs.basis_vectors()

    f = vs_fns.ScalarBilinear.from_action(vs, vs, lambda x, y: int(x == y))

    self.assertEqual(f(a, a), 1)
    self.assertEqual(f(a, b), 0)
    self.assertEqual(f(b, a), 0)
    self.assertEqual(f(b, b), 1)

  def test_non_identity(self):
    vs = bases.VectorSpaceWithBasis.from_names(["a", "b"])
    a, b = vs.basis_vectors()

    f = vs_fns.ScalarBilinear.from_action(vs, vs,
                                          lambda x, y: int(x.name == "a"))

    self.assertEqual(f(a, a), 1)
    self.assertEqual(f(a, b), 1)
    self.assertEqual(f(b, a), 0)
    self.assertEqual(f(b, b), 0)


if __name__ == "__main__":
  absltest.main()
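
As the tests above illustrate, `ScalarBilinear` is simply the quadratic form x^T W y on basis coordinates, the kind of score an attention head assigns to a key/query pair. A short NumPy sketch (not part of this commit):

    import numpy as np

    W = np.array([[1., 2.],
                  [3., 4.]])
    x = np.array([1., 0.])  # magnitudes of basis vector "a"
    y = np.array([0., 1.])  # magnitudes of basis vector "b"
    print(x.T @ W @ y)  # 2.0, i.e. W[0, 1] -- what ScalarBilinear(vs, vs, W)(a, b) returns
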
examples/Visualize_Tracr_Models.ipynb
ADDED
@@ -0,0 +1,262 @@
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "99FBiGH7bsfn"
      },
      "source": [
        "# Compiling \u0026 Visualizing Tracr Models\n",
        "\n",
        "This notebook demonstrates how to compile a tracr model and provides some tools to visualize the model's residual stream or layer outputs for a given input sequence."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qm-PM1PEawCx"
      },
      "outputs": [],
      "source": [
        "#@title Imports\n",
        "import jax\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# The default of float16 can lead to discrepancies between outputs of\n",
        "# the compiled model and the RASP program.\n",
        "jax.config.update('jax_default_matmul_precision', 'float32')\n",
        "\n",
        "from tracr.compiler import compiling\n",
        "from tracr.compiler import lib\n",
        "from tracr.rasp import rasp"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HtOAc_yWawFR"
      },
      "outputs": [],
      "source": [
        "#@title Plotting functions\n",
        "def tidy_label(label, value_width=5):\n",
        "  if ':' in label:\n",
        "    label, value = label.split(':')\n",
        "  else:\n",
        "    value = ''\n",
        "  return label + f\":{value:\u003e{value_width}}\"\n",
        "\n",
        "\n",
        "def add_residual_ticks(model, value_width=5, x=False, y=True):\n",
        "  if y:\n",
        "    plt.yticks(\n",
        "        np.arange(len(model.residual_labels))+0.5, \n",
        "        [tidy_label(l, value_width=value_width)\n",
        "         for l in model.residual_labels], \n",
        "        family='monospace',\n",
        "        fontsize=20,\n",
        "    )\n",
        "  if x:\n",
        "    plt.xticks(\n",
        "        np.arange(len(model.residual_labels))+0.5, \n",
        "        [tidy_label(l, value_width=value_width)\n",
        "         for l in model.residual_labels], \n",
        "        family='monospace',\n",
        "        rotation=90,\n",
        "        fontsize=20,\n",
        "    )\n",
        "\n",
        "\n",
        "def plot_computation_trace(model,\n",
        "                           input_labels,\n",
        "                           residuals_or_outputs,\n",
        "                           add_input_layer=False,\n",
        "                           figsize=(12, 9)):\n",
        "  fig, axes = plt.subplots(nrows=1, ncols=len(residuals_or_outputs), figsize=figsize, sharey=True)\n",
        "  value_width = max(map(len, map(str, input_labels))) + 1\n",
        "\n",
        "  for i, (layer, ax) in enumerate(zip(residuals_or_outputs, axes)):\n",
        "    plt.sca(ax)\n",
        "    plt.pcolormesh(layer[0].T, vmin=0, vmax=1)\n",
        "    if i == 0:\n",
        "      add_residual_ticks(model, value_width=value_width)\n",
        "    plt.xticks(\n",
        "        np.arange(len(input_labels))+0.5,\n",
        "        input_labels,\n",
        "        rotation=90,\n",
        "        fontsize=20,\n",
        "    )\n",
        "    if add_input_layer and i == 0:\n",
        "      title = 'Input'\n",
        "    else:\n",
        "      layer_no = i - 1 if add_input_layer else i\n",
        "      layer_type = 'Attn' if layer_no % 2 == 0 else 'MLP'\n",
        "      title = f'{layer_type} {layer_no // 2 + 1}'\n",
        "    plt.title(title, fontsize=20)\n",
        "\n",
        "\n",
        "def plot_residuals_and_input(model, inputs, figsize=(12, 9)):\n",
        "  \"\"\"Applies model to inputs, and plots the residual stream at each layer.\"\"\"\n",
        "  model_out = assembled_model.apply(inputs)\n",
        "  residuals = np.concatenate([model_out.input_embeddings[None, ...],\n",
        "                              model_out.residuals], axis=0)\n",
        "  plot_computation_trace(\n",
        "      model=model,\n",
        "      input_labels=inputs,\n",
        "      residuals_or_outputs=residuals,\n",
        "      add_input_layer=True,\n",
        "      figsize=figsize)\n",
        "\n",
        "\n",
        "def plot_layer_outputs(model, inputs, figsize=(12, 9)):\n",
        "  \"\"\"Applies model to inputs, and plots the outputs of each layer.\"\"\"\n",
        "  model_out = assembled_model.apply(inputs)\n",
        "  plot_computation_trace(\n",
        "      model=model,\n",
        "      input_labels=inputs,\n",
        "      residuals_or_outputs=model_out.layer_outputs,\n",
        "      add_input_layer=False,\n",
        "      figsize=figsize)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "8hV0nv_ISmhM"
      },
      "outputs": [],
      "source": [
        "#@title Define RASP programs\n",
        "def get_program(program_name, max_seq_len):\n",
        "  \"\"\"Returns RASP program and corresponding token vocabulary.\"\"\"\n",
        "  if program_name == \"length\":\n",
        "    vocab = {\"a\", \"b\", \"c\", \"d\"}\n",
        "    program = lib.make_length()\n",
        "  elif program_name == \"frac_prevs\":\n",
        "    vocab = {\"a\", \"b\", \"c\", \"x\"}\n",
        "    program = lib.make_frac_prevs((rasp.tokens == \"x\").named(\"is_x\"))\n",
        "  elif program_name == \"dyck-2\":\n",
        "    vocab = {\"(\", \")\", \"{\", \"}\"}\n",
        "    program = lib.make_shuffle_dyck(pairs=[\"()\", \"{}\"])\n",
        "  elif program_name == \"dyck-3\":\n",
        "    vocab = {\"(\", \")\", \"{\", \"}\", \"[\", \"]\"}\n",
        "    program = lib.make_shuffle_dyck(pairs=[\"()\", \"{}\", \"[]\"])\n",
        "  elif program_name == \"sort\":\n",
        "    vocab = {1, 2, 3, 4, 5}\n",
        "    program = lib.make_sort(\n",
        "        rasp.tokens, rasp.tokens, max_seq_len=max_seq_len, min_key=1)\n",
        "  elif program_name == \"sort_unique\":\n",
        "    vocab = {1, 2, 3, 4, 5}\n",
        "    program = lib.make_sort_unique(rasp.tokens, rasp.tokens)\n",
        "  elif program_name == \"hist\":\n",
        "    vocab = {\"a\", \"b\", \"c\", \"d\"}\n",
        "    program = lib.make_hist()\n",
        "  elif program_name == \"sort_freq\":\n",
        "    vocab = {\"a\", \"b\", \"c\", \"d\"}\n",
        "    program = lib.make_sort_freq(max_seq_len=max_seq_len)\n",
        "  elif program_name == \"pair_balance\":\n",
        "    vocab = {\"(\", \")\"}\n",
        "    program = lib.make_pair_balance(\n",
        "        sop=rasp.tokens, open_token=\"(\", close_token=\")\")\n",
        "  else:\n",
        "    raise NotImplementedError(f\"Program {program_name} not implemented.\")\n",
        "  return program, vocab"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "L_m_ufaua9ri"
      },
      "outputs": [],
      "source": [
        "#@title: Assemble model\n",
        "program_name = \"sort_unique\" #@param [\"length\", \"frac_prevs\", \"dyck-2\", \"dyck-3\", \"sort\", \"sort_unique\", \"hist\", \"sort_freq\", \"pair_balance\"]\n",
        "max_seq_len = 5 #@param {label: \"Test\", type: \"integer\"}\n",
        "\n",
        "program, vocab = get_program(program_name=program_name,\n",
        "                             max_seq_len=max_seq_len)\n",
        "\n",
        "print(f\"Compiling...\")\n",
        "print(f\"  Program: {program_name}\")\n",
        "print(f\"  Input vocabulary: {vocab}\")\n",
        "print(f\"  Context size: {max_seq_len}\")\n",
        "\n",
        "assembled_model = compiling.compile_rasp_to_model(\n",
        "    program=program,\n",
        "    vocab=vocab,\n",
        "    max_seq_len=max_seq_len,\n",
        "    causal=False,\n",
        "    compiler_bos=\"bos\",\n",
        "    compiler_pad=\"pad\",\n",
        "    mlp_exactness=100)\n",
        "\n",
        "print(\"Done.\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wtwiE-JiXF3F"
      },
      "outputs": [],
      "source": [
        "#@title Forward pass\n",
        "assembled_model.apply([\"bos\", 3, 4, 1]).decoded"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RkEkVcEHa2gf"
      },
      "outputs": [],
      "source": [
        "#@title Plot residual stream\n",
        "plot_residuals_and_input(\n",
        "    model=assembled_model,\n",
        "    inputs=[\"bos\", 3, 4, 1],\n",
        "    figsize=(10, 9)\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8c4LakWHa4ey"
      },
      "outputs": [],
      "source": [
        "#@title Plot layer outputs\n",
        "plot_layer_outputs(\n",
        "    model=assembled_model,\n",
        "    inputs = [\"bos\", 3, 4, 1],\n",
        "    figsize=(8, 9)\n",
        ")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "private_outputs": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
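
The plotting helpers in this notebook read a handful of fields from the object returned by `assembled_model.apply`. A sketch of those fields (not part of this commit; names are taken from the notebook source above, with `assembled_model` compiled as in the "Assemble model" cell):

    out = assembled_model.apply(["bos", 3, 4, 1])
    out.decoded           # decoded output tokens
    out.input_embeddings  # embedded input sequence
    out.residuals         # residual stream after each layer (plotted above)
    out.layer_outputs     # what each attention/MLP layer wrote
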
rasp/causal_eval.py
ADDED
@@ -0,0 +1,39 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RASP Evaluator which applies causal masks to selectors."""

from typing import Sequence, Union

import numpy as np
from tracr.rasp import rasp


class CausalEvaluator(rasp.DefaultRASPEvaluator):
  """Evaluates RASP with causal masking."""

  def evaluate(
      self, expr: rasp.RASPExpr, xs: Sequence[rasp.Value]
  ) -> Union[Sequence[rasp.Value], rasp.SelectorValue]:
    out = super().evaluate(expr, xs)

    if not isinstance(expr, rasp.Selector):
      return out

    out = np.array(out)
    causal_mask = np.tril(np.full(out.shape, 1))
    return np.logical_and(causal_mask, out).tolist()


evaluate = CausalEvaluator().evaluate
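
The masking above works because `np.tril` keeps only the lower triangle, so position i can only attend to positions j <= i. A quick sketch (not part of this commit):

    import numpy as np

    print(np.tril(np.full((3, 3), 1)))
    # [[1 0 0]
    #  [1 1 0]
    #  [1 1 1]]
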
rasp/causal_eval_test.py
ADDED
@@ -0,0 +1,61 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for causal_eval."""

from absl.testing import absltest
from absl.testing import parameterized

from tracr.rasp import causal_eval
from tracr.rasp import rasp


class CausalEvalTest(parameterized.TestCase):

  @parameterized.named_parameters(
      dict(
          testcase_name="constant_selector_3x3_1",
          program=rasp.ConstantSelector([
              [True, True, True],
              [True, True, True],
              [True, True, True],
          ]),
          input_sequence=[True, True, True],
          expected_output=[
              [True, False, False],
              [True, True, False],
              [True, True, True],
          ]),
      dict(
          testcase_name="constant_selector_3x3_2",
          program=rasp.ConstantSelector([
              [True, True, True],
              [False, True, True],
              [True, False, True],
          ]),
          input_sequence=[True, True, True],
          expected_output=[
              [True, False, False],
              [False, True, False],
              [True, False, True],
          ]))
  def test_evaluations(self, program, input_sequence, expected_output):
    self.assertListEqual(
        causal_eval.evaluate(program, input_sequence),
        expected_output,
    )


if __name__ == "__main__":
  absltest.main()
rasp/rasp.py
ADDED
@@ -0,0 +1,932 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RASP program objects.

Every object in the RASP language is a function.

The most important type is S-Op, which is a function list[Value] -> list[Value].

An S-Op represents a state inside the residual stream of the transformer.
Therefore, any RASP program that represents a transformer computation must
define a final S-Op that represents the state of the residual stream at the
end of the computation. In particular, given an S-Op `x`,
`x([1, 2, 3])` represents something like the state of the residual stream
at location `x` when the transformer is fed [1, 2, 3] as input.

A secondary (but still important) type is Selector, which is a function
list[Value] -> list[list[bool]]. Given a Selector `sel`, sel([1, 2, 3])
represents something like an attention matrix in the transformer.

For a full reference on RASP, see https://arxiv.org/abs/2106.06981.
"""
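# Illustrative sketch (not part of the original file), using the definitions
# defined later in this module:
#   Map(lambda x: x + 1, tokens)([1, 2, 3])  ==  [2, 3, 4]
#   Select(tokens, tokens, Comparison.EQ)([1, 2, 1])
#       ==  [[True, False, True],
#            [False, True, False],
#            [True, False, True]]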

import abc
import collections.abc
import copy
import enum
import functools
import itertools
from typing import (Any, Callable, Generic, Mapping, Optional, Protocol,
                    Sequence, TypeVar, Union)
from absl import logging

import numpy as np

SelectorValue = list[list[bool]]
NumericValue = Union[int, float]
Value = Union[None, int, float, str, bool]
VT = TypeVar("VT", bound=Value)
RASPExprT = TypeVar("RASPExprT", bound="RASPExpr")
SOpT = TypeVar("SOpT", bound="SOp")
T = TypeVar("T")

_NAME_KEY = "name"
_ENCODING_KEY = "encoding"

# These are run on every expression when it's initialised.
# Add your own annotators to this dict to add custom default annotations.
#
# For example, DEFAULT_ANNOTATORS['foo'] will provide the default value for
# expr.annotations['foo']. The annotator will get called lazily the first time
# that key is accessed.
#
# See the `default_name` annotator for a full example.
DEFAULT_ANNOTATORS: dict[str, "Annotator"] = {}


class Annotator(Protocol):

  def __call__(self, expr: "RASPExpr") -> Any:
    """What annotation to add to `expr`."""


class _Annotations(collections.abc.Mapping):
  """Holds the expression's annotations.

  It's immutable to the user, but will attempt to generate default values
  lazily when missing keys are requested.
  """

  def __init__(self, expr, **kwargs: Any):
    self._expr = expr
    self._inner_dict: dict[str, Any] = {**kwargs}

  def __getitem__(self, key: str) -> Any:
    if key not in self._inner_dict:
      if key not in DEFAULT_ANNOTATORS:
        raise KeyError(
            f"No annotation exists for key '{key}'. "
            f"Available keys: {[*self.keys(), *DEFAULT_ANNOTATORS.keys()]}")
      self._inner_dict[key] = DEFAULT_ANNOTATORS[key](self._expr)

    return self._inner_dict[key]

  def __iter__(self):
    return iter(self._inner_dict)

  def __len__(self):
    return len(self._inner_dict)


class RASPExpr(abc.ABC):
  """A class distinguishing RASP expressions from other objects."""
  _ids = itertools.count(1)

  def __init__(self):
    self._annotations: Mapping[str, Any] = _Annotations(self)

  @abc.abstractmethod
  def __call__(self,
               xs: Sequence[Value]) -> Union[Sequence[Value], SelectorValue]:
    """Evaluates the RASPExpr using the standard evaluator."""

  @property
  def annotations(self) -> Mapping[str, Any]:
    """The annotations of this expression instance."""
    return self._annotations

  @annotations.setter
  def annotations(self, annotations: Mapping[str, Any]):
    self._annotations = _Annotations(self, **annotations)

  @property
  def name(self) -> str:
    """The name of this expression."""
    return self.annotations[_NAME_KEY]

  @property
  @abc.abstractmethod
  def children(self) -> Sequence["RASPExpr"]:
    """Direct dependencies of this expression."""

  @functools.cached_property
  def unique_id(self):
    """A unique id for every expression instance."""
    return next(self._ids)

  def copy(self: RASPExprT) -> RASPExprT:
    """Returns a shallow copy of this RASPExpr with a new ID."""
    return copy.copy(self)

  @property
  def label(self) -> str:
    return f"{self.name}_{self.unique_id}"

  def named(self: RASPExprT, name: str) -> RASPExprT:
    """Convenience method for adding a name."""
    return annotate(self, name=name)

  def annotated(self: RASPExprT, **annotations) -> RASPExprT:
    """Convenience method for adding annotations."""
    return annotate(self, **annotations)


def annotate(expr: RASPExprT, **annotations) -> RASPExprT:
  """Creates a new expr with added annotations."""
  new = expr.copy()
  # Note that new annotations will overwrite existing ones with matching keys.
  new.annotations = {**expr.annotations, **annotations}
  return new


### S-Ops.


class SOp(RASPExpr):
  """A Sequence Operation."""

  def __call__(self, xs: Sequence[Value]) -> Sequence[Value]:
    return evaluate(self, xs)  # pytype: disable=bad-return-type

  # Allow construction of SOps using numeric operators with constant values.
  # Note: if inheriting SOp by a dataclass, make sure to disable eq and order,
  # as they will override these.

  def __lt__(self, other: Value) -> "SOp":
    """self < other."""
    return Map(lambda x: x < other, self)

  def __le__(self, other: Value) -> "SOp":
    """self <= other."""
    return Map(lambda x: x <= other, self)

  def __eq__(self, other: Value) -> "SOp":
    """self == other."""
    return Map(lambda x: x == other, self)

  def __ne__(self, other: Value) -> "SOp":
    """self != other."""
    return Map(lambda x: x != other, self)

  def __gt__(self, other: Value) -> "SOp":
    """self > other."""
    return Map(lambda x: x > other, self)

  def __ge__(self, other: Value) -> "SOp":
    """self >= other."""
    return Map(lambda x: x >= other, self)

  def __add__(self, other: Union["SOp", Value]) -> "SOp":
    """self + other."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x + y, self, other)
    return Map(lambda x: x + other, self)

  def __radd__(self, other: Union["SOp", Value]) -> "SOp":
    """other + self."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x + y, other, self)
    return Map(lambda x: other + x, self)

  def __sub__(self, other: Union["SOp", NumericValue]) -> "SOp":
    """self - other."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x - y, self, other)
    return Map(lambda x: x - other, self)

  def __rsub__(self, other: Union["SOp", NumericValue]) -> "SOp":
    """other - self."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x - y, other, self)
    return Map(lambda x: other - x, self)

  def __mul__(self, other: Union["SOp", NumericValue]) -> "SOp":
    """self * other."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x * y, self, other)
    return Map(lambda x: x * other, self)

  def __rmul__(self, other: Union["SOp", NumericValue]) -> "SOp":
    """other * self."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x * y, other, self)
    return Map(lambda x: other * x, self)

  def __truediv__(self, other: Union["SOp", NumericValue]) -> "SOp":
    """self / other."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x / y, self, other)
    return Map(lambda x: x / other, self)

  def __rtruediv__(self, other: Union["SOp", NumericValue]) -> "SOp":
    """other / self."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x / y, other, self)
    return Map(lambda x: other / x, self)

  def __invert__(self) -> "SOp":
    return Map(lambda x: not x, self)

  def __and__(self, other: Union["SOp", NumericValue]) -> "SOp":
    """self & other."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x and y, self, other)
    return Map(lambda x: x and other, self)

  def __or__(self, other: Union["SOp", NumericValue]) -> "SOp":
    """self | other."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x or y, self, other)
    return Map(lambda x: x or other, self)

  def __rand__(self, other: Union["SOp", NumericValue]) -> "SOp":
    """other & self."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x and y, other, self)
    return Map(lambda x: other and x, self)

  def __ror__(self, other: Union["SOp", NumericValue]) -> "SOp":
    """other | self."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x or y, other, self)
    return Map(lambda x: x or other, self)

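# Illustrative sketch (not part of the original file): the operator overloads
# above desugar to Map/SequenceMap, e.g.
#   (tokens + 1)([1, 2, 3])  ==  [2, 3, 4]            # becomes a Map
#   (tokens > 1)([1, 2, 3])  ==  [False, True, True]  # likewise a Map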
276 |
+
|
277 |
+
class TokensType(SOp):
|
278 |
+
"""Primitive SOp returning the original input tokens."""
|
279 |
+
|
280 |
+
@property
|
281 |
+
def children(self) -> Sequence[RASPExpr]:
|
282 |
+
return []
|
283 |
+
|
284 |
+
@property
|
285 |
+
def label(self) -> str:
|
286 |
+
return "tokens"
|
287 |
+
|
288 |
+
def __repr__(self):
|
289 |
+
return "tokens"
|
290 |
+
|
291 |
+
|
292 |
+
class IndicesType(SOp):
|
293 |
+
"""Primitive SOp returning the position index at each token."""
|
294 |
+
|
295 |
+
@property
|
296 |
+
def children(self) -> Sequence[RASPExpr]:
|
297 |
+
return []
|
298 |
+
|
299 |
+
@property
|
300 |
+
def label(self) -> str:
|
301 |
+
return "indices"
|
302 |
+
|
303 |
+
def __repr__(self):
|
304 |
+
return "indices"
|
305 |
+
|
306 |
+
|
307 |
+
class LengthType(SOp):
|
308 |
+
"""Primitive SOp returning the total length of the input."""
|
309 |
+
|
310 |
+
@property
|
311 |
+
def children(self) -> Sequence[RASPExpr]:
|
312 |
+
return []
|
313 |
+
|
314 |
+
@property
|
315 |
+
def label(self) -> str:
|
316 |
+
return "length"
|
317 |
+
|
318 |
+
def __repr__(self):
|
319 |
+
return "length"
|
320 |
+
|
321 |
+
|
322 |
+
tokens = TokensType()
|
323 |
+
indices = IndicesType()
|
324 |
+
length = LengthType()
|
325 |
+
|
326 |
+
|
327 |
+
class Map(SOp):
|
328 |
+
"""SOp that evaluates the function elementwise on the input SOp.
|
329 |
+
|
330 |
+
Map(lambda x: x + 1, tokens).eval([1, 2, 3]) == [2, 3, 4]
|
331 |
+
"""
|
332 |
+
|
333 |
+
def __init__(self, f: Callable[[Value], Value], inner: SOp):
|
334 |
+
super().__init__()
|
335 |
+
self.f = f
|
336 |
+
self.inner = inner
|
337 |
+
|
338 |
+
assert isinstance(self.inner, SOp)
|
339 |
+
assert callable(self.f) and not isinstance(self.f, RASPExpr)
|
340 |
+
|
341 |
+
if isinstance(self.inner, Map):
|
342 |
+
# Combine the functions into just one.
|
343 |
+
inner_f = self.inner.f
|
344 |
+
self.f = lambda t: f(inner_f(t))
|
345 |
+
self.inner = self.inner.inner
|
346 |
+
|
347 |
+
@property
|
348 |
+
def children(self) -> Sequence[RASPExpr]:
|
349 |
+
return [self.inner]
|
350 |
+
|
351 |
+
|
352 |
+
class SequenceMap(SOp):
|
353 |
+
"""SOp that evaluates the function elementwise on the two given SOp's.
|
354 |
+
|
355 |
+
SequenceMap(lambda x, y: x - y, length, tokens).eval([1, 2, 3]) == [2, 1, 0]
|
356 |
+
"""
|
357 |
+
|
358 |
+
def __init__(self, f: Callable[[Value, Value], Value], fst: SOp, snd: SOp):
|
359 |
+
super().__init__()
|
360 |
+
|
361 |
+
if fst == snd:
|
362 |
+
logging.warning("Creating a SequenceMap with both inputs being the same "
|
363 |
+
"SOp is discouraged. You should use a Map instead.")
|
364 |
+
|
365 |
+
self.f = f
|
366 |
+
self.fst = fst
|
367 |
+
self.snd = snd
|
368 |
+
assert isinstance(self.fst, SOp)
|
369 |
+
assert isinstance(self.snd, SOp)
|
370 |
+
assert callable(self.f) and not isinstance(self.f, RASPExpr)
|
371 |
+
|
372 |
+
@property
|
373 |
+
def children(self) -> Sequence[RASPExpr]:
|
374 |
+
return [self.fst, self.snd]
|
375 |
+
|
376 |
+
|
377 |
+
class LinearSequenceMap(SequenceMap):
|
378 |
+
"""SOp that evaluates a linear function elementwise on the two given SOp's."""
|
379 |
+
|
380 |
+
def __init__(self, fst: SOp, snd: SOp, fst_fac: float, snd_fac: float):
|
381 |
+
super().__init__(fst=fst, snd=snd, f=lambda x, y: fst_fac * x + snd_fac * y)
|
382 |
+
self.fst_fac = fst_fac
|
383 |
+
self.snd_fac = snd_fac
|
384 |
+
|
385 |
+
|
386 |
+
class Full(SOp):
|
387 |
+
"""A SOp evaluating to [fill]*len(input_values)."""
|
388 |
+
|
389 |
+
def __init__(self, fill: Value):
|
390 |
+
super().__init__()
|
391 |
+
self.fill = fill
|
392 |
+
|
393 |
+
@property
|
394 |
+
def children(self) -> Sequence[RASPExpr]:
|
395 |
+
return []
|
396 |
+
|
397 |
+
|
398 |
+
def sop_not(sop: SOp) -> SOp:
|
399 |
+
return Map(lambda t: not t, sop)
|
400 |
+
|
401 |
+
|
402 |
+
class ConstantSOp(SOp, Generic[VT]):
|
403 |
+
"""A constant S-Op for testing purposes."""
|
404 |
+
|
405 |
+
def __init__(self, value: Sequence[VT], check_length: bool = True):
|
406 |
+
super().__init__()
|
407 |
+
self.value = value
|
408 |
+
self.check_length = check_length
|
409 |
+
|
410 |
+
@property
|
411 |
+
def children(self) -> Sequence[RASPExpr]:
|
412 |
+
return []
|
413 |
+
|
414 |
+
|
415 |
+
### Selectors.
|
416 |
+
|
417 |
+
|
418 |
+
class Predicate(Protocol):
|
419 |
+
|
420 |
+
def __call__(self, key: Value, query: Value) -> bool:
|
421 |
+
"""Applies the predicate."""
|
422 |
+
|
423 |
+
|
424 |
+
class Comparison(enum.Enum):
|
425 |
+
"""A two-place boolean comparison predicate for use in Select."""
|
426 |
+
EQ = "=="
|
427 |
+
LT = "<"
|
428 |
+
LEQ = "<="
|
429 |
+
GT = ">"
|
430 |
+
GEQ = ">="
|
431 |
+
NEQ = "!="
|
432 |
+
TRUE = "True"
|
433 |
+
FALSE = "False"
|
434 |
+
|
435 |
+
def __call__(self, key: Value, query: Value) -> bool:
|
436 |
+
if key is None:
|
437 |
+
raise ValueError("key is None!")
|
438 |
+
if query is None:
|
439 |
+
raise ValueError("query is None!")
|
440 |
+
return _comparison_table[self](key, query)
|
441 |
+
|
442 |
+
|
443 |
+
_comparison_table = {
|
444 |
+
Comparison.EQ: lambda key, query: key == query,
|
445 |
+
Comparison.LT: lambda key, query: key < query,
|
446 |
+
Comparison.LEQ: lambda key, query: key <= query,
|
447 |
+
Comparison.GT: lambda key, query: key > query,
|
448 |
+
Comparison.GEQ: lambda key, query: key >= query,
|
449 |
+
Comparison.NEQ: lambda key, query: key != query,
|
450 |
+
Comparison.TRUE: lambda key, query: True,
|
451 |
+
Comparison.FALSE: lambda key, query: False,
|
452 |
+
}
|
453 |
+
|
454 |
+
|
455 |
+
class Selector(RASPExpr):
|
456 |
+
"""RASP Selector. Represents something like an attention head's weights."""
|
457 |
+
|
458 |
+
def __call__(self, xs: Sequence[Value]) -> SelectorValue:
|
459 |
+
return evaluate(self, xs) # pytype: disable=bad-return-type
|
460 |
+
|
461 |
+
# Allow construction of Selector combinations using Python logical operators.
|
462 |
+
def __and__(self, other: "Selector") -> "Selector":
|
463 |
+
"""self & other."""
|
464 |
+
return selector_and(self, other)
|
465 |
+
|
466 |
+
def __rand__(self, other: "Selector") -> "Selector":
|
467 |
+
"""other & self."""
|
468 |
+
return selector_and(other, self)
|
469 |
+
|
470 |
+
def __or__(self, other: "Selector") -> "Selector":
|
471 |
+
"""self | other."""
|
472 |
+
return selector_or(self, other)
|
473 |
+
|
474 |
+
def __ror__(self, other: "Selector") -> "Selector":
|
475 |
+
"""other | self."""
|
476 |
+
return selector_or(other, self)
|
477 |
+
|
478 |
+
def __invert__(self) -> "Selector":
|
479 |
+
"""~self."""
|
480 |
+
return selector_not(self)
|
481 |
+
|
482 |
+
|
483 |
+
class Select(Selector):
|
484 |
+
"""Primitive that creates a Selector."""
|
485 |
+
|
486 |
+
def __init__(self, keys: SOp, queries: SOp, predicate: Predicate):
|
487 |
+
super().__init__()
|
488 |
+
self.keys = keys
|
489 |
+
self.queries = queries
|
490 |
+
self.predicate = predicate
|
491 |
+
assert isinstance(self.keys, SOp)
|
492 |
+
assert isinstance(self.queries, SOp)
|
493 |
+
|
494 |
+
@property
|
495 |
+
def children(self) -> Sequence[RASPExpr]:
|
496 |
+
return [self.keys, self.queries]
|
497 |
+
|
498 |
+
|
499 |
+
class ConstantSelector(Selector):
|
500 |
+
"""A constant selector for testing purposes."""
|
501 |
+
|
502 |
+
def __init__(self, value: SelectorValue, check_length: bool = True):
|
503 |
+
super().__init__()
|
504 |
+
self.value = value
|
505 |
+
self.check_length = check_length
|
506 |
+
|
507 |
+
@property
|
508 |
+
def children(self) -> Sequence[RASPExpr]:
|
509 |
+
return []
|
510 |
+
|
511 |
+
|
512 |
+
class SelectorWidth(SOp):
|
513 |
+
"""SelectorWidth primitive."""
|
514 |
+
|
515 |
+
def __init__(self, selector: Selector):
|
516 |
+
super().__init__()
|
517 |
+
self.selector = selector
|
518 |
+
assert isinstance(self.selector, Selector)
|
519 |
+
|
520 |
+
@property
|
521 |
+
def children(self) -> Sequence[RASPExpr]:
|
522 |
+
return [self.selector]
|
523 |
+
|
524 |
+
|
525 |
+
class SelectorAnd(Selector):
|
526 |
+
"""Implements elementwise `and` between selectors."""
|
527 |
+
|
528 |
+
def __init__(self, fst: Selector, snd: Selector):
|
529 |
+
super().__init__()
|
530 |
+
self.fst = fst
|
531 |
+
self.snd = snd
|
532 |
+
assert isinstance(self.fst, Selector)
|
533 |
+
assert isinstance(self.snd, Selector)
|
534 |
+
|
535 |
+
@property
|
536 |
+
def children(self) -> Sequence[RASPExpr]:
|
537 |
+
return [self.fst, self.snd]
|
538 |
+
|
539 |
+
|
540 |
+
class SelectorOr(Selector):
|
541 |
+
"""Implements elementwise `or` between selectors."""
|
542 |
+
|
543 |
+
def __init__(self, fst: Selector, snd: Selector):
|
544 |
+
super().__init__()
|
545 |
+
self.fst = fst
|
546 |
+
self.snd = snd
|
547 |
+
assert isinstance(self.fst, Selector)
|
548 |
+
assert isinstance(self.snd, Selector)
|
549 |
+
|
550 |
+
@property
|
551 |
+
def children(self) -> Sequence[RASPExpr]:
|
552 |
+
return [self.fst, self.snd]
|
553 |
+
|
554 |
+
|
555 |
+
class SelectorNot(Selector):
|
556 |
+
"""Implements elementwise `not` on a selector."""
|
557 |
+
|
558 |
+
def __init__(self, inner: Selector):
|
559 |
+
self.inner = inner
|
560 |
+
super().__init__()
|
561 |
+
assert isinstance(self.inner, Selector)
|
562 |
+
|
563 |
+
@property
|
564 |
+
def children(self) -> Sequence[RASPExpr]:
|
565 |
+
return [self.inner]
|
566 |
+
|
567 |
+
|
568 |
+
def selector_not(
|
569 |
+
inner: Selector,
|
570 |
+
simplify: bool = True,
|
571 |
+
) -> Selector:
|
572 |
+
"""Returns a SelectorNot, or a Select if simplifying is possible."""
|
573 |
+
if simplify and isinstance(inner, Select):
|
574 |
+
predicate = lambda k, q: not inner.predicate(k, q)
|
    return Select(inner.keys, inner.queries, predicate=predicate)

  return SelectorNot(inner)


def selector_and(
    fst: Selector,
    snd: Selector,
    simplify: bool = True,
) -> Selector:
  """Returns a SelectorAnd, or a Select if simplifying is possible."""
  if simplify and isinstance(fst, Select) and isinstance(snd, Select):
    simplified = _attempt_simplify(fst, snd, lambda l, r: l and r)
    if simplified:
      return simplified

  return SelectorAnd(fst, snd)


def selector_or(
    fst: Selector,
    snd: Selector,
    simplify: bool = True,
) -> Selector:
  """Returns a SelectorOr, or a Select if simplifying is possible."""
  if simplify and isinstance(fst, Select) and isinstance(snd, Select):
    simplified = _attempt_simplify(fst, snd, lambda l, r: l or r)
    if simplified:
      return simplified

  return SelectorOr(fst, snd)


def _attempt_simplify(
    fst: Select,
    snd: Select,
    combine: Callable[[bool, bool], bool],
) -> Optional[Select]:
  """Simplifies two Selects if possible.

  If two Selects in a compound Selector have matching keys and queries, they
  can be simplified into one Select with a compound predicate:

    lambda k, q: combine(fst.predicate(k, q), snd.predicate(k, q))

  This function returns a Select with this predicate if possible,
  and None otherwise.

  A Full SOp in a key or query position is a special case that always matches
  any SOp in the corresponding position in the other selector. In that case,
  we bake the fill value into the corresponding Select's predicate before
  combining. This allows us to use the other SOp as the input to the
  simplified Select.

  Args:
    fst: the first Select.
    snd: the second Select.
    combine: how to combine the outputs of the individual predicates.

  Returns:
    A combined Select, if possible.
  """
  fst_predicate = fst.predicate
  snd_predicate = snd.predicate
  common_keys = None
  common_queries = None

  if isinstance(fst.keys, Full):
    common_keys = snd.keys
    # We pass the predicate in as a default arg to avoid unintended recursion.
    fst_predicate = lambda key, query, p=fst_predicate: p(fst.keys.fill, query)
  if isinstance(snd.keys, Full):
    common_keys = fst.keys
    snd_predicate = lambda key, query, p=snd_predicate: p(snd.keys.fill, query)
  if isinstance(fst.queries, Full):
    common_queries = snd.queries
    fst_predicate = lambda key, query, p=fst_predicate: p(key, fst.queries.fill)
  if isinstance(snd.queries, Full):
    common_queries = fst.queries
    snd_predicate = lambda key, query, p=snd_predicate: p(key, snd.queries.fill)
  if fst.keys is snd.keys:
    common_keys = fst.keys
  if fst.queries is snd.queries:
    common_queries = fst.queries

  if not common_keys or not common_queries:
    return None

  def predicate(key, query):
    return combine(fst_predicate(key, query), snd_predicate(key, query))

  return Select(common_keys, common_queries, predicate=predicate)

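
# A minimal usage sketch (editor's illustration, assuming the definitions
# above): two Selects over the same keys and queries collapse into a single
# Select whose predicate combines both comparisons, so the compound selector
# can be represented by one Select rather than a SelectorAnd node:
#
#   geq = Select(tokens, indices, Comparison.GEQ)
#   leq = Select(tokens, indices, Comparison.LEQ)
#   assert isinstance(selector_and(geq, leq), Select)
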

class Aggregate(SOp, Generic[VT]):
  """Aggregate primitive."""

  def __init__(self,
               selector: Selector,
               sop: SOp,
               default: Optional[VT] = None):
    """Initialises. The default is used where nothing is selected."""
    super().__init__()
    self.selector = selector
    self.sop = sop
    self.default = default
    assert isinstance(self.selector, Selector)
    assert isinstance(self.sop, SOp)
    assert (self.default is None or
            isinstance(self.default, (str, float, bool, int)))

  @property
  def children(self) -> Sequence[RASPExpr]:
    return [self.selector, self.sop]

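
# Semantics sketch (editor's illustration): each output position averages the
# values at the positions its selector row marks True, falling back to
# `default` when a row selects nothing:
#
#   sel = ConstantSelector([[True, True], [False, False]])
#   Aggregate(sel, ConstantSOp([0, 1]), default=0)([0, 0])  # -> [0.5, 0]
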

### SOp encodings.


class Encoding(enum.Enum):
  """The encoding used by a SOp. Only number-valued SOps support numerical."""
  CATEGORICAL = "categorical"
  NUMERICAL = "numerical"


def numerical(sop: SOpT) -> SOpT:
  return annotate(sop, encoding=Encoding.NUMERICAL)


def categorical(sop: SOpT) -> SOpT:
  return annotate(sop, encoding=Encoding.CATEGORICAL)


def get_encoding(sop: SOp) -> Encoding:
  return sop.annotations["encoding"]


def is_numerical(sop: SOp) -> bool:
  """Check if the SOp is numerically encoded."""
  return get_encoding(sop) == Encoding.NUMERICAL


def is_categorical(sop: SOp) -> bool:
  """Check if the SOp is categorically encoded."""
  return get_encoding(sop) == Encoding.CATEGORICAL

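
# Usage sketch (editor's illustration): encodings are plain annotations, so
# they can be attached and queried without changing the underlying SOp:
#
#   frac = numerical(Map(lambda x: x / 2, tokens))
#   assert is_numerical(frac) and not is_categorical(frac)
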

def default_encoding(expr: RASPExpr) -> Optional[Encoding]:
  """Adds an 'encoding' annotation; the default is categorical."""
  if not isinstance(expr, SOp):
    raise TypeError(f"expr {expr} is not a SOp.")

  return Encoding.CATEGORICAL


DEFAULT_ANNOTATORS[_ENCODING_KEY] = default_encoding

### naming.

# Subclasses must appear here before superclasses in order for
# the most specific entry to be used.

_default_name_by_class = {
    # Primitives
    TokensType: "tokens",
    IndicesType: "indices",
    LengthType: "length",
    # SOps
    LinearSequenceMap: "linear_sequence_map",
    SequenceMap: "sequence_map",
    Map: "map",
    Full: "full",
    ConstantSOp: "constant_sop",
    SelectorWidth: "selector_width",
    Aggregate: "aggregate",
    SOp: "sop",
    # Selectors
    Select: "select",
    SelectorAnd: "selector_and",
    SelectorOr: "selector_or",
    SelectorNot: "selector_not",
    ConstantSelector: "constant_selector",
    Selector: "selector",
}


def default_name(expr: RASPExpr) -> str:
  for cls, name in _default_name_by_class.items():
    if isinstance(expr, cls):
      return name

  raise NotImplementedError(f"{expr} was not given a default name!")


DEFAULT_ANNOTATORS[_NAME_KEY] = default_name

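
# Ordering sketch (editor's illustration): `default_name` walks the dict above
# in insertion order, so the subclass entries win before the generic `SOp` and
# `Selector` fallbacks can match:
#
#   default_name(Select(tokens, indices, Comparison.EQ))  # -> "select"
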

### evaluation.


class RASPEvaluator(abc.ABC):
  """ABC for RASP evaluators."""

  @abc.abstractmethod
  def evaluate(self, expr: RASPExpr,
               xs: Sequence[Value]) -> Union[Sequence[Value], SelectorValue]:
    """Evaluates the RASP expression on input `xs`."""


class DefaultRASPEvaluator(abc.ABC):
  """Default evaluator for RASP."""

  def evaluate(self, expr: RASPExpr,
               xs: Sequence[Value]) -> Union[Sequence[Value], SelectorValue]:
    """Evaluates the RASP expression on input `xs`."""
    return self._eval_fn_by_expr_type[type(expr)](expr, xs)

  def __init__(self):
    self._eval_fn_by_expr_type = {
        # Primitives
        TokensType: self.eval_tokens,
        IndicesType: self.eval_indices,
        LengthType: self.eval_length,
        # SOps
        LinearSequenceMap: self.eval_sequence_map,
        SequenceMap: self.eval_sequence_map,
        Map: self.eval_map,
        Full: self.eval_full,
        ConstantSOp: self.eval_constant_sop,
        SelectorWidth: self.eval_selector_width,
        Aggregate: self.eval_aggregate,
        SOp: _raise_not_implemented,
        # Selectors
        Select: self.eval_select,
        SelectorAnd: self.eval_selector_and,
        SelectorOr: self.eval_selector_or,
        SelectorNot: self.eval_selector_not,
        ConstantSelector: self.eval_constant_selector,
        Selector: _raise_not_implemented,
    }

  def eval_tokens(self, sop: TokensType,
                  xs: Sequence[Value]) -> Sequence[Value]:
    del sop
    return list(xs)

  def eval_indices(self, sop: IndicesType,
                   xs: Sequence[Value]) -> Sequence[Value]:
    del sop
    return list(range(len(xs)))

  def eval_length(self, sop: LengthType, xs: Sequence[Value]) -> Sequence[int]:
    del sop
    return [len(xs)] * len(xs)

  def eval_sequence_map(self, sop: SequenceMap,
                        xs: Sequence[Value]) -> Sequence[Value]:
    fst_values = self.evaluate(sop.fst, xs)
    snd_values = self.evaluate(sop.snd, xs)
    return [
        sop.f(x, y) if None not in [x, y] else None
        for x, y in zip(fst_values, snd_values)
    ]

  def eval_map(self, sop: Map, xs: Sequence[Value]) -> Sequence[Value]:
    return [
        sop.f(x) if x is not None else None
        for x in self.evaluate(sop.inner, xs)
    ]

  def eval_full(self, sop: Full, xs: Sequence[Value]) -> Sequence[Value]:
    return [sop.fill] * len(xs)

  def eval_constant_sop(self, sop: ConstantSOp,
                        xs: Sequence[Value]) -> Sequence[Value]:
    if sop.check_length and (len(xs) != len(sop.value)):
      raise ValueError(
          f"Constant len {len(sop.value)} doesn't match input len {len(xs)}.")
    return sop.value

  def eval_selector_width(self, sop: SelectorWidth,
                          xs: Sequence[Value]) -> Sequence[Value]:
    selector_values = self.evaluate(sop.selector, xs)
    return [sum(row) for row in selector_values]

  def eval_aggregate(self, sop: Aggregate,
                     xs: Sequence[Value]) -> Sequence[Value]:
    selector_value = self.evaluate(sop.selector, xs)
    values = self.evaluate(sop.sop, xs)
    default = sop.default

    return [
        _mean(_get_selected(row, values), default) for row in selector_value
    ]

  def eval_select(self, sel: Select, xs: Sequence[Value]) -> SelectorValue:
    """Evaluates a Select on `xs`."""
    key_values = self.evaluate(sel.keys, xs)
    query_values = self.evaluate(sel.queries, xs)

    key_len = len(key_values)
    query_len = len(query_values)
    out = np.zeros((query_len, key_len), dtype=bool).tolist()
    for row, query in enumerate(query_values):
      for col, key in enumerate(key_values):
        out[row][col] = bool(sel.predicate(key, query))
    return out

  def eval_constant_selector(self, sel: ConstantSelector,
                             xs: Sequence[Value]) -> SelectorValue:
    if sel.check_length and (len(xs) != len(sel.value)):
      raise ValueError(
          f"Constant len {len(sel.value)} doesn't match input len {len(xs)}.")
    return sel.value

  def eval_selector_and(self, sel: SelectorAnd,
                        xs: Sequence[Value]) -> SelectorValue:
    fst_values = self.evaluate(sel.fst, xs)
    snd_values = self.evaluate(sel.snd, xs)
    return np.logical_and(np.array(fst_values), np.array(snd_values)).tolist()

  def eval_selector_or(self, sel: SelectorOr,
                       xs: Sequence[Value]) -> SelectorValue:
    fst_values = self.evaluate(sel.fst, xs)
    snd_values = self.evaluate(sel.snd, xs)
    return np.logical_or(np.array(fst_values), np.array(snd_values)).tolist()

  def eval_selector_not(self, sel: SelectorNot,
                        xs: Sequence[Value]) -> SelectorValue:
    values = self.evaluate(sel.inner, xs)
    return np.logical_not(np.array(values)).tolist()


def _get_selected(
    selector_row: list[bool],
    values: Sequence[VT],
) -> Sequence[VT]:
  """Helper for aggregate. [T T F], [a b c] -> [a b]."""
  return [v for s, v in zip(selector_row, values) if s]


def _mean(xs: Sequence[VT], default: VT) -> VT:
  """Takes the mean for numbers; otherwise requires a single unique value."""
  if not xs:
    return default
  exemplar = xs[0]
  if isinstance(exemplar, (int, bool)):
    return sum(xs) / len(xs)
  elif len(xs) == 1:
    return exemplar
  else:
    raise ValueError(f"Unsupported type for aggregation: {xs}")


def _raise_not_implemented(expr: RASPExpr, xs: Sequence[Value]):
  raise NotImplementedError(f"Evaluation of {expr} is not defined.")


evaluate = DefaultRASPEvaluator().evaluate
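
# End-to-end sketch (editor's illustration, assuming the definitions above):
# `evaluate` interprets a RASP expression directly on a token sequence.
#
#   same_token = Select(tokens, tokens, Comparison.EQ)
#   is_a = tokens == "a"
#   evaluate(Aggregate(same_token, is_a), "aba")  # -> [1.0, 0.0, 1.0]
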
rasp/rasp_test.py
ADDED
@@ -0,0 +1,580 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for rasp.rasp."""

import itertools

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from tracr.rasp import rasp

# Note that the example text labels must match their default names.

_SOP_PRIMITIVE_EXAMPLES = lambda: [  # pylint: disable=g-long-lambda
    ("tokens", rasp.tokens),
    ("length", rasp.length),
    ("indices", rasp.indices),
]

_NONPRIMITIVE_SOP_EXAMPLES = lambda: [  # pylint: disable=g-long-lambda
    ("map", rasp.Map(lambda x: x, rasp.tokens)),
    (
        "sequence_map",
        rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, rasp.tokens),
    ),
    (
        "linear_sequence_map",
        rasp.LinearSequenceMap(rasp.tokens, rasp.tokens, 0.1, 0.2),
    ),
    (
        "aggregate",
        rasp.Aggregate(
            rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),
            rasp.tokens,
        ),
    ),
    (
        "selector_width",
        rasp.SelectorWidth(
            rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ)),
    ),
]

_SOP_EXAMPLES = lambda: _SOP_PRIMITIVE_EXAMPLES() + _NONPRIMITIVE_SOP_EXAMPLES()

_SELECTOR_EXAMPLES = lambda: [  # pylint: disable=g-long-lambda
    ("select", rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ)),
    ("selector_and",
     rasp.SelectorAnd(
         rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),
         rasp.Select(rasp.indices, rasp.tokens, rasp.Comparison.LEQ),
     )),
    ("selector_or",
     rasp.SelectorOr(
         rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),
         rasp.Select(rasp.indices, rasp.tokens, rasp.Comparison.LEQ),
     )),
    ("selector_not",
     rasp.SelectorNot(
         rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.EQ),)),
]

_ALL_EXAMPLES = lambda: _SOP_EXAMPLES() + _SELECTOR_EXAMPLES()


class LabelTest(parameterized.TestCase):

  def test_primitive_labels(self):
    self.assertEqual(rasp.tokens.label, "tokens")
    self.assertEqual(rasp.indices.label, "indices")
    self.assertEqual(rasp.length.label, "length")

  @parameterized.parameters(*_ALL_EXAMPLES())
  def test_default_names(self, default_name: str, expr: rasp.RASPExpr):
    self.assertEqual(expr.name, default_name)


class SOpTest(parameterized.TestCase):
  """Tests for S-Ops."""

  @parameterized.parameters(
      ("hello", ["h", "e", "l", "l", "o"]),
      ("h", ["h"]),
      (["h", "e", "l", "l", "o"], ["h", "e", "l", "l", "o"]),
      (["h"], ["h"]),
      ([1, 2], [1, 2]),
      ([0.1, 0.2], [0.1, 0.2]),
  )
  def test_tokens(self, input_sequence, expected):
    self.assertEqual(rasp.tokens(input_sequence), expected)

  @parameterized.parameters(
      ("hello", [0, 1, 2, 3, 4]),
      ("h", [0]),
      (["h", "e", "l", "l", "o"], [0, 1, 2, 3, 4]),
      (["h"], [0]),
      ([1, 2], [0, 1]),
      ([0.1, 0.2], [0, 1]),
  )
  def test_indices(self, input_sequence, expected):
    self.assertEqual(rasp.indices(input_sequence), expected)

  @parameterized.parameters(
      ("hello", [5, 5, 5, 5, 5]),
      ("h", [1]),
      (["h", "e", "l", "l", "o"], [5, 5, 5, 5, 5]),
      (["h"], [1]),
      ([1, 2], [2, 2]),
      ([0.1, 0.2], [2, 2]),
  )
  def test_length(self, input_sequence, expected):
    self.assertEqual(rasp.length(input_sequence), expected)

  def test_prims_are_sops(self):
    self.assertIsInstance(rasp.tokens, rasp.SOp)
    self.assertIsInstance(rasp.indices, rasp.SOp)
    self.assertIsInstance(rasp.length, rasp.SOp)

  def test_prims_are_raspexprs(self):
    self.assertIsInstance(rasp.tokens, rasp.RASPExpr)
    self.assertIsInstance(rasp.indices, rasp.RASPExpr)
    self.assertIsInstance(rasp.length, rasp.RASPExpr)

  @parameterized.parameters(
      (lambda x: x + "a", "hello", ["ha", "ea", "la", "la", "oa"]),
      (lambda x: x + "t", "h", ["ht"]),
      (lambda x: x + 1, [1, 2], [2, 3]),
      (lambda x: x / 2, [0.1, 0.2], [0.05, 0.1]),
  )
  def test_map(self, f, input_sequence, expected):
    self.assertEqual(rasp.Map(f, rasp.tokens)(input_sequence), expected)

  def test_nested_elementwise_ops_results_in_only_one_map_object(self):
    map_sop = ((rasp.tokens * 2) + 2) / 2
    self.assertEqual(map_sop.inner, rasp.tokens)
    self.assertEqual(map_sop([1]), [2])

  @parameterized.parameters(
      (lambda x, y: x + y, "hello", ["hh", "ee", "ll", "ll", "oo"]),
      (lambda x, y: x + y, "h", ["hh"]),
      (lambda x, y: x + y, [1, 2], [2, 4]),
      (lambda x, y: x * y, [1, 2], [1, 4]),
  )
  def test_sequence_map(self, f, input_sequence, expected):
    self.assertEqual(
        rasp.SequenceMap(f, rasp.tokens, rasp.tokens)(input_sequence), expected)

  def test_sequence_map_with_same_inputs_logs_warning(self):
    with self.assertLogs(level="WARNING"):
      rasp.SequenceMap(lambda x, y: x + y, rasp.tokens, rasp.tokens)

  @parameterized.parameters(
      (1, 1, [1, 2], [2, 4]),
      (1, -1, [1, 2], [0, 0]),
      (1, -2, [1, 2], [-1, -2]),
  )
  def test_linear_sequence_map(self, fst_fac, snd_fac, input_sequence,
                               expected):
    self.assertEqual(
        rasp.LinearSequenceMap(rasp.tokens, rasp.tokens, fst_fac,
                               snd_fac)(input_sequence), expected)

  @parameterized.parameters(
      ([5, 5, 5, 5, 5], "hello", [5, 5, 5, 5, 5]),
      (["e"], "h", ["e"]),
      ([1, 2, 3, 4, 5], ["h", "e", "l", "l", "o"], [1, 2, 3, 4, 5]),
      ([2, 2], [1, 2], [2, 2]),
  )
  def test_constant(self, const, input_sequence, expected):
    self.assertEqual(rasp.ConstantSOp(const)(input_sequence), expected)

  def test_constant_complains_if_sizes_dont_match(self):
    with self.assertRaisesRegex(
        ValueError,
        r"^.*Constant len .* doesn't match input len .*$",):
      rasp.ConstantSOp([1, 2, 3])("longer string")

  def test_can_turn_off_constant_complaints(self):
    rasp.ConstantSOp([1, 2, 3], check_length=False)("longer string")

  def test_numeric_dunders(self):
    # We don't check all the cases here -- only a few representative ones.
    self.assertEqual(
        (rasp.tokens > 1)([0, 1, 2]),
        [0, 0, 1],
    )
    self.assertEqual(
        (1 < rasp.tokens)([0, 1, 2]),
        [0, 0, 1],
    )
    self.assertEqual(
        (rasp.tokens < 1)([0, 1, 2]),
        [1, 0, 0],
    )
    self.assertEqual(
        (1 > rasp.tokens)([0, 1, 2]),
        [1, 0, 0],
    )
    self.assertEqual(
        (rasp.tokens == 1)([0, 1, 2]),
        [0, 1, 0],
    )
    self.assertEqual(
        (rasp.tokens + 1)([0, 1, 2]),
        [1, 2, 3],
    )
    self.assertEqual(
        (1 + rasp.tokens)([0, 1, 2]),
        [1, 2, 3],
    )

  def test_dunders_with_sop(self):
    self.assertEqual(
        (rasp.tokens + rasp.indices)([0, 1, 2]),
        [0, 2, 4],
    )
    self.assertEqual(
        (rasp.length - 1 - rasp.indices)([0, 1, 2]),
        [2, 1, 0],
    )
    self.assertEqual(
        (rasp.length * rasp.length)([0, 1, 2]),
        [9, 9, 9],
    )

  def test_logical_dunders(self):
    self.assertEqual(
        (rasp.tokens & True)([True, False]),
        [True, False],
    )
    self.assertEqual(
        (rasp.tokens & False)([True, False]),
        [False, False],
    )
    self.assertEqual(
        (rasp.tokens | True)([True, False]),
        [True, True],
    )
    self.assertEqual(
        (rasp.tokens | False)([True, False]),
        [True, False],
    )
    self.assertEqual(
        (True & rasp.tokens)([True, False]),
        [True, False],
    )
    self.assertEqual(
        (False & rasp.tokens)([True, False]),
        [False, False],
    )
    self.assertEqual(
        (True | rasp.tokens)([True, False]),
        [True, True],
    )
    self.assertEqual(
        (False | rasp.tokens)([True, False]),
        [True, False],
    )

    self.assertEqual(
        (~rasp.tokens)([True, False]),
        [False, True],
    )

    self.assertEqual(
        (rasp.ConstantSOp([True, True, False, False])
         & rasp.ConstantSOp([True, False, True, False]))([1, 1, 1, 1]),
        [True, False, False, False],
    )

    self.assertEqual(
        (rasp.ConstantSOp([True, True, False, False])
         | rasp.ConstantSOp([True, False, True, False]))([1, 1, 1, 1]),
        [True, True, True, False],
    )


class EncodingTest(parameterized.TestCase):
  """Tests for SOp encodings."""

  @parameterized.named_parameters(*_SOP_EXAMPLES())
  def test_all_sops_are_categorical_by_default(self, sop: rasp.SOp):
    self.assertTrue(rasp.is_categorical(sop))

  @parameterized.named_parameters(*_SOP_EXAMPLES())
  def test_is_numerical(self, sop: rasp.SOp):
    self.assertTrue(rasp.is_numerical(rasp.numerical(sop)))
    self.assertFalse(rasp.is_numerical(rasp.categorical(sop)))

  @parameterized.named_parameters(*_SOP_EXAMPLES())
  def test_is_categorical(self, sop: rasp.SOp):
    self.assertTrue(rasp.is_categorical(rasp.categorical(sop)))
    self.assertFalse(rasp.is_categorical(rasp.numerical(sop)))

  @parameterized.named_parameters(*_SOP_EXAMPLES())
  def test_double_encoding_annotations_overwrites_encoding(self, sop: rasp.SOp):
    num_sop = rasp.numerical(sop)
    cat_num_sop = rasp.categorical(num_sop)
    self.assertTrue(rasp.is_numerical(num_sop))
    self.assertTrue(rasp.is_categorical(cat_num_sop))


class SelectorTest(parameterized.TestCase):
  """Tests for Selectors."""

  def test_select_eq_has_correct_value(self):
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
    self.assertEqual(
        selector("hey"), [
            [True, False, False],
            [False, True, False],
            [False, False, True],
        ])

  def test_select_lt_has_correct_value(self):
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.LT)
    self.assertEqual(selector([0, 1]), [
        [False, False],
        [True, False],
    ])

  def test_select_leq_has_correct_value(self):
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.LEQ)
    self.assertEqual(selector([0, 1]), [
        [True, False],
        [True, True],
    ])

  def test_select_gt_has_correct_value(self):
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.GT)
    self.assertEqual(selector([0, 1]), [
        [False, True],
        [False, False],
    ])

  def test_select_geq_has_correct_value(self):
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.GEQ)
    self.assertEqual(selector([0, 1]), [
        [True, True],
        [False, True],
    ])

  def test_select_neq_has_correct_value(self):
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.NEQ)
    self.assertEqual(selector([0, 1]), [
        [False, True],
        [True, False],
    ])

  def test_select_true_has_correct_value(self):
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
    self.assertEqual(selector([0, 1]), [
        [True, True],
        [True, True],
    ])

  def test_select_false_has_correct_value(self):
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.FALSE)
    self.assertEqual(selector([0, 1]), [
        [False, False],
        [False, False],
    ])

  def test_selector_and_gets_simplified_when_keys_and_queries_match(self):
    selector = rasp.selector_and(
        rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.GEQ),
        rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.LEQ),
    )
    self.assertIsInstance(selector, rasp.Select)
    self.assertIs(selector.keys, rasp.tokens)
    self.assertIs(selector.queries, rasp.indices)

  def test_selector_and_doesnt_get_simplified_when_keys_queries_different(self):
    selector = rasp.selector_and(
        rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.GEQ),
        rasp.Select(rasp.indices, rasp.tokens, rasp.Comparison.LEQ),
    )
    self.assertIsInstance(selector, rasp.SelectorAnd)

  def test_selector_and_gets_simplified_when_keys_are_full(self):
    selector = rasp.selector_and(
        rasp.Select(rasp.Full(1), rasp.indices, rasp.Comparison.GEQ),
        rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.LEQ),
    )
    self.assertIsInstance(selector, rasp.Select)
    self.assertIs(selector.keys, rasp.tokens)
    self.assertIs(selector.queries, rasp.indices)

  def test_selector_and_gets_simplified_when_queries_are_full(self):
    selector = rasp.selector_and(
        rasp.Select(rasp.tokens, rasp.indices, rasp.Comparison.GEQ),
        rasp.Select(rasp.tokens, rasp.Full(1), rasp.Comparison.LEQ),
    )
    self.assertIsInstance(selector, rasp.Select)
    self.assertIs(selector.keys, rasp.tokens)
    self.assertIs(selector.queries, rasp.indices)

  @parameterized.parameters(
      itertools.product(
          (rasp.tokens, rasp.indices, rasp.Full(1)),
          (rasp.tokens, rasp.indices, rasp.Full(1)),
          list(rasp.Comparison),
          (rasp.tokens, rasp.indices, rasp.Full(1)),
          (rasp.tokens, rasp.indices, rasp.Full(1)),
          list(rasp.Comparison),
      ))
  def test_simplified_selector_and_works_the_same_way_as_not(
      self, fst_k, fst_q, fst_p, snd_k, snd_q, snd_p):
    fst = rasp.Select(fst_k, fst_q, fst_p)
    snd = rasp.Select(snd_k, snd_q, snd_p)

    simplified = rasp.selector_and(fst, snd)([0, 1, 2, 3])
    not_simplified = rasp.selector_and(fst, snd, simplify=False)([0, 1, 2, 3])

    np.testing.assert_array_equal(
        np.array(simplified),
        np.array(not_simplified),
    )

  def test_select_is_selector(self):
    self.assertIsInstance(
        rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ),
        rasp.Selector,
    )

  def test_select_is_raspexpr(self):
    self.assertIsInstance(
        rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ),
        rasp.RASPExpr,
    )

  def test_constant_selector(self):
    self.assertEqual(
        rasp.ConstantSelector([[True, True], [False, False]])([1, 2]),
        [[True, True], [False, False]],
    )


class CopyTest(parameterized.TestCase):

  @parameterized.named_parameters(*_ALL_EXAMPLES())
  def test_copy_preserves_name(self, expr: rasp.RASPExpr):
    expr = expr.named("foo")
    self.assertEqual(expr.copy().name, expr.name)

  @parameterized.named_parameters(*_ALL_EXAMPLES())
  def test_renaming_copy_doesnt_rename_original(self, expr: rasp.RASPExpr):
    expr = expr.named("foo")
    expr.copy().named("bar")
    self.assertEqual(expr.name, "foo")

  @parameterized.named_parameters(*_ALL_EXAMPLES())
  def test_renaming_original_doesnt_rename_copy(self, expr: rasp.RASPExpr):
    expr = expr.named("foo")
    copy = expr.copy()
    expr.named("bar")
    self.assertEqual(copy.name, "foo")

  @parameterized.named_parameters(*_ALL_EXAMPLES())
  def test_copy_changes_id(self, expr: rasp.RASPExpr):
    self.assertNotEqual(expr.copy().unique_id, expr.unique_id)

  @parameterized.named_parameters(*_ALL_EXAMPLES())
  def test_copy_preserves_child_ids(self, expr: rasp.RASPExpr):
    copy_child_ids = [c.unique_id for c in expr.copy().children]
    child_ids = [c.unique_id for c in expr.children]
    for child_id, copy_child_id in zip(child_ids, copy_child_ids):
      self.assertEqual(child_id, copy_child_id)


class AggregateTest(parameterized.TestCase):
  """Tests for Aggregate."""

  @parameterized.parameters(
      dict(
          selector=rasp.ConstantSelector([
              [True, False],
              [False, True],
          ]),
          sop=rasp.ConstantSOp(["h", "e"]),
          default=None,
          expected_value=["h", "e"],
      ),
      dict(
          selector=rasp.ConstantSelector([
              [False, True],
              [False, False],
          ]),
          sop=rasp.ConstantSOp(["h", "e"]),
          default=None,
          expected_value=["e", None],
      ),
      dict(
          selector=rasp.ConstantSelector([
              [True, False],
              [False, False],
          ]),
          sop=rasp.ConstantSOp(["h", "e"]),
          default=None,
          expected_value=["h", None],
      ),
      dict(
          selector=rasp.ConstantSelector([
              [True, True],
              [False, True],
          ]),
          sop=rasp.ConstantSOp([0, 1]),
          default=0,
          expected_value=[0.5, 1],
      ),
      dict(
          selector=rasp.ConstantSelector([
              [False, False],
              [True, True],
          ]),
          sop=rasp.ConstantSOp([0, 1]),
          default=0,
          expected_value=[0, 0.5],
      ),
      dict(
          selector=rasp.ConstantSelector([
              [False, False],
              [True, True],
          ]),
          sop=rasp.ConstantSOp([0, 1]),
          default=None,
          expected_value=[None, 0.5],
      ),
  )
  def test_aggregate_on_size_2_inputs(self, selector, sop, default,
                                      expected_value):
    # The 0, 0 input is ignored as it's overridden by the constant SOps.
    self.assertEqual(
        rasp.Aggregate(selector, sop, default)([0, 0]),
        expected_value,
    )


class RaspProgramTest(parameterized.TestCase):
  """Each testcase implements and tests a RASP program."""

  def test_has_prev(self):

    def has_prev(seq: rasp.SOp) -> rasp.SOp:
      prev_copy = rasp.SelectorAnd(
          rasp.Select(seq, seq, rasp.Comparison.EQ),
          rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LT),
      )
      return rasp.Aggregate(prev_copy, rasp.Full(1), default=0) > 0

    self.assertEqual(
        has_prev(rasp.tokens)("hello"),
        [0, 0, 0, 1, 0],
    )

    self.assertEqual(
        has_prev(rasp.tokens)("helllo"),
        [0, 0, 0, 1, 1, 0],
    )

    self.assertEqual(
        has_prev(rasp.tokens)([0, 2, 3, 2, 1, 0, 2]),
        [0, 0, 0, 1, 0, 1, 1],
    )


if __name__ == "__main__":
  absltest.main()
transformer/attention.py
ADDED
@@ -0,0 +1,160 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Instrumented attention layer (forked from the Haiku library implementation).
"""

from typing import Optional
import warnings

import chex
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np


@chex.dataclass
class AttentionOutput:
  out: jax.Array  # [..., T', D']
  logits: jax.Array  # [..., H, T', T]


class MultiHeadAttention(hk.Module):
  """Multi-headed attention (MHA) module.

  This module is intended for attending over sequences of vectors.

  Rough sketch:
  - Compute keys (K), queries (Q), and values (V) as projections of inputs.
  - Attention weights are computed as W = softmax(QK^T / sqrt(key_size)).
  - Output is another projection of WV.

  For more detail, see the original Transformer paper:
    "Attention is all you need" https://arxiv.org/abs/1706.03762.

  Glossary of shapes:
  - T: Sequence length.
  - D: Vector (embedding) size.
  - H: Number of attention heads.
  """

  def __init__(
      self,
      num_heads: int,
      key_size: int,
      # TODO(b/240019186): Remove `w_init_scale`.
      w_init_scale: Optional[float] = None,
      *,
      w_init: Optional[hk.initializers.Initializer] = None,
      value_size: Optional[int] = None,
      model_size: Optional[int] = None,
      name: Optional[str] = None,
  ):
    """Initialises the module.

    Args:
      num_heads: Number of independent attention heads (H).
      key_size: The size of keys (K) and queries used for attention.
      w_init_scale: DEPRECATED. Please use w_init instead.
      w_init: Initialiser for weights in the linear map.
      value_size: Optional size of the value projection (V). If None, defaults
        to the key size (K).
      model_size: Optional size of the output embedding (D'). If None, defaults
        to the key size multiplied by the number of heads (K * H).
      name: Optional name for this module.
    """
    super().__init__(name=name)
    self.num_heads = num_heads
    self.key_size = key_size
    self.value_size = value_size or key_size
    self.model_size = model_size or key_size * num_heads

    # Backwards-compatibility for w_init_scale.
    if w_init_scale is not None:
      warnings.warn(
          "w_init_scale is deprecated; please pass an explicit weight "
          "initialiser instead.", DeprecationWarning)
    if w_init and w_init_scale:
      raise ValueError("Please provide only `w_init`, not `w_init_scale`.")
    if w_init is None and w_init_scale is None:
      raise ValueError("Please provide a weight initializer: `w_init`.")
    if w_init is None:
      w_init = hk.initializers.VarianceScaling(w_init_scale)
    self.w_init = w_init

  def __call__(
      self,
      query: jnp.ndarray,
      key: jnp.ndarray,
      value: jnp.ndarray,
      mask: Optional[jnp.ndarray] = None,
  ) -> AttentionOutput:
    """Computes (optionally masked) MHA with queries, keys & values.

    This module broadcasts over zero or more 'batch-like' leading dimensions.

    Args:
      query: Embeddings sequence used to compute queries; shape [..., T', D_q].
      key: Embeddings sequence used to compute keys; shape [..., T, D_k].
      value: Embeddings sequence used to compute values; shape [..., T, D_v].
      mask: Optional mask applied to attention weights; shape [..., H=1, T', T].

    Returns:
      A new sequence of embeddings, consisting of a projection of the
      attention-weighted value projections; shape [..., T', D'].
    """

    # In shape hints below, we suppress the leading dims [...] for brevity.
    # Hence e.g. [A, B] should be read in every case as [..., A, B].
    *leading_dims, sequence_length, _ = query.shape
    projection = self._linear_projection

    # Compute key/query/values (overload K/Q/V to denote the respective sizes).
    query_heads = projection(query, self.key_size, "query")  # [T', H, Q=K]
    key_heads = projection(key, self.key_size, "key")  # [T, H, K]
    value_heads = projection(value, self.value_size, "value")  # [T, H, V]

    # Compute attention weights.
    attn_logits = jnp.einsum("...thd,...Thd->...htT", query_heads, key_heads)
    attn_logits = attn_logits / np.sqrt(self.key_size).astype(key.dtype)
    if mask is not None:
      if mask.ndim != attn_logits.ndim:
        raise ValueError(
            f"Mask dimensionality {mask.ndim} must match logits dimensionality "
            f"{attn_logits.ndim}.")
      attn_logits = jnp.where(mask, attn_logits, -1e30)
    attn_weights = jax.nn.softmax(attn_logits)  # [H, T', T]

    # Weight the values by the attention and flatten the head vectors.
    attn = jnp.einsum("...htT,...Thd->...thd", attn_weights, value_heads)
    attn = jnp.reshape(attn, (*leading_dims, sequence_length, -1))  # [T', H*V]

    # Apply another projection to get the final embeddings.
    final_projection = hk.Linear(self.model_size, w_init=self.w_init)
    return AttentionOutput(
        out=final_projection(attn),
        logits=attn_logits,
    )

  @hk.transparent
  def _linear_projection(
      self,
      x: jnp.ndarray,
      head_size: int,
      name: Optional[str] = None,
  ) -> jnp.ndarray:
    y = hk.Linear(self.num_heads * head_size, w_init=self.w_init, name=name)(x)
    *leading_dims, _ = x.shape
    return y.reshape((*leading_dims, self.num_heads, head_size))
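
# Usage sketch (editor's illustration; the hk.transform wrapper and shapes are
# assumptions, not part of the original file). With key_size=4 and num_heads=2,
# model_size defaults to K * H = 8:
#
#   @hk.transform
#   def f(x):  # x: [B, T, D]
#     mha = MultiHeadAttention(
#         num_heads=2, key_size=4,
#         w_init=hk.initializers.VarianceScaling(1.0))
#     return mha(x, x, x)  # self-attention
#
#   x = jnp.ones((1, 5, 8))
#   params = f.init(jax.random.PRNGKey(0), x)
#   out = f.apply(params, None, x)  # no rng needed: the module is deterministic
#   out.out.shape     # (1, 5, 8)    [B, T', model_size]
#   out.logits.shape  # (1, 2, 5, 5) [B, H, T', T]
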
transformer/compressed_model.py
ADDED
@@ -0,0 +1,185 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Modified transformer to learn a linear compression of the residual stream.

CompressedTransformer adds three arguments compared to Transformer:
- embedding_size: the size of the compressed residual stream.
- unembed_at_every_layer: whether to apply the unembedding before applying
  attention and MLP layers.
- return_activations: whether to return all model activations rather than just
  the outputs.
"""

import collections
import dataclasses
from typing import Optional

import haiku as hk
import jax
import numpy as np

from tracr.transformer import attention
from tracr.transformer import model


@dataclasses.dataclass
class CompressedTransformer(hk.Module):
  """A transformer stack with linearly compressed residual stream."""

  config: model.TransformerConfig
  name: Optional[str] = None

  def __call__(
      self,
      embeddings: jax.Array,  # [B, T, D]
      mask: jax.Array,  # [B, T]
      *,
      use_dropout: bool = True,
      embedding_size: Optional[int] = None,
      unembed_at_every_layer: bool = False,
  ) -> model.TransformerOutput:  # [B, T, D]
    """Transforms input embedding sequences to output embedding sequences.

    Args:
      embeddings: Input embeddings to pass through the model.
      mask: Boolean mask to restrict the inputs the model uses.
      use_dropout: Turns dropout on/off.
      embedding_size: Dimension to compress the residual stream to.
      unembed_at_every_layer: Whether to unembed the residual stream when
        reading the input for every layer (keeping the layer input sizes) or to
        only unembed before the model output (compressing the layer inputs).

    Returns:
      The outputs of the forward pass through the transformer.
    """

    def layer_norm(x: jax.Array) -> jax.Array:
      """Applies a unique LayerNorm to x with default settings."""
      if self.config.layer_norm:
        return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)
      return x

    initializer = hk.initializers.VarianceScaling(2 / self.config.num_layers)
    dropout_rate = self.config.dropout_rate if use_dropout else 0.
    _, seq_len, model_size = embeddings.shape

    # To compress the model, we multiply with a matrix W when reading from
    # the residual stream, and with W^T when writing to the residual stream.
    if embedding_size is not None:
      # [to_size, from_size]
      w_emb = hk.get_parameter(
          "w_emb", (embedding_size, model_size),
          init=hk.initializers.RandomNormal())

      write_to_residual = lambda x: x @ w_emb.T
      read_from_residual = lambda x: x @ w_emb

      if not unembed_at_every_layer:
        model_size = embedding_size
    else:
      write_to_residual = lambda x: x
      read_from_residual = lambda x: x

    # Compute causal mask for autoregressive sequence modelling.
    mask = mask[:, None, None, :]  # [B, H=1, T'=1, T]
    mask = mask.repeat(seq_len, axis=2)  # [B, H=1, T, T]

    if self.config.causal:
      causal_mask = np.ones((1, 1, seq_len, seq_len))  # [B=1, H=1, T, T]
      causal_mask = np.tril(causal_mask)
      mask = mask * causal_mask  # [B, H=1, T, T]

    # Set up activation collection.
    collected = collections.defaultdict(list)

    def collect(**kwargs):
      for k, v in kwargs.items():
        collected[k].append(v)

    residual = write_to_residual(embeddings)

    for layer in range(self.config.num_layers):
      with hk.experimental.name_scope(f"layer_{layer}"):
        # First the attention block.
        attn_block = attention.MultiHeadAttention(
            num_heads=self.config.num_heads,
            key_size=self.config.key_size,
            model_size=model_size,
            w_init=initializer,
            name="attn")

        attn_in = residual
        if unembed_at_every_layer:
          attn_in = read_from_residual(attn_in)
        attn_in = layer_norm(attn_in)
        attn_out = attn_block(attn_in, attn_in, attn_in, mask=mask)
        attn_out, attn_logits = attn_out.out, attn_out.logits
        if dropout_rate > 0:
          attn_out = hk.dropout(hk.next_rng_key(), dropout_rate, attn_out)

        if unembed_at_every_layer:
          collect(layer_outputs=attn_out, attn_logits=attn_logits)
        else:
          collect(
              layer_outputs=read_from_residual(attn_out),
              attn_logits=attn_logits,
          )

        if unembed_at_every_layer:
          attn_out = write_to_residual(attn_out)
        residual = residual + attn_out

        collect(residuals=residual)

        # Then the dense block.
        with hk.experimental.name_scope("mlp"):
          dense_block = hk.Sequential([
              hk.Linear(
                  self.config.mlp_hidden_size,
                  w_init=initializer,
                  name="linear_1"),
              self.config.activation_function,
              hk.Linear(model_size, w_init=initializer, name="linear_2"),
          ])

        dense_in = residual
        if unembed_at_every_layer:
          dense_in = read_from_residual(dense_in)
        dense_in = layer_norm(dense_in)
        dense_out = dense_block(dense_in)
        if dropout_rate > 0:
          dense_out = hk.dropout(hk.next_rng_key(), dropout_rate, dense_out)

        if unembed_at_every_layer:
          collect(layer_outputs=dense_out)
        else:
          collect(layer_outputs=read_from_residual(dense_out))

        if unembed_at_every_layer:
          dense_out = write_to_residual(dense_out)
        residual = residual + dense_out

        collect(residuals=residual)

    output = read_from_residual(residual)
    output = layer_norm(output)

    return model.TransformerOutput(
        layer_outputs=collected["layer_outputs"],
        residuals=collected["residuals"],
        attn_logits=collected["attn_logits"],
        output=output,
        input_embeddings=embeddings,
    )
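
# Compression sketch (editor's illustration): with embedding_size=E < D, the
# stream is written as x @ w_emb.T (shape [B, T, E]) and read back as
# r @ w_emb (shape [B, T, D]), so w_emb acts as a learned linear down- and
# up-projection around the whole stack:
#
#   out = CompressedTransformer(config)(emb, mask, embedding_size=2)
#   # out.residuals live in 2 dims; out.output is back in the original D dims.
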
transformer/compressed_model_test.py
ADDED
@@ -0,0 +1,318 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for transformer.compressed_model."""

from absl.testing import absltest
from absl.testing import parameterized
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from tracr.transformer import compressed_model
from tracr.transformer import model


class CompressedTransformerTest(parameterized.TestCase):

  def _check_layer_naming(self, params):
    # Modules should be named, for example:
    #   MLPs: "compressed_transformer/layer_{i}/mlp/linear_1"
    #   Attention: "compressed_transformer/layer_{i}/attn/key"
    #   Layer norm: "compressed_transformer/layer_{i}/layer_norm"
    for key in params.keys():
      levels = key.split("/")
      self.assertEqual(levels[0], "compressed_transformer")
      if len(levels) == 1:
        self.assertEqual(list(params[key].keys()), ["w_emb"])
        continue
      if levels[1].startswith("layer_norm"):
        continue  # output layer norm
      self.assertStartsWith(levels[1], "layer")
      if levels[2] == "mlp":
        self.assertIn(levels[3], {"linear_1", "linear_2"})
      elif levels[2] == "attn":
        self.assertIn(levels[3], {"key", "query", "value", "linear"})
      else:
        self.assertStartsWith(levels[2], "layer_norm")

  def _zero_mlps(self, params):
    for module in params:
      if "mlp" in module:
        for param in params[module]:
          params[module][param] = jnp.zeros_like(params[module][param])
    return params

  @parameterized.parameters(dict(layer_norm=True), dict(layer_norm=False))
  def test_layer_norm(self, layer_norm):
    # input = [1, 1, 1, 1]
    # If layer norm is used, this should give all-0 output for a freshly
    # initialized model because LN will subtract the mean after each layer.
    # Else we expect non-zero outputs.

    @hk.transform
    def forward(emb, mask):
      transformer = compressed_model.CompressedTransformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              layer_norm=layer_norm))
      return transformer(emb, mask).output

    seq_len = 4
    emb = jnp.ones((1, seq_len, 1))
    mask = jnp.ones((1, seq_len))
    rng = hk.PRNGSequence(1)
    params = forward.init(next(rng), emb, mask)
    out = forward.apply(params, next(rng), emb, mask)

    self._check_layer_naming(params)
    if layer_norm:
      np.testing.assert_allclose(out, 0)
    else:
      self.assertFalse(np.allclose(out, 0))

  @parameterized.parameters(dict(causal=True), dict(causal=False))
  def test_causal_attention(self, causal):
    # input = [0, random, random, random]
    # mask = [1, 0, 1, 1]
    # For causal attention the second token can only attend to the first one,
    # so it should be the same. For non-causal attention all tokens change.

    @hk.transform
    def forward(emb, mask):
      transformer = compressed_model.CompressedTransformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              layer_norm=False,
              causal=causal))
      return transformer(emb, mask).output

    seq_len = 4
    emb = np.random.random((1, seq_len, 1))
    emb[:, 0, :] = 0
    mask = np.array([[1, 0, 1, 1]])
    emb, mask = jnp.array(emb), jnp.array(mask)

    rng = hk.PRNGSequence(1)
    params = forward.init(next(rng), emb, mask)
    params = self._zero_mlps(params)
    out = forward.apply(params, next(rng), emb, mask)

    self._check_layer_naming(params)
    if causal:
      self.assertEqual(0, out[0, 0, 0])
      self.assertEqual(emb[0, 1, 0], out[0, 1, 0])
    else:
      self.assertNotEqual(0, out[0, 0, 0])
      self.assertNotEqual(emb[0, 1, 0], out[0, 1, 0])
    self.assertNotEqual(emb[0, 2, 0], out[0, 2, 0])
    self.assertNotEqual(emb[0, 3, 0], out[0, 3, 0])

  def test_setting_activation_function_to_zero(self):
    # An activation function that always returns zeros should result in the
    # same model output as setting all MLP weights to zero.

    @hk.transform
    def forward_zero(emb, mask):
      transformer = compressed_model.CompressedTransformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              causal=False,
              layer_norm=False,
              activation_function=jnp.zeros_like))
      return transformer(emb, mask).output

    @hk.transform
    def forward(emb, mask):
      transformer = compressed_model.CompressedTransformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              causal=False,
              layer_norm=False,
              activation_function=jax.nn.gelu))
      return transformer(emb, mask).output

    seq_len = 4
    emb = np.random.random((1, seq_len, 1))
    mask = np.ones((1, seq_len))
    emb, mask = jnp.array(emb), jnp.array(mask)

    rng = hk.PRNGSequence(1)
    params = forward.init(next(rng), emb, mask)
    params_no_mlps = self._zero_mlps(params)

    out_zero_activation = forward_zero.apply(params, next(rng), emb, mask)
    out_no_mlps = forward.apply(params_no_mlps, next(rng), emb, mask)

    self._check_layer_naming(params)
    np.testing.assert_allclose(out_zero_activation, out_no_mlps)
    self.assertFalse(np.allclose(out_zero_activation, 0))

  def test_not_setting_embedding_size_produces_same_output_as_default_model(
      self):
    config = model.TransformerConfig(
        num_heads=2,
        num_layers=2,
        key_size=5,
        mlp_hidden_size=64,
        dropout_rate=0.,
        causal=False,
        layer_norm=False)

    @hk.without_apply_rng
    @hk.transform
    def forward_model(emb, mask):
      return model.Transformer(config)(emb, mask).output

    @hk.without_apply_rng
    @hk.transform
    def forward_superposition(emb, mask):
      return compressed_model.CompressedTransformer(config)(emb, mask).output

    seq_len = 4
    emb = np.random.random((1, seq_len, 1))
    mask = np.ones((1, seq_len))
    emb, mask = jnp.array(emb), jnp.array(mask)

    rng = hk.PRNGSequence(1)
    params = forward_model.init(next(rng), emb, mask)
    params_superposition = {
        k.replace("transformer", "compressed_transformer"): v
        for k, v in params.items()
    }

    out_model = forward_model.apply(params, emb, mask)
    out_superposition = forward_superposition.apply(params_superposition, emb,
                                                    mask)

    self._check_layer_naming(params_superposition)
    np.testing.assert_allclose(out_model, out_superposition)

  @parameterized.parameters(
      dict(embedding_size=2, unembed_at_every_layer=True),
      dict(embedding_size=2, unembed_at_every_layer=False),
      dict(embedding_size=6, unembed_at_every_layer=True),
      dict(embedding_size=6, unembed_at_every_layer=False))
  def test_embedding_size_produces_correct_shape_of_residuals_and_layer_outputs(
      self, embedding_size, unembed_at_every_layer):

    @hk.transform
    def forward(emb, mask):
      transformer = compressed_model.CompressedTransformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              causal=False,
              layer_norm=False))
      return transformer(
          emb,
          mask,
          embedding_size=embedding_size,
          unembed_at_every_layer=unembed_at_every_layer,
      )

    seq_len = 4
    model_size = 16

    emb = np.random.random((1, seq_len, model_size))
    mask = np.ones((1, seq_len))
    emb, mask = jnp.array(emb), jnp.array(mask)

    rng = hk.PRNGSequence(1)
    params = forward.init(next(rng), emb, mask)
    activations = forward.apply(params, next(rng), emb, mask)

    self._check_layer_naming(params)

    for residual in activations.residuals:
|
258 |
+
self.assertEqual(residual.shape, (1, seq_len, embedding_size))
|
259 |
+
|
260 |
+
for layer_output in activations.layer_outputs:
|
261 |
+
self.assertEqual(layer_output.shape, (1, seq_len, model_size))
|
262 |
+
|
263 |
+
@parameterized.parameters(
|
264 |
+
dict(model_size=2, unembed_at_every_layer=True),
|
265 |
+
dict(model_size=2, unembed_at_every_layer=False),
|
266 |
+
dict(model_size=6, unembed_at_every_layer=True),
|
267 |
+
dict(model_size=6, unembed_at_every_layer=False))
|
268 |
+
def test_identity_embedding_produces_same_output_as_standard_model(
|
269 |
+
self, model_size, unembed_at_every_layer):
|
270 |
+
|
271 |
+
config = model.TransformerConfig(
|
272 |
+
num_heads=2,
|
273 |
+
num_layers=2,
|
274 |
+
key_size=5,
|
275 |
+
mlp_hidden_size=64,
|
276 |
+
dropout_rate=0.,
|
277 |
+
causal=False,
|
278 |
+
layer_norm=False)
|
279 |
+
|
280 |
+
@hk.without_apply_rng
|
281 |
+
@hk.transform
|
282 |
+
def forward_model(emb, mask):
|
283 |
+
return model.Transformer(config)(emb, mask).output
|
284 |
+
|
285 |
+
@hk.without_apply_rng
|
286 |
+
@hk.transform
|
287 |
+
def forward_superposition(emb, mask):
|
288 |
+
return compressed_model.CompressedTransformer(config)(
|
289 |
+
emb,
|
290 |
+
mask,
|
291 |
+
embedding_size=model_size,
|
292 |
+
unembed_at_every_layer=unembed_at_every_layer).output
|
293 |
+
|
294 |
+
seq_len = 4
|
295 |
+
emb = np.random.random((1, seq_len, model_size))
|
296 |
+
mask = np.ones((1, seq_len))
|
297 |
+
emb, mask = jnp.array(emb), jnp.array(mask)
|
298 |
+
|
299 |
+
rng = hk.PRNGSequence(1)
|
300 |
+
params = forward_model.init(next(rng), emb, mask)
|
301 |
+
params_superposition = {
|
302 |
+
k.replace("transformer", "compressed_transformer"): v
|
303 |
+
for k, v in params.items()
|
304 |
+
}
|
305 |
+
params_superposition["compressed_transformer"] = {
|
306 |
+
"w_emb": jnp.identity(model_size)
|
307 |
+
}
|
308 |
+
|
309 |
+
out_model = forward_model.apply(params, emb, mask)
|
310 |
+
out_superposition = forward_superposition.apply(params_superposition, emb,
|
311 |
+
mask)
|
312 |
+
|
313 |
+
self._check_layer_naming(params_superposition)
|
314 |
+
np.testing.assert_allclose(out_model, out_superposition)
|
315 |
+
|
316 |
+
|
317 |
+
if __name__ == "__main__":
|
318 |
+
absltest.main()
|
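Haiku parameter trees are flat dicts keyed by module path, which is what lets the two equivalence tests above load parameters initialised for model.Transformer straight into a CompressedTransformer. A minimal sketch of that renaming step, extracted into a hypothetical helper (not part of this commit):

# Sketch only: only the top-level module prefix differs between the two
# models, e.g. "transformer/layer_0/attn/key" becomes
# "compressed_transformer/layer_0/attn/key".
def rename_prefix(params, old="transformer", new="compressed_transformer"):
  return {k.replace(old, new): v for k, v in params.items()}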
transformer/encoder.py
ADDED
@@ -0,0 +1,135 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Basic encoder for inputs with a fixed vocabulary."""

import abc
from typing import Any, Optional, Sequence

from tracr.craft import bases


class Encoder(abc.ABC):
  """Encodes a list of tokens into a list of inputs for a transformer model.

  The abstract class does not make assumptions on the input and output types,
  and we have different encoders for different input types.
  """

  @abc.abstractmethod
  def encode(self, inputs: list[Any]) -> list[Any]:
    return list()

  @abc.abstractmethod
  def decode(self, encodings: list[Any]) -> list[Any]:
    return list()

  @property
  def pad_token(self) -> Optional[str]:
    return None

  @property
  def bos_token(self) -> Optional[str]:
    return None

  @property
  def pad_encoding(self) -> Optional[int]:
    return None

  @property
  def bos_encoding(self) -> Optional[int]:
    return None


class NumericalEncoder(Encoder):
  """Encodes numerical variables (simply using the identity mapping)."""

  def encode(self, inputs: list[float]) -> list[float]:
    return inputs

  def decode(self, encodings: list[float]) -> list[float]:
    return encodings


class CategoricalEncoder(Encoder):
  """Encodes categorical variables with a fixed vocabulary."""

  def __init__(
      self,
      basis: Sequence[bases.BasisDirection],
      enforce_bos: bool = False,
      bos_token: Optional[str] = None,
      pad_token: Optional[str] = None,
      max_seq_len: Optional[int] = None,
  ):
    """Initialises. If enforce_bos is set, ensures inputs start with it."""
    if enforce_bos and not bos_token:
      raise ValueError("BOS token must be specified if enforcing BOS.")

    self.encoding_map = {}
    for i, direction in enumerate(basis):
      val = direction.value
      self.encoding_map[val] = i

    if bos_token and bos_token not in self.encoding_map:
      raise ValueError("BOS token missing in encoding.")

    if pad_token and pad_token not in self.encoding_map:
      raise ValueError("PAD token missing in encoding.")

    self.enforce_bos = enforce_bos
    self._bos_token = bos_token
    self._pad_token = pad_token
    self._max_seq_len = max_seq_len

  def encode(self, inputs: list[bases.Value]) -> list[int]:
    if self.enforce_bos and inputs[0] != self.bos_token:
      raise ValueError("First input token must be BOS token. "
                       f"Should be '{self.bos_token}', but was '{inputs[0]}'.")
    if missing := set(inputs) - set(self.encoding_map.keys()):
      raise ValueError(f"Inputs {missing} not found in encoding "
                       f"{self.encoding_map.keys()}.")
    if self._max_seq_len is not None and len(inputs) > self._max_seq_len:
      raise ValueError(f"{inputs=} are longer than the maximum "
                       f"sequence length {self._max_seq_len}")

    return [self.encoding_map[x] for x in inputs]

  def decode(self, encodings: list[int]) -> list[bases.Value]:
    """Recovers the tokens that correspond to `encodings`. Inverse of encode."""
    decoding_map = {val: key for key, val in self.encoding_map.items()}
    if missing := set(encodings) - set(decoding_map.keys()):
      raise ValueError(f"Inputs {missing} not found in decoding map "
                       f"{decoding_map.keys()}.")
    return [decoding_map[x] for x in encodings]

  @property
  def vocab_size(self) -> int:
    return len(self.encoding_map)

  @property
  def bos_token(self) -> Optional[str]:
    return self._bos_token

  @property
  def pad_token(self) -> Optional[str]:
    return self._pad_token

  @property
  def bos_encoding(self) -> Optional[int]:
    return None if self.bos_token is None else self.encoding_map[self.bos_token]

  @property
  def pad_encoding(self) -> Optional[int]:
    return None if self.pad_token is None else self.encoding_map[self.pad_token]
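A usage sketch for CategoricalEncoder (not part of this commit; the concrete id values assume the basis orders values lexicographically, as the tests below verify):

# Usage sketch: round-tripping tokens through a CategoricalEncoder.
from tracr.craft import bases
from tracr.transformer import encoder

vs = bases.VectorSpaceWithBasis.from_values("input", {"a", "b", "bos"})
enc = encoder.CategoricalEncoder(vs.basis, enforce_bos=True, bos_token="bos")
ids = enc.encode(["bos", "a", "b"])  # e.g. [2, 0, 1] in lexicographic order
assert enc.decode(ids) == ["bos", "a", "b"]  # decode inverts encode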
transformer/encoder_test.py
ADDED
@@ -0,0 +1,123 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for transformer.encoder."""

from absl.testing import absltest
from absl.testing import parameterized
from tracr.craft import bases
from tracr.transformer import encoder

_BOS_TOKEN = "bos_encoder_test"
_PAD_TOKEN = "pad_encoder_test"


class CategoricalEncoderTest(parameterized.TestCase):

  def test_encode_raises_value_error_if_input_doesnt_start_with_bos(self):
    vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, 3, _BOS_TOKEN})
    basic_encoder = encoder.CategoricalEncoder(
        vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN)
    with self.assertRaisesRegex(ValueError,
                                r"^.*First input token must be BOS token.*$"):
      basic_encoder.encode([1, 1, 1])

  def test_encode_raises_value_error_if_input_not_in_vocab(self):
    vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, 3, _BOS_TOKEN})
    basic_encoder = encoder.CategoricalEncoder(
        vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN)
    with self.assertRaisesRegex(ValueError,
                                r"^.*Inputs .* not found in encoding.*$"):
      basic_encoder.encode([_BOS_TOKEN, 1, 2, 3, 4])

  def test_decode_raises_value_error_if_id_outside_of_vocab_size(self):
    vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, _BOS_TOKEN})
    basic_encoder = encoder.CategoricalEncoder(
        vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN)
    with self.assertRaisesRegex(ValueError,
                                r"^.*Inputs .* not found in decoding map.*$"):
      basic_encoder.decode([0, 1, 2, 3])

  def test_encoder_raises_value_error_if_bos_not_in_basis(self):
    vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, 3})
    with self.assertRaisesRegex(ValueError,
                                r"^.*BOS token missing in encoding.*$"):
      unused_basic_encoder = encoder.CategoricalEncoder(
          vs.basis, bos_token=_BOS_TOKEN)

  def test_encoder_raises_value_error_if_pad_not_in_basis(self):
    vs = bases.VectorSpaceWithBasis.from_values("input", {1, 2, 3})
    with self.assertRaisesRegex(ValueError,
                                r"^.*PAD token missing in encoding.*$"):
      unused_basic_encoder = encoder.CategoricalEncoder(
          vs.basis, pad_token=_PAD_TOKEN)

  def test_encoder_encodes_bos_and_pad_tokens_as_expected(self):
    vs = bases.VectorSpaceWithBasis.from_values(
        "input", {1, 2, 3, _BOS_TOKEN, _PAD_TOKEN})
    basic_encoder = encoder.CategoricalEncoder(
        vs.basis, bos_token=_BOS_TOKEN, pad_token=_PAD_TOKEN)
    self.assertEqual(
        basic_encoder.encode([_BOS_TOKEN, _PAD_TOKEN]),
        [basic_encoder.bos_encoding, basic_encoder.pad_encoding])

  @parameterized.parameters([
      dict(
          vocab={1, 2, 3, _BOS_TOKEN},  # lexicographic order
          inputs=[_BOS_TOKEN, 3, 2, 1],
          expected=[3, 2, 1, 0]),
      dict(
          vocab={"a", "b", _BOS_TOKEN, "c"},  # lexicographic order
          inputs=[_BOS_TOKEN, "b", "b", "c"],
          expected=[2, 1, 1, 3]),
  ])
  def test_tokens_are_encoded_in_lexicographic_order(self, vocab, inputs,
                                                     expected):
    # Expect encodings to be assigned to ids according to a lexicographic
    # ordering of the vocab.
    vs = bases.VectorSpaceWithBasis.from_values("input", vocab)
    basic_encoder = encoder.CategoricalEncoder(
        vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN)
    encodings = basic_encoder.encode(inputs)
    self.assertEqual(encodings, expected)

  @parameterized.parameters([
      dict(vocab={_BOS_TOKEN, _PAD_TOKEN, 1, 2, 3}, expected=5),
      dict(vocab={_BOS_TOKEN, _PAD_TOKEN, "a", "b"}, expected=4),
  ])
  def test_vocab_size_has_expected_value(self, vocab, expected):
    vs = bases.VectorSpaceWithBasis.from_values("input", vocab)
    basic_encoder = encoder.CategoricalEncoder(
        vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN, pad_token=_PAD_TOKEN)
    self.assertEqual(basic_encoder.vocab_size, expected)

  @parameterized.parameters([
      dict(
          vocab={_BOS_TOKEN, _PAD_TOKEN, 1, 2, 3},
          inputs=[_BOS_TOKEN, 3, 2, 1]),
      dict(
          vocab={_BOS_TOKEN, _PAD_TOKEN, "a", "b", "c"},
          inputs=[_BOS_TOKEN, "b", "b", "c"]),
  ])
  def test_decode_inverts_encode(self, vocab, inputs):
    vs = bases.VectorSpaceWithBasis.from_values("input", vocab)
    basic_encoder = encoder.CategoricalEncoder(
        vs.basis, enforce_bos=True, bos_token=_BOS_TOKEN, pad_token=_PAD_TOKEN)
    encodings = basic_encoder.encode(inputs)
    recovered = basic_encoder.decode(encodings)
    self.assertEqual(recovered, inputs)


if __name__ == "__main__":
  absltest.main()
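The lexicographic-order cases above assume that from_values sorts the vocabulary by its string representation before assigning basis directions (and hence ids); that ordering is an assumption about from_values, not shown in this diff. A quick sketch:

# Assumed ordering behind the expected ids in the test above.
vocab = {"a", "b", "bos_encoder_test", "c"}
print(sorted(vocab, key=str))
# ['a', 'b', 'bos_encoder_test', 'c'] -> "b" encodes to 1, BOS to 2, "c" to 3.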
transformer/model.py
ADDED
@@ -0,0 +1,199 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Didactic example of an autoregressive Transformer-based language model.

Glossary of shapes:
- B: Batch size.
- T: Sequence length.
- D: Model embedding size.
- H: Number of attention heads.
- V: Vocabulary size.

Forked from: haiku.examples.transformer.model
"""

import collections
import dataclasses
from typing import Callable, Optional

import chex
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from tracr.transformer import attention

# hk.Modules are not always callable: github.com/deepmind/dm-haiku/issues/52
# Ideally, we'd want a type:
#   CallableHaikuModule = Intersection[Callable[..., jax.Array], hk.Module]
# But Intersection does not exist (yet): github.com/python/typing/issues/213
CallableHaikuModule = Callable[..., jax.Array]


@chex.dataclass
class TransformerOutput:
  layer_outputs: list[jax.Array]  # [B, T, D]
  residuals: list[jax.Array]  # [B, T, D]
  attn_logits: list[jax.Array]  # [B, H, T, T]
  output: jax.Array  # [B, T, D]
  input_embeddings: jax.Array  # [B, T, D]


@dataclasses.dataclass
class TransformerConfig:
  num_heads: int
  num_layers: int
  key_size: int
  mlp_hidden_size: int
  dropout_rate: float
  activation_function: Callable[[jax.Array], jax.Array] = jax.nn.gelu
  layer_norm: bool = True
  causal: bool = False


@dataclasses.dataclass
class Transformer(hk.Module):
  """A transformer stack."""

  config: TransformerConfig
  name: Optional[str] = None

  def __call__(
      self,
      embeddings: jax.Array,  # [B, T, D]
      mask: jax.Array,  # [B, T]
      *,
      use_dropout: bool = True,
  ) -> TransformerOutput:
    """Transforms input embedding sequences to output embedding sequences."""

    def layer_norm(x: jax.Array) -> jax.Array:
      """Applies a unique LayerNorm to x with default settings."""
      if self.config.layer_norm:
        return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(x)
      return x

    initializer = hk.initializers.VarianceScaling(2 / self.config.num_layers)
    dropout_rate = self.config.dropout_rate if use_dropout else 0.
    _, seq_len, model_size = embeddings.shape

    # Compute the attention mask (causal for autoregressive modelling).
    mask = mask[:, None, None, :]  # [B, H=1, T'=1, T]
    mask = mask.repeat(seq_len, axis=2)  # [B, H=1, T, T]

    if self.config.causal:
      causal_mask = np.ones((1, 1, seq_len, seq_len))  # [B=1, H=1, T, T]
      causal_mask = np.tril(causal_mask)
      mask = mask * causal_mask  # [B, H=1, T, T]

    # Set up activation collection.
    collected = collections.defaultdict(list)

    def collect(**kwargs):
      for k, v in kwargs.items():
        collected[k].append(v)

    residual = embeddings
    for layer in range(self.config.num_layers):
      with hk.experimental.name_scope(f"layer_{layer}"):
        # First the attention block.
        attn_block = attention.MultiHeadAttention(
            num_heads=self.config.num_heads,
            key_size=self.config.key_size,
            model_size=model_size,
            w_init=initializer,
            name="attn")
        attn_in = layer_norm(residual)
        attn_out = attn_block(attn_in, attn_in, attn_in, mask=mask)
        attn_out, attn_logits = attn_out.out, attn_out.logits
        if dropout_rate > 0:
          attn_out = hk.dropout(hk.next_rng_key(), dropout_rate, attn_out)
        residual = residual + attn_out

        collect(
            residuals=residual,
            layer_outputs=attn_out,
            attn_logits=attn_logits)

        # Then the dense block.
        with hk.experimental.name_scope("mlp"):
          dense_block = hk.Sequential([
              hk.Linear(
                  self.config.mlp_hidden_size,
                  w_init=initializer,
                  name="linear_1"),
              self.config.activation_function,
              hk.Linear(model_size, w_init=initializer, name="linear_2"),
          ])
        dense_in = layer_norm(residual)
        dense_out = dense_block(dense_in)
        if dropout_rate > 0:
          dense_out = hk.dropout(hk.next_rng_key(), dropout_rate, dense_out)
        residual = residual + dense_out

        collect(residuals=residual, layer_outputs=dense_out)

    return TransformerOutput(
        residuals=collected["residuals"],
        layer_outputs=collected["layer_outputs"],
        attn_logits=collected["attn_logits"],
        output=layer_norm(residual),
        input_embeddings=embeddings,
    )


@chex.dataclass
class CompiledTransformerModelOutput:
  transformer_output: TransformerOutput
  unembedded_output: jax.Array  # [B, T]


@dataclasses.dataclass
class CompiledTransformerModel(hk.Module):
  """A transformer model with one-hot embeddings."""
  transformer: Transformer
  token_embed: CallableHaikuModule
  position_embed: CallableHaikuModule
  unembed: CallableHaikuModule
  use_unembed_argmax: bool
  pad_token: Optional[int] = None

  def embed(self, tokens: jax.Array) -> jax.Array:
    token_embeddings = self.token_embed(tokens)
    positional_embeddings = self.position_embed(jnp.indices(tokens.shape)[-1])
    return token_embeddings + positional_embeddings  # [B, T, D]

  def __call__(
      self,
      tokens: jax.Array,
      use_dropout: bool = True,
  ) -> CompiledTransformerModelOutput:
    """Embeds tokens, passes them through the model, and unembeds the output."""
    if self.pad_token is None:
      input_mask = jnp.ones_like(tokens)
    else:
      input_mask = (tokens != self.pad_token)
    input_embeddings = self.embed(tokens)

    transformer_output = self.transformer(
        input_embeddings,
        input_mask,
        use_dropout=use_dropout,
    )
    return CompiledTransformerModelOutput(
        transformer_output=transformer_output,
        unembedded_output=self.unembed(
            transformer_output.output,
            use_unembed_argmax=self.use_unembed_argmax,
        ),
    )
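For readers unfamiliar with Haiku, a minimal usage sketch for this module, following the shape glossary at the top of the file (the concrete sizes are illustrative only, not part of this commit):

# Usage sketch; assumes tracr is installed.
import haiku as hk
import jax.numpy as jnp
from tracr.transformer import model

config = model.TransformerConfig(
    num_heads=2, num_layers=2, key_size=5, mlp_hidden_size=64,
    dropout_rate=0., causal=True)

@hk.transform
def forward(emb, mask):
  return model.Transformer(config)(emb, mask).output  # [B, T, D]

emb = jnp.ones((1, 4, 8))  # [B=1, T=4, D=8]
mask = jnp.ones((1, 4))  # [B, T]
rng = hk.PRNGSequence(0)
params = forward.init(next(rng), emb, mask)
out = forward.apply(params, next(rng), emb, mask)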
transformer/model_test.py
ADDED
@@ -0,0 +1,275 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for transformer.model."""

from absl.testing import absltest
from absl.testing import parameterized
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from tracr.transformer import model


class TransformerTest(parameterized.TestCase):

  def _check_layer_naming(self, params):
    # Modules should be named, for example:
    # For MLPs: "transformer/layer_{i}/mlp/linear_1"
    # For Attention: "transformer/layer_{i}/attn/key"
    # For Layer Norm: "transformer/layer_{i}/layer_norm"
    for key in params.keys():
      levels = key.split("/")
      self.assertEqual(levels[0], "transformer")
      if levels[1].startswith("layer_norm"):
        continue  # output layer norm
      self.assertStartsWith(levels[1], "layer")
      if levels[2] == "mlp":
        self.assertIn(levels[3], {"linear_1", "linear_2"})
      elif levels[2] == "attn":
        self.assertIn(levels[3], {"key", "query", "value", "linear"})
      else:
        self.assertStartsWith(levels[2], "layer_norm")

  def _zero_mlps(self, params):
    for module in params:
      if "mlp" in module:
        for param in params[module]:
          params[module][param] = jnp.zeros_like(params[module][param])
    return params

  @parameterized.parameters(dict(layer_norm=True), dict(layer_norm=False))
  def test_layer_norm(self, layer_norm):
    # input = [1, 1, 1, 1]
    # If layer norm is used, this should give all-0 output for a freshly
    # initialized model because LN will subtract the mean after each layer.
    # Else we expect non-zero outputs.

    @hk.transform
    def forward(emb, mask):
      transformer = model.Transformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              layer_norm=layer_norm))
      return transformer(emb, mask).output

    seq_len = 4
    emb = jnp.ones((1, seq_len, 1))
    mask = jnp.ones((1, seq_len))
    rng = hk.PRNGSequence(1)
    params = forward.init(next(rng), emb, mask)
    out = forward.apply(params, next(rng), emb, mask)

    self._check_layer_naming(params)
    if layer_norm:
      np.testing.assert_allclose(out, 0)
    else:
      self.assertFalse(np.allclose(out, 0))

  @parameterized.parameters(dict(causal=True), dict(causal=False))
  def test_causal_attention(self, causal):
    # input = [0, random, random, random]
    # mask = [1, 0, 1, 1]
    # For causal attention the second token can only attend to the first one,
    # so it should be the same. For non-causal attention all tokens should
    # change.

    @hk.transform
    def forward(emb, mask):
      transformer = model.Transformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              layer_norm=False,
              causal=causal))
      return transformer(emb, mask).output

    seq_len = 4
    emb = np.random.random((1, seq_len, 1))
    emb[:, 0, :] = 0
    mask = np.array([[1, 0, 1, 1]])
    emb, mask = jnp.array(emb), jnp.array(mask)

    rng = hk.PRNGSequence(1)
    params = forward.init(next(rng), emb, mask)
    params = self._zero_mlps(params)
    out = forward.apply(params, next(rng), emb, mask)

    self._check_layer_naming(params)
    if causal:
      self.assertEqual(0, out[0, 0, 0])
      self.assertEqual(emb[0, 1, 0], out[0, 1, 0])
    else:
      self.assertNotEqual(0, out[0, 0, 0])
      self.assertNotEqual(emb[0, 1, 0], out[0, 1, 0])
      self.assertNotEqual(emb[0, 2, 0], out[0, 2, 0])
      self.assertNotEqual(emb[0, 3, 0], out[0, 3, 0])

  def test_setting_activation_function_to_zero(self):
    # An activation function that always returns zeros should result in the
    # same model output as setting all MLP weights to zero.

    @hk.transform
    def forward_zero(emb, mask):
      transformer = model.Transformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              causal=False,
              layer_norm=False,
              activation_function=jnp.zeros_like))
      return transformer(emb, mask).output

    @hk.transform
    def forward(emb, mask):
      transformer = model.Transformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              causal=False,
              layer_norm=False,
              activation_function=jax.nn.gelu))
      return transformer(emb, mask).output

    seq_len = 4
    emb = np.random.random((1, seq_len, 1))
    mask = np.ones((1, seq_len))
    emb, mask = jnp.array(emb), jnp.array(mask)

    rng = hk.PRNGSequence(1)
    params = forward.init(next(rng), emb, mask)
    params_no_mlps = self._zero_mlps(params)

    out_zero_activation = forward_zero.apply(params, next(rng), emb, mask)
    out_no_mlps = forward.apply(params_no_mlps, next(rng), emb, mask)

    self._check_layer_naming(params)
    np.testing.assert_allclose(out_zero_activation, out_no_mlps)
    self.assertFalse(np.allclose(out_zero_activation, 0))


class CompiledTransformerModelTest(parameterized.TestCase):

  def _get_one_hot_embed_unembed(self, vocab_size, max_seq_len):
    # Embeds tokens as one-hot into the first `vocab_size` dimensions.
    token_embed = hk.Embed(
        embedding_matrix=jnp.block(
            [jnp.eye(vocab_size),
             jnp.zeros((vocab_size, max_seq_len))]))

    # Embeds positions as one-hot into the last `max_seq_len` dimensions.
    position_embed = hk.Embed(
        embedding_matrix=jnp.block(
            [jnp.zeros((max_seq_len, vocab_size)),
             jnp.eye(max_seq_len)]))

    class Unembed(hk.Module):

      def __call__(self, embeddings):
        return jnp.argmax(embeddings[:, :, :vocab_size], axis=-1)

    return token_embed, position_embed, Unembed()

  def test_embedding_gives_desired_result(self):
    tokens = jnp.array([[1, 2, 3]])
    vocab_size, max_seq_len, pad_token = 5, 5, 0

    expected_embeddings = jnp.array([[[0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
                                      [0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
                                      [0, 0, 0, 1, 0, 0, 0, 1, 0, 0]]])

    @hk.transform
    def embed(tokens):
      transformer = model.Transformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              causal=False,
              layer_norm=False,
              activation_function=jax.nn.gelu))
      token_embed, position_embed, unembed = self._get_one_hot_embed_unembed(
          vocab_size, max_seq_len)
      compiled_model = model.CompiledTransformerModel(
          transformer=transformer,
          token_embed=token_embed,
          position_embed=position_embed,
          unembed=unembed,
          use_unembed_argmax=True,
          pad_token=pad_token)
      return compiled_model.embed(tokens)

    rng = hk.PRNGSequence(1)
    params = embed.init(next(rng), tokens)
    embeddings = embed.apply(params, next(rng), tokens)

    np.testing.assert_allclose(embeddings, expected_embeddings)

  def test_embedding_then_unembedding_gives_same_tokens(self):
    tokens = jnp.array([[1, 2, 3], [4, 5, 6], [3, 2, 4]])
    vocab_size, max_seq_len, pad_token = 10, 5, 0

    @hk.transform
    def embed_unembed(tokens):
      transformer = model.Transformer(
          model.TransformerConfig(
              num_heads=2,
              num_layers=2,
              key_size=5,
              mlp_hidden_size=64,
              dropout_rate=0.,
              causal=False,
              layer_norm=False,
              activation_function=jax.nn.gelu))
      token_embed, position_embed, unembed = self._get_one_hot_embed_unembed(
          vocab_size, max_seq_len)
      compiled_model = model.CompiledTransformerModel(
          transformer=transformer,
          token_embed=token_embed,
          position_embed=position_embed,
          unembed=unembed,
          use_unembed_argmax=True,
          pad_token=pad_token)
      embeddings = compiled_model.embed(tokens)
      unembeddings = compiled_model.unembed(embeddings)
      return embeddings, unembeddings

    rng = hk.PRNGSequence(1)
    params = embed_unembed.init(next(rng), tokens)
    embeddings, unembeddings = embed_unembed.apply(params, next(rng), tokens)

    self.assertEqual(
        embeddings.shape,
        (tokens.shape[0], tokens.shape[1], vocab_size + max_seq_len))

    np.testing.assert_allclose(unembeddings, tokens)


if __name__ == "__main__":
  absltest.main()
utils/debugging.py
ADDED
@@ -0,0 +1,28 @@
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Useful helpers for model debugging."""


def print_arrays(arrays, labels=None, colwidth=12):
  """Pretty-prints a list of [1, T, D] arrays."""
  if labels is not None:
    print(" |".join(labels))
    widths = [len(l) for l in labels]
  else:
    # One column of fixed width per feature dimension.
    widths = [colwidth] * arrays[0].shape[-1]
  for layer in arrays:
    print("=" * (colwidth + 1) * layer.shape[1])
    for row in layer[0]:
      print(" |".join([f"{x:<{width}.2f}" for x, width in zip(row, widths)]))
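A usage sketch for this helper on activations shaped [1, T, D], e.g. the residuals returned by model.Transformer (the import path is an assumption based on the file location, not part of this commit):

# Usage sketch for print_arrays.
import numpy as np
from tracr.utils import debugging

residuals = [np.random.random((1, 3, 4)) for _ in range(2)]  # two layers
debugging.print_arrays(
    residuals, labels=["dim_0", "dim_1", "dim_2", "dim_3"])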