distilmodernbert / README.md

This is a version of ModernBERT-base distilled down from 22 layers to 16. This reduces the parameter count from 149M to 119M. In practice, since the embedding parameters contribute little to latency, the meaningful change is that the "trunk" of the model shrinks from 110M to 80M parameters. I would expect this to reduce latency by roughly 25% (increasing throughput by roughly 33%). The last 6 local attention layers were removed (layers are numbered from 1 here):

  1. Global
  2. Local
  3. Local
  4. Global
  5. Local
  6. Local
  7. Global
  8. Local
  9. Local
  10. Global
  11. Local
  12. Local
  13. Global
  14. Local (REMOVED)
  15. Local (REMOVED)
  16. Global
  17. Local (REMOVED)
  18. Local (REMOVED)
  19. Global
  20. Local (REMOVED)
  21. Local (REMOVED)
  22. Global
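
As a quick cross-check of the list above, here is a minimal sketch that rebuilds the Global/Local pattern and picks out the last 6 local layers. It assumes the pattern is "every third layer is global, starting from the first", which is what the list shows; the function names are made up for illustration. Indices here are 0-based, so layer 14 in the list is index 13.

```python
NUM_LAYERS = 22
GLOBAL_EVERY = 3  # assumption read off the list above: layers 1, 4, 7, ... are global


def layer_types(num_layers=NUM_LAYERS, global_every=GLOBAL_EVERY):
    """Return 'Global' or 'Local' for each 0-based layer index."""
    return ["Global" if i % global_every == 0 else "Local" for i in range(num_layers)]


def last_local_layers(types, count):
    """Return the 0-based indices of the last `count` local layers."""
    local_indices = [i for i, kind in enumerate(types) if kind == "Local"]
    return local_indices[-count:]


types = layer_types()
layers_to_remove = last_local_layers(types, count=6)
kept = [i for i in range(NUM_LAYERS) if i not in layers_to_remove]

# Back-of-envelope estimate, assuming every layer costs the same to run
# (a simplification: local and global layers need not cost the same).
latency_ratio = len(kept) / NUM_LAYERS   # fraction of trunk compute left
throughput_gain = NUM_LAYERS / len(kept)  # relative throughput

print(layers_to_remove)
print(f"{1 - latency_ratio:.0%} fewer layers, {throughput_gain:.2f}x throughput")
```

Under that equal-cost assumption, 16 of 22 layers works out to about 27% less trunk compute, which is in the same ballpark as the rough 25% figure above.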

Unfortunately the HuggingFace modeling code for ModernBERT relies on global-local attention patterns being uniform throughout the model, so loading this bad boy properly takes a bit of model surgery. I hope in the future that the HuggingFace team will update this model configuration to allow custom striping of global+local layers. For now, here's how to do it:

  1. Download the checkpoint (model.pt) from this repository.
  2. Initialize ModernBERT-base:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForMaskedLM

model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id)
  3. Remove the layers:
# 0-based indices of layers 14, 15, 17, 18, 20, and 21 in the list above
layers_to_remove = [13, 14, 16, 17, 19, 20]
model.model.layers = nn.ModuleList([
  layer for idx, layer in enumerate(model.model.layers)
  if idx not in layers_to_remove
])
  4. Load the checkpoint state dict:
# map_location="cpu" lets this load on machines without a GPU
state_dict = torch.load("model.pt", map_location="cpu")
model.model.load_state_dict(state_dict)
  5. Use the model! Yay!
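
One detail worth knowing about the surgery: rebuilding `nn.ModuleList` renumbers the surviving layers 0 through 15, so the old layer 15 becomes layer 13, and so on. The sketch below shows that index mapping. It assumes the checkpoint's keys follow the new numbering (a guess based on the fact that it loads into the pruned model, not something stated here), and the key names and the `remap_key` helper are hypothetical.

```python
layers_to_remove = {13, 14, 16, 17, 19, 20}
kept = [i for i in range(22) if i not in layers_to_remove]

# Original 0-based layer index -> index after the ModuleList is rebuilt.
old_to_new = {old: new for new, old in enumerate(kept)}


def remap_key(key):
    """Rewrite 'layers.<old>.<rest>' to 'layers.<new>.<rest>'.

    Returns None for keys that belong to a removed layer, and returns
    keys that are not under 'layers.' unchanged.
    """
    parts = key.split(".")
    if len(parts) > 1 and parts[0] == "layers" and parts[1].isdigit():
        old = int(parts[1])
        if old not in old_to_new:
            return None
        parts[1] = str(old_to_new[old])
    return ".".join(parts)


# Hypothetical key names, for illustration only.
print(remap_key("layers.15.example.weight"))
print(remap_key("layers.13.example.weight"))
```

This is only useful if you want to compare weights against the original ModernBERT-base state dict; the loading steps above do not need it.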

Training Information

This model was distilled from ModernBERT-base on the MiniPile dataset, which includes English and code data. Distillation ran for 1 epoch over all 1M samples in this dataset, using MSE loss on the logits, a batch size of 16, the AdamW optimizer, and a constant learning rate of 1.0e-5. The embeddings/LM head were frozen and shared between the teacher and student; only the transformer blocks were trained. I have not yet evaluated this model. However, after the initial model surgery, it failed to correctly complete "The capital of France is [MASK]", and after training, it correctly says "Paris", so something good happened!
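
To make the recipe concrete, here is a minimal sketch of one distillation step with the same ingredients: MSE on the logits, AdamW at a constant 1e-5, a batch of 16, and frozen embeddings/head shared between teacher and student. It is not the actual training script. `ToyMLM` is a tiny made-up stand-in for ModernBERT, and random token ids stand in for MiniPile.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyMLM(nn.Module):
    """Tiny stand-in for a masked LM: embeddings -> blocks -> LM head."""

    def __init__(self, num_layers, vocab_size=50, hidden=16):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, hidden)
        self.layers = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_layers))
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids):
        x = self.embeddings(input_ids)
        for layer in self.layers:
            x = x + torch.tanh(layer(x))  # residual block
        return self.head(x)


def distill_step(student, teacher, optimizer, input_ids):
    """Run one optimizer step that pulls student logits toward teacher logits."""
    with torch.no_grad():
        teacher_logits = teacher(input_ids)
    student_logits = student(input_ids)
    loss = F.mse_loss(student_logits, teacher_logits)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
teacher = ToyMLM(num_layers=4).eval()

# Student starts as a copy of the teacher with some blocks removed.
student = copy.deepcopy(teacher)
student.layers = nn.ModuleList(
    [layer for idx, layer in enumerate(student.layers) if idx not in {1, 2}]
)

# Freeze the teacher, then share its embeddings and head with the student.
for param in teacher.parameters():
    param.requires_grad_(False)
student.embeddings = teacher.embeddings
student.head = teacher.head

# Only the student's transformer blocks are trained.
optimizer = torch.optim.AdamW(student.layers.parameters(), lr=1.0e-5)

batch = torch.randint(0, 50, (16, 8))  # batch size 16, 8 tokens each
loss = distill_step(student, teacher, optimizer, batch)
print(f"distillation loss: {loss:.4f}")
```

Passing only `student.layers.parameters()` to the optimizer is what keeps the shared embeddings and head fixed; gradients still flow through the frozen head back into the blocks.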