Spaces:
Sleeping
Sleeping
File size: 2,432 Bytes
72cfe15 f0b559a 72cfe15 f0b559a 72cfe15 f0b559a 72cfe15 f0b559a 72cfe15 f0b559a 275482b f0b559a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import streamlit as st
import numpy as np
import argparse
import itertools
import time
import ast
import re
from tracr.compiler import compiling
from typing import get_args
import inspect
import pickle
import base64
from abstract_syntax_tree import *
from python_embedded_rasp import *
from rasp_synthesizer import *
# HELPER FUNCTIONS
def download_model(model):
    """Render a link that downloads ``model`` as a pickle file.

    The object is pickled, base64-encoded and embedded in a ``data:`` URI
    inside an HTML anchor, which is then written to the page with
    ``st.markdown``. The anchor's ``download`` attribute names the saved file
    ``model_params.pkl``.

    Args:
        model: Any picklable object; the app passes the compiled model's
            parameters.
    """
    b64_payload = base64.b64encode(pickle.dumps(model)).decode()
    anchor_html = f'<a href="data:file/output_model;base64,{b64_payload}" download="model_params.pkl">Download Haiku Model Parameters in a .pkl File</a>'
    # Raw HTML is required for the anchor to render as a clickable link.
    st.markdown(anchor_html, unsafe_allow_html=True)
# APP DRIVER CODE
st.title("Bottom Up Synthesis for RASP")

# --- User inputs -------------------------------------------------------------
max_weight = st.slider("Choose the maximum program weight to search for (~ size of transformer)", 2, 20, 15)
default_example = "[[['h', 'e', 'l', 'l', 'o'], [1,1,2,2,1]]]"
example_text = st.text_input(label="Provide Input and Output Examples", value=default_example)

# analyze_examples returns two values, unpacked here as the example inputs and
# their expected outputs; they are paired up positionally for the synthesizer.
inputs, outs = analyze_examples(example_text)
examples = list(zip(inputs, outs))
st.write("Received the following input and output examples:")
st.write(examples)

# Longest input sequence among the examples (0 when there are none).
max_seq_len = max((len(seq) for seq in inputs), default=0)
vocab = get_vocabulary(examples)

# --- Echo the configuration --------------------------------------------------
st.subheader("Synthesis Configuration")
st.write("Running synthesizer with")
st.write(f"Vocab: {vocab}")
st.write(f"Max sequence length: {max_seq_len}")
st.write(f"Max weight: {max_weight}")

# --- Run the search ----------------------------------------------------------
# Streamlit emits elements in call order, so the heading and the
# "May take a while." hint must be written BEFORE the synthesizer call;
# otherwise the hint only shows up once the wait is already over.
st.subheader("Synthesis Results:")
st.caption("May take a while.")
program, approx_programs = run_synthesizer(examples, max_weight)
# Matches the "layer_<n>" component of a parameter name; compiled once.
_LAYER_PATTERN = re.compile(r'layer_(\d+)')


def extract_layer_number(s):
    """Return the 1-based layer number found in ``s``, or ``None``.

    Args:
        s: A string that may contain a ``layer_<n>`` component.

    Returns:
        ``n + 1`` for the first ``layer_<n>`` occurrence in ``s`` (the index
        in the name is treated as 0-based), or ``None`` when there is no
        such component.
    """
    match = _LAYER_PATTERN.search(s)
    return int(match.group(1)) + 1 if match else None


def _count_layers(param_names):
    """Return the highest 1-based layer number in ``param_names`` (0 if none)."""
    layer_numbers = (extract_layer_number(name) for name in param_names)
    return max((n for n in layer_numbers if n is not None), default=0)


if program:
    # An exact program was found: compile it to a model and report it.
    algorithm = program.to_python()
    bos = "BOS"
    model = compiling.compile_rasp_to_model(
        algorithm,
        vocab=vocab,
        max_seq_len=max_seq_len,
        compiler_bos=bos,
    )

    # Scan every parameter name rather than only the last one, so the count
    # does not depend on key ordering and never prints "None layer(s)" when
    # the final key happens to lack a layer_<n> component.
    layer_num = _count_layers(model.params.keys())

    st.write(f"The following program has been compiled to a transformer with {layer_num} layer(s):")
    st.write(program.str())
    st.write("Here is a model download link: ")
    hk_model = model.params
    download_model(hk_model)
else:
    # No exact match: show whatever approximate candidates the search returned.
    st.write(f"No exact program found but here is some brainstorming fodder for you: {approx_programs}")