RuntimeError: FlashAttention is not installed.
Hi, can you tell me how to disable flash_attn?
model = SentenceTransformer("jinaai/jina-embeddings-v3",
    device=device, trust_remote_code=True, model_kwargs={'default_task': 'text-matching'})
................
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=train_loss,
    evaluator=dev_evaluator,
)
trainer.train()
RuntimeError: FlashAttention is not installed. To proceed with training, please install FlashAttention. For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model.
Sentence Transformers v3.2
Hi @seregadgl, you need to have FlashAttention installed if you want to train the model. You can only disable it for inference.
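The rule in the reply above (FlashAttention is mandatory for training, optional for inference) can be turned into a small pre-flight check. This is a minimal sketch, assuming the flag is called use_flash_attn as the error message says; flash_attn_available and flash_attn_kwargs are hypothetical helper names, not part of sentence-transformers or the jina remote code.

```python
import importlib.util
from typing import Dict, Optional


def flash_attn_available() -> bool:
    """True if the flash_attn and triton packages can be found on the path.

    This only looks the packages up; it does not import them, so a broken
    install can still pass this check.
    """
    return all(importlib.util.find_spec(name) is not None for name in ("flash_attn", "triton"))


def flash_attn_kwargs(training: bool, available: Optional[bool] = None) -> Dict[str, bool]:
    """Hypothetical helper: choose the use_flash_attn flag before loading the model."""
    if available is None:
        available = flash_attn_available()
    if available:
        return {"use_flash_attn": True}
    if training:
        # Per the reply above, training has no non-FlashAttention fallback.
        raise RuntimeError("Training requires FlashAttention (flash_attn and triton).")
    # Inference only: ask the model to take its non-FlashAttention code path.
    return {"use_flash_attn": False}
```

For inference one might then try passing the result when loading, e.g. SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True, config_kwargs=flash_attn_kwargs(training=False)). Whether the remote code reads the flag from config_kwargs or from model_kwargs is an assumption to verify against the model card.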
Thanks for the answer! Could you tell me which version of FlashAttention to install so that I can fine-tune the model in Google Colab on a T4 GPU? Thanks!
It seems you also need to install other dependencies (namely triton).
If you look at the rotary.py file, you can see that the RuntimeError: FlashAttention is not installed exception is raised when the import from flash_attn.ops.triton.rotary import apply_rotary fails.
This import requires both flash attention and triton.
So I guess you should also install triton by running pip install triton.
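The mechanism described above, a guarded import that swaps in a stub which raises only when called, can be sketched as follows. This mirrors the behaviour visible in the traceback but is not the exact jina source; the broad except clause and the explain_import_failure helper are assumptions added for diagnosis.

```python
# Sketch of the guarded-import pattern (not the exact jina rotary.py source).
MESSAGE = (
    "FlashAttention is not installed. To proceed with training, please install FlashAttention. "
    "For inference, you have two options: either install FlashAttention or disable it by "
    "setting use_flash_attn=False when loading the model."
)

try:
    # This single import needs both flash_attn and triton to be importable.
    from flash_attn.ops.triton.rotary import apply_rotary
    IMPORT_ERROR = None
except Exception as exc:  # assumption: broad catch; the real file may be narrower
    IMPORT_ERROR = exc

    def apply_rotary(*args, **kwargs):
        # Deferred failure: loading the model works, calling the kernel does not.
        raise RuntimeError(MESSAGE)


def explain_import_failure() -> str:
    """Return the underlying import error that the generic message hides."""
    if IMPORT_ERROR is None:
        return "flash_attn.ops.triton.rotary imported fine"
    return f"{type(IMPORT_ERROR).__name__}: {IMPORT_ERROR}"


print(explain_import_failure())
```

Because the stub swallows the original exception, running just the import yourself is a quick way to see whether it is flash_attn or triton that is actually missing.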
@seregadgl you can install any recent version; the latest one (2.6.3) should work fine.
@BlackBeenie you're right, it requires triton as well. However, triton should be installed automatically when you install torch with CUDA enabled.
@jupyterjazz
It seems triton is not installed automatically in Google Colab. I faced a similar error, and running pip install triton actually fixed the issue.
@BlackBeenie, makes sense. This happens because Colab comes with torch pre-installed. If you uninstall it and reinstall it while connected to a GPU runtime, triton should be installed as well.
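Since a pre-installed torch can leave triton missing, it may help to check what is actually installed before starting a training run. A minimal sketch using importlib.metadata; the distribution names torch, triton and flash-attn are assumed to be the PyPI names, and installed_versions and missing are hypothetical helpers.

```python
from importlib import metadata
from typing import Dict, Iterable, List, Optional


def installed_versions(packages: Iterable[str] = ("torch", "triton", "flash-attn")) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is absent."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def missing(versions: Dict[str, Optional[str]]) -> List[str]:
    """Names of the distributions that were not found."""
    return sorted(name for name, version in versions.items() if version is None)


print(installed_versions())
```

If torch shows a version but triton shows None, the two remedies from this thread apply: pip install triton, or reinstall torch while connected to a GPU runtime.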
I am using Windows 11 and have successfully installed flash-attn, but I still get the RuntimeError: FlashAttention is not installed error. Does that mean it does not support Windows if I want to use FlashAttention?
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[5], line 2
1 print(len(chunks))
----> 2 chunks_embeddings = embedder.encode(chunks, convert_to_tensor=True, batch_size=1)
3 # Find the closest 5 sentences of the corpus for each query sentence based on cosine similarity
4 top_k = min(3, len(chunks))
File d:\conda\win\envs\prod\Lib\site-packages\sentence_transformers\SentenceTransformer.py:623, in SentenceTransformer.encode(self, sentences, prompt_name, prompt, batch_size, show_progress_bar, output_value, precision, convert_to_numpy, convert_to_tensor, device, normalize_embeddings, **kwargs)
620 features.update(extra_features)
622 with torch.no_grad():
--> 623 out_features = self.forward(features, **kwargs)
624 if self.device.type == "hpu":
625 out_features = copy.deepcopy(out_features)
File d:\conda\win\envs\prod\Lib\site-packages\sentence_transformers\SentenceTransformer.py:690, in SentenceTransformer.forward(self, input, **kwargs)
688 module_kwarg_keys = self.module_kwargs.get(module_name, [])
689 module_kwargs = {key: value for key, value in kwargs.items() if key in module_kwarg_keys}
--> 690 input = module(input, **module_kwargs)
691 return input
File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~\.cache\huggingface\modules\transformers_modules\jinaai\jina-embeddings-v3\30996fea06f69ecd8382ee4f11e29acaf6b5405e\custom_st.py:143, in Transformer.forward(self, features, task)
139 lora_arguments = (
140 {"adapter_mask": adapter_mask} if adapter_mask is not None else {}
141 )
142 features.pop('prompt_length', None)
--> 143 output_states = self.auto_model.forward(**features, **lora_arguments, return_dict=False)
144 output_tokens = output_states[0]
145 features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})
File ~\.cache\huggingface\modules\transformers_modules\jinaai\xlm-roberta-flash-implementation\9dc60336f6b2df56c4f094dd287ca49fb7b93342\modeling_lora.py:370, in XLMRobertaLoRA.forward(self, *args, **kwargs)
369 def forward(self, *args, **kwargs):
--> 370 return self.roberta(*args, **kwargs)
File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~\.cache\huggingface\modules\transformers_modules\jinaai\xlm-roberta-flash-implementation\9dc60336f6b2df56c4f094dd287ca49fb7b93342\modeling_xlm_roberta.py:709, in XLMRobertaModel.forward(self, input_ids, position_ids, token_type_ids, attention_mask, masked_tokens_mask, return_dict, **kwargs)
706 else:
707 subset_mask = None
--> 709 sequence_output = self.encoder(
710 hidden_states,
711 key_padding_mask=attention_mask,
712 subset_mask=subset_mask,
713 adapter_mask=adapter_mask,
714 )
716 if masked_tokens_mask is None:
717 pooled_output = (
718 self.pooler(sequence_output, adapter_mask=adapter_mask)
719 if self.pooler is not None
720 else None
721 )
File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~\.cache\huggingface\modules\transformers_modules\jinaai\xlm-roberta-flash-implementation\9dc60336f6b2df56c4f094dd287ca49fb7b93342\modeling_xlm_roberta.py:241, in XLMRobertaEncoder.forward(self, hidden_states, key_padding_mask, subset_mask, adapter_mask)
234 hidden_states = torch.utils.checkpoint.checkpoint(
235 layer,
236 hidden_states,
237 use_reentrant=self.use_reentrant,
238 mixer_kwargs=mixer_kwargs,
239 )
240 else:
--> 241 hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
242 hidden_states = pad_input(hidden_states, indices, batch, seqlen)
243 else:
File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~\.cache\huggingface\modules\transformers_modules\jinaai\xlm-roberta-flash-implementation\9dc60336f6b2df56c4f094dd287ca49fb7b93342\block.py:201, in Block.forward(self, hidden_states, residual, mixer_subset, mixer_kwargs)
199 else:
200 assert residual is None
--> 201 mixer_out = self.mixer(
202 hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
203 )
204 if self.return_residual: # mixer out is actually a pair here
205 mixer_out, hidden_states = mixer_out
File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~\.cache\huggingface\modules\transformers_modules\jinaai\xlm-roberta-flash-implementation\9dc60336f6b2df56c4f094dd287ca49fb7b93342\mha.py:732, in MHA.forward(self, x, x_kv, key_padding_mask, cu_seqlens, max_seqlen, mixer_subset, inference_params, adapter_mask, **kwargs)
725 if (
726 inference_params is None
727 or inference_params.seqlen_offset == 0
728 or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
729 or not self.use_flash_attn
730 ):
731 if self.rotary_emb_dim > 0:
--> 732 qkv = self.rotary_emb(
733 qkv,
734 seqlen_offset=seqlen_offset,
735 cu_seqlens=cu_seqlens,
736 max_seqlen=rotary_max_seqlen,
737 )
738 if inference_params is None:
739 if not self.checkpointing:
File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File d:\conda\win\envs\prod\Lib\site-packages\torch\nn\modules\module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~\.cache\huggingface\modules\transformers_modules\jinaai\xlm-roberta-flash-implementation\9dc60336f6b2df56c4f094dd287ca49fb7b93342\rotary.py:604, in RotaryEmbedding.forward(self, qkv, kv, seqlen_offset, cu_seqlens, max_seqlen)
602 if kv is None:
603 if self.scale is None:
--> 604 return apply_rotary_emb_qkv_(
605 qkv,
606 self._cos_cached,
607 self._sin_cached,
608 interleaved=self.interleaved,
609 seqlen_offsets=seqlen_offset,
610 cu_seqlens=cu_seqlens,
611 max_seqlen=max_seqlen,
612 use_flash_attn=self.use_flash_attn,
613 )
614 else:
615 return apply_rotary_emb_qkv_(
616 qkv,
617 self._cos_cached,
(...)
625 use_flash_attn=self.use_flash_attn,
626 )
File ~\.cache\huggingface\modules\transformers_modules\jinaai\xlm-roberta-flash-implementation\9dc60336f6b2df56c4f094dd287ca49fb7b93342\rotary.py:327, in apply_rotary_emb_qkv_(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen, use_flash_attn)
297 def apply_rotary_emb_qkv_(
298 qkv,
299 cos,
(...)
307 use_flash_attn=True,
308 ):
309 """
310 Arguments:
311 qkv: (batch_size, seqlen, 3, nheads, headdim) if cu_seqlens is None
(...)
325 Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
326 """
--> 327 return ApplyRotaryEmbQKV_.apply(
328 qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen, use_flash_attn,
329 )
File d:\conda\win\envs\prod\Lib\site-packages\torch\autograd\function.py:575, in Function.apply(cls, *args, **kwargs)
572 if not torch._C._are_functorch_transforms_active():
573 # See NOTE: [functorch vjp and autograd interaction]
574 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575 return super().apply(*args, **kwargs) # type: ignore[misc]
577 if not is_setup_ctx_defined:
578 raise RuntimeError(
579 "In order to use an autograd.Function with functorch transforms "
580 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
581 "staticmethod. For more details, please see "
582 "https://pytorch.org/docs/main/notes/extending.func.html"
583 )
File ~\.cache\huggingface\modules\transformers_modules\jinaai\xlm-roberta-flash-implementation\9dc60336f6b2df56c4f094dd287ca49fb7b93342\rotary.py:186, in ApplyRotaryEmbQKV_.forward(ctx, qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, cu_seqlens, max_seqlen, use_flash_attn)
184 qk = rearrange(qkv[..., :2, :, :], "... t h d -> ... (t h) d")
185 # qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
--> 186 apply_rotary(
187 qk,
188 cos,
189 sin,
190 seqlen_offsets=seqlen_offsets,
191 interleaved=interleaved,
192 inplace=True,
193 cu_seqlens=cu_seqlens,
194 max_seqlen=max_seqlen,
195 )
196 else:
197 q_rot = apply_rotary_emb_torch(
198 qkv[:, :, 0],
199 cos,
200 sin,
201 interleaved=interleaved,
202 )
File ~\.cache\huggingface\modules\transformers_modules\jinaai\xlm-roberta-flash-implementation\9dc60336f6b2df56c4f094dd287ca49fb7b93342\rotary.py:18, in apply_rotary(*args, **kwargs)
17 def apply_rotary(*args, **kwargs):
---> 18 raise RuntimeError(
19 "FlashAttention is not installed. To proceed with training, please install FlashAttention. "
20 "For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model."
21 )
RuntimeError: FlashAttention is not installed. To proceed with training, please install FlashAttention. For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model.
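The last frames show where the flag matters: ApplyRotaryEmbQKV_.forward calls the Triton-backed apply_rotary in one branch and apply_rotary_emb_torch in its else branch, and use_flash_attn is passed down through every frame, so it presumably selects the branch. The sketch below mirrors that dispatch in plain Python. The rotary math is the standard non-interleaved ("rotate half") formulation written from scratch for illustration, not the jina code, and all function names here are hypothetical.

```python
import math
from typing import Callable, List, Sequence


def rotary_reference(x: Sequence[float], theta: Sequence[float]) -> List[float]:
    """Rotate-half rotary embedding for one head vector (non-interleaved layout).

    x has even length d and theta has d // 2 angles. Element i of the first
    half is paired with element i of the second half and rotated by theta[i].
    """
    half = len(x) // 2
    if len(x) % 2 or len(theta) != half:
        raise ValueError("need an even-length x and len(theta) == len(x) // 2")
    first, second = x[:half], x[half:]
    out_first = [a * math.cos(t) - b * math.sin(t) for a, b, t in zip(first, second, theta)]
    out_second = [b * math.cos(t) + a * math.sin(t) for a, b, t in zip(first, second, theta)]
    return out_first + out_second


def flash_kernel_stub(*args, **kwargs):
    # Stand-in for the Triton kernel when flash_attn/triton failed to import.
    raise RuntimeError("FlashAttention is not installed.")


def apply_rotary_dispatch(x, theta, use_flash_attn: bool, kernel: Callable = flash_kernel_stub):
    """Hypothetical dispatcher mirroring the if/else seen in the traceback."""
    if use_flash_attn:
        return kernel(x, theta)
    return rotary_reference(x, theta)
```

With use_flash_attn=False the stub is never reached, which is consistent with the error message offering that flag as the inference workaround.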
I also had this problem. Thanks to @BlackBeenie, I found out that I was missing the triton package, which doesn't support Python 3.13 yet.
https://pypi.org/project/triton/
https://huggingface.co/jinaai/jina-embeddings-v3/discussions/47#6714c101d0aceb08357afc2a
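Given that triton had no Python 3.13 support when this was posted, checking the interpreter version before installing can avoid a confusing failure. A minimal sketch; the (3, 13) cutoff reflects only the situation described in this post and is an assumption that will go stale as new triton releases appear.

```python
import sys
from typing import Optional, Sequence, Tuple

# Assumption from this thread: triton did not yet support Python 3.13.
FIRST_UNSUPPORTED: Tuple[int, int] = (3, 13)


def python_ok_for_triton(version_info: Optional[Sequence[int]] = None) -> bool:
    """Return True if the interpreter (or the given version) is older than FIRST_UNSUPPORTED."""
    major_minor = tuple((version_info or sys.version_info)[:2])
    return major_minor < FIRST_UNSUPPORTED


print(sys.version_info[:2], python_ok_for_triton())
```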
So I switched to Python 3.12, but ended up getting another error:
ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
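This second error comes from Triton itself. Its kernels work on GPU memory, and the "(cpu tensor?)" hint suggests the first argument ("at 0") lives in CPU memory, for example because the model was loaded on the CPU or because the installed torch build has no CUDA support. A minimal pre-flight sketch under that assumption; first_non_cuda_argument and check_triton_inputs are hypothetical helpers that only need each tensor's device type string (torch tensors expose it as tensor.device.type).

```python
from typing import Iterable, Optional


def first_non_cuda_argument(device_types: Iterable[str]) -> Optional[int]:
    """Return the position of the first argument not on a CUDA device, or None.

    Assumption: every tensor handed to a Triton kernel must have device type "cuda".
    """
    for position, device_type in enumerate(device_types):
        if device_type != "cuda":
            return position
    return None


def check_triton_inputs(device_types: Iterable[str]) -> None:
    """Raise early with a clearer message than Triton's pointer error."""
    position = first_non_cuda_argument(device_types)
    if position is not None:
        raise ValueError(
            f"argument {position} is not on a CUDA device; load the model with a CUDA "
            "device and check that torch.cuda.is_available() returns True"
        )
```

With torch this would be called as check_triton_inputs(t.device.type for t in (qkv, cos, sin)). For the model itself, the corresponding fix is presumably passing device="cuda" to SentenceTransformer on a machine where torch.cuda.is_available() is True.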