File size: 6,895 Bytes
9bdaa77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7eb2f2
9bdaa77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c46567d
9bdaa77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c46567d
9bdaa77
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pieces for making transformers."""

import abc
import dataclasses
from typing import Iterable, List, Optional, Sequence, Union

import numpy as np

from tracr.craft import bases
from tracr.craft import vectorspace_fns

# Module-level shorthand for `vectorspace_fns.project`. Blocks below use it as
# `project(from_space, to_space)(vector)` to move a vector between spaces.
project = vectorspace_fns.project


def _np_softmax(x, axis=-1):
  """Numerically stable softmax of `x` along `axis`.

  The per-slice maximum is subtracted before exponentiating. This keeps
  `np.exp` from overflowing on large logits and leaves the result unchanged.
  Entries equal to -inf get weight 0. A slice whose entries are all -inf
  yields NaN, because -inf - (-inf) is NaN.

  Args:
    x: Array-like of logits.
    axis: Axis along which to normalise. Defaults to the last axis.

  Returns:
    Array with the same shape as `x`, normalised along `axis`.
  """
  shifted = x - np.max(x, axis=axis, keepdims=True)
  # Exponentiate once and reuse it for both numerator and denominator.
  exps = np.exp(shifted)
  return exps / np.sum(exps, axis=axis, keepdims=True)


def _np_relu(x):
  """Elementwise ReLU: keeps strictly positive entries of `x`, others become 0."""
  keep = x > 0
  return np.where(keep, x, 0)


def relu(x: bases.VectorInBasis) -> bases.VectorInBasis:
  """Applies ReLU to every magnitude of `x`, keeping its basis directions."""
  directions = x.basis_directions
  return bases.VectorInBasis(directions, _np_relu(x.magnitudes))


class Block(abc.ABC):
  """Transformer block, acting on a sequence of vector space elements.

  Attributes:
    residual_space: Vector space that contains all subspaces the Block interacts
      with. This can be either the full residual space of a model or a subspace.
  """
  # Concrete subclasses either declare this as a dataclass field or assign it
  # in `__post_init__`.
  residual_space: bases.VectorSpaceWithBasis

  @abc.abstractmethod
  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Applies self to an input.

    Implementations in this module return only the block's own output. The
    residual addition (`x + block_output`) is performed by
    `SeriesWithResiduals.apply`.
    """


@dataclasses.dataclass
class AttentionHead(Block):
  """A transformer attention head expressed in craft vector spaces.

  Attributes:
    w_qk: Bilinear form scoring queries (read from `w_qk.left_space`) against
      keys (read from `w_qk.right_space`).
    w_ov: Linear map from `w_ov.input_space` to `w_ov.output_space` that
      produces the values mixed by the attention weights.
    residual_space: Space the head reads from and writes to. If None, it is
      inferred as the join of the four spaces of `w_qk` and `w_ov`.
    causal: If True, each query position may only attend to key positions at
      or before it.
  """
  w_qk: vectorspace_fns.ScalarBilinear
  w_ov: vectorspace_fns.Linear
  residual_space: Optional[bases.VectorSpaceWithBasis] = None
  causal: bool = False

  def __post_init__(self):
    """Infer residual stream and typecheck subspaces."""
    if self.residual_space is None:
      self.residual_space = bases.join_vector_spaces(self.w_qk.left_space,
                                                     self.w_qk.right_space,
                                                     self.w_ov.input_space,
                                                     self.w_ov.output_space)

    assert self.w_qk.left_space.issubspace(self.residual_space)
    assert self.w_qk.right_space.issubspace(self.residual_space)
    assert self.w_ov.input_space.issubspace(self.residual_space)
    assert self.w_ov.output_space.issubspace(self.residual_space)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Runs attention over the sequence `x` and returns the head's output.

    Args:
      x: Sequence of vectors in `self.residual_space` (one row per position).

    Returns:
      The attention-weighted sum of values for every query position. No
      residual connection is added here.
    """
    assert x in self.residual_space
    # seq_len x query_space
    queries = x.project(self.w_qk.left_space)
    # seq_len x key_space
    keys = x.project(self.w_qk.right_space)

    attn_matrix = queries.magnitudes @ self.w_qk.matrix @ keys.magnitudes.T

    if self.causal:
      # np.triu with k=1 keeps -inf strictly above the diagonal (keys later
      # than the query) and zeroes the diagonal and everything below it, so
      # adding the mask hides future positions only.
      mask = np.triu(np.full_like(attn_matrix, -np.inf), 1)
      attn_matrix = attn_matrix + mask

    attn_weights = _np_softmax(attn_matrix)  # seq_len_from, seq_len_to
    values = self.w_ov_residual(x)  # seq_len_to, d_model

    magnitudes = attn_weights @ values.magnitudes  # seq_len_from, d_model
    # Label the output columns with the basis the value magnitudes are
    # actually expressed in, so basis and magnitudes cannot get out of sync.
    return bases.VectorInBasis(values.basis_directions, magnitudes)

  def w_ov_residual(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Wov but acting on the residual space.

    Projects `x` onto `w_ov.input_space`, applies `w_ov`, then maps the result
    back into `self.residual_space`.
    """
    x = project(self.residual_space, self.w_ov.input_space)(x)
    out = self.w_ov(x)
    return project(self.w_ov.output_space, self.residual_space)(out)

  @property
  def num_heads(self) -> int:
    """Number of attention heads in this block (always 1)."""
    return 1

  def as_multi(self) -> "MultiAttentionHead":
    """Wraps this head in a single-element MultiAttentionHead."""
    return MultiAttentionHead([self])


@dataclasses.dataclass
class MultiAttentionHead(Block):
  """Applies attention heads in parallel and sums their outputs.

  Attributes:
    sub_blocks: Heads (or nested multi-heads) run side by side. All of them
      must share one residual space, which becomes this block's residual space.
  """
  sub_blocks: List[Union[AttentionHead, "MultiAttentionHead"]]

  def __post_init__(self):
    """Adopts the residual space shared by all sub-blocks.

    Raises:
      ValueError: If `sub_blocks` is empty.
    """
    if not self.sub_blocks:
      # Raised explicitly instead of failing on tuple unpacking below.
      raise ValueError("MultiAttentionHead requires at least one sub-block.")
    spaces = [block.residual_space for block in self.sub_blocks]
    self.residual_space, *others = spaces
    assert all(s == self.residual_space for s in others)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Applies every sub-block to `x` and returns the sum of their outputs."""
    # each element is seq_len x embedding
    outs = [block.apply(x) for block in self.sub_blocks]
    return bases.VectorInBasis.sum(outs)  # seq_len x embedding

  @property
  def num_heads(self) -> int:
    """Total number of AttentionHeads, counting nested multi-heads."""
    return sum(sub_block.num_heads for sub_block in self.sub_blocks)

  def heads(self) -> Iterable[AttentionHead]:
    """Yields every AttentionHead, flattening nested MultiAttentionHeads.

    Raises:
      NotImplementedError: If a sub-block is neither kind of attention block.
    """
    for sub_block in self.sub_blocks:
      if isinstance(sub_block, AttentionHead):
        yield sub_block
      elif isinstance(sub_block, MultiAttentionHead):
        yield from sub_block.heads()
      else:
        raise NotImplementedError(
            f"Unsupported sub-block type: {type(sub_block).__name__}")

  def as_multi(self) -> "MultiAttentionHead":
    """Returns self; already a MultiAttentionHead."""
    return self


@dataclasses.dataclass
class MLP(Block):
  """Two-layer MLP with a ReLU between the layers.

  Attributes:
    fst: First linear layer, mapping the block's input subspace to the hidden
      space.
    snd: Second linear layer, mapping the hidden space to the output subspace.
    residual_space: Space the block reads from and writes to. When None it is
      inferred as the join of `fst.input_space` and `snd.output_space`.
  """
  fst: vectorspace_fns.Linear
  snd: vectorspace_fns.Linear
  residual_space: Optional[bases.VectorSpaceWithBasis] = None

  def __post_init__(self):
    """Fills in a default residual space and checks the layers line up."""
    if self.residual_space is None:
      self.residual_space = bases.join_vector_spaces(
          self.fst.input_space, self.snd.output_space)

    # The hidden space produced by `fst` must be exactly what `snd` consumes.
    produced_hidden = self.fst.output_space
    consumed_hidden = self.snd.input_space
    assert produced_hidden == consumed_hidden
    assert self.fst.input_space.issubspace(self.residual_space)
    assert self.snd.output_space.issubspace(self.residual_space)

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Computes snd(relu(fst(x))), restricted to and embedded in the residual."""
    assert x in self.residual_space

    restricted = project(self.residual_space, self.fst.input_space)(x)
    activations = relu(self.fst(restricted))
    output = self.snd(activations)
    embed = project(self.snd.output_space, self.residual_space)
    return embed(output)

  @classmethod
  def combine_in_parallel(cls, mlps: Sequence["MLP"]) -> "MLP":
    """Merges several MLPs into one by combining their layers in parallel."""
    merge = vectorspace_fns.Linear.combine_in_parallel
    fst = merge([mlp.fst for mlp in mlps])
    snd = merge([mlp.snd for mlp in mlps])
    return cls(fst=fst, snd=snd, residual_space=None)


# Block that fits into a half-layer (an attention sublayer or an MLP sublayer),
# without residual connections; SeriesWithResiduals adds those around it.
HalfLayerBlock = Union[MLP, AttentionHead, MultiAttentionHead]


@dataclasses.dataclass
class SeriesWithResiduals(Block):
  """Runs half-layer blocks one after another, each with a residual connection.

  Attributes:
    blocks: Blocks applied in order. The series' residual space is the join of
      all of their residual spaces.
  """
  blocks: List[HalfLayerBlock]

  def __post_init__(self):
    self.residual_space = bases.join_vector_spaces(
        *(block.residual_space for block in self.blocks))

  def apply(self, x: bases.VectorInBasis) -> bases.VectorInBasis:
    """Feeds `x` through every block, adding each block's output back in."""
    stream = x.project(self.residual_space)
    for block in self.blocks:
      # Each block sees only its own subspace; its output is lifted back into
      # the full residual space before being added to the stream.
      update = block.apply(stream.project(block.residual_space))
      stream = stream + update.project(self.residual_space)
    return stream