different results w/ and w/o flash-attn
Hi there,
I trained your model with two settings:
1. default configs
2. I used config._attn_implementation = "flash_attention_2" to enable flash_attn_2.
Under setting 2, the training speed is doubled compared to setting 1, but the loss goes high and is extremely unstable.
What is the correct way of activating flash-attn? Thanks in advance!
Hi @g-h-chen @zhumj34, could you please update to transformers main and try to retrain the model again?
There was a recent commit to the phi model in the library which fixed an issue of fp16 logits becoming NaNs.
Please update to transformers main before re-running the training script:
git clone https://github.com/huggingface/transformers.git
cd transformers
pip install .
Please let me know if this update fixes the issue that you are facing.
Hi, just a heads up: because of the recent changes made to the main branch of transformers, you can't load the model properly (using this repo) because of a weights mismatch. I will most likely fix this today.
Hi @zhumj34, could you please share the training script?
I will have a look at it, but it is much easier when I have the training script available :)
Hi @susnato, sorry for the late reply. My code is based on llava. I simply replaced llama with phi-2 and loaded llava-phi with flash_attention_2 (similar to this code: https://huggingface.co/docs/transformers/v4.36.1/en/model_doc/phi#combining-phi-and-flash-attention-2). The training script is identical to the llava-v1.5 one.
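For readers following along, here is a minimal sketch of loading a phi checkpoint with Flash Attention 2 through the public attn_implementation argument of from_pretrained, instead of setting the private config._attn_implementation field. The helper names and the "microsoft/phi-2" model id are placeholders chosen for illustration, not the exact code used in this thread, and the sketch assumes transformers 4.36 or newer.

```python
import importlib.util


def pick_attn_implementation(prefer_flash: bool = True) -> str:
    """Return "flash_attention_2" only if requested and flash_attn is importable."""
    has_flash = importlib.util.find_spec("flash_attn") is not None
    return "flash_attention_2" if (prefer_flash and has_flash) else "eager"


def load_phi(model_id: str = "microsoft/phi-2", prefer_flash: bool = True):
    """Hypothetical helper: load a causal LM with the chosen attention backend."""
    # Imported lazily so pick_attn_implementation works without torch installed.
    import torch
    from transformers import AutoModelForCausalLM

    return AutoModelForCausalLM.from_pretrained(
        model_id,  # placeholder id; substitute the checkpoint you actually train
        torch_dtype=torch.bfloat16,  # Flash Attention 2 needs fp16 or bf16 weights
        attn_implementation=pick_attn_implementation(prefer_flash),
    )
```

Calling load_phi(prefer_flash=False) gives the same model with the default eager attention, which makes it easy to compare loss curves between the two settings with everything else held fixed.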
OK. About two weeks ago, I fine-tuned llava-phi with the llava-v1.5 fine-tuning script. Afterwards, at inference, I loaded the model in fp16 and the output text was '!!!!! ... !!!!!', which is caused by NaNs. At the time, I thought I had made a mistake somewhere.
I'll try this update to verify whether it solves the NaN problem. Sorry for the late reply.
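One quick way to confirm that degenerate output like this really comes from NaN logits is to scan the last-token logits for non-finite values before decoding. The sketch below is a small illustrative helper (the name first_non_finite is made up here) that works on any flat sequence of floats.

```python
import math
from typing import Iterable, Optional


def first_non_finite(values: Iterable[float]) -> Optional[int]:
    """Return the index of the first NaN or inf in a flat sequence, else None."""
    for index, value in enumerate(values):
        if not math.isfinite(value):
            return index
    return None
```

With a transformers model output, something like first_non_finite(outputs.logits[0, -1].float().tolist()) should return None for a healthy forward pass, assuming outputs is the result of calling the model on a single prompt.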
I noticed that I pretrained and fine-tuned llava-phi with bf16. Is there anything wrong with this setup? I'll retrain the model with this version of transformers.
Hi @zhumj34, I apologise for the huge delay.
I have updated the checkpoint and it should work now. Please install the latest transformers version 4.38.0.dev0 by running this command -
pip uninstall -y transformers && pip install git+https://github.com/huggingface/transformers
After updating the library, you should be able to properly load the weights.
Please let me know if you are facing any issues with it.
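To check that the update actually took effect before reloading the weights, a small version check can help. This is only a sketch: it compares the numeric release part of the version string (so 4.38.0.dev0 is treated as 4.38.0), and the helper names are made up for illustration.

```python
import re
from importlib import metadata
from typing import Optional, Tuple


def release_tuple(version: str, width: int = 3) -> Tuple[int, ...]:
    """Parse the leading numeric release, e.g. "4.38.0.dev0" -> (4, 38, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    parts = [int(p) for p in match.group(0).split(".")]
    parts += [0] * (width - len(parts))  # pad "4.38" to (4, 38, 0)
    return tuple(parts)


def is_at_least(installed: str, minimum: str) -> bool:
    """Compare release numbers only; dev/rc suffixes are ignored."""
    return release_tuple(installed) >= release_tuple(minimum)


def installed_transformers_version() -> Optional[str]:
    """Return the installed transformers version, or None if it is missing."""
    try:
        return metadata.version("transformers")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_transformers_version()
    if found is None:
        print("transformers is not installed")
    else:
        print(found, "ok" if is_at_least(found, "4.38.0") else "too old")
```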
Hi @g-h-chen, regarding the problem you are facing:
Under setting 2, the training speed is doubled compared to setting 1, but the loss goes high and is extremely unstable.
There is an open issue in the transformers repository where people are reporting the same behaviour you describe above. I believe Gugarosa is working on fixing it.
In the meantime, please try to fine-tune it without FA2 (since I have updated the checkpoint, you can now load the weights properly). Sorry for the inconvenience.