OxxoCodes commited on
Commit
9851643
·
verified ·
1 Parent(s): f95bd00

Upload prune.py

Browse files

Upload pruning script used to create Jamba-Small-v1

Files changed (1) hide show
  1. prune.py +41 -0
prune.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+ import torch
5
+ from modeling_jamba import JambaForCausalLM
6
+
7
+ model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", device_map="cpu", torch_dtype=torch.bfloat16)
8
+
9
+ def prune_and_copy_additional_layers(original_state_dict):
10
+ layer_mapping = {
11
+ 0: 0,
12
+ 1: 1,
13
+ 2: 2,
14
+ 3: 2,
15
+ 4: 4,
16
+ 5: 5,
17
+ 6: 30,
18
+ 7: 31
19
+ }
20
+
21
+ new_state_dict = {}
22
+
23
+ # Copy specified layers from the original state dict
24
+ for new_idx, orig_idx in layer_mapping.items():
25
+ prefix = f"model.layers.{orig_idx}"
26
+ for key, value in original_state_dict.items():
27
+ if key.startswith(prefix):
28
+ new_key = key.replace(f"layers.{orig_idx}", f"layers.{new_idx}")
29
+ new_state_dict[new_key] = value
30
+
31
+ global_keys = ['model.embed_tokens.weight', 'model.final_layernorm.weight', 'lm_head.weight']
32
+ for key in global_keys:
33
+ new_state_dict[key] = original_state_dict[key]
34
+
35
+ return new_state_dict
36
+
37
+ pruned_state_dict = prune_and_copy_additional_layers(model.state_dict())
38
+
39
+ print("Saving weights...")
40
+ torch.save(pruned_state_dict, '/scratch/nbrown9/jamba-small-v2.bin')
41
+ print("Done!")