CSquid333 commited on
Commit
72cfe15
·
1 Parent(s): c036e04

added a synthesizer on top

Browse files
Files changed (40) hide show
  1. __pycache__/abstract_syntax_tree.cpython-39.pyc +0 -0
  2. __pycache__/python_embedded_rasp.cpython-39.pyc +0 -0
  3. __pycache__/rasp_synthesizer.cpython-39.pyc +0 -0
  4. __pycache__/utils.cpython-39.pyc +0 -0
  5. abstract_syntax_tree.py +72 -0
  6. app.py +18 -0
  7. comp_flows/( tokens_int . 1 )(.1. 2.).pdf +0 -0
  8. outtest.txt +36 -0
  9. python_embedded_rasp.py +308 -0
  10. rasp_synthesizer.py +257 -0
  11. reverse-viz.ipynb +0 -0
  12. testouts.txt +55 -0
  13. tracr/__pycache__/__init__.cpython-39.pyc +0 -0
  14. tracr/compiler/__pycache__/__init__.cpython-39.pyc +0 -0
  15. tracr/compiler/__pycache__/assemble.cpython-39.pyc +0 -0
  16. tracr/compiler/__pycache__/basis_inference.cpython-39.pyc +0 -0
  17. tracr/compiler/__pycache__/compiling.cpython-39.pyc +0 -0
  18. tracr/compiler/__pycache__/craft_graph_to_model.cpython-39.pyc +0 -0
  19. tracr/compiler/__pycache__/craft_model_to_transformer.cpython-39.pyc +0 -0
  20. tracr/compiler/__pycache__/expr_to_craft_graph.cpython-39.pyc +0 -0
  21. tracr/compiler/__pycache__/nodes.cpython-39.pyc +0 -0
  22. tracr/compiler/__pycache__/rasp_to_graph.cpython-39.pyc +0 -0
  23. tracr/craft/__pycache__/__init__.cpython-39.pyc +0 -0
  24. tracr/craft/__pycache__/bases.cpython-39.pyc +0 -0
  25. tracr/craft/__pycache__/transformers.cpython-39.pyc +0 -0
  26. tracr/craft/__pycache__/vectorspace_fns.cpython-39.pyc +0 -0
  27. tracr/craft/chamber/__pycache__/__init__.cpython-39.pyc +0 -0
  28. tracr/craft/chamber/__pycache__/categorical_attn.cpython-39.pyc +0 -0
  29. tracr/craft/chamber/__pycache__/categorical_mlp.cpython-39.pyc +0 -0
  30. tracr/craft/chamber/__pycache__/numerical_mlp.cpython-39.pyc +0 -0
  31. tracr/craft/chamber/__pycache__/selector_width.cpython-39.pyc +0 -0
  32. tracr/rasp/__pycache__/__init__.cpython-39.pyc +0 -0
  33. tracr/rasp/__pycache__/rasp.cpython-39.pyc +0 -0
  34. tracr/transformer/__pycache__/__init__.cpython-39.pyc +0 -0
  35. tracr/transformer/__pycache__/attention.cpython-39.pyc +0 -0
  36. tracr/transformer/__pycache__/encoder.cpython-39.pyc +0 -0
  37. tracr/transformer/__pycache__/model.cpython-39.pyc +0 -0
  38. tracr/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  39. tracr/utils/__pycache__/errors.cpython-39.pyc +0 -0
  40. utils.py +80 -0
__pycache__/abstract_syntax_tree.cpython-39.pyc ADDED
Binary file (2.94 kB). View file
 
__pycache__/python_embedded_rasp.cpython-39.pyc ADDED
Binary file (9.03 kB). View file
 
__pycache__/rasp_synthesizer.cpython-39.pyc ADDED
Binary file (9.09 kB). View file
 
__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.69 kB). View file
 
abstract_syntax_tree.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ ABSTRACT SYNTAX TREE
3
+ This file contains the Python class that represents programs created by our rasp synthesizer.
4
+ '''
5
+ from utils import *
6
+
7
class OperatorNode:
    '''
    AST node that pairs one operator with the operand nodes it is applied to.

    Args:
        operator (object): operator object (e.g., Select, Aggregate, etc.); it must
            expose n_args, weight, return_type, str(...) and to_python(...).
        children (list): operand nodes; each must expose weight, str() and to_python().

    Attributes:
        weight: operator weight plus the weights of all children.
        return_type: copied from the operator.

    Example:
        select_node: OperatorNode(Select(), [Tokens(), Tokens(), Equal()])
        select_node.str() = "(select(tokens, tokens, ==))"
        select_node.evaluate("hi") = [[1, 0], [0, 1]]
        select_node.to_python() = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
    '''
    def __init__(self, operator, children):
        self.operator = operator
        self.children = children
        operand_weight = 0
        for node in children:
            operand_weight += node.weight
        self.weight = operator.weight + operand_weight
        self.return_type = operator.return_type

    def _require_arity(self):
        # Shared guard: the operator must receive exactly n_args operands.
        if len(self.children) != self.operator.n_args:
            raise ValueError("Improper number of arguments for operator.")

    def str(self):
        '''
        Render the node as a parenthesised string built from its children's strings.
        '''
        self._require_arity()
        rendered = self.operator.str(*[node.str() for node in self.children])
        return f"({rendered})"

    def evaluate(self, input=None):
        '''
        Translate the node with to_python() and call the result on the input.
        (The older approach of sending the string form to the rasp REPL is deprecated.)
        '''
        return self.to_python()(input)

    def to_python(self):
        '''
        Translate every child with to_python() and hand the results to the operator.
        '''
        self._require_arity()
        translated = [node.to_python() for node in self.children]
        return self.operator.to_python(*translated)
49
+
50
+ '''
51
+ TESTING
52
+ '''
53
+ if __name__ == "__main__":
54
+ from python_embedded_rasp import *
55
+ from tracr.rasp import rasp
56
+
57
+ select_op = OperatorNode(Select(), [Tokens(), Tokens(), Equal()]) # wait should children be operators or operator nodes? maybe can be either?
58
+ assert (select_op.weight == 4)
59
+
60
+ select_op_str = select_op.str()
61
+ actual_so_str = "(select(tokens, tokens, ==))"
62
+ assert select_op_str == actual_so_str
63
+
64
+ select_op_res = select_op.evaluate("hi")
65
+ actual_so_res = [[1, 0],[0, 1]]
66
+ assert select_op_res == actual_so_res
67
+
68
+ select_op_python = select_op.to_python()
69
+ actual_so_python = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
70
+ assert type(select_op_python) == type(actual_so_python)
71
+
72
+ print("all tests passed hooray!")
app.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ For future reference with downloading model files:
3
+
4
+ import streamlit as st
5
+ import pickle
6
+ import base64
7
+
8
+ x = {"my": "data"}
9
+
10
+ def download_model(model):
11
+ output_model = pickle.dumps(model)
12
+ b64 = base64.b64encode(output_model).decode()
13
+ href = f'<a href="data:file/output_model;base64,{b64}" download="myfile.pkl">Download Trained Model .pkl File</a>'
14
+ st.markdown(href, unsafe_allow_html=True)
15
+
16
+
17
+ download_model(x)
18
+ '''
comp_flows/( tokens_int . 1 )(.1. 2.).pdf ADDED
Binary file (19.2 kB). View file
 
outtest.txt ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Received the following input and output examples:
2
+ [(['h', 'e', 'l', 'l', 'o'], [1, 1, 2, 2, 1])]
3
+ Running synthesizer with
4
+ Vocab: {'o', 'e', 'h', 'l'}
5
+ Max sequence length: 5
6
+ Max weight: 25
7
+ (indices - indices)
8
+ [[0, 0, 0, 0, 0]]
9
+ (indices - 0)
10
+ [[0, 1, 2, 3, 4]]
11
+ (indices - 1)
12
+ [[-1, 0, 1, 2, 3]]
13
+ (0 - indices)
14
+ [[0, -1, -2, -3, -4]]
15
+ (1 - indices)
16
+ [[1, 0, -1, -2, -3]]
17
+ (select(tokens, tokens, ==))
18
+ [[[True, False, False, False, False], [False, True, False, False, False], [False, False, True, True, False], [False, False, True, True, False], [False, False, False, False, True]]]
19
+ (select(tokens, tokens, true))
20
+ [[[True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True]]]
21
+ (select(tokens, indices, ==))
22
+ [[[False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False]]]
23
+ (select(tokens, indices, true))
24
+ [[[True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True]]]
25
+ (select(indices, tokens, ==))
26
+ [[[False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False], [False, False, False, False, False]]]
27
+ (select(indices, tokens, true))
28
+ [[[True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True]]]
29
+ (select(indices, indices, ==))
30
+ [[[True, False, False, False, False], [False, True, False, False, False], [False, False, True, False, False], [False, False, False, True, False], [False, False, False, False, True]]]
31
+ (select(indices, indices, true))
32
+ [[[True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True], [True, True, True, True, True]]]
33
+ (select_width((select(tokens, tokens, ==))))
34
+ [[1, 1, 2, 2, 1]]
35
+ The following program has been compiled to a transformer with 1 layer(s):
36
+ (select_width((select(tokens, tokens, ==))))
python_embedded_rasp.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ RASP OPERATORS THAT ARE SUPPORTED BY TRACR'S PYTHON EMBEDDING
3
+ This file contains Python classes that define the rasp operators supported by TRACR's python embedding of the language.
4
+ This is a subset of everything that TRACR supports in python, due to project time constraints.
5
+ '''
6
+ import random
7
+ from typing import (Any, Callable, Dict, Generic, List, Mapping, Optional,
8
+ Sequence, TypeVar, Union)
9
+ from tracr.rasp import rasp
10
+ import subprocess
11
+ import time
12
+
13
+ '''
14
+ CLASS DEFINITIONS
15
+ '''
16
+ class Tokens:
17
+ '''
18
+ Tokens constant.
19
+ '''
20
+ def __init__(self):
21
+ self.n_args = 0
22
+ self.arg_types = []
23
+ self.return_type = rasp.SOp
24
+ self.weight = 1
25
+
26
+ def to_python(self):
27
+ # return an object that can be compiled into a TRACR transformer
28
+ # arguments should be python objects
29
+ return rasp.tokens
30
+
31
+ def str(self):
32
+ # represent rasp operator in string form
33
+ # expects arguments to be strings
34
+ return "tokens"
35
+
36
+ class Indices:
37
+ def __init__(self):
38
+ self.n_args = 0
39
+ self.arg_types = []
40
+ self.return_type = rasp.SOp
41
+ self.weight = 1
42
+
43
+ def to_python(self):
44
+ # return an object that can be compiled into a TRACR transformer
45
+ # arguments should be python objects
46
+ return rasp.indices
47
+
48
+ def str(self):
49
+ # represent rasp operator in string form
50
+ # expects arguments to be strings
51
+ return "indices"
52
+
53
+ class Zero:
54
+ def __init__(self):
55
+ self.n_args = 0
56
+ self.arg_types = []
57
+ self.return_type = int
58
+ self.weight = 1
59
+
60
+ def to_python(self):
61
+ # return an object that can be compiled into a TRACR transformer
62
+ # arguments should be python objects
63
+ return 0
64
+
65
+ def str(self):
66
+ # represent rasp operator in string form
67
+ # expects arguments to be strings
68
+ return "0"
69
+
70
+ class One:
71
+ def __init__(self):
72
+ self.n_args = 0
73
+ self.arg_types = []
74
+ self.return_type = int
75
+ self.weight = 1
76
+
77
+ def to_python(self):
78
+ # return an object that can be compiled into a TRACR transformer
79
+ # arguments should be python objects
80
+ return 1
81
+
82
+ def str(self):
83
+ # represent rasp operator in string form
84
+ # expects arguments to be strings
85
+ return "1"
86
+
87
+ class Equal:
88
+ '''
89
+ Comparison Equal constant.
90
+ '''
91
+ def __init__(self):
92
+ self.n_args = 0
93
+ self.arg_types = []
94
+ self.return_type = rasp.Predicate
95
+ self.weight = 1
96
+
97
+ def to_python(self):
98
+ # return an object that can be compiled into a TRACR transformer
99
+ # arguments should be python objects
100
+ return rasp.Comparison.EQ
101
+
102
+ def str(self):
103
+ # represent rasp operator in string form
104
+ # expects arguments to be strings
105
+ return "=="
106
+
107
+ class GT:
108
+ '''
109
+ Greater Than comparison operator.
110
+ '''
111
+ pass
112
+
113
+ class LT:
114
+ '''
115
+ Less Than comparison operator
116
+ '''
117
+ pass
118
+
119
+ class LEQ:
120
+ pass
121
+
122
+ class GEQ:
123
+ pass
124
+
125
+ class TRUE:
126
+ '''
127
+ Comparison True constant.
128
+ '''
129
+ def __init__(self):
130
+ self.n_args = 0
131
+ self.arg_types = []
132
+ self.return_type = rasp.Predicate
133
+ self.weight = 1
134
+
135
+ def to_python(self):
136
+ # return an object that can be compiled into a TRACR transformer
137
+ # arguments should be python objects
138
+ return rasp.Comparison.TRUE
139
+
140
+ def str(self):
141
+ # represent rasp operator in string form
142
+ # expects arguments to be strings
143
+ return "true"
144
+
145
+ class FALSE:
146
+ pass
147
+
148
+ class Add:
149
+ '''
150
+ Element-wise.
151
+ Input can be either int, float or s-op.
152
+ '''
153
+ pass
154
+
155
+ class Subtract:
156
+ '''
157
+ Element-wise.
158
+ Input can be either int, float or s-op.
159
+ '''
160
+ def __init__(self):
161
+ self.n_args = 2
162
+ self.arg_types = [Union[rasp.SOp, float, int], Union[rasp.SOp, float, int]]
163
+ self.return_type = Union[rasp.SOp, int, float]
164
+ self.weight = 1
165
+
166
+ def to_python(self, x, y):
167
+ # return an object that can be compiled into a TRACR transformer
168
+ # arguments should be python objects
169
+ if type(x) == type(rasp.tokens):
170
+ return None
171
+ if type(y) == type(rasp.tokens):
172
+ return None
173
+ return x - y
174
+
175
+ def str(self, x, y):
176
+ # represent rasp operator in string form
177
+ # expects arguments to be strings
178
+ return f"{x} - {y}"
179
+
180
+ class Mult:
181
+ '''
182
+ Element-wise.
183
+ Input can be either int, float or s-op.
184
+ '''
185
+ pass
186
+
187
+ class Divide:
188
+ '''
189
+ Element-wise.
190
+ Input can be either int, float or s-op.
191
+ '''
192
+ pass
193
+
194
+ class Fill:
195
+ '''
196
+ Given fill value and length, returns Sop of that length with that fill value.
197
+ Fill value can be int, float, or char.
198
+ Length must be a positive integer.
199
+ '''
200
+ pass
201
+
202
+ class SelectorAnd:
203
+ '''
204
+ Input can be bool or s-op.
205
+ '''
206
+ pass
207
+
208
+ class SelectorOr:
209
+ '''
210
+ Input can be bool or s-op.
211
+ '''
212
+ pass
213
+
214
+ class SelectorNot:
215
+ '''
216
+ Input is an s-op of bools. (Or bool-convertible values.)
217
+ '''
218
+ pass
219
+
220
+ class Select:
221
+ '''
222
+ Select operator.
223
+ '''
224
+ def __init__(self):
225
+ self.n_args = 3
226
+ self.arg_types = [rasp.SOp, rasp.SOp, rasp.Predicate]
227
+ self.return_type = rasp.Selector
228
+ self.weight = 1
229
+
230
+ def to_python(self, sop1, sop2, comp):
231
+ # return an object that can be compiled into a TRACR transformer
232
+ # arguments should be python objects
233
+ return rasp.Select(sop1, sop2, comp)
234
+
235
+ def str(self, sop1, sop2, comp):
236
+ # represent rasp operator in string form
237
+ # expects arguments to be strings
238
+ return f"select({sop1}, {sop2}, {comp})"
239
+
240
+ class Aggregate:
241
+ '''
242
+ The Aggregate operator.
243
+ '''
244
+ def __init__(self):
245
+ self.n_args = 2
246
+ self.arg_types = [rasp.Selector, rasp.SOp]
247
+ self.return_type = rasp.SOp
248
+ self.weight = 1
249
+
250
+ def to_python(self, sel, sop):
251
+ # return an object that can be compiled into a TRACR transformer
252
+ # arguments should be python objects
253
+ return rasp.Aggregate(sel, sop)
254
+
255
+ def str(self, sel, sop):
256
+ # represent rasp operator in string form
257
+ # expects arguments to be strings
258
+ return f"aggregate({sel}, {sop})"
259
+
260
+ class SelectorWidth:
261
+ '''
262
+ The selector_width operator.
263
+ '''
264
+ def __init__(self):
265
+ self.n_args = 1
266
+ self.arg_types = [rasp.Selector]
267
+ self.return_type = rasp.SOp
268
+ self.weight = 1
269
+
270
+ def to_python(self, sel):
271
+ # return an object that can be compiled into a TRACR transformer
272
+ # arguments should be python objects
273
+ return rasp.SelectorWidth(sel)
274
+
275
+ def str(self, sel):
276
+ # represent rasp operator in string form
277
+ # expects arguments to be strings
278
+ return f"select_width({sel})"
279
+
280
+ '''
281
+ GLOBAL CONSTANTS
282
+ '''
283
+
284
+ # define operators
285
+ rasp_operators = [Select(), SelectorWidth(), Aggregate(), Subtract()]
286
+ rasp_consts = [Tokens(), Tokens(), Equal(), TRUE(), Indices(), Indices(), Zero(), One()]
287
+ '''
288
+ TESTING
289
+ '''
290
if __name__ == "__main__":
    # Smoke tests for the operator wrappers defined in this module.
    test_select = Select()

    # to_python() should produce the same kinds of objects as calling tracr's rasp directly.
    test_select_python = test_select.to_python(Tokens().to_python(), Tokens().to_python(), Equal().to_python())
    actual_ts_python = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
    assert type(Tokens().to_python()) == type(rasp.tokens)
    # The closing parenthesis used to sit after the comparison, so the assert
    # checked type(<bool>), which is always truthy and could never fail.
    assert type(Equal().to_python()) == type(rasp.Comparison.EQ)
    assert type(test_select_python) == type(actual_ts_python)

    # str() should produce the rasp surface syntax.
    test_select_string = test_select.str(Tokens().str(), Tokens().str(), Equal().str())
    actual_ts_string = "select(tokens, tokens, ==)"
    assert test_select_string == actual_ts_string

    # Constructing Aggregate and evaluating a rasp aggregate directly must not raise.
    test_aggregate = Aggregate()
    print(rasp.Aggregate(rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ), rasp.tokens)("hi"))

    print("all tests passed hooray!")
308
+
rasp_synthesizer.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ BOTTOM-UP ENUMERATIVE SYNTHESIS FOR RASP
3
+
4
+ Usage:
5
+ python rasp_synthesizer.py --examples
6
+ '''
7
+ import numpy as np
8
+ import argparse
9
+ import itertools
10
+ import time
11
+ import ast
12
+ import re
13
+ from tracr.compiler import compiling
14
+ from typing import get_args
15
+ import inspect
16
+
17
+ from abstract_syntax_tree import *
18
+ from python_embedded_rasp import *
19
+
20
+ # PARSE ARGUMENTS
21
+ def parse_args():
22
+ '''
23
+ Parse command line arguments.
24
+ '''
25
+ parser = argparse.ArgumentParser(description="Bottom-up enumerative synthesis for RASP.")
26
+ parser.add_argument('--examples', required=True, help="input/output sequence examples for synthesis")
27
+ parser.add_argument('--max_weight', type=int, required=False, default=10, help="Maximum weight of programs to consider before terminating search.")
28
+ args = parser.parse_args()
29
+ return args
30
+
31
+ # ANALYZE EXAMPLES
32
def analyze_examples(inputs):
    '''
    Parse and validate the examples string supplied on the command line.

    Args:
        inputs (str): Python-literal text for a list of examples, where each example
            is indexable as (input_sequence, output_sequence).

    Returns:
        tuple: (example_ins, example_outs), two parallel lists holding the input
        sequence and the output sequence of every example, in order.

    Raises:
        argparse.ArgumentTypeError: if the text is not a valid literal, an example
            does not have an input and an output, or a sequence mixes element types.
        ValueError: if the literal is not a list.
    '''
    try:
        # Safely evaluate the string to a Python object
        examples_lst = ast.literal_eval(inputs)
    except (SyntaxError, ValueError) as e:
        raise argparse.ArgumentTypeError(f"Invalid examples format: {e}")

    if not isinstance(examples_lst, list):
        raise ValueError("Input should be a list.")

    def same_legal_type(lst):
        # True when every element is an instance of one single supported type.
        # NOTE: bool is a subclass of int, so sequences mixing the two pass as ints.
        return any(all(isinstance(x, kind) for x in lst)
                   for kind in (int, float, bool, str))

    example_ins = []
    example_outs = []
    for ex in examples_lst:
        try:
            ins, outs = ex[0], ex[1]
        except (TypeError, IndexError, KeyError) as e:
            raise argparse.ArgumentTypeError("Invalid examples format.") from e

        if not (same_legal_type(ins) and same_legal_type(outs)):
            # The previous message interpolated two names that were never defined,
            # so this path raised NameError instead of the intended error.
            raise argparse.ArgumentTypeError(
                "Each example must have consistent types, but inputs have types "
                f"{[type(x) for x in ins]} and outputs have types {[type(x) for x in outs]}")

        example_ins.append(ins)
        example_outs.append(outs)

    return example_ins, example_outs
67
+
68
+ # GET VOCABULARY
69
def get_vocabulary(examples):
    '''
    Returns vocabulary for later compiling the RASP model.

    Args:
        examples: iterable of examples whose first item is the input sequence.

    Returns:
        set: every distinct element that appears in any input sequence.
    '''
    # Only the input side contributes to the vocabulary, so the output sequence
    # is no longer unpacked into an unused local.
    return {token for ex in examples for token in ex[0]}
78
+
79
+ # CHECK OBSERVATIONAL EQUIVALENCE
80
def check_obs_equivalence(examples, program_a, program_b):
    '''
    Report whether two programs produce the same outputs on all example inputs.

    A program found in rasp_consts is not evaluated; its output is taken as None.

    Args:
        examples: sequence of examples whose first item is the input sequence.
        program_a: program exposing evaluate(input), or a member of rasp_consts.
        program_b: program exposing evaluate(input), or a member of rasp_consts.

    Returns:
        bool: True when the collected outputs compare equal, and also True when
        evaluation raises, so that the caller discards the candidate.
    '''
    try:
        inputs = [example[0] for example in examples]
        a_output = None
        b_output = None
        if program_a not in rasp_consts:
            a_output = [program_a.evaluate(seq) for seq in inputs]
        if program_b not in rasp_consts:
            b_output = [program_b.evaluate(seq) for seq in inputs]
    except Exception:
        # Catch Exception rather than everything, so Ctrl-C and SystemExit
        # still stop the search.
        return True  # force the synthesizer to not consider this program

    return a_output == b_output
93
+
94
+ # CHECK CORRECTNESS
95
def check_correctness(examples, program):
    '''
    Checks if the program's output matches the expected output on all examples.

    Args:
        examples: sequence of (input_sequence, expected_output) examples.
        program: object exposing evaluate(input) and str().

    Returns:
        bool: False if reading the examples or evaluating the program raises;
        otherwise whether the list of outputs equals the list of expected outputs.

    Side effects:
        Prints the program's string form and its outputs when evaluation succeeds.
    '''
    try:
        inputs = [example[0] for example in examples]
        outputs = [example[1] for example in examples]
        program_output = [program.evaluate(seq) for seq in inputs]
    except Exception:
        # Catch Exception rather than everything, so Ctrl-C and SystemExit
        # still stop the search.
        return False

    print(program.str())
    print(program_output)

    # TODO return number that match and return this

    return program_output == outputs
112
+
113
+ # COMPARE TYPE SIGNATURES
114
def compare_types(list1, list2):
    '''
    Check a list of provided types against a list of expected types, position by position.

    Args:
        list1 (list): provided types (e.g., the return types of candidate operands).
        list2 (list): expected types (e.g., an operator's arg_types).

    Returns:
        bool: True only when both lists have the same length and every provided
        type is accepted by the expected type at the same position:
          - expected is a typing.Union: the provided type (or, if it is itself a
            Union, every one of its members) must equal some member of it;
          - otherwise: the two types must compare equal.
    '''
    # Previously only "list1 longer than list2" was rejected, so a shorter
    # list1 was reported as matching.
    if len(list1) != len(list2):
        return False

    def union_members(tp):
        # Member types when tp is a typing.Union, else None.
        if getattr(tp, '__origin__', None) is Union:
            return get_args(tp)
        return None

    for provided, expected in zip(list1, list2):
        allowed = union_members(expected)
        if allowed is None:
            # Direct type comparison
            if provided != expected:
                return False
            continue

        offered = union_members(provided)
        candidates = (provided,) if offered is None else offered
        if not all(any(c == a for a in allowed) for c in candidates):
            return False

    return True
141
+
142
+ # RUN SYNTHESIZER
143
def run_synthesizer(examples, max_weight):
    '''
    Run bottom-up enumerative synthesis.

    Starting from the constants in rasp_consts, repeatedly applies every operator
    in rasp_operators to ordered selections of banked programs, keeping candidates
    that type-check, fit the current weight level, and are neither textually
    identical nor observationally equivalent to a banked program.

    Args:
        examples: sequence of (input_sequence, expected_output) examples.
        max_weight (int): exclusive upper bound of the weight levels searched
            (levels run over range(2, max_weight)).

    Returns:
        The first OperatorNode that check_correctness accepts, or None if the
        search finishes without finding one.
    '''
    # Work on a private copy. Binding the bank to rasp_consts itself made every
    # append grow the module-level constant list, and check_obs_equivalence then
    # saw every banked program as "in rasp_consts" and never evaluated it.
    program_bank = list(rasp_consts)
    # String forms are only used for membership tests, so a set suffices.
    program_bank_str = {p.str() for p in program_bank}

    # TODO: store approximate programs, measured by number of output examples that match

    # iterate over each level
    for weight in range(2, max_weight):

        for op in rasp_operators:
            # permutations() snapshots the bank when it is created, so programs
            # appended below are only picked up on a later operator or level.
            for combination in itertools.permutations(program_bank, op.n_args):

                type_signature = [p.return_type for p in combination]
                if not compare_types(type_signature, op.arg_types):
                    continue

                if sum(p.weight for p in combination) > weight:
                    continue

                program = OperatorNode(op, combination)
                program_str = program.str()

                if program_str in program_bank_str:
                    continue

                if any(check_obs_equivalence(examples, program, p) for p in program_bank):
                    continue

                program_bank.append(program)
                program_bank_str.add(program_str)

                if check_correctness(examples, program):
                    return program

    return None
183
+
184
+ # COMPILE RASP MODEL
185
+ if __name__ == "__main__":
186
+
187
+ '''
188
+ Some examples:
189
+ Identify anagrams:
190
+ [[['V','I','W',',','W','I','V'], [True, True, True, True, True, True, True]],[['a','b',',','b','a'], [True, True, True, True, True]],[['e','l',',','s','t'], [False, False, False, False, False]]]
191
+ Output: times out
192
+ Calculate the median of a list of numbers:
193
+ [[[1,2,3,4,5], [3,3,3,3,3]], [[2,8,10,11], [9,9,9,9]], [[1,2,3],[2,2,2]]]
194
+ Output: times out
195
+ Identity function:
196
+ [[['h','i'], ['h','i']]]
197
+ Output: (aggregate((select(tokens, tokens, ==)), tokens))
198
+ Histogram:
199
+ [[['h', 'e', 'l', 'l', 'o'], [1,1,2,2,1]]]
200
+ Output: (select_width((select(tokens, tokens, ==))))
201
+ Length:
202
+ [[[7,2,5],[3,3,3]],[[1],[1]],[[2,0,1,7,3,6,8,20],[8,8,8,8,8,8,8,8]]]
203
+ Output: (select_width((select(tokens, tokens, true))))
204
+ Calculate mean of list of numbers:
205
+ [[[5,10,3,2,43], [12.6, 12.6, 12.6, 12.6, 12.6]],[[1,2], [1.5, 1.5]],[[3,3,3],[3,3,3]]]
206
+ Output: (aggregate((select(tokens, tokens, true)), tokens))
207
+ Reverse a string:
208
+ [[['h', 'i'], ['i', 'h']]]
209
+ Output: times out
210
+ Expected: aggregate(select(indices, (select_width((select(tokens, tokens, true)))) - indices - 1, ==), tokens);
211
+ PERSONAL TODOS:
212
+ - output several similar programs
213
+ -
214
+
215
+ '''
216
+
217
+ args = parse_args()
218
+ inputs, outs = analyze_examples(args.examples)
219
+ examples = list(zip(inputs, outs))
220
+ print("Received the following input and output examples:")
221
+ print(examples)
222
+ max_seq_len = 0
223
+ for i in inputs:
224
+ max_seq_len = max(len(i), max_seq_len)
225
+ vocab = get_vocabulary(examples)
226
+
227
+ print("Running synthesizer with")
228
+ print("Vocab: {}".format(vocab))
229
+ print("Max sequence length: {}".format(max_seq_len))
230
+ print("Max weight: {}".format(args.max_weight))
231
+
232
+ program = run_synthesizer(examples, args.max_weight)
233
+
234
+ if program:
235
+ algorithm = program.to_python()
236
+
237
+ bos = "BOS"
238
+ model = compiling.compile_rasp_to_model(
239
+ algorithm,
240
+ vocab=vocab,
241
+ max_seq_len=max_seq_len,
242
+ compiler_bos=bos,
243
+ )
244
+
245
+
246
+ def extract_layer_number(s):
247
+ match = re.search(r'layer_(\d+)', s)
248
+ if match:
249
+ return int(match.group(1)) + 1
250
+ else:
251
+ return None
252
+
253
+ layer_num = extract_layer_number(list(model.params.keys())[-1])
254
+ print(f"The following program has been compiled to a transformer with {layer_num} layer(s):")
255
+ print(program.str())
256
+ else:
257
+ print("No program found.")
reverse-viz.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
testouts.txt ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Received the following input and output examples:
2
+ [(['h', 'i'], ['h', 'h'])]
3
+ Running synthesizer with
4
+ Vocab: {'h', 'i'}
5
+ Max sequence length: 2
6
+ Max weight: 15
7
+ - Searching level 2 with 4 primitives.
8
+ - Searching level 3 with 4 primitives.
9
+ (select(tokens, tokens, ==))
10
+ [[[True, False], [False, True]]]
11
+ (select(tokens, tokens, true))
12
+ [[[True, True], [True, True]]]
13
+ - Searching level 4 with 6 primitives.
14
+ (select_width((select(tokens, tokens, ==))))
15
+ [[1, 1]]
16
+ (select_width((select(tokens, tokens, true))))
17
+ [[2, 2]]
18
+ - Searching level 5 with 8 primitives.
19
+ - Searching level 6 with 8 primitives.
20
+ - Searching level 7 with 8 primitives.
21
+ - Searching level 8 with 8 primitives.
22
+ - Searching level 9 with 8 primitives.
23
+ (aggregate((select(tokens, tokens, ==)), (select_width((select(tokens, tokens, ==))))))
24
+ [[1.0, 1.0]]
25
+ (aggregate((select(tokens, tokens, ==)), (select_width((select(tokens, tokens, true))))))
26
+ [[2.0, 2.0]]
27
+ (aggregate((select(tokens, tokens, true)), (select_width((select(tokens, tokens, ==))))))
28
+ [[1.0, 1.0]]
29
+ (aggregate((select(tokens, tokens, true)), (select_width((select(tokens, tokens, true))))))
30
+ [[2.0, 2.0]]
31
+ - Searching level 10 with 12 primitives.
32
+ - Searching level 11 with 12 primitives.
33
+ - Searching level 12 with 12 primitives.
34
+ - Searching level 13 with 12 primitives.
35
+ - Searching level 14 with 12 primitives.
36
+ (aggregate((select(tokens, tokens, ==)), (aggregate((select(tokens, tokens, ==)), (select_width((select(tokens, tokens, ==))))))))
37
+ [[1.0, 1.0]]
38
+ (aggregate((select(tokens, tokens, ==)), (aggregate((select(tokens, tokens, ==)), (select_width((select(tokens, tokens, true))))))))
39
+ [[2.0, 2.0]]
40
+ (aggregate((select(tokens, tokens, ==)), (aggregate((select(tokens, tokens, true)), (select_width((select(tokens, tokens, ==))))))))
41
+ [[1.0, 1.0]]
42
+ (aggregate((select(tokens, tokens, ==)), (aggregate((select(tokens, tokens, true)), (select_width((select(tokens, tokens, true))))))))
43
+ [[2.0, 2.0]]
44
+ > c:\users\18084\desktop\cs252r\final_project\tracr-synthesis\rasp_synthesizer.py(94)check_obs_equivalence()
45
+ -> return a_output == b_output
46
+ (Pdb) --KeyboardInterrupt--
47
+ (Pdb) --KeyboardInterrupt--
48
+ (Pdb) --KeyboardInterrupt--
49
+ (Pdb) *** SyntaxError: invalid syntax
50
+ (Pdb) --KeyboardInterrupt--
51
+ (Pdb) *** SyntaxError: invalid syntax
52
+ (Pdb) --KeyboardInterrupt--
53
+ (Pdb) --KeyboardInterrupt--
54
+ (Pdb) *** SyntaxError: invalid syntax
55
+ (Pdb)
tracr/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (168 Bytes). View file
 
tracr/compiler/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (341 Bytes). View file
 
tracr/compiler/__pycache__/assemble.cpython-39.pyc ADDED
Binary file (9.98 kB). View file
 
tracr/compiler/__pycache__/basis_inference.cpython-39.pyc ADDED
Binary file (2.97 kB). View file
 
tracr/compiler/__pycache__/compiling.cpython-39.pyc ADDED
Binary file (2.48 kB). View file
 
tracr/compiler/__pycache__/craft_graph_to_model.cpython-39.pyc ADDED
Binary file (6.71 kB). View file
 
tracr/compiler/__pycache__/craft_model_to_transformer.cpython-39.pyc ADDED
Binary file (1.69 kB). View file
 
tracr/compiler/__pycache__/expr_to_craft_graph.cpython-39.pyc ADDED
Binary file (7.4 kB). View file
 
tracr/compiler/__pycache__/nodes.cpython-39.pyc ADDED
Binary file (442 Bytes). View file
 
tracr/compiler/__pycache__/rasp_to_graph.cpython-39.pyc ADDED
Binary file (1.78 kB). View file
 
tracr/craft/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (174 Bytes). View file
 
tracr/craft/__pycache__/bases.cpython-39.pyc ADDED
Binary file (10 kB). View file
 
tracr/craft/__pycache__/transformers.cpython-39.pyc ADDED
Binary file (7.64 kB). View file
 
tracr/craft/__pycache__/vectorspace_fns.cpython-39.pyc ADDED
Binary file (5.32 kB). View file
 
tracr/craft/chamber/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (182 Bytes). View file
 
tracr/craft/chamber/__pycache__/categorical_attn.cpython-39.pyc ADDED
Binary file (4.25 kB). View file
 
tracr/craft/chamber/__pycache__/categorical_mlp.cpython-39.pyc ADDED
Binary file (5.04 kB). View file
 
tracr/craft/chamber/__pycache__/numerical_mlp.cpython-39.pyc ADDED
Binary file (10.6 kB). View file
 
tracr/craft/chamber/__pycache__/selector_width.cpython-39.pyc ADDED
Binary file (4.54 kB). View file
 
tracr/rasp/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (173 Bytes). View file
 
tracr/rasp/__pycache__/rasp.cpython-39.pyc ADDED
Binary file (36.6 kB). View file
 
tracr/transformer/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (180 Bytes). View file
 
tracr/transformer/__pycache__/attention.cpython-39.pyc ADDED
Binary file (4.83 kB). View file
 
tracr/transformer/__pycache__/encoder.cpython-39.pyc ADDED
Binary file (5.39 kB). View file
 
tracr/transformer/__pycache__/model.cpython-39.pyc ADDED
Binary file (5.25 kB). View file
 
tracr/utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (174 Bytes). View file
 
tracr/utils/__pycache__/errors.cpython-39.pyc ADDED
Binary file (928 Bytes). View file
 
utils.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import time
3
+ import re
4
+ import ast
5
+
6
+ # Start the REPL subprocess
7
+ python_exe = '/Users/18084/Desktop/CS252R/final_project/rasp-env-py3.9/Scripts/python.exe' #SETUP THING: replace with path to your python environment
8
+
9
+ '''
10
+ THE FOLLOWING FUNCTIONS ARE DEPRECATED
11
+ '''
12
def clean_carrots(text):
    '''
    Extract the text between the first pair of ">>" markers.

    Args:
        text (str): raw text to search.

    Returns:
        The stripped text between the first two ">>" markers, or None when no
        such pair is present.
    '''
    match = re.search(r">>(.*?)>>", text)
    if match:
        # .strip() is used to remove any leading/trailing whitespace
        return match.group(1).strip()
    return None


def parse_output(out):
    '''
    Convert the REPL's printed result into a list.

    Args:
        out (str): raw text containing the result between ">>" markers.

    Returns:
        list: the dict's values for a dict result, otherwise the elements of the
        tuple or list result.

    Raises:
        ValueError: if no ">>"-delimited result is found or it is not a valid literal.
        Exception: if the result is a literal of any other type.
    '''
    cleaned = clean_carrots(out)
    if cleaned is None:
        # Fail with a clear message instead of letting literal_eval choke on None.
        raise ValueError("No '>>'-delimited result found in rasp output.")
    value = ast.literal_eval(cleaned)
    # can arrive as tuple, list, or dictionary
    # ultimately want to convert everything to list form
    if isinstance(value, dict):
        return list(value.values())
    if isinstance(value, (tuple, list)):
        # The list branch previously returned the builtin `list` type itself.
        return list(value)
    raise Exception("Error executing rasp program.")
32
+
33
def run_repl(command):
    '''
    Runs the RASP repl in a separate subprocess.

    Args:
        command (str): one REPL command; it is sent followed by "exit()".

    Returns:
        tuple: (parsed_output, str_error) where parsed_output is the result of
        parse_output on the REPL's stdout and str_error is its stderr, with each
        line stripped and the lines joined by single spaces.
    '''
    process = subprocess.Popen([python_exe, 'RASP/RASP_support/REPL.py'],
                               stdin=subprocess.PIPE,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE,
                               text=True)

    # communicate() sends the commands, drains both pipes and waits for exit.
    # The earlier poll-and-sleep loop waited for exit without reading the pipes,
    # which can hang forever once the child fills a pipe buffer.
    raw_output, raw_error = process.communicate(f'{command}\nexit()\n')

    # Collapse each stream to one line: every line stripped and followed by a space.
    str_output = "".join(line.strip() + " " for line in raw_output.splitlines())
    str_error = "".join(line.strip() + " " for line in raw_error.splitlines())

    str_output = parse_output(str_output)
    return str_output, str_error
72
+
73
+ if __name__ == "__main__":
74
+ command = "select(tokens, tokens, ==)(\"hi\");"
75
+ res, _res_err = run_repl(command)
76
+ print(res)
77
+
78
+ command = "selector_width(select(tokens, tokens, ==))(\"hi\");"
79
+ res, _res_err = run_repl(command)
80
+ print(res)