Crystalcareai committed on
Commit 995fb64 · verified · 1 Parent(s): 6f9c805

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +103 -154
modeling_quiet.py CHANGED
@@ -432,7 +432,7 @@ class QuietFlashAttention2(QuietAttention):
         super().__init__(*args, **kwargs)
 
         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
 
@@ -533,171 +533,123 @@ class QuietFlashAttention2(QuietAttention):
             if torch.is_autocast_enabled():
                 target_dtype = torch.get_autocast_gpu_dtype()
             # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
             else:
-                target_dtype = self.q_proj.weight.dtype
-
-            logger.warning_once(
-                f"The input hidden states seems to be silently casted in float32, this might be related to"
-                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                f" {target_dtype}."
-            )
-
+                target_dtype = torch.float16
             query_states = query_states.to(target_dtype)
             key_states = key_states.to(target_dtype)
             value_states = value_states.to(target_dtype)
 
-        # Reashape to the expected shape for Flash Attention
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-
-        attn_output = self._flash_attention_forward(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            q_len,
-            dropout=dropout_rate,
-            use_sliding_windows=use_sliding_windows,
-        )
-
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
-        attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
-
-    def _flash_attention_forward(
-        self,
-        query_states,
-        key_states,
-        value_states,
-        attention_mask,
-        query_length,
-        dropout=0.0,
-        softmax_scale=None,
-        use_sliding_windows=False,
-    ):
-        """
-        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
-        first unpad the input, then computes the attention scores and pad the final attention scores.
-        Args:
-            query_states (`torch.Tensor`):
-                Input query states to be passed to Flash Attention API
-            key_states (`torch.Tensor`):
-                Input key states to be passed to Flash Attention API
-            value_states (`torch.Tensor`):
-                Input value states to be passed to Flash Attention API
-            attention_mask (`torch.Tensor`):
-                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
-                position of padding tokens and 1 for the position of non-padding tokens.
-            dropout (`int`, *optional*):
-                Attention dropout
-            softmax_scale (`float`, *optional*):
-                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
-            use_sliding_windows (`bool`, *optional*):
-                Whether to activate sliding window attention.
-        """
-        if not self._flash_attn_uses_top_left_mask:
-            causal = self.is_causal
-        else:
-            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
-            causal = self.is_causal and query_length != 1
-
-        # Ensure attention_mask has the correct shape and values
-        if attention_mask is not None:
-            if attention_mask.dim() == 4:
-                # Convert 4D attention mask to 2D
-                attention_mask = attention_mask.squeeze(1).squeeze(1)
-            elif attention_mask.dim() != 2:
-                raise ValueError(
-                    f"Invalid attention mask dimension: {attention_mask.dim()}. Expected 2D or 4D mask."
-                )
-
-            # Ensure attention_mask has values of 0 and 1
-            attention_mask = attention_mask.to(torch.bool).to(torch.int32)
-
-        # Contains at least one padding token in the sequence
-        if attention_mask is not None:
-            batch_size = query_states.shape[0]
-            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
-                query_states, key_states, value_states, attention_mask, query_length
-            )
+        # Compute the causal mask
+        causal = self.config.causal
+        if causal:
+            if self._flash_attn_uses_top_left_mask:
+                # Compute the causal mask
+                causal_mask = torch.tril(torch.ones((q_len, kv_seq_len), dtype=torch.bool, device=query_states.device))
+                # Invert the mask
+                causal_mask = ~causal_mask
+            else:
+                causal_mask = torch.triu(
+                    torch.ones((q_len, kv_seq_len), dtype=torch.bool, device=query_states.device), diagonal=1
+                )
+        else:
+            causal_mask = None
 
-            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+        # Compute the attention mask
+        if attention_mask is not None:
+            if attention_mask.dim() == 2:
+                attention_mask = attention_mask[:, None, :]
+            attention_mask = attention_mask.to(torch.bool)
 
-            # Create the cu_seqlens_q and cu_seqlens_k tensors
-            q_max_s, k_max_s = query_states.shape[1], key_states.shape[1]
-            qkv_max_s = max(q_max_s, k_max_s)
+            if causal:
+                attention_mask = attention_mask & causal_mask
+            else:
+                attention_mask = attention_mask
 
-            q_seqlens = torch.full((batch_size,), q_max_s, dtype=torch.int32, device=query_states.device)
-            k_seqlens = torch.full((batch_size,), k_max_s, dtype=torch.int32, device=key_states.device)
+        # Compute the softmax scale
+        softmax_scale = self.head_dim**-0.5
 
-            # Adjust the attention mask to match the sequence lengths
+        # Compute the attention scores
         if attention_mask is not None:
-            q_seqlens = attention_mask.sum(dim=1).int()
-            k_seqlens = attention_mask.sum(dim=1).int()
-
-            # Convert seqlens to cumulative sequence lengths
-            cu_seqlens_q = torch.cat([torch.zeros(1, dtype=torch.int32, device=q_seqlens.device), q_seqlens.cumsum(dim=0)])
-            cu_seqlens_k = torch.cat([torch.zeros(1, dtype=torch.int32, device=k_seqlens.device), k_seqlens.cumsum(dim=0)])
-
-            if not use_sliding_windows:
-                attn_output_unpad = flash_attn_varlen_func(
-                    query_states,
-                    key_states,
-                    value_states,
-                    cu_seqlens_q=cu_seqlens_q,
-                    cu_seqlens_k=cu_seqlens_k,
-                    max_seqlen_q=qkv_max_s,
-                    max_seqlen_k=qkv_max_s,
-                    dropout_p=dropout,
-                    softmax_scale=softmax_scale,
-                    causal=causal,
-                )
-            else:
-                attn_output_unpad = flash_attn_varlen_func(
+            # Unpad the input
+            (
                 query_states,
                 key_states,
                 value_states,
-                    cu_seqlens_q=cu_seqlens_q,
-                    cu_seqlens_k=cu_seqlens_k,
-                    max_seqlen_q=qkv_max_s,
-                    max_seqlen_k=qkv_max_s,
-                    dropout_p=dropout,
-                    softmax_scale=softmax_scale,
-                    causal=causal,
-                    window_size=(self.config.sliding_window, self.config.sliding_window),
-                )
+                indices_q,
+                cu_seq_lens,
+                max_seq_lens,
+            ) = self._upad_input(query_states, key_states, value_states, attention_mask, q_len)
 
-            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
-        else:
-            if not use_sliding_windows:
-                attn_output = flash_attn_func(
-                    query_states,
-                    key_states,
-                    value_states,
-                    dropout,
-                    softmax_scale=softmax_scale,
-                    causal=causal,
-                )
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            # Create the cu_seqlens_q and cu_seqlens_k tensors
+            q_max_s, k_max_s = query_states.shape[1], key_states.shape[1]
+            qkv_max_s = max(q_max_s, k_max_s)
+
+            q_seqlens = torch.full((batch_size,), q_max_s, dtype=torch.int32, device=query_states.device)
+            k_seqlens = torch.full((batch_size,), k_max_s, dtype=torch.int32, device=key_states.device)
+
+            # Adjust the attention mask to match the sequence lengths
+            if attention_mask is not None:
+                q_seqlens = attention_mask.sum(dim=1).int()
+                k_seqlens = attention_mask.sum(dim=1).int()
+
+            # Convert seqlens to cumulative sequence lengths
+            cu_seqlens_q = torch.cat([torch.zeros(1, dtype=torch.int32, device=q_seqlens.device), q_seqlens.cumsum(dim=0)])
+            cu_seqlens_k = torch.cat([torch.zeros(1, dtype=torch.int32, device=k_seqlens.device), k_seqlens.cumsum(dim=0)])
+
+            if not use_sliding_windows:
+                attn_output_unpad = flash_attn_varlen_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seqlen_q=qkv_max_s,
+                    max_seqlen_k=qkv_max_s,
+                    dropout_p=dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                )
+            else:
+                attn_output_unpad = flash_attn_varlen_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seqlen_q=qkv_max_s,
+                    max_seqlen_k=qkv_max_s,
+                    dropout_p=dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                    window_size=(self.config.sliding_window, self.config.sliding_window),
+                )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
         else:
-            attn_output = flash_attn_func(
-                query_states,
-                key_states,
-                value_states,
-                dropout,
-                softmax_scale=softmax_scale,
-                causal=causal,
-                window_size=(self.config.sliding_window, self.config.sliding_window),
-            )
+            if not use_sliding_windows:
+                attn_output = flash_attn_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                )
+            else:
+                attn_output = flash_attn_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                    window_size=(self.config.sliding_window, self.config.sliding_window),
+                )
 
-        return attn_output
+        return attn_output
 
     def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
         batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
@@ -711,8 +663,7 @@ def _flash_attention_forward(
         indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
 
         key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-
+        value_layer= index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
         if query_length == kv_seq_len:
             query_layer = index_first_axis(
                 query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
@@ -740,8 +691,6 @@ def _flash_attention_forward(
             (cu_seqlens_q, cu_seqlens_k),
             (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
         )
-
-
 # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
 class QuietSdpaAttention(QuietAttention):
     """
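For reference, the dtype selection in the edited cast block reduces to: use the active autocast GPU dtype when autocast is enabled, otherwise fall back to torch.float16 (the commit drops the _pre_quantization_dtype and q_proj.weight.dtype fallbacks and the warning). A minimal stand-alone sketch of that rule; pick_flash_dtype is an illustrative helper, not a function from modeling_quiet.py:

import torch

def pick_flash_dtype() -> torch.dtype:
    # Mirrors the cast logic after this commit: prefer the active autocast
    # dtype, otherwise hard-code float16 for the flash-attention kernels.
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    return torch.float16

hidden_states = torch.randn(2, 5, 8, dtype=torch.float32)
print(hidden_states.to(pick_flash_dtype()).dtype)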
 
 
 
 
 
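The comment change in the first hunk concerns top-left versus bottom-right alignment of the causal mask when q_len != kv_seq_len, the behaviour flash_attn switched at 2.1 (see the release link in the comment). As a point of reference, here is a self-contained sketch of what the two alignments look like for a rectangular mask, with True marking masked-out positions; causal_mask_example is an illustrative helper, not code from this repository:

import torch

def causal_mask_example(q_len: int, kv_len: int, bottom_right: bool) -> torch.Tensor:
    # True = masked. Bottom-right alignment anchors the diagonal so the last
    # query row can see the whole key sequence; top-left anchors it at row 0.
    i = torch.arange(q_len).unsqueeze(1)
    j = torch.arange(kv_len).unsqueeze(0)
    offset = (kv_len - q_len) if bottom_right else 0
    return j > (i + offset)

print(causal_mask_example(2, 5, bottom_right=False).int())
# tensor([[0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)
print(causal_mask_example(2, 5, bottom_right=True).int())
# tensor([[0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]], dtype=torch.int32)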
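The unpadded path added in forward drives flash_attn_varlen_func with cumulative sequence lengths derived from the padding mask, as the q_seqlens / cu_seqlens_q lines in the diff do. A torch-only sketch of that bookkeeping on a toy mask (names chosen for illustration):

import torch

# 1 = real token, 0 = padding; two sequences of length 3 and 5.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]], dtype=torch.int32)

seqlens = attention_mask.sum(dim=1).to(torch.int32)              # tensor([3, 5])
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        seqlens.cumsum(dim=0).to(torch.int32)])  # tensor([0, 3, 8])
max_seqlen = int(seqlens.max())                                  # 5

# flash_attn_varlen_func consumes these boundaries over tokens that the unpad
# step has flattened to shape (total_tokens, num_heads, head_dim).
print(seqlens, cu_seqlens, max_seqlen)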