HugoVoxx committed
Commit 1cdcee4 · verified · 1 Parent(s): 64ef4b4

Upload 3 files

aglib/meliad/transformer/configs/base_htrans.gin ADDED
@@ -0,0 +1,96 @@
+ # Base configuration for the Hierarchical Transformer.
+
+ include "trainer_configuration.gin"
+
+ # Imports
+ from transformer import attention
+ from transformer import decoder_stack
+ from transformer import models
+ from transformer import nn_components
+ from transformer import transformer_base
+ from transformer import transformer_layer
+
+
+ NUM_LAYERS = 12
+ NUM_HEADS = 8
+ HEAD_DIM = 128
+ EMBED_DIM = 512       # Size of embedding vector for each token.
+ MLP_DIM = 2048        # Number of hidden units in transformer FFN.
+ NUM_EMBEDDINGS = 256  # Number of tokens in vocabulary.
+ DROPOUT_RATE = 0.05
+ ATTN_DROPOUT_RATE = 0.05
+
+ # For training on TPU.
+ DTYPE = "bfloat16"
+
+ # Configure the input task.
+ decoder_stack.TransformerTaskConfig:
+   dataset_name = "synthetic"
+   train_split = "train"
+   test_split = "test"
+   sequence_length = 512
+   batch_size = 8
+   vocab_size = %NUM_EMBEDDINGS
+
+ transformer_layer.TransformerLayer:
+   num_heads = %NUM_HEADS
+   head_size = %HEAD_DIM
+   window_length = 512
+   use_long_xl_architecture = True
+   max_unrolled_windows = -1      # Always unroll.
+   relative_position_type = "t5"  # Can be "fourier", "t5", or None.
+   use_causal_mask = True
+   attn_dropout_rate = %ATTN_DROPOUT_RATE  # Attention matrix dropout.
+
+   memory_num_neighbors = 0
+   compute_importance = False
+   dtype = %DTYPE
+
+ transformer_base.TransformerBase:
+   attn_mlp_factory = @transformer_attn/nn_components.MLP
+   ffn_factory = @transformer_ffn/nn_components.MLP
+   normalize_keys = True  # More stable with Transformer XL.
+   dropout_rate = %DROPOUT_RATE
+   pre_attn_dropout = True
+   post_attn_dropout = False
+   pre_ffn_dropout = False
+   post_ffn_dropout = True
+
+ transformer_attn/nn_components.MLP:
+   num_layers = 1  # Just a single dense matmul.
+   num_hidden_units = 0
+   hidden_activation = None
+   use_bias = False
+
+ transformer_ffn/nn_components.MLP:
+   num_layers = 2
+   num_hidden_units = %MLP_DIM
+   hidden_activation = "relu"
+   use_bias = False
+
+ decoder_stack.DecoderStack:
+   # task_config will be passed in from DecoderOnlyLanguageModel.
+   num_layers = %NUM_LAYERS
+   embedding_size = %EMBED_DIM
+   embedding_stddev = 1.0
+   layer_factory = @transformer_layer.TransformerLayer
+   dstack_window_length = 0
+   use_absolute_positions = False
+   use_final_layernorm = True          # Final layernorm before token lookup.
+   final_dropout_rate = %DROPOUT_RATE  # Dropout before token lookup.
+   final_mlp_factory = None            # Final MLP to predict target tokens.
+   recurrent_layer_indices = ()
+   memory_factory = None               # e.g. @memory_factory.memory_on_tpu_factory
+   memory_layer_indices = ()
+   dtype = %DTYPE
+
+ models.DecoderOnlyLanguageModel:
+   task_config = @decoder_stack.TransformerTaskConfig()
+   decoder_factory = @decoder_stack.DecoderStack
+
+ nn_components.LayerNorm:
+   use_scale = True
+   use_bias = False
+   use_mean = False  # Calculate and adjust for the mean as well as the scale.
+   dtype = %DTYPE
+
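
Note: base_htrans.gin is normally consumed through the meliad training launcher rather than parsed by hand. As a minimal sketch only, assuming the standard gin-config API and that the configs directory above is used as the search path (neither detail is part of this commit), the file could be loaded like this:

    import gin

    # Let `include "trainer_configuration.gin"` resolve relative to the configs dir.
    gin.add_config_file_search_path("aglib/meliad/transformer/configs")

    # Parse the base config; macros such as EMBED_DIM and NUM_LAYERS then fill in
    # parameters of the gin-configurable classes (DecoderStack, TransformerLayer, ...).
    gin.parse_config_file("base_htrans.gin")

Parsing also executes the `from transformer import ...` statements at the top of the gin file, so those Python modules must be importable when the config is loaded.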
aglib/meliad/transformer/configs/memory_configuration.gin ADDED
@@ -0,0 +1,22 @@
+ # Configure external memory.
+ # This file should be included after base_htrans.gin.
+
+ import training_loop
+ from transformer import memory_factory
+
+ MEMORY_HEAD_DIM = %HEAD_DIM
+ NUM_MEMORY_HEADS = %gin.REQUIRED
+
+ memory_factory.memory_on_tpu_factory:
+   num_heads = %NUM_MEMORY_HEADS
+   key_size = %MEMORY_HEAD_DIM
+   value_size = %MEMORY_HEAD_DIM
+   database_size = 8192
+   dtype = %DTYPE  # Defined in base_htrans.gin.
+
+
+ training_loop.Trainer:
+   log_every_steps = 100  # Memory can slow down training; need responsive stats.
+   checkpoint_every_steps = 1000
+   generate_every_steps = 0  # Disable generate mode when using external memory.
+
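
Because NUM_MEMORY_HEADS is declared as %gin.REQUIRED, whatever parses this file must bind it, either in a later gin file or as a launch-time binding. A hypothetical example, assuming the standard gin-config API (the value 2 below is purely illustrative and not part of this commit):

    import gin

    gin.add_config_file_search_path("aglib/meliad/transformer/configs")
    gin.parse_config_files_and_bindings(
        ["base_htrans.gin", "memory_configuration.gin"],  # memory config included last
        bindings=["NUM_MEMORY_HEADS = 2"],                # assumed value, for illustration only
    )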
aglib/meliad/transformer/configs/trainer_configuration.gin ADDED
@@ -0,0 +1,56 @@
+
+ import optimizer_config
+ import training_loop
+ from transformer import models
+ from transformer import text_dataset
+
+ # Training setup.
+ training_loop.Trainer:
+   model_definition = @models.DecoderOnlyLanguageModel
+
+   num_steps = 250_000
+   status_every_steps = 10
+   log_every_steps = 1000
+   test_every_steps = 1000
+   num_test_steps = 400
+   generate_every_steps = 5000
+   print_input_every_steps = 5000
+   checkpoint_every_steps = 5000
+   save_checkpoints = True
+   restore_checkpoints = True
+   use_separate_metric_directories = False
+
+   optimizer_factory = @optimizer_config.FlaxAdafactorConfig()
+   learning_rate_schedule = @optimizer_config.lr_cosine_decay
+   max_scheduled_steps = 0  # Use num_steps as max_scheduled_steps.
+   warmup_steps = 1000
+   learning_rate_multiplier = 1.0
+
+   rng_key_names = ("dropout", "sample")
+
+ text_dataset.load_text_dataset:
+   verbose = False  # If true, prints the start of every book/repo read from disk.
+
+ # Use cosine decay to max_scheduled_steps, as described in Chinchilla:
+ # https://arxiv.org/abs/2203.15556
+ optimizer_config.lr_cosine_decay:
+   max_lr = 0.01
+   min_lr = 0.001
+   decay_after = True
+   spike_steps = 0
+   spike_lr = 0.0
+
+
+ # Adam optimizer configuration.
+ # optimizer_config.AdamConfig:
+ #   learning_rate = 0.05  # Will be multiplied by the LR schedule.
+ #   beta1 = 0.9
+ #   beta2 = 0.98
+ #   weight_decay_rate = 0.0
+
+ # Adafactor optimizer configuration.
+ optimizer_config.FlaxAdafactorConfig:
+   learning_rate = 1.0  # Will be multiplied by the LR schedule.
+   beta1 = 0.9          # Can be "None".
+
+
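
For reference, the lr_cosine_decay block above requests a warmup to max_lr = 0.01 followed by a cosine decay down to min_lr = 0.001 over the scheduled steps (num_steps here, since max_scheduled_steps = 0), in the spirit of Chinchilla. The sketch below only illustrates that shape; meliad's actual optimizer_config.lr_cosine_decay (including the exact meaning of decay_after, spike_steps, and spike_lr) may differ:

    import math

    def cosine_decay_lr(step, warmup_steps=1000, max_steps=250_000,
                        max_lr=0.01, min_lr=0.001):
        """Warmup + cosine-decay sketch; not meliad's actual implementation."""
        if step < warmup_steps:
            # Linear warmup from 0 to max_lr.
            return max_lr * step / warmup_steps
        if step >= max_steps:
            # Past the scheduled horizon, hold the floor (one reading of decay_after = True).
            return min_lr
        # Cosine decay from max_lr at the end of warmup to min_lr at max_steps.
        progress = (step - warmup_steps) / (max_steps - warmup_steps)
        return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

Per the comments above, the resulting schedule value is further scaled by learning_rate_multiplier and the Adafactor learning_rate, both set to 1.0 in this config.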