# Base configuration for the Hierarchical Transformer.

include "trainer_configuration.gin"

# Imports
from transformer import attention
from transformer import decoder_stack
from transformer import models
from transformer import nn_components
from transformer import transformer_base
from transformer import transformer_layer

NUM_LAYERS = 12
NUM_HEADS = 8
HEAD_DIM = 128
EMBED_DIM = 512        # Size of embedding vector for each token.
MLP_DIM = 2048         # Number of hidden units in the transformer FFN.
NUM_EMBEDDINGS = 256   # Number of tokens in the vocabulary.

DROPOUT_RATE = 0.05
ATTN_DROPOUT_RATE = 0.05

# For training on TPU.
DTYPE = "bfloat16"

# Configure the input task.
decoder_stack.TransformerTaskConfig:
  dataset_name = "synthetic"
  train_split = "train"
  test_split = "test"
  sequence_length = 512
  batch_size = 8
  vocab_size = %NUM_EMBEDDINGS

transformer_layer.TransformerLayer:
  num_heads = %NUM_HEADS
  head_size = %HEAD_DIM
  window_length = 512
  use_long_xl_architecture = True
  max_unrolled_windows = -1      # Always unroll.
  relative_position_type = "t5"  # Can be "fourier", "t5", or None.
  use_causal_mask = True
  attn_dropout_rate = %ATTN_DROPOUT_RATE  # Attention matrix dropout.
  memory_num_neighbors = 0
  compute_importance = False
  dtype = %DTYPE

transformer_base.TransformerBase:
  attn_mlp_factory = @transformer_attn/nn_components.MLP
  ffn_factory = @transformer_ffn/nn_components.MLP
  normalize_keys = True          # More stable with Transformer XL.
  dropout_rate = %DROPOUT_RATE
  pre_attn_dropout = True
  post_attn_dropout = False
  pre_ffn_dropout = False
  post_ffn_dropout = True

transformer_attn/nn_components.MLP:
  num_layers = 1                 # Just a single dense matmul.
  num_hidden_units = 0
  hidden_activation = None
  use_bias = False

transformer_ffn/nn_components.MLP:
  num_layers = 2
  num_hidden_units = %MLP_DIM
  hidden_activation = "relu"
  use_bias = False

decoder_stack.DecoderStack:
  # task_config will be passed in from DecoderOnlyLanguageModel.
  num_layers = %NUM_LAYERS
  embedding_size = %EMBED_DIM
  embedding_stddev = 1.0
  layer_factory = @transformer_layer.TransformerLayer
  dstack_window_length = 0
  use_absolute_positions = False
  use_final_layernorm = True          # Final layernorm before token lookup.
  final_dropout_rate = %DROPOUT_RATE  # Dropout before token lookup.
  final_mlp_factory = None            # Final MLP to predict target tokens.
  recurrent_layer_indices = ()
  memory_factory = None               # e.g. @memory_factory.memory_on_tpu_factory
  memory_layer_indices = ()
  dtype = %DTYPE

models.DecoderOnlyLanguageModel:
  task_config = @decoder_stack.TransformerTaskConfig()
  decoder_factory = @decoder_stack.DecoderStack

nn_components.LayerNorm:
  use_scale = True
  use_bias = False
  use_mean = False   # Calculate and adjust for the mean as well as the scale.
  dtype = %DTYPE
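
# A minimal sketch (not part of the original configuration) of how a derived
# .gin file might include this one and override individual bindings. The file
# name "base_transformer.gin" and the override values below are illustrative
# assumptions, not settings defined by this repository.
#
#   include "base_transformer.gin"
#
#   NUM_LAYERS = 24
#   DROPOUT_RATE = 0.1
#   transformer_layer.TransformerLayer.window_length = 1024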