Crystalcareai committed
Commit 2c3ad55 · verified · 1 parent: f197dce

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +38 -25
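
Note on the change: the commit replaces QuietFlashAttention2's single-tensor upad_input(hidden_states, attention_mask) helper with the per-tensor unpadding method used by upstream transformers' FlashAttention2 (LlamaFlashAttention2._upad_input), which unpads query, key, and value separately and handles the decode case (query_length == 1). The new method relies on helpers not shown in this hunk: a module-level _get_unpad_data, plus index_first_axis and unpad_input from flash_attn.bert_padding. For reference, a minimal sketch of _get_unpad_data as it appears upstream (an assumption about this repo's file, shown here only as a reading aid):

import torch
import torch.nn.functional as F


def _get_unpad_data(attention_mask):
    # Number of real (non-padding) tokens in each batch row: shape (batch,)
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Flat indices of the real tokens in the flattened (batch * seqlen) view
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Cumulative sequence lengths with a leading 0, shape (batch + 1,),
    # the format flash-attn's variable-length kernels expect
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch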
modeling_quiet.py CHANGED
@@ -656,34 +656,47 @@ class QuietFlashAttention2(QuietAttention):
 
         return attn_output
 
-    def upad_input(hidden_states, attention_mask):
-        """
-        Arguments:
-            hidden_states: (batch, seqlen, dim)
-            attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
-        Return:
-            hidden_states: (total_nnz, dim), where total_nnz = number of tokens in selected in attention_mask.
-            indices: (total_nnz,)
-            cu_seqlens: (batch + 1,), use 0 as delimiter.
-            max_seqlen: int
-        """
-        batch_size, seqlen = hidden_states.shape[:2]
-        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    def upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
 
-        # Handle the case when seqlens_in_batch is empty
-        if seqlens_in_batch.numel() == 0:
-            max_seqlen_in_batch = 0
-        else:
-            max_seqlen_in_batch = seqlens_in_batch.max().item()
+        # On the first iteration we need to properly re-create the padding mask
+        # by slicing it on the proper place
+        if kv_seq_len != attention_mask.shape[-1]:
+            attention_mask_num_tokens = attention_mask.shape[-1]
+            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
+
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+
+        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
 
-        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-        if indices.numel() == 0:
-            indices = torch.zeros(0, dtype=torch.int64, device=hidden_states.device)
-            hidden_states = torch.zeros(0, hidden_states.shape[-1], dtype=hidden_states.dtype, device=hidden_states.device)
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
         else:
-            hidden_states = hidden_states.flatten(0, 1)[indices]
-        cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
-        return hidden_states, indices, cu_seqlens, max_seqlen_in_batch
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
 
 # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
 class QuietSdpaAttention(QuietAttention):
     """