Spaces:
Sleeping
Sleeping
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Basic encoder for inputs with a fixed vocabulary."""
import abc | |
from typing import Any, Sequence, Optional | |
from tracr.craft import bases | |
class Encoder(abc.ABC):
    """Base interface turning token lists into transformer inputs and back.

    Nothing is assumed here about the element types of inputs or outputs;
    concrete encoders for specific input types override these methods.
    Every method below has a placeholder implementation: the two conversion
    methods return an empty list and the four accessors return None.
    """

    def encode(self, inputs: list[Any]) -> list[Any]:
        """Placeholder encoding: ignores `inputs` and returns an empty list."""
        return []

    def decode(self, encodings: list[Any]) -> list[Any]:
        """Placeholder decoding: ignores `encodings` and returns an empty list."""
        return []

    def pad_token(self) -> Optional[str]:
        """Padding token; None in this base implementation."""
        return None

    def bos_token(self) -> Optional[str]:
        """Beginning-of-sequence token; None in this base implementation."""
        return None

    def pad_encoding(self) -> Optional[int]:
        """Encoded padding token; None in this base implementation."""
        return None

    def bos_encoding(self) -> Optional[int]:
        """Encoded beginning-of-sequence token; None in this base implementation."""
        return None
class NumericalEncoder(Encoder):
    """Encoder for numerical variables; both directions are the identity."""

    def encode(self, inputs: list[float]) -> list[float]:
        """Return `inputs` itself, unchanged (the same list object, no copy)."""
        encoded = inputs
        return encoded

    def decode(self, encodings: list[float]) -> list[float]:
        """Return `encodings` itself, unchanged (the same list object, no copy)."""
        decoded = encodings
        return decoded
class CategoricalEncoder(Encoder):
    """Encodes categorical variables with a fixed vocabulary.

    The vocabulary is taken from `basis`: the `value` of each basis direction
    is mapped to that direction's position in the sequence. `encode` looks
    tokens up in this map and `decode` applies the inverse map.

    Internally the BOS/PAD tokens are always read from the private
    `_bos_token` / `_pad_token` attributes. Referring to `self.bos_token`
    without calling it yields a bound method, which is never None and never
    equal to a token, so comparisons against it are always wrong.
    """

    def __init__(
        self,
        basis: Sequence[bases.BasisDirection],
        enforce_bos: bool = False,
        bos_token: Optional[str] = None,
        pad_token: Optional[str] = None,
        max_seq_len: Optional[int] = None,
    ):
        """Initialises. If enforce_bos is set, ensures inputs start with it.

        Args:
          basis: Directions whose `value` attributes form the vocabulary; the
            index of a direction in this sequence becomes its encoding.
          enforce_bos: If True, `encode` rejects inputs whose first element is
            not `bos_token`.
          bos_token: Optional beginning-of-sequence token; when given it must
            be one of the basis values.
          pad_token: Optional padding token; when given it must be one of the
            basis values.
          max_seq_len: Optional upper bound on the length of inputs accepted
            by `encode`.

        Raises:
          ValueError: If `enforce_bos` is set without a `bos_token`, or if
            `bos_token` / `pad_token` is given but absent from the basis.
        """
        if enforce_bos and not bos_token:
            raise ValueError("BOS token must be specified if enforcing BOS.")

        # If the same value occurs more than once, the last index wins.
        self.encoding_map = {
            direction.value: index for index, direction in enumerate(basis)
        }

        if bos_token and bos_token not in self.encoding_map:
            raise ValueError("BOS token missing in encoding.")
        if pad_token and pad_token not in self.encoding_map:
            raise ValueError("PAD token missing in encoding.")

        self.enforce_bos = enforce_bos
        self._bos_token = bos_token
        self._pad_token = pad_token
        self._max_seq_len = max_seq_len

    def encode(self, inputs: list[bases.Value]) -> list[int]:
        """Map each token in `inputs` to its integer encoding.

        Raises:
          ValueError: If BOS is enforced and `inputs` is empty or does not
            start with the BOS token, if any token is not in the vocabulary,
            or if `inputs` is longer than the configured maximum length.
        """
        if self.enforce_bos:
            # An empty sequence cannot start with BOS; report it as a
            # validation error rather than failing on `inputs[0]`.
            if not inputs:
                raise ValueError(
                    "First input token must be BOS token. "
                    f"Should be '{self._bos_token}', but inputs were empty.")
            if inputs[0] != self._bos_token:
                raise ValueError(
                    "First input token must be BOS token. "
                    f"Should be '{self._bos_token}', but was '{inputs[0]}'.")
        if missing := set(inputs).difference(self.encoding_map):
            # Single formatted message: passing the keys as a second argument
            # to ValueError would render the message as a tuple.
            raise ValueError(
                f"Inputs {missing} not found in encoding "
                f"{self.encoding_map.keys()}")
        if self._max_seq_len is not None and len(inputs) > self._max_seq_len:
            raise ValueError(f"{inputs=} are longer than the maximum "
                             f"sequence length {self._max_seq_len}")
        return [self.encoding_map[x] for x in inputs]

    def decode(self, encodings: list[int]) -> list[bases.Value]:
        """Recover the tokens that correspond to `encodings`. Inverse of `encode`.

        Raises:
          ValueError: If any element of `encodings` is not a known encoding.
        """
        # Rebuilt on every call so that it always reflects the current
        # contents of the public `encoding_map` attribute.
        decoding_map = {index: token for token, index in self.encoding_map.items()}
        if missing := set(encodings).difference(decoding_map):
            raise ValueError(
                f"Inputs {missing} not found in decoding map "
                f"{decoding_map.keys()}")
        return [decoding_map[x] for x in encodings]

    def vocab_size(self) -> int:
        """Number of distinct values in the vocabulary."""
        return len(self.encoding_map)

    def bos_token(self) -> Optional[str]:
        """The BOS token passed at construction, or None."""
        return self._bos_token

    def pad_token(self) -> Optional[str]:
        """The PAD token passed at construction, or None."""
        return self._pad_token

    def bos_encoding(self) -> Optional[int]:
        """Integer encoding of the BOS token, or None if no BOS token is set."""
        if self._bos_token is None:
            return None
        return self.encoding_map[self._bos_token]

    def pad_encoding(self) -> Optional[int]:
        """Integer encoding of the PAD token, or None if no PAD token is set."""
        if self._pad_token is None:
            return None
        return self.encoding_map[self._pad_token]