Crystalcareai
commited on
Update modeling_quiet.py
Browse files- modeling_quiet.py +103 -154
modeling_quiet.py
CHANGED
@@ -432,7 +432,7 @@ class QuietFlashAttention2(QuietAttention):
|
|
432 |
super().__init__(*args, **kwargs)
|
433 |
|
434 |
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
435 |
-
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right
|
436 |
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
437 |
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
438 |
|
@@ -533,171 +533,123 @@ class QuietFlashAttention2(QuietAttention):
|
|
533 |
if torch.is_autocast_enabled():
|
534 |
target_dtype = torch.get_autocast_gpu_dtype()
|
535 |
# Handle the case where the model is quantized
|
536 |
-
elif hasattr(self.config, "_pre_quantization_dtype"):
|
537 |
-
target_dtype = self.config._pre_quantization_dtype
|
538 |
else:
|
539 |
-
target_dtype =
|
540 |
-
|
541 |
-
logger.warning_once(
|
542 |
-
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
543 |
-
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
544 |
-
f" {target_dtype}."
|
545 |
-
)
|
546 |
-
|
547 |
query_states = query_states.to(target_dtype)
|
548 |
key_states = key_states.to(target_dtype)
|
549 |
value_states = value_states.to(target_dtype)
|
550 |
|
551 |
-
#
|
552 |
-
|
553 |
-
|
554 |
-
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
567 |
-
attn_output = self.o_proj(attn_output)
|
568 |
-
|
569 |
-
if not output_attentions:
|
570 |
-
attn_weights = None
|
571 |
-
|
572 |
-
return attn_output, attn_weights, past_key_value
|
573 |
-
|
574 |
-
def _flash_attention_forward(
|
575 |
-
self,
|
576 |
-
query_states,
|
577 |
-
key_states,
|
578 |
-
value_states,
|
579 |
-
attention_mask,
|
580 |
-
query_length,
|
581 |
-
dropout=0.0,
|
582 |
-
softmax_scale=None,
|
583 |
-
use_sliding_windows=False,
|
584 |
-
):
|
585 |
-
"""
|
586 |
-
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
587 |
-
first unpad the input, then computes the attention scores and pad the final attention scores.
|
588 |
-
Args:
|
589 |
-
query_states (`torch.Tensor`):
|
590 |
-
Input query states to be passed to Flash Attention API
|
591 |
-
key_states (`torch.Tensor`):
|
592 |
-
Input key states to be passed to Flash Attention API
|
593 |
-
value_states (`torch.Tensor`):
|
594 |
-
Input value states to be passed to Flash Attention API
|
595 |
-
attention_mask (`torch.Tensor`):
|
596 |
-
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
597 |
-
position of padding tokens and 1 for the position of non-padding tokens.
|
598 |
-
dropout (`int`, *optional*):
|
599 |
-
Attention dropout
|
600 |
-
softmax_scale (`float`, *optional*):
|
601 |
-
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
602 |
-
use_sliding_windows (`bool`, *optional*):
|
603 |
-
Whether to activate sliding window attention.
|
604 |
-
"""
|
605 |
-
if not self._flash_attn_uses_top_left_mask:
|
606 |
-
causal = self.is_causal
|
607 |
-
else:
|
608 |
-
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
|
609 |
-
causal = self.is_causal and query_length != 1
|
610 |
-
|
611 |
-
# Ensure attention_mask has the correct shape and values
|
612 |
-
if attention_mask is not None:
|
613 |
-
if attention_mask.dim() == 4:
|
614 |
-
# Convert 4D attention mask to 2D
|
615 |
-
attention_mask = attention_mask.squeeze(1).squeeze(1)
|
616 |
-
elif attention_mask.dim() != 2:
|
617 |
-
raise ValueError(
|
618 |
-
f"Invalid attention mask dimension: {attention_mask.dim()}. Expected 2D or 4D mask."
|
619 |
-
)
|
620 |
-
|
621 |
-
# Ensure attention_mask has values of 0 and 1
|
622 |
-
attention_mask = attention_mask.to(torch.bool).to(torch.int32)
|
623 |
-
|
624 |
-
# Contains at least one padding token in the sequence
|
625 |
-
if attention_mask is not None:
|
626 |
-
batch_size = query_states.shape[0]
|
627 |
-
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
628 |
-
query_states, key_states, value_states, attention_mask, query_length
|
629 |
-
)
|
630 |
|
631 |
-
|
632 |
-
|
|
|
|
|
|
|
633 |
|
634 |
-
|
635 |
-
|
636 |
-
|
|
|
637 |
|
638 |
-
|
639 |
-
|
640 |
|
641 |
-
#
|
642 |
if attention_mask is not None:
|
643 |
-
|
644 |
-
|
645 |
-
|
646 |
-
# Convert seqlens to cumulative sequence lengths
|
647 |
-
cu_seqlens_q = torch.cat([torch.zeros(1, dtype=torch.int32, device=q_seqlens.device), q_seqlens.cumsum(dim=0)])
|
648 |
-
cu_seqlens_k = torch.cat([torch.zeros(1, dtype=torch.int32, device=k_seqlens.device), k_seqlens.cumsum(dim=0)])
|
649 |
-
|
650 |
-
if not use_sliding_windows:
|
651 |
-
attn_output_unpad = flash_attn_varlen_func(
|
652 |
-
query_states,
|
653 |
-
key_states,
|
654 |
-
value_states,
|
655 |
-
cu_seqlens_q=cu_seqlens_q,
|
656 |
-
cu_seqlens_k=cu_seqlens_k,
|
657 |
-
max_seqlen_q=qkv_max_s,
|
658 |
-
max_seqlen_k=qkv_max_s,
|
659 |
-
dropout_p=dropout,
|
660 |
-
softmax_scale=softmax_scale,
|
661 |
-
causal=causal,
|
662 |
-
)
|
663 |
-
else:
|
664 |
-
attn_output_unpad = flash_attn_varlen_func(
|
665 |
query_states,
|
666 |
key_states,
|
667 |
value_states,
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
|
672 |
-
dropout_p=dropout,
|
673 |
-
softmax_scale=softmax_scale,
|
674 |
-
causal=causal,
|
675 |
-
window_size=(self.config.sliding_window, self.config.sliding_window),
|
676 |
-
)
|
677 |
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
-
|
682 |
-
|
683 |
-
|
684 |
-
|
685 |
-
|
686 |
-
|
687 |
-
|
688 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
689 |
else:
|
690 |
-
|
691 |
-
|
692 |
-
|
693 |
-
|
694 |
-
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
699 |
|
700 |
-
|
701 |
|
702 |
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
703 |
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
|
@@ -711,8 +663,7 @@ def _flash_attention_forward(
|
|
711 |
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
712 |
|
713 |
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
|
714 |
-
value_layer
|
715 |
-
|
716 |
if query_length == kv_seq_len:
|
717 |
query_layer = index_first_axis(
|
718 |
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
|
@@ -740,8 +691,6 @@ def _flash_attention_forward(
|
|
740 |
(cu_seqlens_q, cu_seqlens_k),
|
741 |
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
742 |
)
|
743 |
-
|
744 |
-
|
745 |
# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
|
746 |
class QuietSdpaAttention(QuietAttention):
|
747 |
"""
|
|
|
432 |
super().__init__(*args, **kwargs)
|
433 |
|
434 |
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
435 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
436 |
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
437 |
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
438 |
|
|
|
533 |
if torch.is_autocast_enabled():
|
534 |
target_dtype = torch.get_autocast_gpu_dtype()
|
535 |
# Handle the case where the model is quantized
|
|
|
|
|
536 |
else:
|
537 |
+
target_dtype = torch.float16
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
538 |
query_states = query_states.to(target_dtype)
|
539 |
key_states = key_states.to(target_dtype)
|
540 |
value_states = value_states.to(target_dtype)
|
541 |
|
542 |
+
# Compute the causal mask
|
543 |
+
causal = self.config.causal
|
544 |
+
if causal:
|
545 |
+
if self._flash_attn_uses_top_left_mask:
|
546 |
+
# Compute the causal mask
|
547 |
+
causal_mask = torch.tril(torch.ones((q_len, kv_seq_len), dtype=torch.bool, device=query_states.device))
|
548 |
+
# Invert the mask
|
549 |
+
causal_mask = ~causal_mask
|
550 |
+
else:
|
551 |
+
causal_mask = torch.triu(
|
552 |
+
torch.ones((q_len, kv_seq_len), dtype=torch.bool, device=query_states.device), diagonal=1
|
553 |
+
)
|
554 |
+
else:
|
555 |
+
causal_mask = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
556 |
|
557 |
+
# Compute the attention mask
|
558 |
+
if attention_mask is not None:
|
559 |
+
if attention_mask.dim() == 2:
|
560 |
+
attention_mask = attention_mask[:, None, :]
|
561 |
+
attention_mask = attention_mask.to(torch.bool)
|
562 |
|
563 |
+
if causal:
|
564 |
+
attention_mask = attention_mask & causal_mask
|
565 |
+
else:
|
566 |
+
attention_mask = attention_mask
|
567 |
|
568 |
+
# Compute the softmax scale
|
569 |
+
softmax_scale = self.head_dim**-0.5
|
570 |
|
571 |
+
# Compute the attention scores
|
572 |
if attention_mask is not None:
|
573 |
+
# Unpad the input
|
574 |
+
(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
575 |
query_states,
|
576 |
key_states,
|
577 |
value_states,
|
578 |
+
indices_q,
|
579 |
+
cu_seq_lens,
|
580 |
+
max_seq_lens,
|
581 |
+
) = self._upad_input(query_states, key_states, value_states, attention_mask, q_len)
|
|
|
|
|
|
|
|
|
|
|
582 |
|
583 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
584 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
585 |
+
|
586 |
+
# Create the cu_seqlens_q and cu_seqlens_k tensors
|
587 |
+
q_max_s, k_max_s = query_states.shape[1], key_states.shape[1]
|
588 |
+
qkv_max_s = max(q_max_s, k_max_s)
|
589 |
+
|
590 |
+
q_seqlens = torch.full((batch_size,), q_max_s, dtype=torch.int32, device=query_states.device)
|
591 |
+
k_seqlens = torch.full((batch_size,), k_max_s, dtype=torch.int32, device=key_states.device)
|
592 |
+
|
593 |
+
# Adjust the attention mask to match the sequence lengths
|
594 |
+
if attention_mask is not None:
|
595 |
+
q_seqlens = attention_mask.sum(dim=1).int()
|
596 |
+
k_seqlens = attention_mask.sum(dim=1).int()
|
597 |
+
|
598 |
+
# Convert seqlens to cumulative sequence lengths
|
599 |
+
cu_seqlens_q = torch.cat([torch.zeros(1, dtype=torch.int32, device=q_seqlens.device), q_seqlens.cumsum(dim=0)])
|
600 |
+
cu_seqlens_k = torch.cat([torch.zeros(1, dtype=torch.int32, device=k_seqlens.device), k_seqlens.cumsum(dim=0)])
|
601 |
+
|
602 |
+
if not use_sliding_windows:
|
603 |
+
attn_output_unpad = flash_attn_varlen_func(
|
604 |
+
query_states,
|
605 |
+
key_states,
|
606 |
+
value_states,
|
607 |
+
cu_seqlens_q=cu_seqlens_q,
|
608 |
+
cu_seqlens_k=cu_seqlens_k,
|
609 |
+
max_seqlen_q=qkv_max_s,
|
610 |
+
max_seqlen_k=qkv_max_s,
|
611 |
+
dropout_p=dropout,
|
612 |
+
softmax_scale=softmax_scale,
|
613 |
+
causal=causal,
|
614 |
+
)
|
615 |
+
else:
|
616 |
+
attn_output_unpad = flash_attn_varlen_func(
|
617 |
+
query_states,
|
618 |
+
key_states,
|
619 |
+
value_states,
|
620 |
+
cu_seqlens_q=cu_seqlens_q,
|
621 |
+
cu_seqlens_k=cu_seqlens_k,
|
622 |
+
max_seqlen_q=qkv_max_s,
|
623 |
+
max_seqlen_k=qkv_max_s,
|
624 |
+
dropout_p=dropout,
|
625 |
+
softmax_scale=softmax_scale,
|
626 |
+
causal=causal,
|
627 |
+
window_size=(self.config.sliding_window, self.config.sliding_window),
|
628 |
+
)
|
629 |
+
|
630 |
+
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
631 |
else:
|
632 |
+
if not use_sliding_windows:
|
633 |
+
attn_output = flash_attn_func(
|
634 |
+
query_states,
|
635 |
+
key_states,
|
636 |
+
value_states,
|
637 |
+
dropout,
|
638 |
+
softmax_scale=softmax_scale,
|
639 |
+
causal=causal,
|
640 |
+
)
|
641 |
+
else:
|
642 |
+
attn_output = flash_attn_func(
|
643 |
+
query_states,
|
644 |
+
key_states,
|
645 |
+
value_states,
|
646 |
+
dropout,
|
647 |
+
softmax_scale=softmax_scale,
|
648 |
+
causal=causal,
|
649 |
+
window_size=(self.config.sliding_window, self.config.sliding_window),
|
650 |
+
)
|
651 |
|
652 |
+
return attn_output
|
653 |
|
654 |
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
655 |
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
|
|
|
663 |
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
664 |
|
665 |
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
|
666 |
+
value_layer= index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
|
|
|
667 |
if query_length == kv_seq_len:
|
668 |
query_layer = index_first_axis(
|
669 |
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
|
|
|
691 |
(cu_seqlens_q, cu_seqlens_k),
|
692 |
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
|
693 |
)
|
|
|
|
|
694 |
# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
|
695 |
class QuietSdpaAttention(QuietAttention):
|
696 |
"""
|