# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Hierarchical transformer."""

import functools
from typing import Any, Callable, Optional, Sequence, Tuple

from absl import logging

from flax import linen as nn
from flax import struct
import gin
import jax.numpy as jnp
from transformer import attention
from transformer import metric_utils
from transformer import nn_components
from transformer import position
from transformer import transformer_layer


Array = Any


# Basic task options are shared among multiple classes.
@gin.configurable
@struct.dataclass
class TransformerTaskConfig:
  """Configuration hyperparameters for sequence-to-sequence tasks."""

  dataset_name: str = "synthetic"
  train_split: str = "train"
  test_split: str = "test"
  sequential_chunks: bool = True  # Process chunks of text in sequential order.

  sequence_length: int = 4096
  batch_size: int = 1  # per device batch size
  vocab_size: int = 256

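# Since TransformerTaskConfig is gin.configurable, these fields are normally
# overridden from a gin config file rather than in code.  An illustrative
# (not canonical) set of bindings might look like:
#
#   TransformerTaskConfig.dataset_name = "synthetic"
#   TransformerTaskConfig.sequence_length = 4096
#   TransformerTaskConfig.batch_size = 1
#   TransformerTaskConfig.vocab_size = 256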

DStackDecoderState = Tuple[transformer_layer.DecoderState, ...]
DStackWindowState = Tuple[transformer_layer.WindowState, ...]


@gin.configurable
class DecoderStack(nn.Module):
  """Stack of transformer decoder layers."""

  mode: str
  task_config: TransformerTaskConfig = gin.REQUIRED

  # Configurable hyperparameters.
  num_layers: int = gin.REQUIRED
  embedding_size: int = gin.REQUIRED
  embedding_stddev: float = 1.0

  # The class to use for an individual transformer layer.
  layer_factory: Any = gin.REQUIRED

  # Window length to use for the decoder stack.
  # If nonzero, use this instead of TransformerLayer.window_length.
  dstack_window_length: int = 0
  use_absolute_positions: bool = False
  use_final_layernorm: bool = True
  final_dropout_rate: float = 0.0
  final_mlp_factory: Optional[Callable[[int], nn.Module]] = None

  # Enable recurrence on particular layers.
  recurrent_layer_indices: Sequence[int] = ()
  feedback_recurrence: bool = True

  # The factory function which creates a MemoryManager, or None.
  memory_factory: Any = None
  # Layers to equip with external memory.
  memory_layer_indices: Sequence[int] = ()

  dtype: Any = jnp.float32

  def is_training(self):
    return self.mode == "train"

  def supports_generate(self) -> bool:
    return all(lyr.supports_generate() for lyr in self.transformer_layers)

  def setup(self):
    task_config = self.task_config

    embed_init = nn.initializers.normal(stddev=self.embedding_stddev,
                                        dtype=jnp.float32)
    self.embed = nn.Embed(num_embeddings=task_config.vocab_size,
                          features=self.embedding_size,
                          embedding_init=embed_init)

    # Create a memory_factory.MemoryManager object, which is shared among
    # all transformer layers.  Each layer will use the MemoryManager object
    # to instantiate a block of memory for that layer.
    memory = None
    if self.memory_factory is not None:
      if self.memory_layer_indices:
        memory = self.memory_factory(batch_size=task_config.batch_size,
                                     mode=self.mode)
      else:
        logging.warning(
            "Memory factory specified, but memory_layer_indices is empty.")

    # Allow negative numbers in memory_layer_indices.
    # Negative numbers refer to layers at the top of the stack.
    for k in self.memory_layer_indices:
      if k < -self.num_layers or k >= self.num_layers:
        raise ValueError(f"Invalid memory layer index {k}")
    # The % operator will convert negative k to self.num_layers + k.
    mem_layer_indices = [
        idx % self.num_layers for idx in self.memory_layer_indices
    ]
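    # For example, with num_layers = 12, a memory_layer_index of -1 refers to
    # the topmost layer and maps to 11, since -1 % 12 == 11.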

    # Allow negative numbers in recurrent_layer_indices.
    for k in self.recurrent_layer_indices:
      if k < -self.num_layers or k >= self.num_layers:
        raise ValueError(f"Invalid recurrent layer index {k}")
    recurrent_layer_indices = [
        idx % self.num_layers for idx in self.recurrent_layer_indices
    ]
    # Turn on cross attention if there are recurrent layers with feedback.
    enable_cross_attn = (self.feedback_recurrence and
                         self.recurrent_layer_indices and
                         self.dstack_window_length > 0)

    layers = []
    for i in range(self.num_layers):
      mem = memory if (i in mem_layer_indices) else None
      rec_i = i in recurrent_layer_indices
      layer_fn = functools.partial(
          self.layer_factory,
          mode=self.mode,
          batch_size=self.task_config.batch_size,
          embedding_size=self.embedding_size,
          name=f"transformer{i}",
          recurrent_attention=rec_i,
          cross_attention=enable_cross_attn and not rec_i)
      if mem:
        logging.info("Using external memory with transformer layer %d.", i)
        layer_fn = functools.partial(
            layer_fn,
            memory=mem,
            # We use partial function applications here only to avoid
            # overwriting the head size unless memory is involved.
            head_size=mem.key_size,
            num_heads=mem.num_heads)
      layers.append(layer_fn())
    self.transformer_layers = layers

    if self.use_final_layernorm:
      self.final_layernorm = nn_components.LayerNorm()

    if self.final_mlp_factory is not None:
      self.final_mlp = self.final_mlp_factory(self.embedding_size)

  def init_decoder_state(self, sequence_length: int,
                         start_of_sequence: Array) -> DStackDecoderState:
    """Return initial state for autoregressive generation."""
    return tuple([
        layer.init_decoder_state(sequence_length, start_of_sequence)
        for layer in self.transformer_layers
    ])

  def load_window_state(self, start_of_sequence: Array) -> DStackWindowState:
    """Load cached state that is passed from one window to the next."""
    return tuple([
        layer.load_window_state(start_of_sequence)
        for layer in self.transformer_layers
    ])

  def store_window_state(self, window_state: DStackWindowState):
    """Write window state to the cache."""
    for (layer, wstate) in zip(self.transformer_layers, window_state):
      layer.store_window_state(wstate)

  def _eval_layer_stack(self, xs: Array, start_of_sequence: Array,
                        window_state: Optional[DStackWindowState],
                        decoder_state: Optional[DStackDecoderState]) -> (
                            Tuple[Array, Optional[DStackWindowState],
                                  Optional[DStackDecoderState], Any]):
    """Evaluate a stack of transformer layers on an input."""

    ys = xs  # (batch_size, seq_len, num_hidden)
    importance = None  # (batch_size, sequence_length)
    next_window_states = []
    next_decoder_states = []
    attn_viz_dicts = []

    # If we have a recurrent layer, grab the keys and values from it.
    # All other layers can then cross-attend to the recurrent keys and values.
    recurrent_kv = None
    enable_cross_attn = (self.feedback_recurrence and
                         self.recurrent_layer_indices and
                         self.dstack_window_length > 0)
    if enable_cross_attn and window_state is not None:
      # TODO(delesley): fix this so it works with the autoregressive decoder.
      assert decoder_state is None
      logging.info("dstack: using recurrent cross attention on all layers.")
      for (layer, wstate_i) in zip(self.transformer_layers, window_state):
        rkv = layer.get_recurrent_kv(wstate_i)
        if rkv is not None:
          recurrent_kv = rkv

    # Apply transformer layers.
    for (i, layer) in enumerate(self.transformer_layers):
      if layer.recurrent_attention:
        cross_kv = None  # The recurrent layer handles rkv internally.
      else:
        cross_kv = recurrent_kv  # Other layers cross-attend to recurrent one.

      logging.info("dstack: ---- Layer %d ----", i)
      wstate_i = None if window_state is None else window_state[i]
      dstate_i = None if decoder_state is None else decoder_state[i]
      (ys, importance, n_wstate_i, n_dstate_i, viz_dict) = layer(
          ys, start_of_sequence,
          importance=importance,
          cross_attention_kv=cross_kv,   # cross-attend to recurrent_kv.
          window_state=wstate_i,
          decoder_state=dstate_i)
      next_window_states.append(n_wstate_i)
      next_decoder_states.append(n_dstate_i)
      attn_viz_dicts.append(viz_dict)

    window_state = tuple(next_window_states)
    decoder_state = tuple(next_decoder_states)
    return (ys, window_state, decoder_state, attn_viz_dicts)

  def __call__(self,
               input_tokens: Array,
               target_tokens: Array,
               start_of_sequence: Array,
               decoder_state: Optional[DStackDecoderState] = None) -> (
                   Tuple[Array, Optional[DStackDecoderState], Any]):
    """Call the decoder stack.

    This function will embed tokens, run the embeddings through a stack of
    decoder layers, and then compute logits for the target tokens using the
    transpose of the embeddings.  It returns un-normalized (pre-softmax)
    logits.

    Args:
      input_tokens: Integer array of shape [batch_size, sequence_length]
      target_tokens: For compatibility.  Ignored by this class.
      start_of_sequence: Boolean array of shape [batch_size], which indicates
          whether each example in the batch is at the start of a new sequence.
      decoder_state: State object for autoregressive decoding,
          created from init_decoder_state.

    Returns:
      A tuple (logits, next_decoder_state, viz_dict):
        logits: Un-normalized logits of shape
            [batch_size, sequence_length, vocab_size].
        next_decoder_state: Updated state for autoregressive decoding.
        viz_dict: A dictionary of visualizations.
    """
    del target_tokens
    task_config = self.task_config

    # Embed tokens.
    embeddings = self.embed(input_tokens)  # (batch_size, seq_len, num_hidden)
    embeddings = embeddings.astype(self.dtype)
    sequence_length = embeddings.shape[1]
    logging.info("dstack: embeddings = %r", embeddings)

    # Add absolute position encodings if necessary.
    if self.use_absolute_positions:
      # Use a large max_wavelength so that only part of the input vector
      # is used for positions.
      positions = position.position_encoding(
          num_positions=task_config.sequence_length,
          input_dim=self.embedding_size,
          max_wavelength=10_000)
      positions = jnp.asarray(positions, dtype=self.dtype)
      positions = jnp.expand_dims(positions, 0)  # Add batch dimension.
      logging.info("dstack: absolute positions = %r", positions)
      embeddings = embeddings + positions

    # Function to run the whole transformer stack on a single window.
    # ---------------------------------------------------------------
    def single_window_stack(carry, inputs_w):
      (window_state_w, start_of_seq_w) = carry
      (outputs_w, window_state_w, _, _) = self._eval_layer_stack(
          inputs_w, start_of_seq_w,
          window_state=window_state_w, decoder_state=None)

      # start_of_sequence is false after the first window.
      bsize = self.task_config.batch_size
      next_start_of_seq = jnp.asarray([False] * bsize, dtype=jnp.bool_)
      return ((window_state_w, next_start_of_seq), outputs_w)

    # Find the number of windows.  A sequence may be split into multiple
    # windows here, or alternatively, it may be split (or further split) within
    # TransformerLayer, depending on configuration.
    if (self.dstack_window_length == 0 or
        self.dstack_window_length >= sequence_length):
      num_windows = 1
    else:
      num_windows = sequence_length // self.dstack_window_length
      assert (num_windows * self.dstack_window_length) == sequence_length
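      # For example, with sequence_length = 4096 and dstack_window_length =
      # 512, the scan below runs the layer stack over 8 windows of 512 tokens.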

    # Evaluate the stack of layers, scanning over windows if configured.
    # ------------------------------------------------------------------
    if decoder_state is None:
      logging.info("dstack: scanning over %d windows.", num_windows)
      # Load cached state from the previous training step, for truncated BPTT.
      window_state = self.load_window_state(start_of_sequence)

      # Scan single_window_stack over the sequence.
      cstate = (window_state, start_of_sequence)
      (cstate, ys) = attention.split_and_scan(single_window_stack,
                                              cstate,
                                              embeddings,
                                              sections=num_windows,
                                              axis=1)
      (window_state, _) = cstate

      # Cache state for the next training step, for truncated BPTT.
      self.store_window_state(window_state)
      attn_viz_dicts = {}  # Temporarily disabled.
    else:
      logging.info("dstack: autoregressive generator.")
      # Run as an autoregressive decoder: evaluate the whole stack on a token.
      # Do not load or store window_state; decoder_state is used instead.
      (ys, _, decoder_state, _) = self._eval_layer_stack(
          embeddings, start_of_sequence,
          window_state=None, decoder_state=decoder_state)
      attn_viz_dicts = {}

    # Apply layernorm to the final output, before calculating logits.
    # With a pre-layernorm architecture, this has to be done here.
    if self.use_final_layernorm:
      logging.info("dstack: Final layernorm.")
      ys = self.final_layernorm(ys)

    # Final dropout before token prediction.
    drop_tile_shape = (1, 128, self.embedding_size)
    get_dropout_rng = lambda: self.make_rng("dropout")
    ys = nn_components.tiled_dropout(ys, drop_tile_shape,
                                     self.final_dropout_rate,
                                     rng_function=get_dropout_rng,
                                     deterministic=not self.is_training())

    # Apply an MLP at the very end to convert the output of the transformer
    # into a vector to look up target tokens in the embedding table.
    # This final layer allows the NN to distinguish between the "input context",
    # which is returned by the transformer resnet, and the "predicted target".
    if self.final_mlp_factory is not None:
      logging.info("dstack: Final MLP layer.")
      ys = self.final_mlp(ys)

    # Reverse embedding to generate logits which predict the output tokens.
    logits = self.embed.attend(ys)  # (..., seq_len, vocab_size)
    logging.info("dstack: logits = %r", logits)

    # Normalize so that the range of logits is reasonable.
    logits = logits / jnp.sqrt(logits.shape[-1]).astype(self.dtype)

    # Produce various visualizations in generate mode.
    # TODO(delesley): Too many visualizations crashes the summary writer.
    if self.mode == "generate":
      img_dict = self._make_images(attn_viz_dicts, [])
      hist_dict = {}  # metric_utils.make_histograms(attn_viz_dicts)
      info_dict = {**img_dict, **hist_dict}
    else:
      info_dict = {}  # Don't output any visualizations.

    return (logits, decoder_state, info_dict)

  def _make_importance_image(self, importance_list, scaled=True) -> Array:
    rows = []
    for imp in importance_list:
      rows += [imp] * 8  # Rows are 8 pixels high for better visibility.
    image = jnp.stack(rows)
    if scaled:
      image = jnp.exp(image)
    image = metric_utils.normalize_image(image, True)
    return metric_utils.reshape_image(image)

  def _make_images(self, viz_dicts, importance_list):
    image_dict = {}
    for (i, viz_dict) in enumerate(viz_dicts):
      if "attn_importance_gate" in viz_dict:
        imp_gate = viz_dict["attn_importance_gate"][0]  # First item in batch.
        imp_strip = metric_utils.normalize_image(imp_gate[:, 0:8, :], True)
      else:
        imp_strip = None

      for (k, attn_images) in viz_dict.items():
        if k not in {"attn_content",
                     "attn_pre_softmax",
                     "attn_log",
                     "attn",
                     "attn_position_bias",
                     "attn_importance_bias",
                     "attn_importance_gate"}:
          continue

        attn_img = attn_images[0]  # Grab the first item in the batch.
        attn_img = metric_utils.normalize_image(attn_img,
                                                as_group=(k != "attn"))
        if imp_strip is not None and k in {"attn_log", "attn"}:
          # Show importance bias in a strip at the bottom of the image.
          attn_img = metric_utils.overlay_images(attn_img, imp_strip)
        attn_img = metric_utils.reshape_image(attn_img)  # Returns None on fail.
        if attn_img is not None:
          image_dict[k + "_" + str(i)] = attn_img

    if importance_list:
      # Create an image out of the importance for each layer.
      image_dict["importance_gate"] = self._make_importance_image(
          importance_list, scaled=True)
      image_dict["importance_raw"] = self._make_importance_image(
          importance_list, scaled=False)
    return image_dict
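

# Usage sketch (illustrative only).  All gin.REQUIRED fields must be bound,
# either via a gin config or explicitly as below; the layer factory name and
# the mutable "state" collection are assumptions, not part of this module.
#
#   import jax
#
#   task = TransformerTaskConfig(sequence_length=512, batch_size=1)
#   model = DecoderStack(
#       mode="train",
#       task_config=task,
#       num_layers=2,
#       embedding_size=128,
#       layer_factory=transformer_layer.TransformerLayer)  # Assumed factory.
#
#   tokens = jnp.zeros((task.batch_size, task.sequence_length), dtype=jnp.int32)
#   start = jnp.ones((task.batch_size,), dtype=jnp.bool_)
#   rngs = {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)}
#   variables = model.init(rngs, tokens, tokens, start)
#   ((logits, _, _), _) = model.apply(
#       variables, tokens, tokens, start,
#       rngs={"dropout": jax.random.PRNGKey(2)},
#       mutable=["state"])  # Allow cached window state to be updated.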