# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pieces for making transformers."""

import abc
import dataclasses
from typing import Iterable, List, Optional, Sequence, Union

import numpy as np

from tracr.craft import bases
from tracr.craft import vectorspace_fns

project = vectorspace_fns.project


def _np_softmax(x, axis=-1):
  """Softmax of `x` along `axis`.

  The per-slice maximum is subtracted before exponentiating so large logits do
  not overflow; mathematically this does not change the result.
  """
  x_max = np.max(x, axis=axis, keepdims=True)
  exps = np.exp(x - x_max)  # Computed once, reused for the normalizer.
  return exps / np.sum(exps, axis=axis, keepdims=True)


def _np_relu(x):
  """Elementwise ReLU: keeps positive entries of `x`, replaces the rest by 0."""
  return np.where(x > 0, x, 0)


def relu(x: bases.VectorInBasis) -> bases.VectorInBasis:
  """Applies ReLU to the magnitudes of `x`, keeping its basis directions."""
  return bases.VectorInBasis(x.basis_directions, _np_relu(x.magnitudes))


class Block(abc.ABC):
  """Transformer block, acting on a sequence of vector space elements.

  Attributes:
    residual_space: Vector space that contains all subspaces the Block interacts
      with. This can be either the full residual space of a model or a subspace.
  """
  residual_space: bases.VectorSpaceWithBasis

  @abc.abstractmethod
  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Applies self to an input."""


@dataclasses.dataclass
class AttentionHead(Block):
  """A transformer attention head.

  Attributes:
    w_qk: Bilinear form scoring queries (its left space) against keys (its
      right space).
    w_ov: Linear map from the value input space to the head's output space.
    residual_space: Space the head reads from and writes to. If None, it is
      inferred as the join of the left/right spaces of `w_qk` and the
      input/output spaces of `w_ov`.
    causal: If True, a query position cannot attend to later key positions.
  """
  w_qk: vectorspace_fns.ScalarBilinear
  w_ov: vectorspace_fns.Linear
  residual_space: Optional[bases.VectorSpaceWithBasis] = None
  causal: bool = False

  def __post_init__(self):
    """Infer residual stream and typecheck subspaces."""
    if self.residual_space is None:
      self.residual_space = bases.join_vector_spaces(self.w_qk.left_space,
                                                     self.w_qk.right_space,
                                                     self.w_ov.input_space,
                                                     self.w_ov.output_space)

    assert self.w_qk.left_space.issubspace(self.residual_space)
    assert self.w_qk.right_space.issubspace(self.residual_space)
    assert self.w_ov.input_space.issubspace(self.residual_space)
    assert self.w_ov.output_space.issubspace(self.residual_space)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Attends over the sequence `x` and returns the head's output.

    Args:
      x: Residual-space vectors, one row per sequence position.

    Returns:
      A VectorInBasis over the residual space whose row i is the
      attention-weighted sum of the values for query position i.
    """
    assert x in self.residual_space
    # seq_len x query_space
    queries = x.project(self.w_qk.left_space)
    # seq_len x key_space
    keys = x.project(self.w_qk.right_space)

    attn_matrix = queries.magnitudes @ self.w_qk.matrix @ keys.magnitudes.T

    if self.causal:
      # -inf is not representable in an integer dtype: with integer-valued
      # inputs the mask would become a finite (possibly overflowing) number
      # instead of -inf. Promote such logits to float before masking.
      if not np.issubdtype(attn_matrix.dtype, np.inexact):
        attn_matrix = attn_matrix.astype(np.float64)
      # The 1 gives us the matrix above the diagonal.
      mask = np.triu(np.full_like(attn_matrix, -np.inf), 1)
      attn_matrix = attn_matrix + mask

    attn_weights = _np_softmax(attn_matrix)  # seq_len_from, seq_len_to
    values = self.w_ov_residual(x).magnitudes  # seq_len_to, d_model

    magnitudes = attn_weights @ values  # seq_len_from, d_model
    # NOTE(review): assumes the columns of `values` follow the order of
    # sorted(self.residual_space.basis); confirm against bases.VectorInBasis.
    return bases.VectorInBasis(sorted(self.residual_space.basis), magnitudes)

  def w_ov_residual(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Wov but acting on the residual space.

    Projects `x` onto the input space of `w_ov`, applies `w_ov`, and embeds the
    result back into the residual space.
    """
    x = project(self.residual_space, self.w_ov.input_space)(x)
    out = self.w_ov(x)
    return project(self.w_ov.output_space, self.residual_space)(out)

  @property
  def num_heads(self) -> int:
    """A single head always counts as one."""
    return 1

  def as_multi(self) -> "MultiAttentionHead":
    """Wraps this head in a one-element MultiAttentionHead."""
    return MultiAttentionHead([self])


@dataclasses.dataclass
class MultiAttentionHead(Block):
  """Applies attention heads in parallel.

  All sub-blocks must share the same residual space; their outputs are summed.
  """
  sub_blocks: List[Union[AttentionHead, "MultiAttentionHead"]]

  def __post_init__(self):
    """Takes the (shared) residual space from the sub-blocks.

    Raises:
      ValueError: If `sub_blocks` is empty.
    """
    spaces = [block.residual_space for block in self.sub_blocks]
    if not spaces:
      # Without this check the unpacking below fails with an unhelpful
      # "not enough values to unpack" ValueError.
      raise ValueError("MultiAttentionHead requires at least one sub-block.")
    self.residual_space, *others = spaces
    assert all(s == self.residual_space for s in others)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Applies every sub-block to `x` and sums their outputs."""
    # each element is seq_len x embedding
    outs = [block.apply(x) for block in self.sub_blocks]
    return bases.VectorInBasis.sum(outs)  # seq_len x embedding

  @property
  def num_heads(self) -> int:
    """Total number of single heads, counting nested sub-blocks."""
    return sum(sub_block.num_heads for sub_block in self.sub_blocks)

  def heads(self) -> Iterable[AttentionHead]:
    """Yields every AttentionHead, flattening nested MultiAttentionHeads.

    Raises:
      NotImplementedError: If a sub-block has an unsupported type.
    """
    for sub_block in self.sub_blocks:
      if isinstance(sub_block, AttentionHead):
        yield sub_block
      elif isinstance(sub_block, MultiAttentionHead):
        yield from sub_block.heads()
      else:
        raise NotImplementedError(
            f"Unsupported sub-block type: {type(sub_block).__name__}")

  def as_multi(self) -> "MultiAttentionHead":
    """Already a MultiAttentionHead; returns self."""
    return self


@dataclasses.dataclass
class MLP(Block):
  """A transformer MLP block.

  Computes snd(relu(fst(x))) on the relevant subspaces of the residual space.

  Attributes:
    fst: First linear layer; its output space is the hidden space.
    snd: Second linear layer; its input space must equal `fst.output_space`.
    residual_space: If None, inferred as the join of `fst.input_space` and
      `snd.output_space`.
  """
  fst: vectorspace_fns.Linear
  snd: vectorspace_fns.Linear
  residual_space: Optional[bases.VectorSpaceWithBasis] = None

  def __post_init__(self):
    """Typecheck subspaces."""
    if self.residual_space is None:
      self.residual_space = bases.join_vector_spaces(self.fst.input_space,
                                                     self.snd.output_space)

    assert self.fst.output_space == self.snd.input_space
    assert self.fst.input_space.issubspace(self.residual_space)
    assert self.snd.output_space.issubspace(self.residual_space)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Runs the MLP on `x` and returns the result in the residual space."""
    assert x in self.residual_space

    x = project(self.residual_space, self.fst.input_space)(x)
    hidden = self.fst(x)
    hidden = relu(hidden)
    out = self.snd(hidden)
    return project(self.snd.output_space, self.residual_space)(out)

  @classmethod
  def combine_in_parallel(cls, mlps: Sequence["MLP"]) -> "MLP":
    """Merges several MLPs into one by combining their layers in parallel.

    The residual space of the result is inferred from the combined layers.
    """
    fst = vectorspace_fns.Linear.combine_in_parallel(
        [block.fst for block in mlps])
    snd = vectorspace_fns.Linear.combine_in_parallel(
        [block.snd for block in mlps])
    return cls(fst=fst, snd=snd, residual_space=None)


# Block that fits into a half-layer, without residual connections.
HalfLayerBlock = Union[MLP, AttentionHead, MultiAttentionHead]


@dataclasses.dataclass
class SeriesWithResiduals(Block):
  """A series of blocks with residual connections.

  Each block reads its own residual subspace and its output is added back onto
  the running residual stream.
  """
  blocks: List[HalfLayerBlock]

  def __post_init__(self):
    """Sets the residual space to the join of all blocks' residual spaces."""
    spaces = [block.residual_space for block in self.blocks]
    self.residual_space = bases.join_vector_spaces(*spaces)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Applies the blocks in order, adding each output to the stream."""
    x = x.project(self.residual_space)
    for block in self.blocks:
      x_in = x.project(block.residual_space)
      x_out = block.apply(x_in).project(self.residual_space)
      x = x + x_out
    return x