Crystalcareai committed on
Commit ced45b7 · verified · 1 Parent(s): 7d42e86

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +81 -2
modeling_quiet.py CHANGED
@@ -294,6 +294,7 @@ class QuietAttention(nn.Module):
         self.rope_theta = config.rope_theta
         self.is_causal = True
         self.attention_dropout = config.attention_dropout
+        self._attn_implementation = config._attn_implementation
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
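Note: the new `_attn_implementation` attribute mirrors the flag that transformers places on the config when the attention backend is selected at load time. A minimal sketch of how that flag is normally populated (the repo id is a placeholder, and the `attn_implementation=` kwarg assumes transformers >= 4.36):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "your-org/your-quiet-star-checkpoint",      # placeholder id for a checkpoint that ships modeling_quiet.py
    attn_implementation="flash_attention_2",    # or "sdpa" / "eager"
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,                     # needed because modeling_quiet.py is custom remote code
)
print(model.config._attn_implementation)        # "flash_attention_2"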
@@ -365,7 +366,30 @@ class QuietAttention(nn.Module):
                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                 f" {attn_weights.size()}"
             )
-
+        if self._attn_implementation == "flash_attention_2":
+            # Prepare attention mask for flash-attn
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._attn_implementation == "sdpa":
+            # Prepare attention mask for SDPA
+            if attention_mask is None or attention_mask.dim() == 2:
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                    sliding_window=self.config.sliding_window,
+                )
+        else:
+            # Prepare attention mask for other implementations
+            if attention_mask is None or attention_mask.dim() == 2:
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                    sliding_window=self.config.sliding_window,
+                )
+
         if attention_mask is not None:
             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
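Note: both non-flash branches in this hunk build the same additive 4D causal mask from a 2D padding mask via `_prepare_4d_causal_attention_mask`. A minimal, self-contained sketch of that call (toy sizes, not taken from the model config; assumes a transformers version that still exposes this helper from `transformers.modeling_attn_mask_utils`):

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

bsz, seq_len, hidden = 2, 5, 8                      # toy sizes for illustration only
mask_2d = torch.tensor([[1, 1, 1, 1, 0],
                        [1, 1, 1, 0, 0]])           # 1 = real token, 0 = padding
inputs_embeds = torch.zeros(bsz, seq_len, hidden)   # only dtype/device are read from this
mask_4d = _prepare_4d_causal_attention_mask(
    mask_2d,
    (bsz, seq_len),
    inputs_embeds,
    past_key_values_length=0,
    sliding_window=4096,                            # stands in for self.config.sliding_window
)
print(mask_4d.shape)                                # torch.Size([2, 1, 5, 5])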
@@ -643,8 +667,63 @@ class QuietFlashAttention2(QuietAttention):
                     causal=causal,
                     window_size=(self.config.sliding_window, self.config.sliding_window),
                 )
+            try:
+                attn_output_unpad = flash_attn_varlen_func(
+                    query_states,
+                    key_states,
+                    value_states,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seqlen_q=max_seqlen_in_batch_q,
+                    max_seqlen_k=max_seqlen_in_batch_k,
+                    dropout_p=dropout,
+                    softmax_scale=softmax_scale,
+                    causal=causal,
+                )
+            except RuntimeError as e:
+                if "cu_seqlens_q must have shape (batch_size + 1)" in str(e):
+                    # Handle the case when cu_seqlens_q has an invalid shape
+                    if attention_mask is not None:
+                        # Ensure attention_mask has the correct shape
+                        if attention_mask.dim() == 2:
+                            # Convert 2D attention mask to 4D
+                            attention_mask = _prepare_4d_causal_attention_mask(
+                                attention_mask,
+                                (query_states.size(0), query_states.size(1)),
+                                query_states,
+                                past_key_values_length=0,
+                                sliding_window=0,
+                            )
+                        elif attention_mask.dim() != 4:
+                            raise ValueError(
+                                f"Invalid attention mask dimension: {attention_mask.dim()}. Expected 2D or 4D mask."
+                            )
+
+                        # Update cu_seqlens_q based on the attention mask
+                        cu_seqlens_q = attention_mask.sum(dim=-1).flatten().cumsum(dim=0).to(torch.int32)
+                        max_seqlen_in_batch_q = cu_seqlens_q[-1].item()
+
+                        # Retry flash_attn_varlen_func with updated cu_seqlens_q
+                        attn_output_unpad = flash_attn_varlen_func(
+                            query_states,
+                            key_states,
+                            value_states,
+                            cu_seqlens_q=cu_seqlens_q,
+                            cu_seqlens_k=cu_seqlens_k,
+                            max_seqlen_q=max_seqlen_in_batch_q,
+                            max_seqlen_k=max_seqlen_in_batch_k,
+                            dropout_p=dropout,
+                            softmax_scale=softmax_scale,
+                            causal=causal,
+                        )
+                    else:
+                        raise ValueError(
+                            "Attention mask is required for flash-attn when cu_seqlens_q has an invalid shape."
+                        )
+                else:
+                    raise e
 
-        return attn_output
+            return attn_output
 
     def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
         batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
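Note: `flash_attn_varlen_func` expects `cu_seqlens_q` of shape `(batch_size + 1,)` — cumulative sequence lengths with a leading zero — which is the shape the RuntimeError caught above complains about. A sketch of the conventional construction from a 2D padding mask (`cu_seqlens_from_mask` is a hypothetical helper for illustration, not part of modeling_quiet.py):

import torch
import torch.nn.functional as F

def cu_seqlens_from_mask(attention_mask: torch.Tensor):
    # attention_mask: (batch, seq_len), 1 = real token, 0 = padding
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                       # tokens per sequence
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))   # leading 0 -> shape (batch + 1,)
    return cu_seqlens, int(seqlens.max())

mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])
print(cu_seqlens_from_mask(mask))   # (tensor([0, 3, 5], dtype=torch.int32), 3)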