import os
import re
import torch

from modeling_jamba import JambaForCausalLM

# Destination directory for the pruned checkpoint.
output_dir = "/home/user/jamba-small"

model = JambaForCausalLM.from_pretrained(
    "ai21labs/Jamba-v0.1", device_map="cpu", torch_dtype=torch.bfloat16
)


def prune_and_copy_additional_layers(original_state_dict):
    """Build a pruned state dict that keeps a subset of the original layers.

    Copies the weights of selected original transformer layers into a new,
    contiguous layer numbering, plus the model-wide tensors (embeddings,
    final layer norm, LM head).

    Args:
        original_state_dict: full state dict of the source model, with keys
            of the form ``model.layers.<idx>.<param>``.

    Returns:
        dict: new state dict whose layer keys are renumbered to
        ``model.layers.<new_idx>.<param>``.
    """
    # new layer index -> original layer index to copy from.
    # NOTE(review): new layers 2 and 3 both copy original layer 2 — confirm
    # this duplication is intentional and not a typo for ``3: 3``.
    layer_mapping = {0: 0, 1: 1, 2: 2, 3: 2, 4: 4, 5: 5, 6: 30, 7: 31}

    new_state_dict = {}
    for new_idx, orig_idx in layer_mapping.items():
        # Include the trailing dot: a bare "model.layers.2" prefix would
        # also match layers 20-29 (and "model.layers.3" would match 30-39),
        # copying — and mis-renaming — layers that should be pruned.
        old_prefix = f"model.layers.{orig_idx}."
        new_prefix = f"model.layers.{new_idx}."
        for key, value in original_state_dict.items():
            if key.startswith(old_prefix):
                # Replace only the leading occurrence to avoid touching any
                # later substring that happens to look like a layer index.
                new_state_dict[new_prefix + key[len(old_prefix):]] = value

    # Tensors shared by the whole model, copied through unchanged.
    global_keys = [
        'model.embed_tokens.weight',
        'model.final_layernorm.weight',
        'lm_head.weight',
    ]
    for key in global_keys:
        new_state_dict[key] = original_state_dict[key]
    return new_state_dict


pruned_state_dict = prune_and_copy_additional_layers(model.state_dict())

print("Saving weights...")
# torch.save expects a file path; the original passed the output *directory*
# itself, which raises IsADirectoryError when the directory exists (or writes
# a file literally named "jamba-small" when it does not). Save a file inside
# the directory instead.
os.makedirs(output_dir, exist_ok=True)
torch.save(pruned_state_dict, os.path.join(output_dir, "pytorch_model.bin"))
print("Done!")