fix sdp attention to use the flash/mem-efficient context manager
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
CHANGED
@@ -184,14 +184,15 @@ def sdp_attention_forward(
 
     # We only apply sdp attention if we don't need to output the whole attention matrix
     if not output_attentions:
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            is_causal=False,
-        )
-        attn_weights = None
+        with torch.backends.cuda.sdp_kernel():
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                is_causal=False,
+            )
+        attn_weights = None
     else:
         attn_weights = torch.matmul(
             query_states, key_states.transpose(2, 3)
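For reference, a minimal, self-contained sketch of how torch.backends.cuda.sdp_kernel is used: it is a context manager that limits which scaled_dot_product_attention backends PyTorch may select inside the block. This example is not taken from the commit; the tensor shapes, dtype, and backend flags below are illustrative assumptions, whereas the patched code above calls sdp_kernel() with no arguments, which leaves the flash, memory-efficient, and math backends all enabled.

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq_len, head_dim); assumes a CUDA device.
query = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
key = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
value = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# Allow only the flash and memory-efficient kernels inside this block; the math
# fallback is disabled here purely to show how the flags constrain selection.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=True
):
    attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=True)

Outside the with block, PyTorch reverts to its default backend selection, so the context manager only affects calls made while it is active.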