import os
import re
import torch

from modeling_jamba import JambaForCausalLM

output_dir = "/home/user/jamba-small"

# Load the full 32-layer Jamba model on CPU in bfloat16; no GPU is needed for pruning.
model = JambaForCausalLM.from_pretrained(
    "ai21labs/Jamba-v0.1", device_map="cpu", torch_dtype=torch.bfloat16
)


def prune_and_copy_additional_layers(original_state_dict):
    # Map each layer index of the pruned 8-layer model to the original layer it is copied from.
    layer_mapping = {
        0: 0,
        1: 1,
        2: 2,
        3: 2,
        4: 4,
        5: 5,
        6: 30,
        7: 31,
    }
    new_state_dict = {}
    # Copy the selected layers from the original state dict, renaming their keys to the
    # new layer indices. The trailing dot in the prefix keeps e.g. layer 2 from also
    # matching layers 20-29.
    for new_idx, orig_idx in layer_mapping.items():
        prefix = f"model.layers.{orig_idx}."
        for key, value in original_state_dict.items():
            if key.startswith(prefix):
                new_key = key.replace(f"layers.{orig_idx}.", f"layers.{new_idx}.")
                new_state_dict[new_key] = value
    # Keep the weights that sit outside the decoder layers.
    global_keys = ['model.embed_tokens.weight', 'model.final_layernorm.weight', 'lm_head.weight']
    for key in global_keys:
        new_state_dict[key] = original_state_dict[key]
    return new_state_dict


pruned_state_dict = prune_and_copy_additional_layers(model.state_dict())
print("Saving weights...")
torch.save(pruned_state_dict, output_dir)
print("Done!") |