Crystalcareai committed
Update modeling_quiet.py

modeling_quiet.py  CHANGED  (+26 -38)
@@ -656,46 +656,34 @@ class QuietFlashAttention2(QuietAttention):
 
         return attn_output
 
-    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
-        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
-
-        # On the first iteration we need to properly re-create the padding mask
-        # by slicing it on the proper place
-        if kv_seq_len != attention_mask.shape[-1]:
-            attention_mask_num_tokens = attention_mask.shape[-1]
-            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
-
-        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
-
-        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-
-        if query_length == kv_seq_len:
-            query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-            cu_seqlens_q = cu_seqlens_k
-            max_seqlen_in_batch_q = max_seqlen_in_batch_k
-            indices_q = indices_k
-        elif query_length == 1:
-            max_seqlen_in_batch_q = 1
-            cu_seqlens_q = torch.arange(
-                batch_size + 1, dtype=torch.int32, device=query_layer.device
-            )  # There is a memcpy here, that is very bad.
-            indices_q = cu_seqlens_q[:-1]
-            query_layer = query_layer.squeeze(1)
-        else:
-            # The -q_len: slice assumes left padding.
-            attention_mask = attention_mask[:, -query_length:]
-            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
-
-        return (
-            query_layer,
-            key_layer,
-            value_layer,
-            indices_q,
-            (cu_seqlens_q, cu_seqlens_k),
-            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
-        )
+def unpad_input(hidden_states, attention_mask):
+    """
+    Arguments:
+        hidden_states: (batch, seqlen, dim)
+        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+    Return:
+        hidden_states: (total_nnz, dim), where total_nnz = number of tokens selected in attention_mask.
+        indices: (total_nnz,)
+        cu_seqlens: (batch + 1,), the cumulative sequence lengths, starting from 0.
+        max_seqlen: int
+    """
+    batch_size, seqlen = hidden_states.shape[:2]
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+
+    # Handle the case when seqlens_in_batch is empty
+    if seqlens_in_batch.numel() == 0:
+        max_seqlen_in_batch = 0
+    else:
+        max_seqlen_in_batch = seqlens_in_batch.max().item()
+
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    if indices.numel() == 0:
+        indices = torch.zeros(0, dtype=torch.int64, device=hidden_states.device)
+        hidden_states = torch.zeros(0, hidden_states.shape[-1], dtype=hidden_states.dtype, device=hidden_states.device)
+    else:
+        hidden_states = hidden_states.flatten(0, 1)[indices]
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return hidden_states, indices, cu_seqlens, max_seqlen_in_batch
 
 # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
 class QuietSdpaAttention(QuietAttention):
     """
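
For reference, a minimal usage sketch of the patched unpad_input; the tensors, shapes, and values below are illustrative assumptions, not taken from the repository (torch and F are already in scope in modeling_quiet.py):

import torch
import torch.nn.functional as F

# Two left-aligned sequences with 3 and 2 valid tokens in a padded (2, 4) batch.
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]], dtype=torch.int32)
hidden = torch.randn(2, 4, 8)  # (batch, seqlen, dim)

out, indices, cu_seqlens, max_seqlen = unpad_input(hidden, mask)
# out.shape  == (5, 8)                               -> 3 + 2 valid tokens, padding removed
# indices    == tensor([0, 1, 2, 4, 5])              -> positions on the flattened (batch*seqlen) axis
# cu_seqlens == tensor([0, 3, 5], dtype=torch.int32) -> sequence i spans out[cu_seqlens[i]:cu_seqlens[i+1]]
# max_seqlen == 3

# Edge cases the added guards cover: a fully masked batch now returns
# well-formed empty tensors with explicit dtype and device, and the
# seqlens_in_batch.numel() == 0 branch avoids calling .max() on an empty
# tensor (which would raise a RuntimeError) when the batch itself is empty.
out0, idx0, cu0, max0 = unpad_input(torch.randn(2, 4, 8), torch.zeros(2, 4, dtype=torch.int32))
# out0.shape == (0, 8), idx0.numel() == 0, cu0 == tensor([0, 0, 0]), max0 == 0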