# Copyright 2022 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Hierarchical transformer."""

import functools
from typing import Any, Callable, Optional, Sequence, Tuple

from absl import logging
from flax import linen as nn
from flax import struct
import gin
import jax.numpy as jnp

from transformer import attention
from transformer import metric_utils
from transformer import nn_components
from transformer import position
from transformer import transformer_layer


Array = Any
# Basic task options are shared among multiple classes.
@gin.configurable
@struct.dataclass
class TransformerTaskConfig:
  """Configuration hyperparameters for sequence-to-sequence tasks."""

  dataset_name: str = "synthetic"
  train_split: str = "train"
  test_split: str = "test"
  sequential_chunks: bool = True  # Process chunks of text in sequential order.

  sequence_length: int = 4096
  batch_size: int = 1  # Per-device batch size.
  vocab_size: int = 256


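# Per-layer state for the decoder stack: each tuple holds one entry per
# transformer layer, in the same order as DecoderStack.transformer_layers.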
DStackDecoderState = Tuple[transformer_layer.DecoderState, ...]
DStackWindowState = Tuple[transformer_layer.WindowState, ...]


@gin.configurable
class DecoderStack(nn.Module):
  """Stack of transformer decoder layers."""

  mode: str
  task_config: TransformerTaskConfig = gin.REQUIRED

  # Configurable hyperparameters.
  num_layers: int = gin.REQUIRED
  embedding_size: int = gin.REQUIRED
  embedding_stddev: float = 1.0

  # The class to use for an individual transformer layer.
  layer_factory: Any = gin.REQUIRED

  # Window length to use for the decoder stack.
  # If nonzero, use this instead of TransformerLayer.window_length.
  dstack_window_length: int = 0
  use_absolute_positions: bool = False
  use_final_layernorm: bool = True
  final_dropout_rate: float = 0.0
  final_mlp_factory: Optional[Callable[[int], nn.Module]] = None

  # Enable recurrence on particular layers.
  recurrent_layer_indices: Sequence[int] = ()
  feedback_recurrence: bool = True

  # The factory function which creates a MemoryManager, or None.
  memory_factory: Any = None
  # Layers to equip with external memory.
  memory_layer_indices: Sequence[int] = ()

  dtype: Any = jnp.float32

  def is_training(self):
    return self.mode == "train"

  def supports_generate(self) -> bool:
    return all([lyr.supports_generate() for lyr in self.transformer_layers])

  def setup(self):
    task_config = self.task_config

    embed_init = nn.initializers.normal(stddev=self.embedding_stddev,
                                        dtype=jnp.float32)
    self.embed = nn.Embed(num_embeddings=task_config.vocab_size,
                          features=self.embedding_size,
                          embedding_init=embed_init)

    # Create a memory_factory.MemoryManager object, which is shared among
    # all transformer layers.  Each layer will use the MemoryManager object
    # to instantiate a block of memory for that layer.
    memory = None
    if self.memory_factory is not None:
      if self.memory_layer_indices:
        memory = self.memory_factory(batch_size=task_config.batch_size,
                                     mode=self.mode)
      else:
        logging.warning(
            "Memory factory specified, but memory_layer_indices is empty.")

    # Allow negative numbers in memory_layer_indices.
    # Negative numbers refer to layers at the top of the stack.
    for k in self.memory_layer_indices:
      if k < -self.num_layers or k >= self.num_layers:
        raise ValueError(f"Invalid memory layer index {k}")
    # The % operator will convert negative k to self.num_layers + k.
    mem_layer_indices = [
        idx % self.num_layers for idx in self.memory_layer_indices
    ]
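    # For example, with num_layers=12, an index of -1 maps to 11 (the top
    # layer) and -12 maps to 0 (the bottom layer).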

    # Allow negative numbers in recurrent_layer_indices.
    for k in self.recurrent_layer_indices:
      if k < -self.num_layers or k >= self.num_layers:
        raise ValueError(f"Invalid recurrent layer index {k}")
    recurrent_layer_indices = [
        idx % self.num_layers for idx in self.recurrent_layer_indices
    ]

    # Turn on cross attention if there are recurrent layers with feedback.
    enable_cross_attn = (self.feedback_recurrence and
                         self.recurrent_layer_indices and
                         self.dstack_window_length > 0)

    layers = []
    for i in range(0, self.num_layers):
      mem = memory if (i in mem_layer_indices) else None
      rec_i = i in recurrent_layer_indices
      layer_fn = functools.partial(
          self.layer_factory,
          mode=self.mode,
          batch_size=self.task_config.batch_size,
          embedding_size=self.embedding_size,
          name=f"transformer{i}",
          recurrent_attention=rec_i,
          cross_attention=enable_cross_attn and not rec_i)
      if mem:
        logging.info("Using external memory with transformer layer %d.", i)
        layer_fn = functools.partial(
            layer_fn,
            memory=mem,
            # We use partial function applications here only to avoid
            # overwriting the head size unless memory is involved.
            head_size=mem.key_size,
            num_heads=mem.num_heads)
      layers.append(layer_fn())
    self.transformer_layers = layers

    if self.use_final_layernorm:
      self.final_layernorm = nn_components.LayerNorm()

    if self.final_mlp_factory is not None:
      self.final_mlp = self.final_mlp_factory(self.embedding_size)

  def init_decoder_state(self, sequence_length: int,
                         start_of_sequence: Array) -> DStackDecoderState:
    """Return initial state for autoregressive generation."""
    return tuple([
        layer.init_decoder_state(sequence_length, start_of_sequence)
        for layer in self.transformer_layers
    ])

  def load_window_state(self, start_of_sequence: Array) -> DStackWindowState:
    """Load cached state that is passed from one window to the next."""
    return tuple([
        layer.load_window_state(start_of_sequence)
        for layer in self.transformer_layers
    ])

  def store_window_state(self, window_state: DStackWindowState):
    """Write window state to the cache."""
    for (layer, wstate) in zip(self.transformer_layers, window_state):
      layer.store_window_state(wstate)

  def _eval_layer_stack(self, xs: Array, start_of_sequence: Array,
                        window_state: Optional[DStackWindowState],
                        decoder_state: Optional[DStackDecoderState]) -> (
                            Tuple[Array, Optional[DStackWindowState],
                                  Optional[DStackDecoderState], Any]):
    """Evaluate a stack of transformer layers on an input."""

    ys = xs  # (batch_size, seq_len, num_hidden)
    importance = None  # (batch_size, sequence_length)
    next_window_states = []
    next_decoder_states = []
    attn_viz_dicts = []

    # If we have a recurrent layer, grab the keys and values from it.
    # All other layers can then cross-attend to the recurrent keys and values.
    recurrent_kv = None
    enable_cross_attn = (self.feedback_recurrence and
                         self.recurrent_layer_indices and
                         self.dstack_window_length > 0)
    if enable_cross_attn and window_state is not None:
      # TODO(delesley): fix this so it works with the autoregressive decoder.
      assert decoder_state is None
      logging.info("dstack: using recurrent cross attention on all layers.")
      for (layer, wstate_i) in zip(self.transformer_layers, window_state):
        rkv = layer.get_recurrent_kv(wstate_i)
        if rkv is not None:
          recurrent_kv = rkv
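      # Note: if more than one layer is recurrent, each iteration overwrites
      # recurrent_kv, so the keys/values from the last (highest) recurrent
      # layer are the ones that the other layers cross-attend to.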

    # Apply transformer layers.
    for (i, layer) in enumerate(self.transformer_layers):
      if layer.recurrent_attention:
        cross_kv = None          # The recurrent layer handles rkv internally.
      else:
        cross_kv = recurrent_kv  # Other layers cross-attend to recurrent one.

      logging.info("dstack: ---- Layer %d ----", i)
      wstate_i = None if window_state is None else window_state[i]
      dstate_i = None if decoder_state is None else decoder_state[i]
      (ys, importance, n_wstate_i, n_dstate_i, viz_dict) = layer(
          ys, start_of_sequence,
          importance=importance,
          cross_attention_kv=cross_kv,  # Cross-attend to recurrent_kv.
          window_state=wstate_i,
          decoder_state=dstate_i)
      next_window_states.append(n_wstate_i)
      next_decoder_states.append(n_dstate_i)
      attn_viz_dicts.append(viz_dict)

    window_state = tuple(next_window_states)
    decoder_state = tuple(next_decoder_states)
    return (ys, window_state, decoder_state, attn_viz_dicts)

  def __call__(self,
               input_tokens: Array,
               target_tokens: Array,
               start_of_sequence: Array,
               decoder_state: Optional[DStackDecoderState] = None) -> (
                   Tuple[Array, Optional[DStackDecoderState], Any]):
    """Call the decoder stack.

    This function will embed tokens, run the embeddings through a stack of
    decoder layers, and then compute logits for the target tokens using the
    transpose of the embedding table.  It returns un-normalized (pre-softmax)
    logits.

    Args:
      input_tokens: Integer array of shape [batch_size, sequence_length].
      target_tokens: For compatibility.  Ignored by this class.
      start_of_sequence: Boolean array of shape [batch_size], which indicates
        whether each example in the batch is at the start of a new sequence.
      decoder_state: State object for autoregressive decoding, created by
        init_decoder_state.

    Returns:
      (logits of shape [batch_size, sequence_length, vocab_size],
       next_decoder_state for autoregressive decoding,
       viz_dict: a dictionary of visualizations)
    """
    del target_tokens
    task_config = self.task_config

    # Embed tokens.
    embeddings = self.embed(input_tokens)  # (batch_size, seq_len, num_hidden)
    embeddings = embeddings.astype(self.dtype)
    sequence_length = embeddings.shape[1]
    logging.info("dstack: embeddings = %r", embeddings)

    # Add absolute position encodings if necessary.
    if self.use_absolute_positions:
      # Use a large max_wavelength so that only part of the input vector
      # is used for positions.
      positions = position.position_encoding(
          num_positions=task_config.sequence_length,
          input_dim=self.embedding_size,
          max_wavelength=10_000)
      positions = jnp.asarray(positions, dtype=self.dtype)
      positions = jnp.expand_dims(positions, 0)  # Add batch dimension.
      logging.info("dstack: absolute positions = %r", positions)
      embeddings = embeddings + positions

    # Function to run the whole transformer stack on a single window.
    # ---------------------------------------------------------------
    def single_window_stack(carry, inputs_w):
      (window_state_w, start_of_seq_w) = carry
      (outputs_w, window_state_w, _, _) = self._eval_layer_stack(
          inputs_w, start_of_seq_w,
          window_state=window_state_w, decoder_state=None)

      # start_of_sequence is false after the first window.
      bsize = self.task_config.batch_size
      next_start_of_seq = jnp.asarray([False] * bsize, dtype=jnp.bool_)
      return ((window_state_w, next_start_of_seq), outputs_w)
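    # single_window_stack follows the scan convention expected by
    # attention.split_and_scan below: it maps (carry, window_inputs) to
    # (new_carry, window_outputs), where the carry holds the cached window
    # state and the start_of_sequence flags.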

    # Find the number of windows.  A sequence may be split into multiple
    # windows here, or alternatively, it may be split (or further split)
    # within TransformerLayer, depending on configuration.
    if (self.dstack_window_length == 0 or
        self.dstack_window_length >= sequence_length):
      num_windows = 1
    else:
      num_windows = sequence_length // self.dstack_window_length
      assert (num_windows * self.dstack_window_length) == sequence_length
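    # For example, with sequence_length=4096 and dstack_window_length=512,
    # the stack is scanned over 8 windows of 512 tokens each; the sequence
    # length must be an exact multiple of the window length.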

    # Evaluate the stack of layers, scanning over windows if configured.
    # ------------------------------------------------------------------
    if decoder_state is None:
      logging.info("dstack: scanning over %d windows.", num_windows)

      # Load cached state from the previous training step, for truncated BPTT.
      window_state = self.load_window_state(start_of_sequence)

      # Scan single_window_stack over the sequence.
      cstate = (window_state, start_of_sequence)
      (cstate, ys) = attention.split_and_scan(single_window_stack,
                                              cstate,
                                              embeddings,
                                              sections=num_windows,
                                              axis=1)
      (window_state, _) = cstate

      # Cache state for the next training step, for truncated BPTT.
      self.store_window_state(window_state)
      attn_viz_dicts = {}  # Temporarily disabled.
    else:
      logging.info("dstack: autoregressive generator.")

      # Run as an autoregressive decoder: evaluate the whole stack on a token.
      # Do not load or store window_state; decoder_state is used instead.
      (ys, _, decoder_state, _) = self._eval_layer_stack(
          embeddings, start_of_sequence,
          window_state=None, decoder_state=decoder_state)
      attn_viz_dicts = {}

    # Apply layernorm to the final output, before calculating logits.
    # With a pre-layernorm architecture, this has to be done here.
    if self.use_final_layernorm:
      logging.info("dstack: Final layernorm.")
      ys = self.final_layernorm(ys)

    # Final dropout before token prediction.
    drop_tile_shape = (1, 128, self.embedding_size)
    get_dropout_rng = lambda: self.make_rng("dropout")
    ys = nn_components.tiled_dropout(ys, drop_tile_shape,
                                     self.final_dropout_rate,
                                     rng_function=get_dropout_rng,
                                     deterministic=not self.is_training())

    # Apply an MLP at the very end to convert the output of the transformer
    # into a vector to look up target tokens in the embedding table.
    # This final layer allows the NN to distinguish between the "input
    # context", which is returned by the transformer resnet, and the
    # "predicted target".
    if self.final_mlp_factory is not None:
      logging.info("dstack: Final MLP layer.")
      ys = self.final_mlp(ys)

    # Reverse embedding to generate logits which predict the output tokens.
    logits = self.embed.attend(ys)  # (..., seq_len, vocab_size)
    logging.info("dstack: logits = %r", logits)

    # Normalize so that the range of logits is reasonable.
    logits = logits / jnp.sqrt(logits.shape[-1]).astype(self.dtype)
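    # Since the last dimension of the logits is vocab_size, this divides by
    # sqrt(vocab_size); with the default vocab_size of 256, the scale is 1/16.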

    # Produce various visualizations in generate mode.
    # TODO(delesley): Too many visualizations crash the summary writer.
    if self.mode == "generate":
      img_dict = self._make_images(attn_viz_dicts, [])
      hist_dict = {}  # metric_utils.make_histograms(attn_viz_dicts)
      info_dict = {**img_dict, **hist_dict}
    else:
      info_dict = {}  # Don't output any visualizations.

    return (logits, decoder_state, info_dict)

  def _make_importance_image(self, importance_list, scaled=True) -> Array:
    rows = []
    for imp in importance_list:
      rows += [imp] * 8  # Rows are 8 pixels high for better visibility.
    image = jnp.stack(rows)
    if scaled:
      image = jnp.exp(image)
    image = metric_utils.normalize_image(image, True)
    return metric_utils.reshape_image(image)

  def _make_images(self, viz_dicts, importance_list):
    image_dict = {}
    for (i, viz_dict) in enumerate(viz_dicts):
      if "attn_importance_gate" in viz_dict:
        imp_gate = viz_dict["attn_importance_gate"][0]  # First item in batch.
        imp_strip = metric_utils.normalize_image(imp_gate[:, 0:8, :], True)
      else:
        imp_strip = None

      for (k, attn_images) in viz_dict.items():
        if k not in {"attn_content",
                     "attn_pre_softmax",
                     "attn_log",
                     "attn",
                     "attn_position_bias",
                     "attn_importance_bias",
                     "attn_importance_gate"}:
          continue

        attn_img = attn_images[0]  # Grab the first item in the batch.
        attn_img = metric_utils.normalize_image(attn_img,
                                                as_group=(k != "attn"))
        if imp_strip is not None and k in {"attn_log", "attn"}:
          # Show importance bias in a strip at the bottom of the image.
          attn_img = metric_utils.overlay_images(attn_img, imp_strip)
        attn_img = metric_utils.reshape_image(attn_img)  # None on failure.
        if attn_img is not None:
          image_dict[k + "_" + str(i)] = attn_img

    if importance_list:
      # Create an image out of the importance for each layer.
      image_dict["importance_gate"] = self._make_importance_image(
          importance_list, scaled=True)
      image_dict["importance_raw"] = self._make_importance_image(
          importance_list, scaled=False)
    return image_dict
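

# Illustrative sketch only: because DecoderStack and TransformerTaskConfig are
# gin-configurable, their gin.REQUIRED fields are normally filled in from a
# .gin config file rather than in code.  The binding names below follow the
# class attributes defined in this module, but the concrete values and the
# assumption that transformer_layer.TransformerLayer is registered with gin
# are hypothetical and depend on the surrounding configuration:
#
#   TransformerTaskConfig.sequence_length = 4096
#   TransformerTaskConfig.vocab_size = 256
#   DecoderStack.num_layers = 12
#   DecoderStack.embedding_size = 1024
#   DecoderStack.layer_factory = @transformer_layer.TransformerLayer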