# Copyright 2022 Google. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Class for T5 relative position biases. Adapted from flaxformer.components.relative_position_biases.py """ from typing import Any, Callable, Optional from flax import linen as nn import gin from jax import lax import jax.numpy as jnp from transformer import position import numpy as np Array = Any @gin.configurable class T5RelativePositionBiases(nn.Module): """Adds T5-style relative positional embeddings to the attention logits. Attributes: num_buckets: Number of buckets to bucket distances between key and query positions into. max_distance: Maximum distance before everything is lumped into the last distance bucket. num_heads: Number of heads in the attention layer. Each head will get a different relative position weighting. dtype: Type of arrays through this module. embedding_init: initializer for relative embedding table. """ num_buckets: int max_distance: int num_heads: int dtype: Any embedding_init: Callable[..., Array] = nn.linear.default_embed_init @staticmethod def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): """Translate relative position to a bucket number for relative attention. The relative position is defined as memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for small absolute relative_position and larger buckets for larger absolute relative_positions. All relative positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. This should allow for more graceful generalization to longer sequences than the model has been trained on. Args: relative_position: an int32 array bidirectional: a boolean - whether the attention is bidirectional num_buckets: an integer max_distance: an integer Returns: a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) """ ret = 0 n = -relative_position if bidirectional: num_buckets //= 2 ret += (n < 0).astype(np.int32) * num_buckets n = np.abs(n) else: n = np.maximum(n, 0) # now n is in the range [0, inf) max_exact = num_buckets // 2 is_small = (n < max_exact) val_if_large = max_exact + ( np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) / np.log(max_distance / max_exact) * (num_buckets - max_exact)).astype(np.int32) val_if_large = np.minimum(val_if_large, num_buckets - 1) ret += np.where(is_small, n, val_if_large) return ret @nn.compact def __call__(self, num_queries, num_keys, offset: Optional[int]=None, bidirectional=True): """Produce relative position embedding attention biases. Args: num_queries: Number of queries. num_keys: Number of keys. offset: Offset of the first query with respect to the first key. (See position.relative_positions() for more info.) bidirectional: whether to allow positive memory-query relative position embeddings. Returns: output: `(1, num_heads, num_queries, num_keys)` attention bias """ # Find the distance between each query and each key. # This is where this implementation differs from the T5 implementation; # this version lines the /last/ N queries up with the /last/ N keys, # (which is appropriate for XL/sliding window) while the T5 version lines # up the /first/ N queries with the first N keys, in cases where the # number of keys and queries differ. relative_position = position.relative_positions_np( num_queries=num_queries, num_keys=num_keys, offset=offset) rp_bucket = self._relative_position_bucket( relative_position, bidirectional=bidirectional, num_buckets=self.num_buckets, max_distance=self.max_distance) relative_attention_bias = self.param('rel_embedding', self.embedding_init, (self.num_heads, self.num_buckets), jnp.float32) relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype) # Instead of using a slow gather, we create a leading-dimension one-hot # array from rp_bucket and use it to perform the gather-equivalent via a # contraction, i.e.: # (num_head, num_buckets) x (num_buckets one-hot, num_queries, num_keys). # This is equivalent to relative_attention_bias[:, rp_bucket] bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0) rp_bucket_one_hot = jnp.array( rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype) # --> shape (num_queries, num_keys, num_heads) values = lax.dot_general( relative_attention_bias, rp_bucket_one_hot, ( ((1,), (0,)), # rhs, lhs contracting dims ((), ()))) # no batched dims # Add a singleton batch dimension. # --> shape (1, num_heads, num_queries, num_keys) out = values[jnp.newaxis, ...] return out