# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RASP programs only using the subset of RASP supported by the compiler."""
from typing import Sequence
from tracr.rasp import rasp
### Programs that work only under non-causal evaluation.
def make_length() -> rasp.SOp:
  """Creates the `length` SOp using the selector width primitive.

  Example usage:
    length = make_length()
    length("abcdefg")
    >> [7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0]

  Returns:
    length: SOp mapping an input to a sequence, where every element
      is the length of that sequence.
  """
  all_true_selector = rasp.Select(
      rasp.tokens, rasp.tokens, rasp.Comparison.TRUE).named("all_true_selector")
  return rasp.SelectorWidth(all_true_selector).named("length")


length = make_length()
def make_reverse(sop: rasp.SOp) -> rasp.SOp:
  """Creates an SOp that reverses a sequence, using the length primitive.

  Example usage:
    reverse = make_reverse(rasp.tokens)
    reverse("Hello")
    >> ['o', 'l', 'l', 'e', 'H']

  Args:
    sop: an SOp

  Returns:
    reverse: SOp that reverses the input sequence.
  """
  opp_idx = (length - rasp.indices).named("opp_idx")
  opp_idx = (opp_idx - 1).named("opp_idx-1")
  reverse_selector = rasp.Select(rasp.indices, opp_idx,
                                 rasp.Comparison.EQ).named("reverse_selector")
  return rasp.Aggregate(reverse_selector, sop).named("reverse")
def make_pair_balance(sop: rasp.SOp, open_token: str,
                      close_token: str) -> rasp.SOp:
  """Returns the fraction of previous open tokens minus that of close tokens.

  (As implemented in the RASP paper.)

  If the outputs are always non-negative and end in 0, that implies the input
  has balanced parentheses.

  Example usage:
    num_l = make_pair_balance(rasp.tokens, "(", ")")
    num_l("a()b(c))")
    >> [0, 1/2, 0, 0, 1/5, 1/6, 0, -1/8]

  Args:
    sop: Input SOp.
    open_token: Token that counts positive.
    close_token: Token that counts negative.

  Returns:
    pair_balance: SOp mapping an input to a sequence, where every element
      is the fraction of previous open tokens minus previous close tokens.
  """
  bools_open = rasp.numerical(sop == open_token).named("bools_open")
  opens = rasp.numerical(make_frac_prevs(bools_open)).named("opens")

  bools_close = rasp.numerical(sop == close_token).named("bools_close")
  closes = rasp.numerical(make_frac_prevs(bools_close)).named("closes")

  pair_balance = rasp.numerical(rasp.LinearSequenceMap(opens, closes, 1, -1))
  return pair_balance.named("pair_balance")
def make_shuffle_dyck(pairs: list[str]) -> rasp.SOp:
  """Returns 1 if a set of parentheses is balanced, and 0 otherwise.

  (As implemented in the RASP paper.)

  Example usage:
    shuffle_dyck2 = make_shuffle_dyck(pairs=["()", "{}"])
    shuffle_dyck2("({)}")
    >> [1, 1, 1, 1]
    shuffle_dyck2("(){)}")
    >> [0, 0, 0, 0, 0]

  Args:
    pairs: List of pairs of open and close tokens that each should be balanced.
  """
  assert len(pairs) >= 1

  # Compute running balance of each type of parenthesis.
  balances = []
  for pair in pairs:
    assert len(pair) == 2
    open_token, close_token = pair
    balance = make_pair_balance(
        rasp.tokens, open_token=open_token,
        close_token=close_token).named(f"balance_{pair}")
    balances.append(balance)

  # Check if balances were negative anywhere -> parentheses not balanced.
  any_negative = balances[0] < 0
  for balance in balances[1:]:
    any_negative = any_negative | (balance < 0)

  # Convert to numerical SOp.
  any_negative = rasp.numerical(rasp.Map(lambda x: x,
                                         any_negative)).named("any_negative")

  select_all = rasp.Select(rasp.indices, rasp.indices,
                           rasp.Comparison.TRUE).named("select_all")
  has_neg = rasp.numerical(rasp.Aggregate(select_all, any_negative,
                                          default=0)).named("has_neg")

  # Check if all balances are 0 at the end -> closed all parentheses.
  all_zero = balances[0] == 0
  for balance in balances[1:]:
    all_zero = all_zero & (balance == 0)

  select_last = rasp.Select(rasp.indices, length - 1,
                            rasp.Comparison.EQ).named("select_last")
  last_zero = rasp.Aggregate(select_last, all_zero).named("last_zero")

  not_has_neg = (~has_neg).named("not_has_neg")
  return (last_zero & not_has_neg).named("shuffle_dyck")
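# A worked illustration of the logic above (a sketch, reusing the docstring
# example): for shuffle_dyck2("({)}") the running balances are
# balance_() = [1, 1/2, 0, 0] and balance_{} = [0, 1/2, 1/3, 0]. Neither
# balance ever goes negative (so has_neg is 0) and both are 0 at the last
# position (so last_zero is True), hence the program outputs 1 everywhere.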
def make_shuffle_dyck2() -> rasp.SOp:
  """Returns the shuffle-Dyck program for the two token pairs "()" and "{}"."""
  return make_shuffle_dyck(pairs=["()", "{}"]).named("shuffle_dyck2")
def make_hist() -> rasp.SOp:
  """Returns the number of times each token occurs in the input.

  (As implemented in the RASP paper.)

  Example usage:
    hist = make_hist()
    hist("abac")
    >> [2, 1, 2, 1]
  """
  same_tok = rasp.Select(rasp.tokens, rasp.tokens,
                         rasp.Comparison.EQ).named("same_tok")
  return rasp.SelectorWidth(same_tok).named("hist")
def make_sort_unique(vals: rasp.SOp, keys: rasp.SOp) -> rasp.SOp:
  """Returns vals sorted by < relation on keys.

  Only supports unique keys.

  Example usage:
    sort = make_sort_unique(rasp.tokens, rasp.tokens)
    sort([2, 4, 3, 1])
    >> [1, 2, 3, 4]

  Args:
    vals: Values to sort.
    keys: Keys for sorting.
  """
  smaller = rasp.Select(keys, keys, rasp.Comparison.LT).named("smaller")
  target_pos = rasp.SelectorWidth(smaller).named("target_pos")
  sel_new = rasp.Select(target_pos, rasp.indices, rasp.Comparison.EQ)
  return rasp.Aggregate(sel_new, vals).named("sort")
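# A brief trace of the idea above (a sketch, reusing the docstring example):
# for sort([2, 4, 3, 1]) the selector width of `smaller` gives
# target_pos = [1, 3, 2, 0], i.e. how many keys are strictly smaller than the
# key at each position. Aggregating vals with Select(target_pos, indices, EQ)
# then moves each value to its target position, producing [1, 2, 3, 4].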
def make_sort(vals: rasp.SOp, keys: rasp.SOp, *, max_seq_len: int,
              min_key: float) -> rasp.SOp:
  """Returns vals sorted by < relation on keys, which don't need to be unique.

  The implementation differs from the RASP paper, as it avoids using
  compositions of selectors to break ties. Instead, it uses the arguments
  max_seq_len and min_key to ensure the keys are unique.

  Note that this approach only works for numerical keys.

  Example usage:
    sort = make_sort(rasp.tokens, rasp.tokens, max_seq_len=5, min_key=1)
    sort([2, 4, 3, 1])
    >> [1, 2, 3, 4]
    sort([2, 4, 1, 2])
    >> [1, 2, 2, 4]

  Args:
    vals: Values to sort.
    keys: Keys for sorting.
    max_seq_len: Maximum sequence length (used to ensure keys are unique).
    min_key: Minimum key value (used to ensure keys are unique).

  Returns:
    Output SOp of sort program.
  """
  keys = rasp.SequenceMap(lambda x, i: x + min_key * i / max_seq_len, keys,
                          rasp.indices)
  return make_sort_unique(vals, keys)
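# A worked example of the tie-breaking trick above (a sketch, reusing the
# docstring example): with max_seq_len=5 and min_key=1, the keys [2, 4, 1, 2]
# become [2.0, 4.2, 1.4, 2.6]. The perturbation min_key * i / max_seq_len is
# always smaller than min_key, so keys that differ by at least min_key keep
# their relative order, while equal keys are ordered by position in the input.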
def make_sort_freq(max_seq_len: int) -> rasp.SOp:
  """Returns tokens sorted by the frequency they appear in the input.

  Tokens that appear the same number of times are output in the same order as
  in the input.

  Example usage:
    sort = make_sort_freq(max_seq_len=5)
    sort([2, 4, 2, 1])
    >> [2, 2, 4, 1]

  Args:
    max_seq_len: Maximum sequence length (used to ensure keys are unique).
  """
  hist = -1 * make_hist().named("hist")
  return make_sort(
      rasp.tokens, hist, max_seq_len=max_seq_len, min_key=1).named("sort_freq")
### Programs that work under both causal and regular evaluation.
def make_frac_prevs(bools: rasp.SOp) -> rasp.SOp:
  """Count the fraction of previous tokens where a specific condition was True.

  (As implemented in the RASP paper.)

  Example usage:
    num_l = make_frac_prevs(rasp.tokens=="l")
    num_l("hello")
    >> [0, 0, 1/3, 1/2, 2/5]

  Args:
    bools: SOp mapping a sequence to a sequence of booleans.

  Returns:
    frac_prevs: SOp mapping an input to a sequence, where every element
      is the fraction of previous "True" tokens.
  """
  bools = rasp.numerical(bools)
  prevs = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ)
  return rasp.numerical(rasp.Aggregate(prevs, bools,
                                       default=0)).named("frac_prevs")
def shift_by(offset: int, /, sop: rasp.SOp) -> rasp.SOp:
  """Returns the sop, shifted by `offset`, None-padded."""
  select_off_by_offset = rasp.Select(rasp.indices, rasp.indices,
                                     lambda k, q: q == k + offset)
  out = rasp.Aggregate(select_off_by_offset, sop, default=None)
  return out.named(f"shift_by({offset})")
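# An illustrative sketch: shift_by(1, rasp.tokens) reads the token `offset`
# positions to the left, so on "hello" it should produce
# [None, 'h', 'e', 'l', 'l']; the first `offset` positions have no source
# position to read from and therefore fall back to the None default.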
def detect_pattern(sop: rasp.SOp, pattern: Sequence[rasp.Value]) -> rasp.SOp:
  """Returns an SOp which is True at the final element of the pattern.

  The first len(pattern) - 1 elements of the output SOp are None-padded.

  detect_pattern(tokens, "abc")("abcabc") == [None, None, T, F, F, T]

  Args:
    sop: the SOp in which to look for patterns.
    pattern: a sequence of values to look for.

  Returns:
    a sop which detects the pattern.
  """
  if len(pattern) < 1:
    raise ValueError(f"Length of `pattern` must be at least 1. Got {pattern}")

  # detectors[i] will be a boolean-valued SOp which is true at position j iff
  # the i'th (from the end) element of the pattern was detected at position j-i.
  detectors = []
  for i, element in enumerate(reversed(pattern)):
    detector = sop == element
    if i != 0:
      detector = shift_by(i, detector)
    detectors.append(detector)

  # All that's left is to take the AND over all detectors.
  pattern_detected = detectors.pop()
  while detectors:
    pattern_detected = pattern_detected & detectors.pop()

  return pattern_detected.named(f"detect_pattern({pattern})")
def make_count_less_freq(n: int) -> rasp.SOp:
  """Returns how many tokens appear at most n times in the input.

  The output sequence contains this count in each position.

  Example usage:
    count_less_freq = make_count_less_freq(2)
    count_less_freq(["a", "a", "a", "b", "b", "c"])
    >> [3, 3, 3, 3, 3, 3]
    count_less_freq(["a", "a", "c", "b", "b", "c"])
    >> [6, 6, 6, 6, 6, 6]

  Args:
    n: Integer to compare token frequencies to.
  """
  hist = make_hist().named("hist")
  select_less = rasp.Select(hist, hist,
                            lambda x, y: x <= n).named("select_less")
  return rasp.SelectorWidth(select_less).named("count_less_freq")
def make_count(sop, token):
  """Returns the count of `token` in `sop`.

  The output sequence contains this count in each position.

  Example usage:
    count = make_count(tokens, "a")
    count(["a", "a", "a", "b", "b", "c"])
    >> [3, 3, 3, 3, 3, 3]
    count(["c", "a", "b", "c"])
    >> [1, 1, 1, 1]

  Args:
    sop: Sop to count tokens in.
    token: Token to count.
  """
  return rasp.SelectorWidth(rasp.Select(
      sop, sop, lambda k, q: k == token)).named(f"count_{token}")
def make_nary_sequencemap(f, *sops):
  """Returns an SOp that simulates an n-ary SequenceMap.

  Uses multiple binary SequenceMaps to convert n SOps x_1, x_2, ..., x_n
  into a single SOp that takes n-tuples as values. The n-ary sequence map
  implementing f is then a Map on this resulting SOp.

  Note that the intermediate variables representing tuples of varying length
  will be encoded categorically, and can become very high-dimensional. So,
  using this function might lead to very large compiled models.

  Args:
    f: Function with n arguments.
    *sops: Sequence of SOps, one for each argument of f.
  """
  values, *sops = sops
  for sop in sops:
    # x is a single entry in the first iteration but a tuple in later iterations.
    values = rasp.SequenceMap(
        lambda x, y: (*x, y) if isinstance(x, tuple) else (x, y), values, sop)
  return rasp.Map(lambda args: f(*args), values)
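# A hypothetical usage sketch (the name `add3` is illustrative, not part of
# the library): a ternary map could be built as
#   add3 = make_nary_sequencemap(lambda x, y, z: x + y + z,
#                                rasp.tokens, rasp.indices, length)
# Internally, tokens and indices are first paired into 2-tuples, `length` is
# then appended to form 3-tuples, and a final Map applies the function to each
# tuple. The categorical encoding of these tuples grows quickly, which is why
# the docstring warns about large compiled models.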