# GeoGenSolve/aglib/meliad/transformer/configs/trainer_configuration.gin
# Defines the trainer and training configuration.
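
# The imports below load the Python modules whose gin-configurable
# definitions are bound in the rest of this file.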
import optimizer_config
import training_loop
from transformer import models
from transformer import text_dataset

# Training setup.
training_loop.Trainer:
  model_definition = @models.DecoderOnlyLanguageModel

  num_steps = 250_000
  status_every_steps = 10
  log_every_steps = 1000
  test_every_steps = 1000
  num_test_steps = 400
  generate_every_steps = 5000
  print_input_every_steps = 5000
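
  # Cadence implied by the values above: over the 250_000 training steps,
  # the trainer logs and evaluates 250 times (every 1000 steps, with 400
  # test batches per evaluation) and generates/prints samples 50 times
  # (every 5000 steps).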

  checkpoint_every_steps = 5000
  save_checkpoints = True
  restore_checkpoints = True
  use_separate_metric_directories = False
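
  # Gin syntax note: "@optimizer_config.FlaxAdafactorConfig()" (with
  # parentheses) is called when the binding is resolved, so the Trainer
  # receives a config object, whereas "@optimizer_config.lr_cosine_decay"
  # (no parentheses) passes the schedule function itself, presumably invoked
  # by the training loop with the current step.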

  optimizer_factory = @optimizer_config.FlaxAdafactorConfig()
  learning_rate_schedule = @optimizer_config.lr_cosine_decay
  max_scheduled_steps = 0  # Use num_steps as max_scheduled_steps.
  warmup_steps = 1000
  learning_rate_multiplier = 1.0
  rng_key_names = ("dropout", "sample")

text_dataset.load_text_dataset:
  verbose = False  # If True, prints the start of every book/repo read from disk.

# Use cosine decay to max_scheduled_steps, as described in the Chinchilla paper:
# https://arxiv.org/abs/2203.15556
optimizer_config.lr_cosine_decay:
  max_lr = 0.01
  min_lr = 0.001
  decay_after = True
  spike_steps = 0
  spike_lr = 0.0
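
# Illustrative values, assuming the usual warmup-then-cosine form
#   lr(t) = min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * t / T))
# (the exact formula lives in optimizer_config.py): after the 1000-step
# warmup the rate starts at max_lr = 0.01, passes 0.0055 at the schedule
# midpoint, and reaches min_lr = 0.001 at max_scheduled_steps.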

# Adam optimizer configuration.
# optimizer_config.AdamConfig:
#   learning_rate = 0.05  # Will be multiplied by the LR schedule.
#   beta1 = 0.9
#   beta2 = 0.98
#   weight_decay_rate = 0.0
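
# To switch optimizers, uncomment the AdamConfig block above and rebind the
# factory in the Trainer block:
#   optimizer_factory = @optimizer_config.AdamConfig()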

# Adafactor optimizer configuration.
optimizer_config.FlaxAdafactorConfig:
  learning_rate = 1.0  # Will be multiplied by the LR schedule.
  beta1 = 0.9          # Can be "None".
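
# Example launch (flag names here are assumptions, not confirmed by this
# file; meliad's trainers are typically started through a gin-configured
# main such as transformer/ht_main.py):
#   python -m transformer.ht_main \
#     --gin_file=trainer_configuration.gin \
#     --workdir=/tmp/model_dir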