---
# Model architecture
model:
  input_dim: 3  # RGB images
  hidden_dim: 512
  num_blocks: 8
  num_heads: 8
  # NOTE(review): patch_stride < patch_size, so patches presumably overlap — confirm
  patch_size: 8
  patch_stride: 4
  time_freq_dim: 256
  time_max_period: 1024
  mlp_ratio: 4
  use_bias: false
  padding: "SAME"
  pos_embed_cls_token: false
  pos_embed_extra_tokens: 0

# Training parameters
training:
  # Dot plus signed exponent keeps this a float under both YAML 1.1 and 1.2.
  learning_rate: 1.0e-4
  batch_size: 128
  # Plain digits: the YAML 1.2 core schema has no digit separators,
  # so "1_000_000" would load as a string there.
  num_steps: 1000000
  warmup_pct: 0.01  # NOTE(review): presumably a fraction of num_steps — confirm
  weight_decay: 0.0
  grad_clip_norm: 100.0

# Checkpointing and logging
checkpointing:
  # NOTE(review): intervals are presumably counted in training steps — confirm.
  # Plain digits (no "_" separators) so strict YAML 1.2 loaders read integers.
  log_every: 1000
  plot_every: 10000
  save_every: 10000
  # NOTE(review): presumably a checkpoint path when resuming — confirm
  resume_from_checkpoint: null

# Data
data:
  train_split: 0.9  # 90% for training, 10% for testing
  random_seed: 42