Crystalcareai committed
Update modeling_quiet.py
modeling_quiet.py CHANGED (+38 -25)
@@ -656,34 +656,47 @@ class QuietFlashAttention2(QuietAttention):
 
         return attn_output
 
-    def upad_input(
-
-        Arguments:
-            hidden_states: (batch, seqlen, dim)
-            attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
-        Return:
-            hidden_states: (total_nnz, dim), where total_nnz = number of tokens in selected in attention_mask.
-            indices: (total_nnz,)
-            cu_seqlens: (batch + 1,), use 0 as delimiter.
-            max_seqlen: int
-        """
-        batch_size, seqlen = hidden_states.shape[:2]
-        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    def upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
 
-        #
-
-
-
-
+        # On the first iteration we need to properly re-create the padding mask
+        # by slicing it on the proper place
+        if kv_seq_len != attention_mask.shape[-1]:
+            attention_mask_num_tokens = attention_mask.shape[-1]
+            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
+
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+
+        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
 
-
-
-
-
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
         else:
-
-
-
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
 # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
 class QuietSdpaAttention(QuietAttention):
     """
|