How to run this model

#1
by AgeOfAlgorithms - opened

I tried running it the same way I would run xcodec2:

from xcodec2.modeling_xcodec2 import XCodec2Model
codec_model_path = "annuvin/xcodec2-bf16"
Codec_model = XCodec2Model.from_pretrained(codec_model_path)

but this gives the following error:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/x/Llasa/ref_vllm.py", line 18, in <module>
[rank0]:     Codec_model = XCodec2Model.from_pretrained(codec_model_path)
[rank0]:   File "/home/x/miniconda3/envs/xcodec2-2/lib/python3.9/site-packages/transformers/modeling_utils.py", line 3792, in from_pretrained
[rank0]:     if metadata.get("format") == "pt":
[rank0]: AttributeError: 'NoneType' object has no attribute 'get'

It seems metadata is missing. How do you run the model?

I'm able to load it this exact way without any errors. Try updating your xcodec2 package? Alternatively, you can also cast the fp32 model on the fly with .to(torch.bfloat16). The result should be the same.

I updated xcodec2, but none of these methods work for me. I tried the above code on both this model and srinivasbilla/xcodec2, but it seems both of these repos are missing "metadata", which results in an error. To be clear, my code works well with HKUSTAudio/xcodec2.
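For anyone who wants to check a checkpoint for this before loading it, here is a minimal sketch. It assumes the weights are stored as a safetensors file and that the "metadata" in the traceback is the optional `__metadata__` entry in that file's JSON header. The helper name `read_safetensors_metadata` and the file path in the comment are made up for illustration.

```python
import json
import struct


def read_safetensors_metadata(path):
    """Return the optional __metadata__ dict from a safetensors header, or None."""
    with open(path, "rb") as f:
        # The file starts with an 8-byte little-endian unsigned integer
        # giving the length of the JSON header that follows.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
    # Tensor entries live next to this key; the key itself is optional.
    return header.get("__metadata__")


# Hypothetical usage (the path is an assumption, not taken from the repo):
# print(read_safetensors_metadata("model.safetensors"))
```

If this returns None, the file has no metadata block, which would match the `'NoneType' object has no attribute 'get'` error above.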

As for casting to bfloat16, I tried it on the codec model:

Codec_model = Codec_model.to(torch.bfloat16)

but the program keeps hitting errors where it tries to multiply a float32 matrix with a bfloat16 one. It happens when you step into this line:

vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)

I tried casting everything I could see to bfloat16, but could not get it to work.

My bad, seems like it's an issue with older transformers, not xcodec. According to this it was fixed in 4.48.0. I'll try and upload a model with metadata later though, just in case.
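A quick way to see whether an environment already has that release is a sketch like the one below. It only uses the standard library, and the helper names `parse_release` and `has_fix` are made up for illustration. The only fact taken from this thread is the 4.48.0 version number.

```python
from importlib.metadata import PackageNotFoundError, version

FIXED_IN = (4, 48, 0)  # release mentioned in this thread


def parse_release(v):
    """Turn '4.48.0', '4.48' or '4.50.0.dev0' into a 3-part integer tuple."""
    parts = []
    for piece in v.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    # Pad short versions such as '4.48' so tuple comparison behaves.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def has_fix(installed):
    """True if the installed version string is at least FIXED_IN."""
    return parse_release(installed) >= FIXED_IN


try:
    installed = version("transformers")
    print("transformers", installed, "has fix:", has_fix(installed))
except PackageNotFoundError:
    print("transformers is not installed in this environment")
```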

As for encoding with bf16, you might need to cast the input tensor to bf16 as well, or use autocast, something like this:

with torch.autocast("cuda", torch.bfloat16), torch.inference_mode():
    vq_code_prompt = Codec_model.encode_code(prompt_wav)
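To show the two options side by side without the real codec, here is a toy sketch. The single `nn.Linear` layer is only a stand-in for the model and the random tensor is a stand-in for the waveform; neither comes from xcodec2. It uses the "cpu" autocast device so it runs without a GPU.

```python
import torch
from torch import nn

# Hypothetical stand-in for the codec: one bf16 layer fed a float32 "waveform".
layer = nn.Linear(4, 2).to(torch.bfloat16)
wav = torch.randn(1, 4)  # float32, like a freshly loaded waveform

# Without any cast or autocast, the float32 input clashes with the bf16 weights.
try:
    layer(wav)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# Option 1: cast the input tensor so it matches the weights.
out_cast = layer(wav.to(torch.bfloat16))

# Option 2: leave the input alone and let autocast run the matmul in bf16.
with torch.autocast("cpu", dtype=torch.bfloat16), torch.inference_mode():
    out_autocast = layer(wav)

print(mismatch_raised, out_cast.dtype, out_autocast.dtype)
```

With the real model on a GPU, the device string would be "cuda" as in the snippet above.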

Uploaded a new one, should work now.

Annuvin changed discussion status to closed
