Flash Attention 2
GPU: A100 40 GB (Colab)
!python -c "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'" # pass
!export CUDA_HOME=/usr/local/cuda-11.8
# !MAX_JOBS=4 pip install flash-attn --no-build-isolation
!MAX_JOBS=4 pip install flash-attn --no-build-isolation -qqq
!pip install git+"https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary" -qqq
!python -m pip install optimum -qqq
import torch, transformers,torchvision
torch.__version__,transformers.__version__, torchvision.__version__ # ('2.0.1+cu118', '4.34.0.dev0', '0.15.2+cu118')
```Python
from transformers import AutoModelForCausalLM  # import needed for the snippet below

model_id = 'mistralai/Mistral-7B-Instruct-v0.1'
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    load_in_4bit=True,
    quantization_config=quantization_config,  # presumably a BitsAndBytesConfig defined earlier in the notebook (not shown)
    use_flash_attention_2=True,
    low_cpu_mem_usage=True,
)
```
ERROR
```Python
ValueError Traceback (most recent call last)
in <cell line: 32>()
30 # from optimum.bettertransformer import BetterTransformer #flash attention 2
31
---> 32 model = AutoModelForCausalLM.from_pretrained(model_id,
33 device_map="auto",
34 trust_remote_code=True,
2 frames
/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py in _check_and_enable_flash_attn_2(cls, config, torch_dtype, device_map)
1263 """
1264 if not cls._supports_flash_attn_2:
-> 1265 raise ValueError(
1266 "The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to "
1267 "request support for this architecture: https://github.com/huggingface/transformers/issues/new"
ValueError: The current architecture does not support Flash Attention 2.0. Please open an issue on GitHub to request support for this architecture: https://github.com/huggingface/transformers/issues/new
```
Other ERROR: BetterTransformer -> Flash Attention 2
!python -c "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'"
!export CUDA_HOME=/usr/local/cuda-11.8
# !MAX_JOBS=4 pip install flash-attn --no-build-isolation
!MAX_JOBS=4 pip install flash-attn --no-build-isolation -qqq
!pip install git+"https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary" -qqq
!python -m pip install optimum -qqq
```Python
from transformers import AutoModelForCausalLM  # import needed for the snippet below
from optimum.bettertransformer import BetterTransformer  # flash attention 2

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    load_in_4bit=True,
    quantization_config=quantization_config,  # presumably a BitsAndBytesConfig defined earlier in the notebook (not shown)
    # use_flash_attention_2=True,
    low_cpu_mem_usage=True,
)
model = BetterTransformer.transform(model, keep_original_model=False)  # flash attention 2 -- this call raises ERROR 2 below
```
ERROR 2
```Python
NotImplementedError Traceback (most recent call last)
in <cell line: 43>()
41 )
42
---> 43 model = BetterTransformer.transform(model, keep_original_model=False) #flash attention 2
44
45
1 frames
/usr/local/lib/python3.10/dist-packages/optimum/bettertransformer/transformation.py in transform(model, keep_original_model, max_memory, offload_dir, **kwargs)
226 )
227 if not BetterTransformerManager.supports(model.config.model_type):
--> 228 raise NotImplementedError(
229 f"The model type {model.config.model_type} is not yet supported to be used with BetterTransformer. Feel free"
230 f" to open an issue at https://github.com/huggingface/optimum/issues if you would like this model type to be supported."
NotImplementedError: The model type mistral is not yet supported to be used with BetterTransformer. Feel free to open an issue at https://github.com/huggingface/optimum/issues
if you would like this model type to be supported. Currently supported models are:
dict_keys(['albert', 'bark', 'bart', 'bert', 'bert-generation',
'blenderbot', 'bloom', 'camembert', 'blip-2', 'clip', 'codegen', 'data2vec-text', 'deit',
'distilbert', 'electra', 'ernie', 'fsmt', 'falcon', 'gpt2', 'gpt_bigcode', 'gptj', 'gpt_neo',
'gpt_neox', 'hubert', 'layoutlm', 'llama', 'm2m_100', 'marian', 'markuplm', 'mbart',
'opt', 'pegasus', 'rembert', 'prophetnet', 'roberta', 'roc_bert', 'roformer', 'splinter',
'tapas', 't5', 'vilt', 'vit', 'vit_mae', 'vit_msn', 'wav2vec2', 'whisper', 'xlm-roberta', 'yolos']).
```
I am getting the same error. The supported models list says Mistral is included, but I don't know why it is giving this error.
Hi @NickyNicky and @sgauravm,
If you install the latest version of transformers:
pip install -U transformers
Flash Attention 2 should be supported.
Check out this specific section of the docs for more details: https://huggingface.co/docs/transformers/model_doc/mistral#combining-mistral-and-flash-attention-2
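For reference, the recipe in that docs section boils down to something like this (a sketch, assuming transformers >= 4.34 and a working flash-attn build; the model id and bf16 dtype are taken from this thread, not from the docs verbatim):
```Python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-Instruct-v0.1"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,   # FA2 kernels require fp16 or bf16
    use_flash_attention_2=True,   # the flag discussed in this thread (transformers >= 4.34)
    device_map="auto",
)
```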
Thank you very much,
I have the latest version of transformers.
@NickyNicky
Your first script:
```Python
!python -c "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'"  # pass
!export CUDA_HOME=/usr/local/cuda-11.8
# !MAX_JOBS=4 pip install flash-attn --no-build-isolation
!MAX_JOBS=4 pip install flash-attn --no-build-isolation -qqq
!pip install git+"https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary" -qqq
!python -m pip install optimum -qqq

import torch, transformers, torchvision
torch.__version__, transformers.__version__, torchvision.__version__  # ('2.0.1+cu118', '4.34.0.dev0', '0.15.2+cu118')

model_id = 'mistralai/Mistral-7B-Instruct-v0.1'
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    load_in_4bit=True,
    quantization_config=quantization_config,
    use_flash_attention_2=True,
    low_cpu_mem_usage=True,
)
```
should work if you have the latest transformers installed. However, Mistral is not in BetterTransformer yet; we will add support for F.SDPA natively in transformers core soon.
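If you want one script that handles both cases, you can guard the BetterTransformer call with the same check that raised the NotImplementedError above (a sketch; importing BetterTransformerManager from optimum.bettertransformer is an assumption here, and `model` is the Mistral model loaded earlier in the thread):
```Python
from optimum.bettertransformer import BetterTransformer, BetterTransformerManager  # BetterTransformerManager export assumed

# Only convert when the architecture is supported by BetterTransformer; "mistral" currently is not.
if BetterTransformerManager.supports(model.config.model_type):
    model = BetterTransformer.transform(model, keep_original_model=False)
else:
    print(f"{model.config.model_type} is not supported by BetterTransformer; "
          "relying on use_flash_attention_2=True instead.")
```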
With these versions it works:
```Python
import torch, transformers
torch.__version__, transformers.__version__  # ('2.0.1+cu118', '4.34.0')
```
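For anyone hitting this later, a quick environment sanity check before loading the model (a sketch; it assumes the flash-attn wheel built successfully and exposes `__version__`):
```Python
import torch
import transformers
import flash_attn  # this import fails if the flash-attn wheel did not build

# Flash Attention 2 needs an Ampere-or-newer GPU (e.g. the A100 used in this thread).
assert torch.cuda.get_device_capability()[0] >= 8, "Hardware not supported for Flash Attention"

print(torch.__version__, transformers.__version__, flash_attn.__version__)
```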