Crystalcareai committed
Commit 6f9c805 · verified · 1 Parent(s): 3bbca75

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +122 -109
modeling_quiet.py CHANGED
@@ -571,117 +571,133 @@ class QuietFlashAttention2(QuietAttention):
 
         return attn_output, attn_weights, past_key_value
 
-    def _flash_attention_forward(
-        self,
-        query_states,
-        key_states,
-        value_states,
-        attention_mask,
-        query_length,
-        dropout=0.0,
-        softmax_scale=None,
-        use_sliding_windows=False,
-    ):
-        """
-        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
-        first unpad the input, then computes the attention scores and pad the final attention scores.
-        Args:
-            query_states (`torch.Tensor`):
-                Input query states to be passed to Flash Attention API
-            key_states (`torch.Tensor`):
-                Input key states to be passed to Flash Attention API
-            value_states (`torch.Tensor`):
-                Input value states to be passed to Flash Attention API
-            attention_mask (`torch.Tensor`):
-                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
-                position of padding tokens and 1 for the position of non-padding tokens.
-            dropout (`int`, *optional*):
-                Attention dropout
-            softmax_scale (`float`, *optional*):
-                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
-            use_sliding_windows (`bool`, *optional*):
-                Whether to activate sliding window attention.
-        """
-        if not self._flash_attn_uses_top_left_mask:
-            causal = self.is_causal
-        else:
-            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
-            causal = self.is_causal and query_length != 1
-
-        # Ensure attention_mask has the correct shape and values
-        if attention_mask is not None:
-            if attention_mask.dim() == 4:
-                # Convert 4D attention mask to 2D
-                attention_mask = attention_mask.squeeze(1).squeeze(1)
-            elif attention_mask.dim() != 2:
-                raise ValueError(
-                    f"Invalid attention mask dimension: {attention_mask.dim()}. Expected 2D or 4D mask."
-                )
-
-            # Ensure attention_mask has values of 0 and 1
-            attention_mask = attention_mask.to(torch.bool).to(torch.int32)
-
-        # Contains at least one padding token in the sequence
-        if attention_mask is not None:
-            batch_size = query_states.shape[0]
-            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
-                query_states, key_states, value_states, attention_mask, query_length
-            )
-
-            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
-            if not use_sliding_windows:
-                attn_output_unpad = flash_attn_varlen_func(
-                    query_states,
-                    key_states,
-                    value_states,
-                    cu_seqlens_q=cu_seqlens_q,
-                    cu_seqlens_k=cu_seqlens_k,
-                    max_seqlen_q=max_seqlen_in_batch_q,
-                    max_seqlen_k=max_seqlen_in_batch_k,
-                    dropout_p=dropout,
-                    softmax_scale=softmax_scale,
-                    causal=causal,
-                )
-            else:
-                attn_output_unpad = flash_attn_varlen_func(
-                    query_states,
-                    key_states,
-                    value_states,
-                    cu_seqlens_q=cu_seqlens_q,
-                    cu_seqlens_k=cu_seqlens_k,
-                    max_seqlen_q=max_seqlen_in_batch_q,
-                    max_seqlen_k=max_seqlen_in_batch_k,
-                    dropout_p=dropout,
-                    softmax_scale=softmax_scale,
-                    causal=causal,
-                    window_size=(self.config.sliding_window, self.config.sliding_window),
-                )
-
-            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
-        else:
-            if not use_sliding_windows:
-                attn_output = flash_attn_func(
-                    query_states,
-                    key_states,
-                    value_states,
-                    dropout,
-                    softmax_scale=softmax_scale,
-                    causal=causal,
-                )
-            else:
-                attn_output = flash_attn_func(
-                    query_states,
-                    key_states,
-                    value_states,
-                    dropout,
-                    softmax_scale=softmax_scale,
-                    causal=causal,
-                    window_size=(self.config.sliding_window, self.config.sliding_window),
-                )
-
-        return attn_output
+    def _flash_attention_forward(
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        query_length,
+        dropout=0.0,
+        softmax_scale=None,
+        use_sliding_windows=False,
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+        first unpad the input, then computes the attention scores and pad the final attention scores.
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`int`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+            use_sliding_windows (`bool`, *optional*):
+                Whether to activate sliding window attention.
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
+        # Ensure attention_mask has the correct shape and values
+        if attention_mask is not None:
+            if attention_mask.dim() == 4:
+                # Convert 4D attention mask to 2D
+                attention_mask = attention_mask.squeeze(1).squeeze(1)
+            elif attention_mask.dim() != 2:
+                raise ValueError(
+                    f"Invalid attention mask dimension: {attention_mask.dim()}. Expected 2D or 4D mask."
+                )
+
+            # Ensure attention_mask has values of 0 and 1
+            attention_mask = attention_mask.to(torch.bool).to(torch.int32)
+
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            # Create the cu_seqlens_q and cu_seqlens_k tensors
+            q_max_s, k_max_s = query_states.shape[1], key_states.shape[1]
+            qkv_max_s = max(q_max_s, k_max_s)
+
+            q_seqlens = torch.full((batch_size,), q_max_s, dtype=torch.int32, device=query_states.device)
+            k_seqlens = torch.full((batch_size,), k_max_s, dtype=torch.int32, device=key_states.device)
+
+        # Adjust the attention mask to match the sequence lengths
+        if attention_mask is not None:
+            q_seqlens = attention_mask.sum(dim=1).int()
+            k_seqlens = attention_mask.sum(dim=1).int()
+
+            # Convert seqlens to cumulative sequence lengths
+            cu_seqlens_q = torch.cat([torch.zeros(1, dtype=torch.int32, device=q_seqlens.device), q_seqlens.cumsum(dim=0)])
+            cu_seqlens_k = torch.cat([torch.zeros(1, dtype=torch.int32, device=k_seqlens.device), k_seqlens.cumsum(dim=0)])
+
+            if not use_sliding_windows:
+                attn_output_unpad = flash_attn_varlen_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seqlen_q=qkv_max_s,
+                    max_seqlen_k=qkv_max_s,
+                    dropout_p=dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                )
+            else:
+                attn_output_unpad = flash_attn_varlen_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seqlen_q=qkv_max_s,
+                    max_seqlen_k=qkv_max_s,
+                    dropout_p=dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                    window_size=(self.config.sliding_window, self.config.sliding_window),
+                )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            if not use_sliding_windows:
+                attn_output = flash_attn_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                )
+            else:
+                attn_output = flash_attn_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                    window_size=(self.config.sliding_window, self.config.sliding_window),
+                )
+
+        return attn_output
 
     def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
         batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
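The substance of this hunk is how the 2D padding mask is turned into the cumulative sequence lengths (`cu_seqlens_q` / `cu_seqlens_k`) that `flash_attn_varlen_func` consumes. Below is a minimal, self-contained sketch of that construction on a hypothetical two-sequence batch; the mask values are illustrative only, and the `_upad_input` step and the kernel call itself are omitted.

import torch

# Hypothetical padding mask: 1 marks a real token, 0 marks padding (illustrative values).
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]], dtype=torch.int32)

# Per-sequence valid lengths, derived the same way as in the updated branch.
seqlens = attention_mask.sum(dim=1).int()    # lengths: [3, 5]

# Prepend a zero and take the cumulative sum to get the offsets that
# flash_attn_varlen_func expects for packed (unpadded) inputs.
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), seqlens.cumsum(dim=0)])
print(cu_seqlens)                            # offsets: [0, 3, 8]

max_seqlen = int(seqlens.max())              # longest sequence in the batch: 5

Each consecutive pair of offsets delimits one sequence inside the flattened token layout that unpadding produces, which is the layout the varlen kernel operates on.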
 
@@ -1848,10 +1864,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
             [shift_labels, padding],
             dim=-1
         )
-
-        # Adjust the labels to account for the additional thinking tokens
-        new_rm_tokens = torch.cat([torch.full_like(new_rm_tokens[..., :self.n_ahead], self.tokenizer.pad_token_id, dtype=torch.long, device=new_rm_tokens.device), new_rm_tokens], dim=-1)
-
+
 
         # print((new_rm_tokens > self.vocab_size - 1).any().item())
         new_rm_tokens = torch.clamp(new_rm_tokens, 0, self.vocab_size - 1)
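For context on the second hunk: the removed line prepended `self.n_ahead` pad-token ids along the last dimension of `new_rm_tokens` before the clamp above. A standalone sketch of that prepend pattern, with purely illustrative values for `n_ahead`, the pad id, and the token tensor:

import torch

n_ahead = 4        # illustrative; the model reads this from its own attribute
pad_token_id = 0   # illustrative pad id

new_rm_tokens = torch.tensor([[11, 12, 13, 14, 15, 16]])

# Build a (batch, n_ahead) block of pad ids with matching dtype/device,
# then prepend it along the last dimension - the pattern this commit removes.
pad_block = torch.full_like(new_rm_tokens[..., :n_ahead], pad_token_id)
shifted = torch.cat([pad_block, new_rm_tokens], dim=-1)
print(shifted)     # tensor([[ 0,  0,  0,  0, 11, 12, 13, 14, 15, 16]])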