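# Prune ai21labs/Jamba-v0.1 down to an 8-layer "jamba-small" checkpoint:
# keep the first six and the last two decoder layers, re-index them, and
# save the resulting state dict.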
import os

import torch
# JambaForCausalLM from the modeling_jamba.py that ships with the model repo
from modeling_jamba import JambaForCausalLM

output_dir = "/home/user/jamba-small"

# Load the full 32-layer Jamba-v0.1 on CPU in bfloat16
model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", device_map="cpu", torch_dtype=torch.bfloat16)

def prune_and_copy_additional_layers(original_state_dict):
    # New layer index -> original layer index. Keeping layers 0-5 and 30-31
    # preserves Jamba's periodic layer pattern (attention every 8th layer at
    # offset 4, MoE on odd layers), so an 8-layer config can load these weights.
    layer_mapping = {
        0: 0,
        1: 1,
        2: 2,
        3: 3,
        4: 4,
        5: 5,
        6: 30,
        7: 31,
    }
    
    new_state_dict = {}
    
    # Copy the selected layers' weights, renaming them to their new indices.
    # The trailing dot keeps "layers.3" from also matching "layers.30"/"layers.31".
    for new_idx, orig_idx in layer_mapping.items():
        prefix = f"model.layers.{orig_idx}."
        for key, value in original_state_dict.items():
            if key.startswith(prefix):
                new_key = key.replace(f"layers.{orig_idx}.", f"layers.{new_idx}.", 1)
                new_state_dict[new_key] = value
                
    # Copy the non-layer weights: token embeddings, final norm, and LM head
    global_keys = ['model.embed_tokens.weight', 'model.final_layernorm.weight', 'lm_head.weight']
    for key in global_keys:
        new_state_dict[key] = original_state_dict[key]

    return new_state_dict

pruned_state_dict = prune_and_copy_additional_layers(model.state_dict())

print("Saving weights...")
os.makedirs(output_dir, exist_ok=True)
torch.save(pruned_state_dict, os.path.join(output_dir, "pytorch_model.bin"))
print("Done!")