Vladimir Mikulik
add typing_extensions to list of deps.
d4d39d0
raw
history blame
12.1 kB
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RASP programs only using the subset of RASP supported by the compiler."""
from typing import Sequence
from tracr.rasp import rasp
### Programs that work only under non-causal evaluation.
def make_length() -> rasp.SOp:
    """Builds the `length` SOp from the selector width primitive.

    A selector comparing tokens with tokens under `Comparison.TRUE` is
    created, and its selector width is returned.

    Example usage:
      length = make_length()
      length("abcdefg")
      >> [7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0]

    Returns:
      length: SOp mapping an input to a sequence, where every element
        is the length of that sequence.
    """
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.TRUE)
    selector = selector.named("all_true_selector")
    width = rasp.SelectorWidth(selector)
    return width.named("length")
# Module-level `length` SOp, built once at import time. It is referenced by
# make_reverse and make_shuffle_dyck below.
length = make_length()
def make_reverse(sop: rasp.SOp) -> rasp.SOp:
    """Builds an SOp that outputs `sop` in reversed order.

    The opposite index `length - indices - 1` is computed from the
    module-level `length` SOp and matched against the indices with an
    equality selector; `sop` is then aggregated through that selector.

    Example usage:
      reverse = make_reverse(rasp.tokens)
      reverse("Hello")
      >> ['o', 'l', 'l', 'e', 'H']

    Args:
      sop: an SOp

    Returns:
      reverse : SOp that reverses the input sequence.
    """
    from_end = (length - rasp.indices).named("opp_idx")
    mirrored_idx = (from_end - 1).named("opp_idx-1")
    selector = rasp.Select(rasp.indices, mirrored_idx, rasp.Comparison.EQ)
    selector = selector.named("reverse_selector")
    return rasp.Aggregate(selector, sop).named("reverse")
def make_pair_balance(sop: rasp.SOp, open_token: str,
                      close_token: str) -> rasp.SOp:
    """Return fraction of previous open tokens minus the fraction of close tokens.

    (As implemented in the RASP paper.)

    If the outputs are always non-negative and end in 0, that implies the
    input has balanced parentheses.

    Example usage:
      num_l = make_pair_balance(rasp.tokens, "(", ")")
      num_l("a()b(c))")
      >> [0, 1/2, 0, 0, 1/5, 1/6, 0, -1/8]

    Args:
      sop: Input SOp.
      open_token: Token that counts positive.
      close_token: Token that counts negative.

    Returns:
      pair_balance: SOp mapping an input to a sequence, where every element
        is the fraction of previous open tokens minus previous close tokens.
    """

    def running_fraction(token, bools_name, frac_name):
        # Numerical indicator of `sop == token`, then its `make_frac_prevs`.
        indicator = rasp.numerical(sop == token).named(bools_name)
        return rasp.numerical(make_frac_prevs(indicator)).named(frac_name)

    # The open-token expressions are built before the close-token ones.
    opens = running_fraction(open_token, "bools_open", "opens")
    closes = running_fraction(close_token, "bools_close", "closes")

    # Linear combination 1 * opens + (-1) * closes.
    difference = rasp.LinearSequenceMap(opens, closes, 1, -1)
    return rasp.numerical(difference).named("pair_balance")
def make_shuffle_dyck(pairs: list[str]) -> rasp.SOp:
    """Returns 1 if a set of parentheses are balanced, 0 else.

    (As implemented in the RASP paper.)

    One running balance per pair is built with `make_pair_balance`. The
    output combines two checks: no balance is negative anywhere, and all
    balances are zero at the last position.

    Example usage:
      shuffle_dyck2 = make_shuffle_dyck(pairs=["()", "{}"])
      shuffle_dyck2("({)}")
      >> [1, 1, 1, 1]
      shuffle_dyck2("(){)}")
      >> [0, 0, 0, 0, 0]

    Args:
      pairs: List of pairs of open and close tokens that each should be
        balanced. Must contain at least one entry, and every entry must
        have length 2 (both are checked with `assert`).

    Returns:
      The SOp `last_zero & not_has_neg`, named "shuffle_dyck".
    """
    assert len(pairs) >= 1

    # Compute the running balance of each type of parenthesis.
    balances = []
    for pair in pairs:
        assert len(pair) == 2
        opener, closer = pair
        pair_balance = make_pair_balance(
            rasp.tokens, open_token=opener, close_token=closer)
        balances.append(pair_balance.named(f"balance_{pair}"))
    head, *tail = balances

    # Check if balances were negative anywhere -> parentheses not balanced.
    went_negative = head < 0
    for balance in tail:
        went_negative = went_negative | (balance < 0)

    # Convert to a numerical SOp (via an identity Map).
    went_negative = rasp.Map(lambda x: x, went_negative)
    went_negative = rasp.numerical(went_negative).named("any_negative")

    everything = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.TRUE)
    everything = everything.named("select_all")
    has_neg = rasp.Aggregate(everything, went_negative, default=0)
    has_neg = rasp.numerical(has_neg).named("has_neg")

    # Check if all balances are 0 at the end -> closed all parentheses.
    ends_at_zero = head == 0
    for balance in tail:
        ends_at_zero = ends_at_zero & (balance == 0)

    final_position = rasp.Select(rasp.indices, length - 1, rasp.Comparison.EQ)
    final_position = final_position.named("select_last")
    last_zero = rasp.Aggregate(final_position, ends_at_zero).named("last_zero")

    not_has_neg = (~has_neg).named("not_has_neg")
    return (last_zero & not_has_neg).named("shuffle_dyck")
def make_shuffle_dyck2() -> rasp.SOp:
    """Builds the shuffle-dyck program for the pairs "()" and "{}".

    Returns:
      The SOp produced by `make_shuffle_dyck`, named "shuffle_dyck2".
    """
    program = make_shuffle_dyck(pairs=["()", "{}"])
    return program.named("shuffle_dyck2")
def make_hist() -> rasp.SOp:
    """Returns the number of times each token occurs in the input.

    (As implemented in the RASP paper.)

    Implemented as the selector width of a selector comparing tokens with
    tokens for equality.

    Example usage:
      hist = make_hist()
      hist("abac")
      >> [2, 1, 2, 1]

    Returns:
      The selector-width SOp, named "hist".
    """
    selector = rasp.Select(rasp.tokens, rasp.tokens, rasp.Comparison.EQ)
    selector = selector.named("same_tok")
    return rasp.SelectorWidth(selector).named("hist")
def make_sort_unique(vals: rasp.SOp, keys: rasp.SOp) -> rasp.SOp:
    """Returns vals sorted by < relation on keys.

    Only supports unique keys.

    The target position of every element is obtained as the width of a
    selector comparing keys with keys under `Comparison.LT`; `vals` is then
    aggregated with a selector matching target positions against indices.

    Example usage:
      sort = make_sort_unique(rasp.tokens, rasp.tokens)
      sort([2, 4, 3, 1])
      >> [1, 2, 3, 4]

    Args:
      vals: Values to sort.
      keys: Keys for sorting.

    Returns:
      The aggregated SOp, named "sort".
    """
    smaller = rasp.Select(keys, keys, rasp.Comparison.LT).named("smaller")
    target_pos = rasp.SelectorWidth(smaller).named("target_pos")
    sel_new = rasp.Select(target_pos, rasp.indices, rasp.Comparison.EQ)
    return rasp.Aggregate(sel_new, vals).named("sort")
def make_sort(vals: rasp.SOp, keys: rasp.SOp, *, max_seq_len: int,
              min_key: float) -> rasp.SOp:
    """Returns vals sorted by < on keys; keys don't need to be unique.

    The implementation differs from the RASP paper, as it avoids using
    compositions of selectors to break ties. Instead, it adds the offset
    `min_key * index / max_seq_len` to every key to ensure the keys are
    unique, and then delegates to `make_sort_unique`.

    Note that this approach only works for numerical keys.

    Example usage:
      sort = make_sort(rasp.tokens, rasp.tokens, max_seq_len=5, min_key=1)
      sort([2, 4, 3, 1])
      >> [1, 2, 3, 4]
      sort([2, 4, 1, 2])
      >> [1, 2, 2, 4]

    Args:
      vals: Values to sort.
      keys: Keys for sorting.
      max_seq_len: Maximum sequence length (used to ensure keys are unique).
        Keyword-only.
      min_key: Minimum key value (used to ensure keys are unique).
        Keyword-only. It scales the tie-breaking offset, which stays below
        `min_key` as long as every index is smaller than `max_seq_len`.
        NOTE(review): because the offset can approach `min_key`, the order
        of two distinct keys is only guaranteed to be preserved when they
        differ by at least `min_key` -- confirm against callers.

    Returns:
      Output SOp of sort program.
    """
    # Later positions receive a larger offset, which makes equal keys distinct.
    keys = rasp.SequenceMap(lambda x, i: x + min_key * i / max_seq_len, keys,
                            rasp.indices)
    return make_sort_unique(vals, keys)
def make_sort_freq(max_seq_len: int) -> rasp.SOp:
    """Returns tokens sorted by the frequency they appear in the input.

    Tokens that appear the same number of times are output in the same
    order as in the input.

    Example usage:
      sort = make_sort_freq(max_seq_len=5)
      sort([2, 4, 2, 1])
      >> [2, 2, 4, 1]

    Args:
      max_seq_len: Maximum sequence length (used to ensure keys are unique)

    Returns:
      The SOp returned by `make_sort`, named "sort_freq".
    """
    # The histogram is negated so that it can serve as the key of make_sort,
    # which sorts by the < relation: more frequent tokens get smaller keys.
    hist = -1 * make_hist().named("hist")
    return make_sort(
        rasp.tokens, hist, max_seq_len=max_seq_len, min_key=1
    ).named("sort_freq")
### Programs that work under both causal and regular evaluation.
def make_frac_prevs(bools: rasp.SOp) -> rasp.SOp:
    """Count the fraction of previous tokens where a specific condition was True.

    (As implemented in the RASP paper.)

    `bools` is marked as numerical and aggregated through a selector that
    compares indices with indices under `Comparison.LEQ`. As the example
    shows, the current position is part of the fraction.

    Example usage:
      num_l = make_frac_prevs(rasp.tokens=="l")
      num_l("hello")
      >> [0, 0, 1/3, 1/2, 2/5]

    Args:
      bools: SOp mapping a sequence to a sequence of booleans.

    Returns:
      frac_prevs: SOp mapping an input to a sequence, where every element
        is the fraction of previous "True" tokens.
    """
    numeric_bools = rasp.numerical(bools)
    up_to_here = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ)
    fraction = rasp.Aggregate(up_to_here, numeric_bools, default=0)
    return rasp.numerical(fraction).named("frac_prevs")
def shift_by(offset: int, /, sop: rasp.SOp) -> rasp.SOp:
    """Returns the sop, shifted by `offset`, None-padded.

    Args:
      offset: number of positions to shift by (positional-only).
      sop: the SOp to shift.

    Returns:
      The aggregation of `sop` through a selector on the indices whose
      predicate is `q == k + offset`, with default None, named
      "shift_by(<offset>)".
    """
    selector = rasp.Select(
        rasp.indices, rasp.indices,
        lambda key, query: query == key + offset)
    shifted = rasp.Aggregate(selector, sop, default=None)
    return shifted.named(f"shift_by({offset})")
def detect_pattern(sop: rasp.SOp, pattern: Sequence[rasp.Value]) -> rasp.SOp:
    """Returns an SOp which is True at the final element of the pattern.

    The first len(pattern) - 1 elements of the output SOp are None-padded.

    detect_pattern(tokens, "abc")("abcabc") == [None, None, T, F, F, T]

    Args:
      sop: the SOp in which to look for patterns.
      pattern: a sequence of values to look for.

    Returns:
      a sop which detects the pattern.

    Raises:
      ValueError: if `pattern` is empty.
    """
    if len(pattern) < 1:
        raise ValueError(f"Length of `pattern` must be at least 1. Got {pattern}")

    # detectors[i] compares `sop` with the i'th (from the end) element of the
    # pattern and is shifted by i, so that all detectors of one occurrence
    # line up at the position of the final pattern element.
    detectors = []
    for shift, element in enumerate(reversed(pattern)):
        matches = sop == element
        detectors.append(shift_by(shift, matches) if shift else matches)

    # AND all detectors together, starting from the most shifted one.
    *earlier, pattern_detected = detectors
    for detector in reversed(earlier):
        pattern_detected = pattern_detected & detector
    return pattern_detected.named(f"detect_pattern({pattern})")
def make_count_less_freq(n: int) -> rasp.SOp:
    """Returns how many tokens appear at most n times in the input.

    The output sequence contains this count in each position. The count is
    over positions: a token that occurs k <= n times contributes k, as the
    examples below show.

    Example usage:
      count_less_freq = make_count_less_freq(2)
      count_less_freq(["a", "a", "a", "b", "b", "c"])
      >> [3, 3, 3, 3, 3, 3]
      count_less_freq(["a", "a", "c", "b", "b", "c"])
      >> [6, 6, 6, 6, 6, 6]

    Args:
      n: Integer to compare token frequencies to (the comparison is
        inclusive, `frequency <= n`).

    Returns:
      The selector-width SOp, named "count_less_freq".
    """
    hist = make_hist().named("hist")
    # The predicate only looks at its first argument, so the selection does
    # not depend on the second one.
    select_less = rasp.Select(hist, hist,
                              lambda x, y: x <= n).named("select_less")
    return rasp.SelectorWidth(select_less).named("count_less_freq")
def make_count(sop: rasp.SOp, token: rasp.Value) -> rasp.SOp:
    """Returns the count of `token` in `sop`.

    The output sequence contains this count in each position.

    Example usage:
      count = make_count(rasp.tokens, "a")
      count(["a", "a", "a", "b", "b", "c"])
      >> [3, 3, 3, 3, 3, 3]
      count(["c", "a", "b", "c"])
      >> [1, 1, 1, 1]

    Args:
      sop: Sop to count tokens in.
      token: Token to count.

    Returns:
      The selector-width SOp, named "count_<token>".
    """
    # The predicate only tests its first argument against `token`.
    selector = rasp.Select(sop, sop, lambda k, q: k == token)
    return rasp.SelectorWidth(selector).named(f"count_{token}")
def make_nary_sequencemap(f, *sops: rasp.SOp) -> rasp.SOp:
    """Returns an SOp that simulates an n-ary SequenceMap.

    Uses multiple binary SequenceMaps to convert n SOps x_1, x_2, ..., x_n
    into a single SOp that takes n-tuples as values. The n-ary sequence map
    implementing f is then a Map on this resulting SOp.

    Note that the intermediate variables representing tuples of varying
    length will be encoded categorically, and can become very
    high-dimensional. So, using this function might lead to very large
    compiled models.

    Values of the individual SOps are never unpacked: a tuple-valued SOp
    contributes its tuple as one single argument of f. With exactly one SOp
    no tuples are built at all and f is applied directly to its values.

    Args:
      f: Function with n arguments.
      *sops: Sequence of SOps, one for each argument of f.

    Returns:
      An SOp applying f to the values of all `sops` at each position.

    Raises:
      ValueError: if no SOps are given.
    """
    if not sops:
        raise ValueError("make_nary_sequencemap requires at least one SOp.")

    first, *rest = sops
    if not rest:
        # Unary case: star-unpacking a non-tuple value would be wrong.
        return rasp.Map(f, first)

    # The first two SOps are always packed into a pair. Checking
    # `isinstance(x, tuple)` instead would wrongly flatten a first SOp
    # whose values are themselves tuples.
    second, *rest = rest
    values = rasp.SequenceMap(lambda x, y: (x, y), first, second)
    for sop in rest:
        # From here on the left-hand value is the argument tuple built so far.
        values = rasp.SequenceMap(lambda xs, y: (*xs, y), values, sop)
    return rasp.Map(lambda args: f(*args), values)