R1ckShi's picture
init push
5b4c852 verified
raw
history blame
6.77 kB
import sys
import torch
def convert_llm(state_dict, *, llm_input_size=1024, spk_embed_dim=192):
    """Rename the keys of an original-format LLM checkpoint to the new layout.

    The mapping is modified in place: every renamed entry is popped and
    re-inserted under its new name (so it moves to the end of the mapping),
    and the same object is returned.

    Renames are prefix-only: whatever follows the matched prefix is kept
    verbatim, even if it happens to contain the prefix text again.

    The conversion is not idempotent: a converted key such as
    ``text_encoder.embed.out.0.weight`` still starts with
    ``text_encoder.embed.``, so running it twice inserts a second ``out.``.

    Args:
        state_dict: Mutable mapping from parameter name to value.
        llm_input_size: First dimension of the all-zero placeholder created
            when ``spk_embed_affine_layer.weight`` / ``.bias`` is missing.
        spk_embed_dim: Second dimension of the all-zero placeholder weight.

    Returns:
        The same ``state_dict`` object with its keys renamed.
    """
    rename_passes = (
        # The LM was restructured: codec_lm.encoder becomes the llm and
        # codec_lm.decoder becomes the decoder.
        (
            ('codec_lm.encoder.', 'llm.'),
            ('codec_lm.decoder.', 'llm_decoder.'),
        ),
        # Implementation differences between espnet and wenet. This pass must
        # run after the previous one because it matches the 'llm.' keys that
        # pass creates.
        (
            ('text_encoder.embed.', 'text_encoder.embed.out.'),
            ('llm.embed.', 'llm.embed.out.'),
        ),
        (
            ('text_enc_out_layer.', 'text_encoder_affine_layer.'),
            ('token_embedding.', 'text_embedding.'),
            ('xvec_proj.', 'spk_embed_affine_layer.'),
            ('lm_embedding.', 'llm_embedding.'),
            ('codec_embedder.', 'speech_embedding.'),
        ),
    )
    for renames in rename_passes:
        # Snapshot the keys: the mapping is mutated while we walk it.
        for key in list(state_dict):
            for old_prefix, new_prefix in renames:
                if key.startswith(old_prefix):
                    value = state_dict.pop(key)
                    # Slice instead of str.replace so only the leading prefix
                    # is rewritten, never a later occurrence inside the name.
                    key = new_prefix + key[len(old_prefix):]
                    state_dict[key] = value

    # The instruct model has no speaker-embedding parameters; add all-zero ones.
    if 'spk_embed_affine_layer.weight' not in state_dict:
        print('no spk_embed_affine_layer.weight, should be instruct model')
        state_dict['spk_embed_affine_layer.weight'] = torch.zeros(llm_input_size, spk_embed_dim)
    if 'spk_embed_affine_layer.bias' not in state_dict:
        print('no spk_embed_affine_layer.bias, should be instruct model')
        state_dict['spk_embed_affine_layer.bias'] = torch.zeros(llm_input_size)
    return state_dict
def convert_hift(state_dict):
    """Strip the leading ``decoder.`` / ``generator.`` scope from HiFT keys.

    Upstream note: the HiFi-GAN structure in CosyVoice was reorganized so that
    f0_predictor lives inside the generator.

    The mapping is modified in place (renamed entries are popped and
    re-inserted, so they move to the end) and the same object is returned.
    The two prefixes are checked in order against the current name, so a key
    of the form ``decoder.generator.x`` ends up as ``x``.

    Only the leading prefix is removed; a ``decoder.`` or ``generator.``
    segment appearing later in the name is left untouched.

    Args:
        state_dict: Mutable mapping from parameter name to value.

    Returns:
        The same ``state_dict`` object with its keys renamed.
    """
    # Snapshot the keys: the mapping is mutated while we walk it.
    for key in list(state_dict):
        for prefix in ('decoder.', 'generator.'):
            if key.startswith(prefix):
                value = state_dict.pop(key)
                # Slice rather than str.replace(prefix, '') so inner
                # occurrences of the prefix text are preserved.
                key = key[len(prefix):]
                state_dict[key] = value
    return state_dict
def convert_flow(state_dict):
    """Rename the keys of an original-format flow checkpoint to the new layout.

    Two passes run over one snapshot of the original keys:
    ``encoder.embed.*`` becomes ``encoder.embed.out.*``, then ``xvec_proj.*``
    becomes ``spk_embed_affine_layer.*``.

    The mapping is modified in place (renamed entries are popped and
    re-inserted, so they move to the end) and the same object is returned.
    Only the leading prefix is rewritten; the remainder of each name is kept
    verbatim.

    Args:
        state_dict: Mutable mapping from parameter name to value.

    Returns:
        The same ``state_dict`` object with its keys renamed.
    """
    renames = (
        ('encoder.embed.', 'encoder.embed.out.'),
        ('xvec_proj.', 'spk_embed_affine_layer.'),
    )
    # One snapshot is enough: neither pass produces a key the other matches.
    original_keys = list(state_dict)
    for old_prefix, new_prefix in renames:
        for key in original_keys:
            if key.startswith(old_prefix):
                # Slice rather than str.replace so a later occurrence of the
                # prefix text inside the name is not rewritten as well.
                state_dict[new_prefix + key[len(old_prefix):]] = state_dict.pop(key)
    return state_dict
def convert_llm2(state_dict):
    """Rename and prune the keys of an original-format LLM checkpoint (v2 layout).

    In a single pass over a snapshot of the keys:

    * ``codec_lm.encoder.*`` -> ``llm.*``
    * ``codec_lm.decoder.*`` -> ``llm_decoder.*``
    * ``lm_embedding.*``     -> ``llm_embedding.*``
    * ``codec_embedder.*``   -> ``speech_embedding.*``
    * keys starting with ``text_enc_out_layer.`` or ``token_embedding.weight``
      are removed.

    The mapping is modified in place (renamed entries are popped and
    re-inserted, so they move to the end) and the same object is returned.
    Only the leading prefix is rewritten; the remainder of each name is kept
    verbatim.

    Args:
        state_dict: Mutable mapping from parameter name to value.

    Returns:
        The same ``state_dict`` object with its keys renamed and pruned.
    """
    # The LM was restructured: codec_lm.encoder becomes the llm and
    # codec_lm.decoder becomes the decoder.
    renames = (
        ('codec_lm.encoder.', 'llm.'),
        ('codec_lm.decoder.', 'llm_decoder.'),
        ('lm_embedding.', 'llm_embedding.'),
        ('codec_embedder.', 'speech_embedding.'),
    )
    dropped_prefixes = ('text_enc_out_layer.', 'token_embedding.weight')

    # Snapshot the keys: the mapping is mutated while we walk it.
    for key in list(state_dict):
        for old_prefix, new_prefix in renames:
            if key.startswith(old_prefix):
                value = state_dict.pop(key)
                # Slice rather than str.replace so only the leading prefix
                # is rewritten.
                key = new_prefix + key[len(old_prefix):]
                state_dict[key] = value
        # A name can match at most one of these prefixes, so one pop suffices.
        if key.startswith(dropped_prefixes):
            state_dict.pop(key)
    return state_dict
def convert_flow2(state_dict):
    """Rename and prune the keys of an original-format flow checkpoint (v2 layout).

    Working from one snapshot of the original keys, in this order:

    1. ``encoder.embed.*`` -> ``encoder.embed.out.*``
    2. ``xvec_proj.*`` -> ``spk_embed_affine_layer.*``
    3. ``mel_extractor.*`` entries are removed.
    4. ``encoder.upsample_blocks.0.{0,1,2}.*`` ->
       ``encoder.up_layer.*`` / ``encoder.up_embed.out.*`` /
       ``encoder.up_encoders.*``; and for ``decoder.estimator.*`` keys the
       trailing ``block.1.weight`` / ``block.1.bias`` becomes
       ``block.2.weight`` / ``block.2.bias``.

    The mapping is modified in place (renamed entries are popped and
    re-inserted, so they move to the end) and the same object is returned.
    Only the matched leading prefix or trailing suffix is rewritten; the rest
    of each name is kept verbatim.

    Args:
        state_dict: Mutable mapping from parameter name to value.

    Returns:
        The same ``state_dict`` object with its keys renamed and pruned.
    """
    original_keys = list(state_dict)

    for old_prefix, new_prefix in (
        ('encoder.embed.', 'encoder.embed.out.'),
        ('xvec_proj.', 'spk_embed_affine_layer.'),
    ):
        for key in original_keys:
            if key.startswith(old_prefix):
                # Slice rather than str.replace so only the leading prefix
                # is rewritten.
                state_dict[new_prefix + key[len(old_prefix):]] = state_dict.pop(key)

    for key in original_keys:
        if key.startswith('mel_extractor.'):
            state_dict.pop(key)

    upsample_renames = (
        ('encoder.upsample_blocks.0.0.', 'encoder.up_layer.'),
        ('encoder.upsample_blocks.0.1.', 'encoder.up_embed.out.'),
        ('encoder.upsample_blocks.0.2.', 'encoder.up_encoders.'),
    )
    # In CausalBlock1D the sequential index changes from 1 to 2.
    estimator_suffix_renames = (
        ('block.1.weight', 'block.2.weight'),
        ('block.1.bias', 'block.2.bias'),
    )
    for key in original_keys:
        for old_prefix, new_prefix in upsample_renames:
            if key.startswith(old_prefix):
                value = state_dict.pop(key)
                key = new_prefix + key[len(old_prefix):]
                state_dict[key] = value
        for old_suffix, new_suffix in estimator_suffix_renames:
            if key.startswith('decoder.estimator.') and key.endswith(old_suffix):
                value = state_dict.pop(key)
                # Rewrite only the trailing suffix that endswith() matched.
                key = key[:-len(old_suffix)] + new_suffix
                state_dict[key] = value
    return state_dict
if __name__ == '__main__':
    # Usage:
    #   python3 convert.py <original-format llm.pt> llm normalize <new-format llm.pt>
    # or, per the original usage note:
    #   python3 convert.py <new-format llm.pt> llm inverse_normalize <original-format llm.pt>
    # NOTE(review): argv[3] (normalize / inverse_normalize) is never read by
    # this script, so the same conversion runs whichever word is passed --
    # confirm whether an inverse conversion was meant to exist.
    converters = {
        'llm': convert_llm,
        'flow': convert_flow,
        'hift': convert_hift,
        'llm2': convert_llm2,
        'flow2': convert_flow2,
    }
    # Check the arguments up front so a missing output path is reported
    # before the checkpoint is loaded and converted.
    if len(sys.argv) < 5:
        sys.exit('usage: python3 {} <input.pt> <{}> <mode> <output.pt>'.format(
            sys.argv[0], '|'.join(converters)))
    src_path, model_type, dst_path = sys.argv[1], sys.argv[2], sys.argv[4]
    # Reject an unknown model type before paying for the checkpoint load.
    if model_type not in converters:
        raise ValueError('unsupported model type {!r}; expected one of {}'.format(
            model_type, ', '.join(converters)))
    # NOTE: depending on the torch version / weights_only setting, torch.load
    # may unpickle arbitrary objects; only run this on trusted checkpoints.
    state_dict = torch.load(src_path, map_location='cpu')
    state_dict = converters[model_type](state_dict)
    torch.save(state_dict, dst_path)