Spaces:
Sleeping
Sleeping
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""RASP program objects. | |
Every object in the RASP language is a function. | |
The most important type is S-Op, which is a function list[Value] -> list[Value]. | |
An S-Op represents a state inside the residual stream of the transformer. | |
Therefore, any RASP program that represents a transformer computation must | |
define a final S-Op that represents the state of the residual stream at the | |
end of the computation. In particular, given an S-Op `x`, | |
`x([1, 2, 3])` represents something like the state of the residual stream | |
at location `x` when the transformer is fed [1, 2, 3] as input. | |
A secondary (but still important) type is Selector, which is a function | |
list[Value] -> list[list[bool]]. Given a Selector `sel`, sel([1, 2, 3]) | |
represents something like an attention matrix in the transformer. | |
For a full reference on RASP, see https://arxiv.org/abs/2106.06981. | |
""" | |
import abc | |
import collections.abc | |
import copy | |
import enum | |
import functools | |
import itertools | |
from typing import (Any, Callable, Generic, Mapping, Optional, Protocol, | |
Sequence, TypeVar, Union) | |
from absl import logging | |
import numpy as np | |
# Value-level type aliases.
# A selector evaluates to a boolean matrix with one row per query and one
# column per key.
SelectorValue = list[list[bool]]
NumericValue = Union[int, float]
Value = Union[None, int, float, str, bool]

# Type variables used by the generic helpers in this module.
VT = TypeVar("VT", bound=Value)
RASPExprT = TypeVar("RASPExprT", bound="RASPExpr")
SOpT = TypeVar("SOpT", bound="SOp")
T = TypeVar("T")

# Annotation keys for which this module registers a default annotator.
_NAME_KEY = "name"
_ENCODING_KEY = "encoding"

# Default annotators, keyed by annotation name.
# Add your own annotators to this dict to add custom default annotations.
#
# For example, DEFAULT_ANNOTATORS['foo'] will provide the default value for
# expr.annotations['foo']. The annotator is not run eagerly: it gets called
# lazily the first time that key is accessed on an expression.
#
# See the `default_name` annotator for a full example.
DEFAULT_ANNOTATORS: dict[str, "Annotator"] = {}
class Annotator(Protocol):
  """Structural type of the callables stored in `DEFAULT_ANNOTATORS`."""

  def __call__(self, expr: "RASPExpr") -> Any:
    """Returns the annotation value to add to `expr`."""
    ...
class _Annotations(collections.abc.Mapping):
  """Holds the expression's annotations.

  It's immutable to the user, but will attempt to generate default values
  lazily when missing keys are requested: a key that is absent but has an
  entry in `DEFAULT_ANNOTATORS` is computed from the owning expression on
  first access and then stored.

  Iteration and `len` only cover the annotations stored so far; defaults
  that have not been requested yet are not included.
  """

  def __init__(self, expr, **kwargs: Any):
    """Initialises.

    Args:
      expr: the expression these annotations belong to; it is passed to the
        default annotators.
      **kwargs: the explicitly provided annotations.
    """
    self._expr = expr
    self._inner_dict: dict[str, Any] = dict(kwargs)

  def __getitem__(self, key: str) -> Any:
    """Returns the annotation for `key`, generating a default if needed.

    Raises:
      KeyError: if `key` is neither stored nor has a default annotator.
    """
    if key not in self._inner_dict:
      if key not in DEFAULT_ANNOTATORS:
        # Stored keys first, then the defaults that are not stored yet.
        available = [
            *self._inner_dict,
            *(k for k in DEFAULT_ANNOTATORS if k not in self._inner_dict),
        ]
        raise KeyError(
            f"No annotation exists for key '{key}'. "
            f"Available keys: {available}")
      self._inner_dict[key] = DEFAULT_ANNOTATORS[key](self._expr)
    return self._inner_dict[key]

  def __iter__(self):
    return iter(self._inner_dict)

  def __len__(self):
    return len(self._inner_dict)
class RASPExpr(abc.ABC):
  """A class distinguishing RASP expressions from other objects.

  Every expression carries a mapping of annotations and a `unique_id` that is
  drawn lazily from a counter shared by all expressions.
  """
  # Class-level counter, shared by every subclass and instance.
  _ids = itertools.count(1)

  def __init__(self):
    self._annotations: Mapping[str, Any] = _Annotations(self)

  @abc.abstractmethod
  def __call__(self,
               xs: Sequence[Value]) -> Union[Sequence[Value], SelectorValue]:
    """Evaluates the RASPExpr using the standard evaluator."""

  @property
  def annotations(self) -> Mapping[str, Any]:
    """The annotations of this expression instance."""
    return self._annotations

  @annotations.setter
  def annotations(self, annotations: Mapping[str, Any]):
    # Re-wrap so that lazily generated defaults are bound to this instance.
    self._annotations = _Annotations(self, **annotations)

  @property
  def name(self) -> str:
    """The name of this expression."""
    return self.annotations[_NAME_KEY]

  @property
  @abc.abstractmethod
  def children(self) -> Sequence["RASPExpr"]:
    """Direct dependencies of this expression."""

  @functools.cached_property
  def unique_id(self):
    """A unique id for every expression instance.

    The id is taken from the shared counter on first access and then cached
    on the instance, so repeated accesses return the same value.
    """
    return next(self._ids)

  def copy(self: RASPExprT) -> RASPExprT:
    """Returns a shallow copy of this RASPExpr with a new ID."""
    new = copy.copy(self)
    # copy.copy duplicates the instance __dict__, which would carry over an
    # already cached `unique_id`. Drop it so the copy draws a fresh id.
    new.__dict__.pop("unique_id", None)
    return new

  @property
  def label(self) -> str:
    """A label made of the expression's name and unique id."""
    return f"{self.name}_{self.unique_id}"

  def named(self: RASPExprT, name: str) -> RASPExprT:
    """Convenience method for adding a name."""
    return annotate(self, name=name)

  def annotated(self: RASPExprT, **annotations) -> RASPExprT:
    """Convenience method for adding annotations."""
    return annotate(self, **annotations)
def annotate(expr: RASPExprT, **annotations) -> RASPExprT:
  """Returns a copy of `expr` with the given annotations added.

  The input expression is left untouched. Where a new annotation has the same
  key as an existing one, the new value wins.
  """
  annotated_copy = expr.copy()
  merged = dict(expr.annotations)
  merged.update(annotations)
  annotated_copy.annotations = merged
  return annotated_copy
### S-Ops. | |
class SOp(RASPExpr):
  """A Sequence Operation.

  Calling an SOp on a sequence of input values runs the module-level
  `evaluate` function on it.

  The comparison, arithmetic and logical operators are overloaded to build
  new SOps. Comparisons always build a `Map` against the other operand. The
  arithmetic and logical operators build a `SequenceMap` when the other
  operand is an SOp and a `Map` otherwise.
  """

  def __call__(self, xs: Sequence[Value]) -> Sequence[Value]:
    """Evaluates this SOp on the input sequence `xs`."""
    return evaluate(self, xs)  # pytype: disable=bad-return-type

  # Allow construction of SOps using numeric operators with constant values.
  # Note: if inheriting SOp by a dataclass, make sure to disable eq and order,
  # as they will override these.
  # Note: `__eq__` is overridden without defining `__hash__`, so Python sets
  # `__hash__` to None for this class.
  def __lt__(self, other: Value) -> "SOp":
    """self < other."""
    return Map(lambda x: x < other, self)

  def __le__(self, other: Value) -> "SOp":
    """self <= other."""
    return Map(lambda x: x <= other, self)

  def __eq__(self, other: Value) -> "SOp":
    """self == other."""
    return Map(lambda x: x == other, self)

  def __ne__(self, other: Value) -> "SOp":
    """self != other."""
    return Map(lambda x: x != other, self)

  def __gt__(self, other: Value) -> "SOp":
    """self > other."""
    return Map(lambda x: x > other, self)

  def __ge__(self, other: Value) -> "SOp":
    """self >= other."""
    return Map(lambda x: x >= other, self)

  def __add__(self, other: Union["SOp", Value]) -> "SOp":
    """self + other."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x + y, self, other)
    return Map(lambda x: x + other, self)

  def __radd__(self, other: Union["SOp", Value]) -> "SOp":
    """other + self."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x + y, other, self)
    return Map(lambda x: other + x, self)

  def __sub__(self, other: Union["SOp", NumericValue]) -> "SOp":
    """self - other."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x - y, self, other)
    return Map(lambda x: x - other, self)

  def __rsub__(self, other: Union["SOp", NumericValue]) -> "SOp":
    """other - self."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x - y, other, self)
    return Map(lambda x: other - x, self)

  def __mul__(self, other: Union["SOp", NumericValue]) -> "SOp":
    """self * other."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x * y, self, other)
    return Map(lambda x: x * other, self)

  def __rmul__(self, other: Union["SOp", NumericValue]) -> "SOp":
    """other * self."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x * y, other, self)
    return Map(lambda x: other * x, self)

  def __truediv__(self, other: Union["SOp", NumericValue]) -> "SOp":
    """self / other."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x / y, self, other)
    return Map(lambda x: x / other, self)

  def __rtruediv__(self, other: Union["SOp", NumericValue]) -> "SOp":
    """other / self."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x / y, other, self)
    return Map(lambda x: other / x, self)

  def __invert__(self) -> "SOp":
    """~self, implemented as an elementwise logical `not`."""
    return Map(lambda x: not x, self)

  def __and__(self, other: Union["SOp", NumericValue]) -> "SOp":
    """self & other."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x and y, self, other)
    return Map(lambda x: x and other, self)

  def __or__(self, other: Union["SOp", NumericValue]) -> "SOp":
    """self | other."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x or y, self, other)
    return Map(lambda x: x or other, self)

  def __rand__(self, other: Union["SOp", NumericValue]) -> "SOp":
    """other & self."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x and y, other, self)
    return Map(lambda x: other and x, self)

  def __ror__(self, other: Union["SOp", NumericValue]) -> "SOp":
    """other | self."""
    if isinstance(other, SOp):
      return SequenceMap(lambda x, y: x or y, other, self)
    # `or` returns one of its operands, so the order matters for non-bool
    # values: the reflected form must evaluate `other` first, like __rand__.
    return Map(lambda x: other or x, self)
class TokensType(SOp):
  """Primitive SOp returning the original input tokens."""

  @property
  def children(self) -> Sequence[RASPExpr]:
    """A primitive has no dependencies."""
    return []

  @property
  def label(self) -> str:
    """Fixed label; no unique id is appended for this primitive."""
    return "tokens"

  def __repr__(self):
    return "tokens"
class IndicesType(SOp):
  """Primitive SOp returning the position index at each token."""

  @property
  def children(self) -> Sequence[RASPExpr]:
    """A primitive has no dependencies."""
    return []

  @property
  def label(self) -> str:
    """Fixed label; no unique id is appended for this primitive."""
    return "indices"

  def __repr__(self):
    return "indices"
class LengthType(SOp):
  """Primitive SOp returning the total length of the input."""

  @property
  def children(self) -> Sequence[RASPExpr]:
    """A primitive has no dependencies."""
    return []

  @property
  def label(self) -> str:
    """Fixed label; no unique id is appended for this primitive."""
    return "length"

  def __repr__(self):
    return "length"
# Module-level instances of the three primitive SOps.
tokens, indices, length = TokensType(), IndicesType(), LengthType()
class Map(SOp):
  """SOp that evaluates the function elementwise on the input SOp.

  Map(lambda x: x + 1, tokens)([1, 2, 3]) == [2, 3, 4]

  A Map applied directly to another Map is fused with it at construction
  time: the two functions are composed and the resulting Map reads from the
  inner Map's input.
  """

  def __init__(self, f: Callable[[Value], Value], inner: SOp):
    """Initialises.

    Args:
      f: the function to apply to every element.
      inner: the SOp whose values are fed to `f`.
    """
    super().__init__()
    self.f = f
    self.inner = inner

    assert isinstance(self.inner, SOp)
    assert callable(self.f) and not isinstance(self.f, RASPExpr)

    if isinstance(self.inner, Map):
      # Combine the functions into just one.
      inner_f = self.inner.f
      self.f = lambda t: f(inner_f(t))
      self.inner = self.inner.inner

  @property
  def children(self) -> Sequence[RASPExpr]:
    """The single SOp this Map reads from."""
    return [self.inner]
class SequenceMap(SOp):
  """SOp that evaluates the function elementwise on the two given SOp's.

  SequenceMap(lambda x, y: x - y, length, tokens)([1, 2, 3]) == [2, 1, 0]
  """

  def __init__(self, f: Callable[[Value, Value], Value], fst: SOp, snd: SOp):
    """Initialises.

    Args:
      f: the binary function applied to each pair of elements.
      fst: the SOp providing the first argument of `f`.
      snd: the SOp providing the second argument of `f`.
    """
    super().__init__()

    # Compare by identity: `==` on SOps is overloaded to build a new Map,
    # which is always truthy and would make this warning fire every time.
    if fst is snd:
      logging.warning("Creating a SequenceMap with both inputs being the same "
                      "SOp is discouraged. You should use a Map instead.")

    self.f = f
    self.fst = fst
    self.snd = snd
    assert isinstance(self.fst, SOp)
    assert isinstance(self.snd, SOp)
    assert callable(self.f) and not isinstance(self.f, RASPExpr)

  @property
  def children(self) -> Sequence[RASPExpr]:
    """The two input SOps, in argument order."""
    return [self.fst, self.snd]
class LinearSequenceMap(SequenceMap):
  """SOp that evaluates a linear function elementwise on the two given SOp's.

  The function applied to each pair (x, y) is `fst_fac * x + snd_fac * y`.
  """

  def __init__(self, fst: SOp, snd: SOp, fst_fac: float, snd_fac: float):
    """Initialises, keeping both factors as attributes."""
    super().__init__(
        f=lambda x, y: fst_fac * x + snd_fac * y,
        fst=fst,
        snd=snd,
    )
    self.fst_fac, self.snd_fac = fst_fac, snd_fac
class Full(SOp):
  """A SOp evaluating to [fill]*len(input_values)."""

  def __init__(self, fill: Value):
    """Initialises with the value repeated at every position."""
    super().__init__()
    self.fill = fill

  @property
  def children(self) -> Sequence[RASPExpr]:
    """A Full has no dependencies."""
    return []
def sop_not(sop: SOp) -> SOp:
  """Returns a Map applying logical `not` to every element of `sop`."""
  return Map(f=lambda t: not t, inner=sop)
class ConstantSOp(SOp, Generic[VT]):
  """A constant S-Op for testing purposes."""

  def __init__(self, value: Sequence[VT], check_length: bool = True):
    """Initialises.

    Args:
      value: the sequence this SOp evaluates to.
      check_length: whether evaluation should verify that `value` has the
        same length as the input.
    """
    super().__init__()
    self.value = value
    self.check_length = check_length

  @property
  def children(self) -> Sequence[RASPExpr]:
    """A constant has no dependencies."""
    return []
### Selectors. | |
class Predicate(Protocol):
  """Structural type of a two-place boolean predicate over key and query."""

  def __call__(self, key: Value, query: Value) -> bool:
    """Applies the predicate."""
    ...
class Comparison(enum.Enum):
  """A two-place boolean comparison predicate for use in Select.

  Members are callable as `member(key, query)`; the comparison is written
  with the key on the left-hand side, e.g. `LT` is `key < query`.
  """
  EQ = "=="
  LT = "<"
  LEQ = "<="
  GT = ">"
  GEQ = ">="
  NEQ = "!="
  TRUE = "True"
  FALSE = "False"

  def __call__(self, key: Value, query: Value) -> bool:
    """Applies the comparison; raises ValueError if either side is None."""
    if key is None:
      raise ValueError("key is None!")
    if query is None:
      raise ValueError("query is None!")
    compare = _comparison_table[self]
    return compare(key, query)


# Implementation of each Comparison member, looked up at call time.
_comparison_table = {
    Comparison.EQ: lambda key, query: key == query,
    Comparison.LT: lambda key, query: key < query,
    Comparison.LEQ: lambda key, query: key <= query,
    Comparison.GT: lambda key, query: key > query,
    Comparison.GEQ: lambda key, query: key >= query,
    Comparison.NEQ: lambda key, query: key != query,
    # The constant predicates ignore both arguments.
    Comparison.TRUE: lambda key, query: True,
    Comparison.FALSE: lambda key, query: False,
}
class Selector(RASPExpr):
  """RASP Selector. Represents something like an attention head's weights.

  The `&`, `|` and `~` operators build compound selectors through
  `selector_and`, `selector_or` and `selector_not`.
  """

  def __call__(self, xs: Sequence[Value]) -> SelectorValue:
    """Evaluates this Selector on the input sequence `xs`."""
    return evaluate(self, xs)  # pytype: disable=bad-return-type

  # Allow construction of Selector combinations using Python logical operators.
  def __and__(self, other: "Selector") -> "Selector":
    """self & other."""
    return selector_and(fst=self, snd=other)

  def __rand__(self, other: "Selector") -> "Selector":
    """other & self."""
    return selector_and(fst=other, snd=self)

  def __or__(self, other: "Selector") -> "Selector":
    """self | other."""
    return selector_or(fst=self, snd=other)

  def __ror__(self, other: "Selector") -> "Selector":
    """other | self."""
    return selector_or(fst=other, snd=self)

  def __invert__(self) -> "Selector":
    """~self."""
    return selector_not(inner=self)
class Select(Selector):
  """Primitive that creates a Selector."""

  def __init__(self, keys: SOp, queries: SOp, predicate: Predicate):
    """Initialises.

    Args:
      keys: the SOp providing the key values.
      queries: the SOp providing the query values.
      predicate: called as `predicate(key, query)` for each key/query pair.
    """
    super().__init__()
    self.keys = keys
    self.queries = queries
    self.predicate = predicate
    assert isinstance(self.keys, SOp)
    assert isinstance(self.queries, SOp)

  @property
  def children(self) -> Sequence[RASPExpr]:
    """The keys and queries SOps, in that order."""
    return [self.keys, self.queries]
class ConstantSelector(Selector):
  """A constant selector for testing purposes."""

  def __init__(self, value: SelectorValue, check_length: bool = True):
    """Initialises.

    Args:
      value: the selector matrix this Selector evaluates to.
      check_length: whether evaluation should verify that `value` has the
        same length as the input.
    """
    super().__init__()
    self.value = value
    self.check_length = check_length

  @property
  def children(self) -> Sequence[RASPExpr]:
    """A constant has no dependencies."""
    return []
class SelectorWidth(SOp):
  """SelectorWidth primitive.

  An SOp computed from a Selector; the default evaluator in this module
  returns, for each selector row, the number of selected entries.
  """

  def __init__(self, selector: Selector):
    """Initialises with the Selector whose rows are measured."""
    super().__init__()
    self.selector = selector
    assert isinstance(self.selector, Selector)

  @property
  def children(self) -> Sequence[RASPExpr]:
    """The single Selector this SOp reads from."""
    return [self.selector]
class SelectorAnd(Selector):
  """Implements elementwise `and` between selectors."""

  def __init__(self, fst: Selector, snd: Selector):
    """Initialises with the two Selectors to combine."""
    super().__init__()
    self.fst = fst
    self.snd = snd
    assert isinstance(self.fst, Selector)
    assert isinstance(self.snd, Selector)

  @property
  def children(self) -> Sequence[RASPExpr]:
    """The two combined Selectors, in argument order."""
    return [self.fst, self.snd]
class SelectorOr(Selector):
  """Implements elementwise `or` between selectors."""

  def __init__(self, fst: Selector, snd: Selector):
    """Initialises with the two Selectors to combine."""
    super().__init__()
    self.fst = fst
    self.snd = snd
    assert isinstance(self.fst, Selector)
    assert isinstance(self.snd, Selector)

  @property
  def children(self) -> Sequence[RASPExpr]:
    """The two combined Selectors, in argument order."""
    return [self.fst, self.snd]
class SelectorNot(Selector):
  """Implements elementwise `not` on a selector."""

  def __init__(self, inner: Selector):
    """Initialises with the Selector to negate."""
    self.inner = inner
    super().__init__()
    assert isinstance(self.inner, Selector)

  @property
  def children(self) -> Sequence[RASPExpr]:
    """The single negated Selector."""
    return [self.inner]
def selector_not(
    inner: Selector,
    simplify: bool = True,
) -> Selector:
  """Returns a SelectorNot, or a Select if simplifying is possible.

  Simplification applies when `inner` is a Select: the result is then a new
  Select over the same keys and queries with the predicate negated.
  """
  if not (simplify and isinstance(inner, Select)):
    return SelectorNot(inner)

  def negated(k, q):
    return not inner.predicate(k, q)

  return Select(inner.keys, inner.queries, predicate=negated)
def selector_and(
    fst: Selector,
    snd: Selector,
    simplify: bool = True,
) -> Selector:
  """Returns a SelectorAnd, or a Select if simplifying is possible.

  Simplification is only attempted when both operands are Selects; if it
  does not succeed, a SelectorAnd is returned.
  """
  both_selects = isinstance(fst, Select) and isinstance(snd, Select)
  if simplify and both_selects:
    merged = _attempt_simplify(fst, snd, lambda l, r: l and r)
    if merged:
      return merged

  return SelectorAnd(fst, snd)
def selector_or(
    fst: Selector,
    snd: Selector,
    simplify: bool = True,
) -> Selector:
  """Returns a SelectorOr, or a Select if simplifying is possible.

  Simplification is only attempted when both operands are Selects; if it
  does not succeed, a SelectorOr is returned.
  """
  both_selects = isinstance(fst, Select) and isinstance(snd, Select)
  if simplify and both_selects:
    merged = _attempt_simplify(fst, snd, lambda l, r: l or r)
    if merged:
      return merged

  return SelectorOr(fst, snd)
def _attempt_simplify(
    fst: Select,
    snd: Select,
    combine: Callable[[bool, bool], bool],
) -> Optional[Select]:
  """Simplifies two Selects if possible.

  If two Selects in a compound Selector have matching keys and queries, they can
  be simplified into one Select with a compound predicate:

  lambda k,q: combine(fst.predicate(k,q), snd.predicate(k,q))

  This function returns a Select with this predicate if possible,
  and None otherwise.

  A Full SOp in a key or query position is a special case that always matches
  any SOp in the corresponding position in the other selector. In that case,
  we bake in the fill value into the corresponding Select's predicate before
  combining. This allows us to use the other SOp as the input to the simplified
  Select.

  Args:
    fst: the first Select.
    snd: the second Select.
    combine: how to combine the outputs of the individual predicates.

  Returns:
    A combined Select, if possible.
  """
  pred_fst = fst.predicate
  pred_snd = snd.predicate
  shared_keys = None
  shared_queries = None

  # A Full in key position: use the other Select's keys and fix the key
  # argument of this Select's predicate to the fill value.
  # The current predicate is passed in as a default arg so that each wrapper
  # calls the predicate it wraps and not itself.
  if isinstance(fst.keys, Full):
    shared_keys = snd.keys
    pred_fst = lambda key, query, p=pred_fst: p(fst.keys.fill, query)
  if isinstance(snd.keys, Full):
    shared_keys = fst.keys
    pred_snd = lambda key, query, p=pred_snd: p(snd.keys.fill, query)

  # Same treatment for a Full in query position.
  if isinstance(fst.queries, Full):
    shared_queries = snd.queries
    pred_fst = lambda key, query, p=pred_fst: p(key, fst.queries.fill)
  if isinstance(snd.queries, Full):
    shared_queries = fst.queries
    pred_snd = lambda key, query, p=pred_snd: p(key, snd.queries.fill)

  # Identical SOps always match.
  if fst.keys is snd.keys:
    shared_keys = fst.keys
  if fst.queries is snd.queries:
    shared_queries = fst.queries

  if not (shared_keys and shared_queries):
    return None

  def predicate(key, query):
    return combine(pred_fst(key, query), pred_snd(key, query))

  return Select(shared_keys, shared_queries, predicate=predicate)
class Aggregate(SOp, Generic[VT]):
  """Aggregate primitive.

  Combines the values of `sop` according to `selector`; `default` is used
  where nothing is selected.
  """

  def __init__(self,
               selector: Selector,
               sop: SOp,
               default: Optional[VT] = None):
    """Initialises. The default is used where nothing is selected."""
    super().__init__()
    self.selector = selector
    self.sop = sop
    self.default = default
    assert isinstance(self.selector, Selector)
    assert isinstance(self.sop, SOp)
    assert (self.default is None or isinstance(self.default,
                                               (str, float, bool, int)))

  @property
  def children(self) -> Sequence[RASPExpr]:
    """The selector and the aggregated SOp, in that order."""
    return [self.selector, self.sop]
### SOp encodings. | |
class Encoding(enum.Enum):
  """The encoding used by a SOp. Only number-valued SOps support numerical.

  Stored on an SOp under its 'encoding' annotation.
  """
  CATEGORICAL = "categorical"
  NUMERICAL = "numerical"
def numerical(sop: SOpT) -> SOpT:
  """Returns a copy of `sop` annotated with the numerical encoding."""
  return annotate(sop, **{_ENCODING_KEY: Encoding.NUMERICAL})
def categorical(sop: SOpT) -> SOpT:
  """Returns a copy of `sop` annotated with the categorical encoding."""
  return annotate(sop, **{_ENCODING_KEY: Encoding.CATEGORICAL})
def get_encoding(sop: SOp) -> Encoding:
  """Returns the encoding annotation of `sop`.

  The annotation is read through `_ENCODING_KEY`, the same key under which
  the default encoding annotator is registered.
  """
  return sop.annotations[_ENCODING_KEY]
def is_numerical(sop: SOp) -> bool:
  """Check if the SOp is numerically encoded."""
  encoding = get_encoding(sop)
  return encoding == Encoding.NUMERICAL
def is_categorical(sop: SOp) -> bool:
  """Check if the SOp is categorically encoded."""
  encoding = get_encoding(sop)
  return encoding == Encoding.CATEGORICAL
def default_encoding(expr: RASPExpr) -> Optional[Encoding]:
  """Default annotator for 'encoding': SOps are categorical by default.

  Raises:
    TypeError: if `expr` is not an SOp.
  """
  if isinstance(expr, SOp):
    return Encoding.CATEGORICAL
  raise TypeError(f"expr {expr} is not a SOp.")


# Register as the lazy default for the encoding annotation.
DEFAULT_ANNOTATORS[_ENCODING_KEY] = default_encoding
### naming. | |
# Default name for each expression class. The table is scanned in insertion
# order and the first class the expression is an instance of wins, so
# subclasses must appear here before superclasses in order for the most
# specific entry to be used.
_default_name_by_class = {
    # Primitives
    TokensType: "tokens",
    IndicesType: "indices",
    LengthType: "length",
    # SOps (LinearSequenceMap precedes its base class SequenceMap).
    LinearSequenceMap: "linear_sequence_map",
    SequenceMap: "sequence_map",
    Map: "map",
    Full: "full",
    ConstantSOp: "constant_sop",
    SelectorWidth: "selector_width",
    Aggregate: "aggregate",
    # Catch-all for any other SOp.
    SOp: "sop",
    # Selectors
    Select: "select",
    SelectorAnd: "selector_and",
    SelectorOr: "selector_or",
    SelectorNot: "selector_not",
    ConstantSelector: "constant_selector",
    # Catch-all for any other Selector.
    Selector: "selector",
}
def default_name(expr: RASPExpr) -> str:
  """Default annotator for 'name': picks a name based on the class of `expr`.

  The entries of `_default_name_by_class` are tried in order and the name of
  the first class that `expr` is an instance of is returned.

  Raises:
    NotImplementedError: if no entry matches `expr`.
  """
  for cls, name in _default_name_by_class.items():
    if isinstance(expr, cls):
      return name

  raise NotImplementedError(f"{expr} was not given a default name!")


# Register as the lazy default for the name annotation.
DEFAULT_ANNOTATORS[_NAME_KEY] = default_name
### evaluation. | |
class RASPEvaluator(abc.ABC):
  """ABC for RASP evaluators.

  Concrete evaluators must implement `evaluate`.
  """

  @abc.abstractmethod
  def evaluate(self, expr: RASPExpr,
               xs: Sequence[Value]) -> Union[Sequence[Value], SelectorValue]:
    """Evaluates the RASP expression on input `xs`."""
# NOTE(review): this class derives from abc.ABC rather than RASPEvaluator;
# presumably it is meant to subclass RASPEvaluator - confirm before changing.
class DefaultRASPEvaluator(abc.ABC):
  """Default evaluator for RASP.

  Evaluates an expression by recursively evaluating its inputs on the same
  input sequence and combining the results with plain Python / numpy.
  """

  def __init__(self):
    # Dispatch table from expression class to the method evaluating it.
    self._eval_fn_by_expr_type = {
        # Primitives
        TokensType: self.eval_tokens,
        IndicesType: self.eval_indices,
        LengthType: self.eval_length,
        # SOps
        LinearSequenceMap: self.eval_sequence_map,
        SequenceMap: self.eval_sequence_map,
        Map: self.eval_map,
        Full: self.eval_full,
        ConstantSOp: self.eval_constant_sop,
        SelectorWidth: self.eval_selector_width,
        Aggregate: self.eval_aggregate,
        SOp: _raise_not_implemented,
        # Selectors
        Select: self.eval_select,
        SelectorAnd: self.eval_selector_and,
        SelectorOr: self.eval_selector_or,
        SelectorNot: self.eval_selector_not,
        ConstantSelector: self.eval_constant_selector,
        Selector: _raise_not_implemented,
    }

  def evaluate(self, expr: RASPExpr,
               xs: Sequence[Value]) -> Union[Sequence[Value], SelectorValue]:
    """Evaluates the RASP expression on input `xs`.

    The evaluation function is looked up along the method resolution order of
    `type(expr)`, so the exact class is tried first and subclasses of a
    registered class fall back to their closest registered ancestor.

    Raises:
      KeyError: if no class in the MRO of `type(expr)` is registered.
    """
    for cls in type(expr).__mro__:
      eval_fn = self._eval_fn_by_expr_type.get(cls)
      if eval_fn is not None:
        return eval_fn(expr, xs)
    raise KeyError(type(expr))

  def eval_tokens(self, sop: TokensType,
                  xs: Sequence[Value]) -> Sequence[Value]:
    """Returns the input tokens as a list."""
    del sop
    return list(xs)

  def eval_indices(self, sop: IndicesType,
                   xs: Sequence[Value]) -> Sequence[Value]:
    """Returns the positions 0..len(xs)-1."""
    del sop
    return list(range(len(xs)))

  def eval_length(self, sop: LengthType, xs: Sequence[Value]) -> Sequence[int]:
    """Returns len(xs) repeated at every position."""
    del sop
    return [len(xs)] * len(xs)

  def eval_sequence_map(self, sop: SequenceMap,
                        xs: Sequence[Value]) -> Sequence[Value]:
    """Applies `sop.f` pairwise; a None on either side yields None."""
    fst_values = self.evaluate(sop.fst, xs)
    snd_values = self.evaluate(sop.snd, xs)
    return [
        sop.f(x, y) if x is not None and y is not None else None
        for x, y in zip(fst_values, snd_values)
    ]

  def eval_map(self, sop: Map, xs: Sequence[Value]) -> Sequence[Value]:
    """Applies `sop.f` elementwise; None inputs are passed through."""
    return [
        sop.f(x) if x is not None else None
        for x in self.evaluate(sop.inner, xs)
    ]

  def eval_full(self, sop: Full, xs: Sequence[Value]) -> Sequence[Value]:
    """Returns the fill value repeated at every position."""
    return [sop.fill] * len(xs)

  def eval_constant_sop(self, sop: ConstantSOp,
                        xs: Sequence[Value]) -> Sequence[Value]:
    """Returns the stored value, optionally checking its length."""
    if sop.check_length and (len(xs) != len(sop.value)):
      raise ValueError(
          f"Constant len {len(sop.value)} doesn't match input len {len(xs)}.")
    return sop.value

  def eval_selector_width(self, sop: SelectorWidth,
                          xs: Sequence[Value]) -> Sequence[Value]:
    """Returns the number of selected entries in each selector row."""
    selector_values = self.evaluate(sop.selector, xs)
    return [sum(row) for row in selector_values]

  def eval_aggregate(self, sop: Aggregate,
                     xs: Sequence[Value]) -> Sequence[Value]:
    """Aggregates, per selector row, the selected values of `sop.sop`."""
    selector_value = self.evaluate(sop.selector, xs)
    values = self.evaluate(sop.sop, xs)
    default = sop.default

    return [
        _mean(_get_selected(row, values), default) for row in selector_value
    ]

  def eval_select(self, sel: Select, xs: Sequence[Value]) -> SelectorValue:
    """Evaluates a Select on `xs`.

    The result has one row per query value and one column per key value.
    """
    key_values = self.evaluate(sel.keys, xs)
    query_values = self.evaluate(sel.queries, xs)

    return [
        [bool(sel.predicate(key, query)) for key in key_values]
        for query in query_values
    ]

  def eval_constant_selector(self, sel: ConstantSelector,
                             xs: Sequence[Value]) -> SelectorValue:
    """Returns the stored value, optionally checking its length."""
    if sel.check_length and (len(xs) != len(sel.value)):
      raise ValueError(
          f"Constant len {len(sel.value)} doesn't match input len {len(xs)}.")
    return sel.value

  def eval_selector_and(self, sel: SelectorAnd,
                        xs: Sequence[Value]) -> SelectorValue:
    """Elementwise `and` of the two evaluated selectors."""
    fst_values = self.evaluate(sel.fst, xs)
    snd_values = self.evaluate(sel.snd, xs)
    return np.logical_and(np.array(fst_values), np.array(snd_values)).tolist()

  def eval_selector_or(self, sel: SelectorOr,
                       xs: Sequence[Value]) -> SelectorValue:
    """Elementwise `or` of the two evaluated selectors."""
    fst_values = self.evaluate(sel.fst, xs)
    snd_values = self.evaluate(sel.snd, xs)
    return np.logical_or(np.array(fst_values), np.array(snd_values)).tolist()

  def eval_selector_not(self, sel: SelectorNot,
                        xs: Sequence[Value]) -> SelectorValue:
    """Elementwise `not` of the evaluated selector."""
    values = self.evaluate(sel.inner, xs)
    return np.logical_not(np.array(values)).tolist()
def _get_selected(
    selector_row: list[bool],
    values: Sequence[VT],
) -> Sequence[VT]:
  """Helper for aggregate. [T T F], [a b c] -> [a b].

  Keeps the values whose corresponding selector entry is truthy.
  """
  return list(itertools.compress(values, selector_row))
def _mean(xs: Sequence[VT], default: VT) -> VT:
  """Takes the mean for numbers; passes a single non-numeric value through.

  Args:
    xs: the selected values.
    default: returned when `xs` is empty.

  Returns:
    `default` if `xs` is empty; the arithmetic mean if the first element is
    a number (int, bool or float); the element itself if there is exactly one
    element.

  Raises:
    ValueError: if several non-numeric values are selected.
  """
  if not xs:
    return default

  exemplar = xs[0]
  if isinstance(exemplar, int):  # Also covers bool.
    return sum(xs) / len(xs)
  if len(xs) == 1:
    # A single value needs no averaging and is returned unchanged.
    return exemplar
  if isinstance(exemplar, float):
    return sum(xs) / len(xs)
  raise ValueError(f"Unsupported type for aggregation: {xs}")
def _raise_not_implemented(expr: RASPExpr, xs: Sequence[Value]):
  """Evaluation fallback that always raises NotImplementedError for `expr`."""
  message = f"Evaluation of {expr} is not defined."
  raise NotImplementedError(message)
# Module-level entry point: the bound `evaluate` method of a single shared
# DefaultRASPEvaluator instance.
_default_evaluator = DefaultRASPEvaluator()
evaluate = _default_evaluator.evaluate