Crystalcareai committed
Commit 120f09f · verified · 1 Parent(s): ced45b7

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +14 -56
modeling_quiet.py CHANGED
@@ -607,6 +607,19 @@ class QuietFlashAttention2(QuietAttention):
         else:
             # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
             causal = self.is_causal and query_length != 1
+
+        # Ensure attention_mask has the correct shape and values
+        if attention_mask is not None:
+            if attention_mask.dim() == 4:
+                # Convert 4D attention mask to 2D
+                attention_mask = attention_mask.squeeze(1).squeeze(1)
+            elif attention_mask.dim() != 2:
+                raise ValueError(
+                    f"Invalid attention mask dimension: {attention_mask.dim()}. Expected 2D or 4D mask."
+                )
+
+            # Ensure attention_mask has values of 0 and 1
+            attention_mask = attention_mask.to(torch.bool).to(torch.int32)
 
         # Contains at least one padding token in the sequence
         if attention_mask is not None:
@@ -667,63 +680,8 @@ class QuietFlashAttention2(QuietAttention):
                 causal=causal,
                 window_size=(self.config.sliding_window, self.config.sliding_window),
             )
-            try:
-                attn_output_unpad = flash_attn_varlen_func(
-                    query_states,
-                    key_states,
-                    value_states,
-                    cu_seqlens_q=cu_seqlens_q,
-                    cu_seqlens_k=cu_seqlens_k,
-                    max_seqlen_q=max_seqlen_in_batch_q,
-                    max_seqlen_k=max_seqlen_in_batch_k,
-                    dropout_p=dropout,
-                    softmax_scale=softmax_scale,
-                    causal=causal,
-                )
-            except RuntimeError as e:
-                if "cu_seqlens_q must have shape (batch_size + 1)" in str(e):
-                    # Handle the case when cu_seqlens_q has an invalid shape
-                    if attention_mask is not None:
-                        # Ensure attention_mask has the correct shape
-                        if attention_mask.dim() == 2:
-                            # Convert 2D attention mask to 4D
-                            attention_mask = _prepare_4d_causal_attention_mask(
-                                attention_mask,
-                                (query_states.size(0), query_states.size(1)),
-                                query_states,
-                                past_key_values_length=0,
-                                sliding_window=0,
-                            )
-                        elif attention_mask.dim() != 4:
-                            raise ValueError(
-                                f"Invalid attention mask dimension: {attention_mask.dim()}. Expected 2D or 4D mask."
-                            )
-
-                        # Update cu_seqlens_q based on the attention mask
-                        cu_seqlens_q = attention_mask.sum(dim=-1).flatten().cumsum(dim=0).to(torch.int32)
-                        max_seqlen_in_batch_q = cu_seqlens_q[-1].item()
-
-                        # Retry flash_attn_varlen_func with updated cu_seqlens_q
-                        attn_output_unpad = flash_attn_varlen_func(
-                            query_states,
-                            key_states,
-                            value_states,
-                            cu_seqlens_q=cu_seqlens_q,
-                            cu_seqlens_k=cu_seqlens_k,
-                            max_seqlen_q=max_seqlen_in_batch_q,
-                            max_seqlen_k=max_seqlen_in_batch_k,
-                            dropout_p=dropout,
-                            softmax_scale=softmax_scale,
-                            causal=causal,
-                        )
-                    else:
-                        raise ValueError(
-                            "Attention mask is required for flash-attn when cu_seqlens_q has an invalid shape."
-                        )
-                else:
-                    raise e
 
-        return attn_output
+        return attn_output
 
     def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
         batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
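Note: the sketch below is not part of the commit; it is a minimal, self-contained illustration of what the newly added mask-normalization block does. The helper name normalize_attention_mask and the sample mask values are assumptions for illustration only; in the commit the same logic is inlined in QuietFlashAttention2 ahead of the padding check.

    import torch

    def normalize_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
        # Mirrors the added block: accept a 2D (batch, seq_len) mask or a
        # 4D (batch, 1, 1, seq_len) mask and return a 2D 0/1 int32 mask.
        if attention_mask.dim() == 4:
            # Drop the two singleton dimensions to get back to (batch, seq_len)
            attention_mask = attention_mask.squeeze(1).squeeze(1)
        elif attention_mask.dim() != 2:
            raise ValueError(
                f"Invalid attention mask dimension: {attention_mask.dim()}. Expected 2D or 4D mask."
            )
        # The double cast maps any nonzero entry to 1 and zeros to 0
        return attention_mask.to(torch.bool).to(torch.int32)

    mask_4d = torch.tensor([[[[1.0, 1.0, 0.0]]],
                            [[[1.0, 0.0, 0.0]]]])  # illustrative shape (2, 1, 1, 3)
    print(normalize_attention_mask(mask_4d))
    # tensor([[1, 1, 0],
    #         [1, 0, 0]], dtype=torch.int32)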
 
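Note: the retry path removed above was keyed off flash-attn's requirement that cu_seqlens_q have shape (batch_size + 1). Below is a minimal sketch, under the assumption of a 2D 0/1 mask, of how such cumulative sequence-length offsets are typically derived; the variable names and sample values are illustrative and not taken from modeling_quiet.py.

    import torch
    import torch.nn.functional as F

    # Illustrative 2D mask: a batch of 2 sequences with 3 and 2 real tokens
    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 0, 0]], dtype=torch.int32)

    # Per-sequence token counts, then a cumulative sum padded with a leading
    # zero, giving the (batch_size + 1,)-shaped offsets flash_attn_varlen_func expects
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)   # tensor([3, 2])
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen_in_batch = seqlens_in_batch.max().item()

    print(cu_seqlens)           # tensor([0, 3, 5], dtype=torch.int32), shape (batch_size + 1,)
    print(max_seqlen_in_batch)  # 3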