# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Vectors and bases."""
import dataclasses
import itertools
from typing import Sequence, Union, Optional, Iterable

import numpy as np


# Type of a basis direction's name.
Name = Union[int, str]
# Type of a value that a basis direction may one-hot-encode.
Value = Union[int, float, bool, str, tuple]


@dataclasses.dataclass(frozen=True)
class BasisDirection:
    """Represents a basis direction (no magnitude) in a vector space.

    Instances are frozen dataclasses, so they compare and hash by
    (name, value) and can be stored in sets and used as dict keys.

    Attributes:
        name: a unique name for this direction.
        value: used to hold a value one-hot-encoded by this direction. e.g.,
            [BasisDirection("vs_1", True), BasisDirection("vs_1", False)] would
            be basis directions of a subspace called "vs_1" which one-hot-encodes
            the values True and False. If provided, considered part of the name
            for the purpose of disambiguating directions.
    """

    name: Name
    value: Optional[Value] = None

    def __str__(self):
        if self.value is None:
            return str(self.name)
        return f"{self.name}:{self.value}"

    def __lt__(self, other: "BasisDirection") -> bool:
        """Orders directions by (name, value), falling back to string order.

        The fallback covers pairs whose names or values cannot be compared
        directly (e.g. an int name vs a str name, or a None value vs an int
        value), so directions of mixed types can still be sorted.
        """
        if not isinstance(other, BasisDirection):
            return NotImplemented
        try:
            return (self.name, self.value) < (other.name, other.value)
        except TypeError:
            # NOTE(review): mixing native and string comparison is not
            # transitive across mixed int/str names (e.g. 9 < 10 natively,
            # "10" < "1a" and "1a" < "9" as strings), so the sort order of such
            # mixed bases can depend on the input order.
            return str(self) < str(other)


@dataclasses.dataclass
class VectorInBasis:
    """A vector (or array of vectors) in a given basis.

    When magnitudes are 1-d, this is a vector.
    When magnitudes are (n+1)-d, this is an array of vectors,
    where the -1th dimension is the basis dimension.

    On construction the basis directions are sorted and the last axis of
    `magnitudes` is permuted to match, so vectors built from the same
    directions listed in different orders end up aligned (assuming
    `BasisDirection.__lt__` orders those directions consistently).

    Binary operators between two vectors require identical
    `basis_directions` and raise TypeError otherwise.

    Attributes:
        basis_directions: the directions labelling the last axis of
            `magnitudes`.
        magnitudes: the coordinates; `magnitudes.shape[-1]` must equal
            `len(basis_directions)`.
    """

    basis_directions: Sequence["BasisDirection"]
    magnitudes: np.ndarray

    def __post_init__(self):
        """Validates shapes, then sorts basis directions.

        Raises:
            ValueError: if the last dimension of `magnitudes` differs from the
                number of basis directions.
        """
        if len(self.basis_directions) != self.magnitudes.shape[-1]:
            raise ValueError(
                "Last dimension of magnitudes must be the same as number "
                f"of basis directions. Was {len(self.basis_directions)} "
                f"and {self.magnitudes.shape[-1]}.")
        sort_idx = np.argsort(self.basis_directions)
        self.basis_directions = [self.basis_directions[i] for i in sort_idx]
        # Permute the basis axis with the same order so every column keeps
        # its direction.
        self.magnitudes = np.take(self.magnitudes, sort_idx, -1)

    def __add__(self, other: "VectorInBasis") -> "VectorInBasis":
        if self.basis_directions != other.basis_directions:
            raise TypeError(f"Adding incompatible bases: {self} + {other}")
        magnitudes = self.magnitudes + other.magnitudes
        return VectorInBasis(self.basis_directions, magnitudes)

    def __radd__(self, other: "VectorInBasis") -> "VectorInBasis":
        # Called for `other + self` when `other` does not handle the addition.
        # `other` must still be a VectorInBasis: builtin sum(), which starts
        # from 0, raises AttributeError here -- use VectorInBasis.sum instead.
        if self.basis_directions != other.basis_directions:
            raise TypeError(f"Adding incompatible bases: {other} + {self}")
        return self + other

    def __sub__(self, other: "VectorInBasis") -> "VectorInBasis":
        if self.basis_directions != other.basis_directions:
            raise TypeError(f"Subtracting incompatible bases: {self} - {other}")
        magnitudes = self.magnitudes - other.magnitudes
        return VectorInBasis(self.basis_directions, magnitudes)

    def __rsub__(self, other: "VectorInBasis") -> "VectorInBasis":
        if self.basis_directions != other.basis_directions:
            raise TypeError(f"Subtracting incompatible bases: {other} - {self}")
        magnitudes = other.magnitudes - self.magnitudes
        return VectorInBasis(self.basis_directions, magnitudes)

    def __mul__(self, scalar: float) -> "VectorInBasis":
        return VectorInBasis(self.basis_directions, self.magnitudes * scalar)

    def __rmul__(self, scalar: float) -> "VectorInBasis":
        return self * scalar

    def __truediv__(self, scalar: float) -> "VectorInBasis":
        return VectorInBasis(self.basis_directions, self.magnitudes / scalar)

    def __neg__(self) -> "VectorInBasis":
        return (-1) * self

    def __eq__(self, other: object) -> bool:
        """Equal iff same basis directions, same shape and same magnitudes.

        Comparing against a non-VectorInBasis defers to Python's default
        handling (so `v == None` is False) instead of raising AttributeError.
        """
        if not isinstance(other, VectorInBasis):
            return NotImplemented
        return (self.basis_directions == other.basis_directions
                and self.magnitudes.shape == other.magnitudes.shape
                and bool(np.all(self.magnitudes == other.magnitudes)))

    @classmethod
    def sum(cls, vectors: Sequence["VectorInBasis"]) -> "VectorInBasis":
        """Adds up a non-empty sequence of vectors that share one basis.

        Raises:
            TypeError: if the vectors do not all have the same basis
                directions (consistent with `__add__` and `stack`).
        """
        first = vectors[0]
        for v in vectors[1:]:
            if v.basis_directions != first.basis_directions:
                raise TypeError(f"Adding incompatible bases: {first} + {v}")
        return cls(first.basis_directions,
                   np.sum([x.magnitudes for x in vectors], axis=0))

    @classmethod
    def stack(cls,
              vectors: Sequence["VectorInBasis"],
              axis: int = 0) -> "VectorInBasis":
        """Stacks a non-empty sequence of same-basis vectors along `axis`.

        Raises:
            TypeError: if the vectors do not all have the same basis
                directions.
        """
        for v in vectors[1:]:
            if v.basis_directions != vectors[0].basis_directions:
                raise TypeError(f"Stacking incompatible bases: {vectors[0]} + {v}")
        return cls(vectors[0].basis_directions,
                   np.stack([v.magnitudes for v in vectors], axis=axis))

    def project(
        self, basis: Union["VectorSpaceWithBasis", Sequence["BasisDirection"]]
    ) -> "VectorInBasis":
        """Projects to the basis.

        Components along directions that also appear in `basis` are copied;
        directions of `basis` that this vector lacks get zero magnitude; this
        vector's directions outside `basis` are dropped.

        Args:
            basis: a vector space, or a sequence of basis directions, to
                project onto.

        Returns:
            A VectorInBasis over the directions of `basis` whose magnitudes
            keep the leading (batch) shape of `self.magnitudes`.
        """
        if isinstance(basis, VectorSpaceWithBasis):
            basis = basis.basis
        # Materialize once: `basis` is iterated for the components and again
        # for the result's directions.
        basis = list(basis)

        # Column of each of our directions, built once so each lookup below is
        # O(1) rather than a linear list search. setdefault keeps the first
        # occurrence, matching list.index.
        column_of = {}
        for column, direction in enumerate(self.basis_directions):
            column_of.setdefault(direction, column)

        batch_shape = self.magnitudes.shape[:-1]
        dtype = self.magnitudes.dtype
        # Built from the batch shape rather than by slicing column 0, so this
        # also works when this vector has no basis directions at all.
        zeros = np.zeros(batch_shape, dtype=dtype)
        components = [
            self.magnitudes[..., column_of[direction]]
            if direction in column_of else zeros
            for direction in basis
        ]
        if components:
            magnitudes = np.stack(components, axis=-1)
        else:
            # np.stack rejects an empty list; an empty target basis yields an
            # empty last axis instead.
            magnitudes = np.zeros(batch_shape + (0,), dtype=dtype)
        return VectorInBasis(basis, magnitudes)


@dataclasses.dataclass
class VectorSpaceWithBasis:
    """A vector subspace in a given basis.

    Attributes:
        basis: the directions spanning the space, kept sorted by
            `__post_init__`.
    """

    basis: Sequence[BasisDirection]

    def __post_init__(self):
        """Keep basis directions sorted."""
        self.basis = sorted(self.basis)

    @property
    def num_dims(self) -> int:
        """The number of basis directions."""
        return len(self.basis)

    def __contains__(self, item: Union[VectorInBasis, BasisDirection]) -> bool:
        """Supports `item in space`.

        A BasisDirection is contained if it is one of the basis directions. A
        VectorInBasis is contained only if its set of basis directions equals
        this space's basis exactly; a vector over a strict subset of the
        directions is not considered contained.
        """
        if isinstance(item, BasisDirection):
            return item in self.basis
        return set(self.basis) == set(item.basis_directions)

    def issubspace(self, other: "VectorSpaceWithBasis") -> bool:
        """Whether every basis direction of this space is also in `other`."""
        return set(self.basis).issubset(set(other.basis))

    def basis_vectors(self) -> Sequence[VectorInBasis]:
        """One unit vector per basis direction."""
        basis_vector_magnitudes = list(np.eye(self.num_dims))
        return [VectorInBasis(self.basis, m) for m in basis_vector_magnitudes]

    def vector_from_basis_direction(
            self, basis_direction: BasisDirection) -> VectorInBasis:
        """The unit vector along `basis_direction`.

        Raises:
            ValueError: if `basis_direction` is not in this space's basis.
        """
        i = self.basis.index(basis_direction)
        # Build the single one-hot row directly instead of materializing the
        # full num_dims x num_dims identity matrix to take one row of it.
        magnitudes = np.zeros(self.num_dims)
        magnitudes[i] = 1.0
        return VectorInBasis(self.basis, magnitudes)

    def null_vector(self) -> VectorInBasis:
        """The zero vector of this space."""
        return VectorInBasis(self.basis, np.zeros(self.num_dims))

    @classmethod
    def from_names(cls, names: Sequence[Name]) -> "VectorSpaceWithBasis":
        """Creates a VectorSpace from a list of names for its basis directions."""
        return cls([BasisDirection(n) for n in names])

    @classmethod
    def from_values(
        cls,
        name: Name,
        values: Iterable[Value],
    ) -> "VectorSpaceWithBasis":
        """Creates a VectorSpace from a list of values for its basis directions."""
        return cls([BasisDirection(name, v) for v in values])


def direct_sum(*vs: VectorSpaceWithBasis) -> VectorSpaceWithBasis:
    """Create a direct sum of the vector spaces.

    Assumes the basis elements of all input vector spaces are orthogonal to
    each other. The input order of the bases is NOT preserved: the combined
    basis is sorted by the VectorSpaceWithBasis constructor.

    Args:
        *vs: the vector spaces to sum.

    Returns:
        the combined vector space.

    Raises:
        ValueError: if any basis direction occurs more than once across the
            inputs (overlapping bases).
    """
    # Concatenate all bases in a single pass; `sum(lists, [])` copies the
    # growing list at every step and is quadratic in the number of spaces.
    total_basis = list(itertools.chain.from_iterable(v.basis for v in vs))
    if len(total_basis) != len(set(total_basis)):
        raise ValueError("Overlapping bases!")
    return VectorSpaceWithBasis(total_basis)


def join_vector_spaces(*vs: VectorSpaceWithBasis) -> VectorSpaceWithBasis:
    """Joins vector spaces into one spanned by the union of their bases.

    Unlike `direct_sum`, the inputs may share basis directions; each shared
    direction appears once in the result. The input order of the bases is not
    maintained: the combined basis is sorted.

    Assumes the basis elements of all input vector spaces are orthogonal to
    each other.

    Args:
        *vs: the vector spaces to join.

    Returns:
        the combined vector space.
    """
    per_space = (set(space.basis) for space in vs)
    merged = set().union(*per_space)
    return VectorSpaceWithBasis(sorted(merged))


def ensure_dims(
    vs: VectorSpaceWithBasis,
    num_dims: int,
    name: str = "vector space",
) -> None:
    """Checks that a vector space has exactly the expected dimensionality.

    Args:
        vs: the vector space to check.
        num_dims: the required number of dimensions.
        name: how `vs` is referred to in the error message.

    Raises:
        ValueError: if `vs.num_dims` differs from `num_dims`.
    """
    if vs.num_dims == num_dims:
        return
    raise ValueError(f"{name} must have num_dims={num_dims}, "
                     f"but got {vs.num_dims}: {vs.basis}")