# NOTE: the following metadata was captured from the Hugging Face file-viewer
# page this script was copied from; kept as a comment so the file parses.
#   uploader: esunn0412 | commit message: "tei support" | commit: 9085c0d
#   viewer links: raw / history / blame | file size: 2.39 kB
from safetensors import safe_open
from safetensors.torch import save_file
import torch
def rename_key(key):
    """Translate one tensor name from the source checkpoint layout to the
    layout expected by the converted model.

    Transformations, applied in order:
      * drop wrapper segments ('roberta', 'parametrizations', and the
        'original' segment inserted by weight parametrization);
      * singularize 'encoder.layers' -> 'encoder.layer';
      * rename modules (mixer -> attention, Wqkv -> qkv_proj,
        out_proj -> o_proj, mlp.fc1/fc2 -> up_proj/down_proj,
        emb_ln -> embeddings.LayerNorm, norm1/norm2 -> attn_ln/mlp_ln);
      * rewrite LayerNorm weight/bias to the legacy gamma/beta names.
    """
    segments = key.split('.')

    # Strip framework wrapper segments.
    for wrapper in ('roberta', 'parametrizations'):
        if wrapper in segments:
            segments.remove(wrapper)
    if 'weight' in segments and 'original' in segments:
        segments.remove('original')

    # Singular/plural fix for the encoder stack (matched on the full key).
    if 'encoder.layers' in key:
        segments[segments.index('layers')] = 'layer'

    # Straight segment-for-segment module renames.
    for old, new in (('mixer', 'attention'),
                     ('Wqkv', 'qkv_proj'),
                     ('out_proj', 'o_proj')):
        if old in segments:
            segments[segments.index(old)] = new

    # MLP projections are matched on the full key so only 'mlp.fc*' hits.
    if 'mlp.fc1' in key:
        segments[segments.index('fc1')] = 'up_proj'
    if 'mlp.fc2' in key:
        segments[segments.index('fc2')] = 'down_proj'

    # Embedding LayerNorm moves under the 'embeddings' prefix.
    if 'emb_ln' in segments:
        segments[segments.index('emb_ln')] = 'LayerNorm'
        segments.insert(0, 'embeddings')

    for old, new in (('norm1', 'attn_ln'), ('norm2', 'mlp_ln')):
        if old in segments:
            segments[segments.index(old)] = new

    # LayerNorm parameters use the legacy gamma/beta naming convention.
    layernorm_modules = ('attn_ln', 'mlp_ln', 'LayerNorm')
    if 'weight' in segments and segments[-2] in layernorm_modules:
        segments[-1] = 'gamma'
    if 'bias' in segments and segments[-2] in layernorm_modules:
        segments[-1] = 'beta'

    return '.'.join(segments)
def convert_checkpoint(input_file="original_model.safetensors",
                       output_file="model.safetensors"):
    """Rename all non-LoRA tensors of *input_file* via ``rename_key`` and
    write the result to *output_file*.

    For every renamed key containing ``mlp.up_proj`` an additional
    ``mlp.up_gate_proj`` tensor is emitted, built by stacking the up_proj
    tensor with itself along dim 0.

    Returns the path of the written file.
    """
    new_tensors = {}
    with safe_open(input_file, framework="pt", device="cpu") as f:
        for key in f.keys():
            if 'lora' in key:
                continue  # LoRA adapter weights are not carried over
            new_key = rename_key(key)
            tensor = f.get_tensor(key)
            new_tensors[new_key] = tensor
            if 'mlp.up_proj' in new_key:
                # The target layout expects a separate gate projection.
                # NOTE(review): the gate half is just up_proj duplicated
                # along dim 0 to match the expected shape — the gate
                # weights are NOT real trained parameters; confirm the
                # target model tolerates (or retrains) this.
                gate_key = new_key.replace('up_proj', 'up_gate_proj')
                new_tensors[gate_key] = torch.cat([tensor, tensor], dim=0)
    save_file(new_tensors, output_file)
    print(f"Renamed tensors saved to {output_file}")
    return output_file


def inspect_checkpoint(path):
    """Print every tensor name and shape stored in the file at *path*."""
    with safe_open(path, framework="pt", device="cpu") as f:
        print("\nRenamed tensors:")
        for key in f.keys():
            print(f"{key}: {f.get_tensor(key).shape}")


if __name__ == "__main__":
    # Guarded so importing this module no longer triggers the conversion
    # as a side effect; running it as a script behaves as before.
    inspect_checkpoint(convert_checkpoint())