|
from safetensors import safe_open |
|
from safetensors.torch import save_file |
|
import torch |
|
|
|
def rename_key(key):
    """Map a source-checkpoint tensor name to its target-model equivalent.

    The conversion: strips the ``roberta`` prefix and parametrization
    wrappers, renames attention/MLP/LayerNorm components to the target
    naming scheme, moves ``emb_ln`` under ``embeddings``, and renames
    LayerNorm ``weight``/``bias`` to ``gamma``/``beta``.
    """
    parts = key.split('.')

    # Strip wrapper segments that have no counterpart in the target model.
    for wrapper in ('roberta', 'parametrizations'):
        if wrapper in parts:
            parts.remove(wrapper)
    # A parametrized weight shows up as "...weight.original" — drop the suffix.
    if 'weight' in parts and 'original' in parts:
        parts.remove('original')

    # HF-style block list is singular "layer".
    if 'encoder.layers' in key:
        parts[parts.index('layers')] = 'layer'

    # One-to-one component renames (sources and targets are disjoint,
    # so the replacement order does not matter).
    renames = {
        'mixer': 'attention',
        'Wqkv': 'qkv_proj',
        'out_proj': 'o_proj',
        'norm1': 'attn_ln',
        'norm2': 'mlp_ln',
    }
    for old, new in renames.items():
        if old in parts:
            parts[parts.index(old)] = new

    # MLP projections — match on the dotted key to anchor "fc1"/"fc2" to "mlp".
    if 'mlp.fc1' in key:
        parts[parts.index('fc1')] = 'up_proj'
    if 'mlp.fc2' in key:
        parts[parts.index('fc2')] = 'down_proj'

    # Embedding LayerNorm lives under the "embeddings" prefix in the target.
    if 'emb_ln' in parts:
        parts[parts.index('emb_ln')] = 'LayerNorm'
        parts.insert(0, 'embeddings')

    # Target model uses the legacy gamma/beta names for LayerNorm params.
    ln_modules = ('attn_ln', 'mlp_ln', 'LayerNorm')
    if 'weight' in parts and parts[-2] in ln_modules:
        parts[-1] = 'gamma'
    if 'bias' in parts and parts[-2] in ln_modules:
        parts[-1] = 'beta'

    return '.'.join(parts)
|
|
|
input_file = "original_model.safetensors"
output_file = "model.safetensors"

# Converted state dict, keyed by the renamed tensor names.
new_tensors = {}

with safe_open(input_file, framework="pt", device="cpu") as f:
    for key in f.keys():
        # LoRA adapter weights are dropped from the converted checkpoint.
        if 'lora' in key:
            continue

        new_key = rename_key(key)
        tensor = f.get_tensor(key)
        new_tensors[new_key] = tensor

        # The target architecture expects a fused up+gate projection;
        # synthesize it by stacking the up projection with itself along
        # the output dimension.
        if 'mlp.up_proj' in new_key:
            gate_key = new_key.replace('up_proj', 'up_gate_proj')
            new_tensors[gate_key] = torch.cat([tensor, tensor], dim=0)

save_file(new_tensors, output_file)
print(f"Renamed tensors saved to {output_file}")
|
|
|
|
|
# Sanity check: reopen the converted file and list every tensor's shape.
with safe_open(output_file, framework="pt", device="cpu") as f:
    print("\nRenamed tensors:")
    for key in f.keys():
        shape = f.get_tensor(key).shape
        print(f"{key}: {shape}")
|