Crystalcareai committed
Update modeling_quiet.py

modeling_quiet.py  +81 -2
CHANGED
@@ -294,6 +294,7 @@ class QuietAttention(nn.Module):
         self.rope_theta = config.rope_theta
         self.is_causal = True
         self.attention_dropout = config.attention_dropout
+        self._attn_implementation = config._attn_implementation

         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
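The stored _attn_implementation simply mirrors the flag that transformers records on the config when the model is loaded, which is what the new branches later in this diff dispatch on. A rough usage sketch (the checkpoint id is a placeholder, and flash_attention_2 additionally requires the flash-attn package and a CUDA device):

from transformers import AutoModelForCausalLM

# attn_implementation can be "eager", "sdpa", or "flash_attention_2"
model = AutoModelForCausalLM.from_pretrained(
    "Crystalcareai/quiet-star",              # placeholder model id
    attn_implementation="flash_attention_2",
    trust_remote_code=True,                  # custom modeling_quiet.py code
)
print(model.config._attn_implementation)     # "flash_attention_2"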
@@ -365,7 +366,30 @@
                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                 f" {attn_weights.size()}"
             )
-
+        if self._attn_implementation == "flash_attention_2":
+            # Prepare attention mask for flash-attn
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._attn_implementation == "sdpa":
+            # Prepare attention mask for SDPA
+            if attention_mask is None or attention_mask.dim() == 2:
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                    sliding_window=self.config.sliding_window,
+                )
+        else:
+            # Prepare attention mask for other implementations
+            if attention_mask is None or attention_mask.dim() == 2:
+                attention_mask = _prepare_4d_causal_attention_mask(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                    sliding_window=self.config.sliding_window,
+                )
+
         if attention_mask is not None:
             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
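The two _prepare_4d_causal_attention_mask branches above expand a 2D padding mask into the (bsz, 1, q_len, kv_seq_len) layout that the size check just below expects; the batch_size, seq_length, inputs_embeds and past_key_values_length names are assumed to be available in the surrounding method. A minimal standalone sketch of that helper (it lives in transformers.modeling_attn_mask_utils), with illustrative shapes and an arbitrary sliding-window value:

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, hidden_size = 2, 5, 8
# 2D padding mask: the second sequence has two padded positions
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])
inputs_embeds = torch.randn(batch_size, seq_length, hidden_size)

mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask,
    (batch_size, seq_length),
    inputs_embeds,
    past_key_values_length=0,
    sliding_window=4096,
)
print(mask_4d.shape)  # torch.Size([2, 1, 5, 5]) -> (bsz, 1, q_len, kv_seq_len)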
@@ -643,8 +667,63 @@ class QuietFlashAttention2(QuietAttention):
                     causal=causal,
                     window_size=(self.config.sliding_window, self.config.sliding_window),
                 )
+        try:
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+        except RuntimeError as e:
+            if "cu_seqlens_q must have shape (batch_size + 1)" in str(e):
+                # Handle the case when cu_seqlens_q has an invalid shape
+                if attention_mask is not None:
+                    # Ensure attention_mask has the correct shape
+                    if attention_mask.dim() == 2:
+                        # Convert 2D attention mask to 4D
+                        attention_mask = _prepare_4d_causal_attention_mask(
+                            attention_mask,
+                            (query_states.size(0), query_states.size(1)),
+                            query_states,
+                            past_key_values_length=0,
+                            sliding_window=0,
+                        )
+                    elif attention_mask.dim() != 4:
+                        raise ValueError(
+                            f"Invalid attention mask dimension: {attention_mask.dim()}. Expected 2D or 4D mask."
+                        )
+
+                    # Update cu_seqlens_q based on the attention mask
+                    cu_seqlens_q = attention_mask.sum(dim=-1).flatten().cumsum(dim=0).to(torch.int32)
+                    max_seqlen_in_batch_q = cu_seqlens_q[-1].item()
+
+                    # Retry flash_attn_varlen_func with updated cu_seqlens_q
+                    attn_output_unpad = flash_attn_varlen_func(
+                        query_states,
+                        key_states,
+                        value_states,
+                        cu_seqlens_q=cu_seqlens_q,
+                        cu_seqlens_k=cu_seqlens_k,
+                        max_seqlen_q=max_seqlen_in_batch_q,
+                        max_seqlen_k=max_seqlen_in_batch_k,
+                        dropout_p=dropout,
+                        softmax_scale=softmax_scale,
+                        causal=causal,
+                    )
+                else:
+                    raise ValueError(
+                        "Attention mask is required for flash-attn when cu_seqlens_q has an invalid shape."
+                    )
+            else:
+                raise e

-
+        return attn_output

     def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
         batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
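For reference, a small sketch of the cu_seqlens_q fallback computed in the new except branch, run on a made-up 2D padding mask so the intermediate values are visible. The quoted error string refers to flash-attn's requirement that cu_seqlens_q have shape (batch_size + 1); the snippet below only reproduces the expression used in the retry path and does not call flash-attn itself:

import torch

# Hypothetical 2D padding mask: batch of 3 sequences, max length 6
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1, 1],
                               [1, 1, 0, 0, 0, 0]])

# Same expression as in the retry path: per-sequence token counts,
# flattened into a running (cumulative) sum of int32 offsets
cu_seqlens_q = attention_mask.sum(dim=-1).flatten().cumsum(dim=0).to(torch.int32)
max_seqlen_in_batch_q = cu_seqlens_q[-1].item()

print(cu_seqlens_q)           # tensor([ 4, 10, 12], dtype=torch.int32)
print(max_seqlen_in_batch_q)  # 12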