# jamba-small-v1 / prune.py
# Last commit by OxxoCodes: 4556966 "Nit" (verified). Original file size: 1.22 kB.
import os
import re
import torch
from modeling_jamba import JambaForCausalLM
# Destination passed to torch.save at the end of the script.
output_dir = "/home/user/jamba-small"

# Load the full source checkpoint on CPU in bfloat16; its state_dict() is the
# input to the pruning step below.
model = JambaForCausalLM.from_pretrained(
    "ai21labs/Jamba-v0.1",
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
def prune_and_copy_additional_layers(original_state_dict, layer_mapping=None):
    """Return a reduced state dict holding only selected, renumbered layers.

    For each ``new_idx: orig_idx`` pair in *layer_mapping*, every entry whose
    key starts with ``model.layers.<orig_idx>.`` is copied under a key whose
    leading prefix is rewritten to ``model.layers.<new_idx>.``. The
    non-layer entries ``model.embed_tokens.weight``,
    ``model.final_layernorm.weight`` and ``lm_head.weight`` are copied
    verbatim. All other entries are dropped. Values are shared with the
    input, not cloned.

    Args:
        original_state_dict: Mapping from parameter name to value, e.g. the
            result of ``model.state_dict()``.
        layer_mapping: Optional dict ``{new_idx: orig_idx}``. When ``None``,
            new layers 0-5 come from original layers 0-5 and new layers 6-7
            come from original layers 30-31.

    Returns:
        A new ``dict`` containing only the selected entries.

    Raises:
        KeyError: If one of the three global keys above is missing from
            *original_state_dict*.
    """
    if layer_mapping is None:
        layer_mapping = {
            0: 0,
            1: 1,
            2: 2,
            # Previously ``3: 2``. Every other entry keeps
            # new_idx % 8 == orig_idx % 8 (6<-30, 7<-31 included), so a layout
            # that picks layer type by index modulo a period (as Jamba's config
            # appears to) sees the same layer type at both positions. Copying
            # layer 2 into slot 3 broke that pattern and looks like a typo.
            # NOTE(review): confirm against the pruned model's config.
            3: 3,
            4: 4,
            5: 5,
            6: 30,
            7: 31,
        }
    new_state_dict = {}
    # Copy the selected layers, renumbering only their leading key prefix.
    for new_idx, orig_idx in layer_mapping.items():
        # The trailing "." is essential. Without it, "model.layers.3" also
        # matches "model.layers.30"/"model.layers.31", "model.layers.1" matches
        # layers 10-19, and so on, dragging unrelated layers into the output.
        prefix = f"model.layers.{orig_idx}."
        new_prefix = f"model.layers.{new_idx}."
        for key, value in original_state_dict.items():
            if key.startswith(prefix):
                # Slice off the prefix instead of str.replace, which would also
                # rewrite any later "layers.<orig_idx>" inside the key.
                new_state_dict[new_prefix + key[len(prefix):]] = value
    # Non-layer parameters that the smaller model still needs.
    global_keys = ['model.embed_tokens.weight', 'model.final_layernorm.weight', 'lm_head.weight']
    for key in global_keys:
        new_state_dict[key] = original_state_dict[key]
    return new_state_dict
# Build the reduced checkpoint from the loaded model's full state dict.
pruned_state_dict = prune_and_copy_additional_layers(model.state_dict())
# `output_dir` is named like a directory, and torch.save cannot write to a path
# that is an existing directory. In that case, write the weights inside it
# under the filename transformers uses for single-file PyTorch weights.
# Otherwise keep the original behaviour of writing a file at exactly that path.
save_path = output_dir
if os.path.isdir(output_dir):
    save_path = os.path.join(output_dir, "pytorch_model.bin")
print("Saving weights...")
torch.save(pruned_state_dict, save_path)
print("Done!")