Update flash_attention.py
flash_attention.py — CHANGED (+1 −1)
@@ -3,7 +3,7 @@ import torch.nn as nn
 from einops import rearrange
 
 
-from
+from triton_flash_atn import _attention
 from triton_bert_padding import pad_input, unpad_input
 
 