Crystalcareai committed (verified)
Commit 8d1df50 · 1 Parent(s): f1bc6fa

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +13 -31
modeling_quiet.py CHANGED
@@ -136,22 +136,17 @@ def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_
     c.save()
 
 
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-
-    # Handle the case when seqlens_in_batch is empty
-    if seqlens_in_batch.numel() == 0:
-        max_seqlen_in_batch = 0
-    else:
-        max_seqlen_in_batch = seqlens_in_batch.max().item()
-
-    # Ensure seqlens_in_batch has the correct shape before cumulative sum
-    seqlens_in_batch = seqlens_in_batch.view(-1)
-    cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32, device=attention_mask.device), seqlens_in_batch.cumsum(dim=0)])
-
-    return indices, cu_seqlens, max_seqlen_in_batch
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Quiet
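For reference, a minimal self-contained sketch of what the simplified _get_unpad_data above returns for a toy padded batch. It assumes nothing beyond torch itself; torch.int32 is used where the diff writes torch.torch.int32, which in practice resolves to the same dtype.

    import torch
    import torch.nn.functional as F

    def _get_unpad_data(attention_mask):
        # Per-sequence token counts, flat indices of the real tokens,
        # and zero-prefixed cumulative sequence lengths.
        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = seqlens_in_batch.max().item()
        cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
        return indices, cu_seqlens, max_seqlen_in_batch

    # Two sequences of length 4; the second has one padding position.
    mask = torch.tensor([[1, 1, 1, 1],
                         [1, 1, 1, 0]])
    indices, cu_seqlens, max_seqlen = _get_unpad_data(mask)
    # indices    -> tensor([0, 1, 2, 3, 4, 5, 6])         flat positions of real tokens
    # cu_seqlens -> tensor([0, 4, 7], dtype=torch.int32)  cumulative lengths, zero-prefixed
    # max_seqlen -> 4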
@@ -545,7 +540,7 @@ class QuietFlashAttention2(QuietAttention):
             value_states = value_states.to(target_dtype)
 
         # Compute the causal mask
-        causal = self.config.is_causal
+        causal = self.config.causal
         if causal:
            if self._flash_attn_uses_top_left_mask:
                # Compute the causal mask
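The causal flag read from the config above is eventually forwarded to the flash-attention kernel. A rough sketch of that call follows; it assumes the flash-attn 2.x flash_attn_func API and a CUDA device, and the tensor shapes and standalone causal variable are illustrative stand-ins rather than the module's actual state.

    import torch
    from flash_attn import flash_attn_func  # requires the flash-attn package and a CUDA GPU

    batch, seqlen, nheads, headdim = 1, 8, 4, 64
    q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    causal = True  # stands in for the self.config.causal value read in the hunk above
    # flash_attn_func takes (batch, seqlen, nheads, headdim) tensors and a boolean causal flag.
    out = flash_attn_func(q, k, v, dropout_p=0.0, causal=causal)
    # out: (batch, seqlen, nheads, headdim)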
@@ -583,7 +578,7 @@ class QuietFlashAttention2(QuietAttention):
                 indices_q,
                 cu_seq_lens,
                 max_seq_lens,
-            ) = self.upad_input(query_states, key_states, value_states, attention_mask, q_len)
+            ) = self._upad_input(query_states, key_states, value_states, attention_mask, q_len)
 
             cu_seqlens_q, cu_seqlens_k = cu_seq_lens
             max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
@@ -656,7 +651,7 @@ class QuietFlashAttention2(QuietAttention):
 
         return attn_output
 
-    def upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
         batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
 
         # On the first iteration we need to properly re-create the padding mask
@@ -665,23 +660,10 @@ class QuietFlashAttention2(QuietAttention):
         attention_mask_num_tokens = attention_mask.shape[-1]
         attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
 
-        # Check if attention_mask is empty or all zeros
-        if attention_mask.numel() == 0 or attention_mask.sum() == 0:
-            # Return the original query_layer, key_layer, and value_layer without modifications
-            return (
-                query_layer,
-                key_layer,
-                value_layer,
-                torch.arange(batch_size, device=query_layer.device),
-                (torch.arange(batch_size + 1, device=query_layer.device), torch.arange(batch_size + 1, device=key_layer.device)),
-                (query_layer.shape[1], key_layer.shape[1]),
-            )
-
         indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
 
         key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-
+        value_layer= index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
         if query_length == kv_seq_len:
             query_layer = index_first_axis(
                 query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
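To make the restored value_layer line above concrete, here is a torch-only sketch of the reshape-then-gather pattern. The index_first_axis helper below is an illustrative stand-in for the one imported from flash_attn.bert_padding, not its actual implementation.

    import torch

    def index_first_axis(tensor, indices):
        # Keep only the rows at `indices` along the flattened (batch * seq_len) axis.
        return tensor[indices]

    batch_size, kv_seq_len, num_heads, head_dim = 2, 4, 1, 2
    key_layer = torch.randn(batch_size, kv_seq_len, num_heads, head_dim)
    value_layer = torch.randn(batch_size, kv_seq_len, num_heads, head_dim)
    attention_mask = torch.tensor([[1, 1, 1, 1],
                                   [1, 1, 1, 0]])
    indices_k = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()

    # Same pattern as the key_layer / value_layer lines in the hunk above.
    key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
    value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
    # key_layer.shape == value_layer.shape == torch.Size([7, 1, 2]): only the 7 real tokens remain.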
@@ -786,7 +768,7 @@ class QuietSdpaAttention(QuietAttention):
             attn_mask=attention_mask.to(query_states.device) if attention_mask is not None else None,
             dropout_p=self.attention_dropout if self.training else 0.0,
             # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-            is_causal=self.is_causal and attention_mask is None and q_len > 1,
+            causal=self.is_causal and attention_mask is None and q_len > 1,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()