# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Instrumented attention layer (forked from the Haiku library implementation). | |
""" | |
from typing import Optional | |
import warnings | |
import chex | |
import haiku as hk | |
import jax | |
import jax.numpy as jnp | |
import numpy as np | |


@chex.dataclass
class AttentionOutput:
  out: jax.Array  # [..., T', D']
  logits: jax.Array  # [..., H, T', T]


class MultiHeadAttention(hk.Module):
  """Multi-headed attention (MHA) module.

  This module is intended for attending over sequences of vectors.

  Rough sketch:
  - Compute keys (K), queries (Q), and values (V) as projections of inputs.
  - Attention weights are computed as W = softmax(QK^T / sqrt(key_size)).
  - Output is another projection of WV^T.

  For more detail, see the original Transformer paper:
    "Attention is all you need" https://arxiv.org/abs/1706.03762.

  Glossary of shapes:
  - T: Sequence length.
  - D: Vector (embedding) size.
  - H: Number of attention heads.
  """

  def __init__(
      self,
      num_heads: int,
      key_size: int,
      # TODO(b/240019186): Remove `w_init_scale`.
      w_init_scale: Optional[float] = None,
      *,
      w_init: Optional[hk.initializers.Initializer] = None,
      value_size: Optional[int] = None,
      model_size: Optional[int] = None,
      name: Optional[str] = None,
  ):
    """Initialises the module.

    Args:
      num_heads: Number of independent attention heads (H).
      key_size: The size of keys (K) and queries used for attention.
      w_init_scale: DEPRECATED. Please use w_init instead.
      w_init: Initialiser for weights in the linear map.
      value_size: Optional size of the value projection (V). If None, defaults
        to the key size (K).
      model_size: Optional size of the output embedding (D'). If None, defaults
        to the key size multiplied by the number of heads (K * H).
      name: Optional name for this module.
    """
    super().__init__(name=name)
    self.num_heads = num_heads
    self.key_size = key_size
    self.value_size = value_size or key_size
    self.model_size = model_size or key_size * num_heads

    # Backwards-compatibility for w_init_scale.
    if w_init_scale is not None:
      warnings.warn(
          "w_init_scale is deprecated; please pass an explicit weight "
          "initialiser instead.", DeprecationWarning)
    if w_init and w_init_scale:
      raise ValueError("Please provide only `w_init`, not `w_init_scale`.")
    if w_init is None and w_init_scale is None:
      raise ValueError("Please provide a weight initializer: `w_init`.")
    if w_init is None:
      w_init = hk.initializers.VarianceScaling(w_init_scale)
    self.w_init = w_init
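
  # A minimal construction sketch (illustrative values, not part of the forked
  # module; hk.Modules must be created inside an hk.transform-ed function):
  #
  #   mha = MultiHeadAttention(
  #       num_heads=8,
  #       key_size=32,
  #       w_init=hk.initializers.VarianceScaling(1.0),
  #   )
  #   # value_size defaults to key_size (32); model_size to 8 * 32 = 256.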

  def __call__(
      self,
      query: jnp.ndarray,
      key: jnp.ndarray,
      value: jnp.ndarray,
      mask: Optional[jnp.ndarray] = None,
  ) -> AttentionOutput:
    """Computes (optionally masked) MHA with queries, keys & values.

    This module broadcasts over zero or more 'batch-like' leading dimensions.

    Args:
      query: Embeddings sequence used to compute queries; shape [..., T', D_q].
      key: Embeddings sequence used to compute keys; shape [..., T, D_k].
      value: Embeddings sequence used to compute values; shape [..., T, D_v].
      mask: Optional mask applied to attention weights; shape [..., H=1, T', T].

    Returns:
      An AttentionOutput holding:
        out: A new sequence of embeddings, consisting of a projection of the
          attention-weighted value projections; shape [..., T', D'].
        logits: The scaled (and, if a mask is given, masked) attention logits;
          shape [..., H, T', T].
    """
    # In shape hints below, we suppress the leading dims [...] for brevity.
    # Hence e.g. [A, B] should be read in every case as [..., A, B].
    *leading_dims, sequence_length, _ = query.shape
    projection = self._linear_projection

    # Compute key/query/values (overload K/Q/V to denote the respective sizes).
    query_heads = projection(query, self.key_size, "query")  # [T', H, Q=K]
    key_heads = projection(key, self.key_size, "key")  # [T, H, K]
    value_heads = projection(value, self.value_size, "value")  # [T, H, V]

    # Compute attention weights.
    attn_logits = jnp.einsum("...thd,...Thd->...htT", query_heads, key_heads)
    attn_logits = attn_logits / np.sqrt(self.key_size).astype(key.dtype)
    if mask is not None:
      if mask.ndim != attn_logits.ndim:
        raise ValueError(
            f"Mask dimensionality {mask.ndim} must match logits dimensionality "
            f"{attn_logits.ndim}.")
      attn_logits = jnp.where(mask, attn_logits, -1e30)
    attn_weights = jax.nn.softmax(attn_logits)  # [H, T', T]

    # Weight the values by the attention and flatten the head vectors.
    attn = jnp.einsum("...htT,...Thd->...thd", attn_weights, value_heads)
    attn = jnp.reshape(attn, (*leading_dims, sequence_length, -1))  # [T', H*V]

    # Apply another projection to get the final embeddings.
    final_projection = hk.Linear(self.model_size, w_init=self.w_init)
    return AttentionOutput(
        out=final_projection(attn),
        logits=attn_logits,
    )
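
  # A minimal sketch (illustrative, not part of the forked module) of building
  # a causal mask matching the [..., H=1, T', T] convention documented above.
  # Note that mask.ndim must equal the logits' ndim, so a batched input of
  # shape [B, T, D] needs a 4-D mask:
  #
  #   seq_len = 16                                      # illustrative length
  #   causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
  #   causal = causal[None, None, :, :]                 # [B=1, H=1, T', T]
  #   # ...then, inside an hk.transform-ed function:
  #   #   MultiHeadAttention(...)(x, x, x, mask=causal)
  #
  # Positions where the mask is False get a large negative logit (-1e30)
  # before the softmax, so they receive effectively zero attention weight.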

  def _linear_projection(
      self,
      x: jnp.ndarray,
      head_size: int,
      name: Optional[str] = None,
  ) -> jnp.ndarray:
    # Project to [..., T, H * head_size], then split out the head dimension.
    y = hk.Linear(self.num_heads * head_size, w_init=self.w_init, name=name)(x)
    *leading_dims, _ = x.shape
    return y.reshape((*leading_dims, self.num_heads, head_size))
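
# A minimal end-to-end usage sketch (the `forward` function, shapes and
# hyperparameters below are illustrative assumptions, not part of this file):
#
#   def forward(x):
#     attn = MultiHeadAttention(
#         num_heads=8,
#         key_size=32,
#         w_init=hk.initializers.VarianceScaling(1.0),
#     )
#     return attn(query=x, key=x, value=x)
#
#   f = hk.transform(forward)
#   x = jnp.zeros((2, 16, 64))                  # [batch, T, D]
#   params = f.init(jax.random.PRNGKey(0), x)
#   result = f.apply(params, None, x)
#   # result.out:    [2, 16, 256]  (model_size defaults to H * K = 8 * 32)
#   # result.logits: [2, 8, 16, 16]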