Crystalcareai commited on
Commit
37a8486
·
verified ·
1 Parent(s): 7530909

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +26 -38
modeling_quiet.py CHANGED
@@ -656,46 +656,34 @@ class QuietFlashAttention2(QuietAttention):
656
 
657
  return attn_output
658
 
659
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
660
- batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
661
-
662
- # On the first iteration we need to properly re-create the padding mask
663
- # by slicing it on the proper place
664
- if kv_seq_len != attention_mask.shape[-1]:
665
- attention_mask_num_tokens = attention_mask.shape[-1]
666
- attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
667
-
668
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
669
-
670
- key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
671
- value_layer= index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
672
- if query_length == kv_seq_len:
673
- query_layer = index_first_axis(
674
- query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
675
- )
676
- cu_seqlens_q = cu_seqlens_k
677
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
678
- indices_q = indices_k
679
- elif query_length == 1:
680
- max_seqlen_in_batch_q = 1
681
- cu_seqlens_q = torch.arange(
682
- batch_size + 1, dtype=torch.int32, device=query_layer.device
683
- ) # There is a memcpy here, that is very bad.
684
- indices_q = cu_seqlens_q[:-1]
685
- query_layer = query_layer.squeeze(1)
686
  else:
687
- # The -q_len: slice assumes left padding.
688
- attention_mask = attention_mask[:, -query_length:]
689
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
690
 
691
- return (
692
- query_layer,
693
- key_layer,
694
- value_layer,
695
- indices_q,
696
- (cu_seqlens_q, cu_seqlens_k),
697
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
698
- )
699
  # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
700
  class QuietSdpaAttention(QuietAttention):
701
  """
 
656
 
657
  return attn_output
658
 
659
+ def unpad_input(hidden_states, attention_mask):
660
+ """
661
+ Arguments:
662
+ hidden_states: (batch, seqlen, dim)
663
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
664
+ Return:
665
+ hidden_states: (total_nnz, dim), where total_nnz = number of tokens in selected in attention_mask.
666
+ indices: (total_nnz,)
667
+ cu_seqlens: (batch + 1,), use 0 as delimiter.
668
+ max_seqlen: int
669
+ """
670
+ batch_size, seqlen = hidden_states.shape[:2]
671
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
672
+
673
+ # Handle the case when seqlens_in_batch is empty
674
+ if seqlens_in_batch.numel() == 0:
675
+ max_seqlen_in_batch = 0
 
 
 
 
 
 
 
 
 
 
676
  else:
677
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
 
 
678
 
679
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
680
+ if indices.numel() == 0:
681
+ indices = torch.zeros(0, dtype=torch.int64, device=hidden_states.device)
682
+ hidden_states = torch.zeros(0, hidden_states.shape[-1], dtype=hidden_states.dtype, device=hidden_states.device)
683
+ else:
684
+ hidden_states = hidden_states.flatten(0, 1)[indices]
685
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
686
+ return hidden_states, indices, cu_seqlens, max_seqlen_in_batch
687
  # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
688
  class QuietSdpaAttention(QuietAttention):
689
  """