Crystalcareai committed
Update modeling_quiet.py

modeling_quiet.py CHANGED (+14 -56)
@@ -607,6 +607,19 @@ class QuietFlashAttention2(QuietAttention):
         else:
             # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
             causal = self.is_causal and query_length != 1
+
+        # Ensure attention_mask has the correct shape and values
+        if attention_mask is not None:
+            if attention_mask.dim() == 4:
+                # Convert 4D attention mask to 2D
+                attention_mask = attention_mask.squeeze(1).squeeze(1)
+            elif attention_mask.dim() != 2:
+                raise ValueError(
+                    f"Invalid attention mask dimension: {attention_mask.dim()}. Expected 2D or 4D mask."
+                )
+
+            # Ensure attention_mask has values of 0 and 1
+            attention_mask = attention_mask.to(torch.bool).to(torch.int32)
 
         # Contains at least one padding token in the sequence
         if attention_mask is not None:
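For context, a minimal standalone sketch of what the added normalization does. The 4D shape (batch, 1, 1, seq_len) and the sample values are assumptions for illustration; note the double squeeze(1) only collapses a mask whose second and third dims are singleton:

import torch

# Hypothetical 4D mask of shape (2, 1, 1, 3); 0 marks padding.
attention_mask = torch.tensor([[[[1.0, 1.0, 0.0]]],
                               [[[1.0, 1.0, 1.0]]]])

if attention_mask.dim() == 4:
    # (2, 1, 1, 3) -> (2, 1, 3) -> (2, 3): drop the two singleton dims
    attention_mask = attention_mask.squeeze(1).squeeze(1)

# Round-trip through bool: nonzero entries become 1, zeros stay 0, dtype int32.
attention_mask = attention_mask.to(torch.bool).to(torch.int32)
print(attention_mask)  # tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.int32)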
@@ -667,63 +680,8 @@ class QuietFlashAttention2(QuietAttention):
                 causal=causal,
                 window_size=(self.config.sliding_window, self.config.sliding_window),
             )
-            try:
-                attn_output_unpad = flash_attn_varlen_func(
-                    query_states,
-                    key_states,
-                    value_states,
-                    cu_seqlens_q=cu_seqlens_q,
-                    cu_seqlens_k=cu_seqlens_k,
-                    max_seqlen_q=max_seqlen_in_batch_q,
-                    max_seqlen_k=max_seqlen_in_batch_k,
-                    dropout_p=dropout,
-                    softmax_scale=softmax_scale,
-                    causal=causal,
-                )
-            except RuntimeError as e:
-                if "cu_seqlens_q must have shape (batch_size + 1)" in str(e):
-                    # Handle the case when cu_seqlens_q has an invalid shape
-                    if attention_mask is not None:
-                        # Ensure attention_mask has the correct shape
-                        if attention_mask.dim() == 2:
-                            # Convert 2D attention mask to 4D
-                            attention_mask = _prepare_4d_causal_attention_mask(
-                                attention_mask,
-                                (query_states.size(0), query_states.size(1)),
-                                query_states,
-                                past_key_values_length=0,
-                                sliding_window=0,
-                            )
-                        elif attention_mask.dim() != 4:
-                            raise ValueError(
-                                f"Invalid attention mask dimension: {attention_mask.dim()}. Expected 2D or 4D mask."
-                            )
-
-                        # Update cu_seqlens_q based on the attention mask
-                        cu_seqlens_q = attention_mask.sum(dim=-1).flatten().cumsum(dim=0).to(torch.int32)
-                        max_seqlen_in_batch_q = cu_seqlens_q[-1].item()
-
-                        # Retry flash_attn_varlen_func with updated cu_seqlens_q
-                        attn_output_unpad = flash_attn_varlen_func(
-                            query_states,
-                            key_states,
-                            value_states,
-                            cu_seqlens_q=cu_seqlens_q,
-                            cu_seqlens_k=cu_seqlens_k,
-                            max_seqlen_q=max_seqlen_in_batch_q,
-                            max_seqlen_k=max_seqlen_in_batch_k,
-                            dropout_p=dropout,
-                            softmax_scale=softmax_scale,
-                            causal=causal,
-                        )
-                    else:
-                        raise ValueError(
-                            "Attention mask is required for flash-attn when cu_seqlens_q has an invalid shape."
-                        )
-                else:
-                    raise e
 
-
+        return attn_output
 
     def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
         batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
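Worth noting why the deleted fallback was safe to remove: flash_attn_varlen_func expects cu_seqlens_q of shape (batch_size + 1,) starting at 0, but the handler rebuilt it as a bare cumulative sum of shape (batch_size,) (and set max_seqlen_in_batch_q to the total token count rather than the longest sequence), so the retry could not satisfy the very check whose error it caught. A sketch of the conventional construction, following the pattern of the _get_unpad_data helper in Transformers; the sample mask is illustrative:

import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 0],
                               [1, 1, 1]], dtype=torch.int32)  # (batch=2, seq=3)

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)  # tensor([2, 3])
# Prepend a leading 0 so the result has shape (batch_size + 1,).
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
max_seqlen_in_batch = int(seqlens_in_batch.max())

print(cu_seqlens)           # tensor([0, 2, 5], dtype=torch.int32)
print(max_seqlen_in_batch)  # 3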