Crystalcareai
commited on
Update modeling_quiet.py
Browse files- modeling_quiet.py +122 -109
modeling_quiet.py
CHANGED
@@ -571,117 +571,133 @@ class QuietFlashAttention2(QuietAttention):
|
|
571 |
|
572 |
return attn_output, attn_weights, past_key_value
|
573 |
|
574 |
-
|
575 |
-
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
|
582 |
-
|
583 |
-
|
584 |
-
|
585 |
-
|
586 |
-
|
587 |
-
|
588 |
-
|
589 |
-
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
-
|
610 |
-
|
611 |
-
|
612 |
-
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
|
622 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
623 |
|
624 |
-
|
|
|
|
|
|
|
625 |
if attention_mask is not None:
|
626 |
-
|
627 |
-
|
628 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
629 |
)
|
630 |
|
631 |
-
|
632 |
-
|
633 |
-
|
634 |
-
|
635 |
-
|
636 |
-
|
637 |
-
|
638 |
-
|
639 |
-
|
640 |
-
|
641 |
-
|
642 |
-
max_seqlen_k=max_seqlen_in_batch_k,
|
643 |
-
dropout_p=dropout,
|
644 |
-
softmax_scale=softmax_scale,
|
645 |
-
causal=causal,
|
646 |
-
)
|
647 |
-
else:
|
648 |
-
attn_output_unpad = flash_attn_varlen_func(
|
649 |
-
query_states,
|
650 |
-
key_states,
|
651 |
-
value_states,
|
652 |
-
cu_seqlens_q=cu_seqlens_q,
|
653 |
-
cu_seqlens_k=cu_seqlens_k,
|
654 |
-
max_seqlen_q=max_seqlen_in_batch_q,
|
655 |
-
max_seqlen_k=max_seqlen_in_batch_k,
|
656 |
-
dropout_p=dropout,
|
657 |
-
softmax_scale=softmax_scale,
|
658 |
-
causal=causal,
|
659 |
-
window_size=(self.config.sliding_window, self.config.sliding_window),
|
660 |
-
)
|
661 |
-
|
662 |
-
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
663 |
else:
|
664 |
-
|
665 |
-
|
666 |
-
|
667 |
-
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
|
672 |
-
|
673 |
-
else:
|
674 |
-
attn_output = flash_attn_func(
|
675 |
-
query_states,
|
676 |
-
key_states,
|
677 |
-
value_states,
|
678 |
-
dropout,
|
679 |
-
softmax_scale=softmax_scale,
|
680 |
-
causal=causal,
|
681 |
-
window_size=(self.config.sliding_window, self.config.sliding_window),
|
682 |
-
)
|
683 |
|
684 |
-
|
685 |
|
686 |
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
687 |
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
|
@@ -1848,10 +1864,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1848 |
[shift_labels, padding],
|
1849 |
dim=-1
|
1850 |
)
|
1851 |
-
|
1852 |
-
# Adjust the labels to account for the additional thinking tokens
|
1853 |
-
new_rm_tokens = torch.cat([torch.full_like(new_rm_tokens[..., :self.n_ahead], self.tokenizer.pad_token_id, dtype=torch.long, device=new_rm_tokens.device), new_rm_tokens], dim=-1)
|
1854 |
-
|
1855 |
|
1856 |
# print((new_rm_tokens > self.vocab_size - 1).any().item())
|
1857 |
new_rm_tokens = torch.clamp(new_rm_tokens, 0, self.vocab_size - 1)
|
|
|
571 |
|
572 |
return attn_output, attn_weights, past_key_value
|
573 |
|
574 |
+
def _flash_attention_forward(
|
575 |
+
self,
|
576 |
+
query_states,
|
577 |
+
key_states,
|
578 |
+
value_states,
|
579 |
+
attention_mask,
|
580 |
+
query_length,
|
581 |
+
dropout=0.0,
|
582 |
+
softmax_scale=None,
|
583 |
+
use_sliding_windows=False,
|
584 |
+
):
|
585 |
+
"""
|
586 |
+
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
587 |
+
first unpad the input, then computes the attention scores and pad the final attention scores.
|
588 |
+
Args:
|
589 |
+
query_states (`torch.Tensor`):
|
590 |
+
Input query states to be passed to Flash Attention API
|
591 |
+
key_states (`torch.Tensor`):
|
592 |
+
Input key states to be passed to Flash Attention API
|
593 |
+
value_states (`torch.Tensor`):
|
594 |
+
Input value states to be passed to Flash Attention API
|
595 |
+
attention_mask (`torch.Tensor`):
|
596 |
+
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
597 |
+
position of padding tokens and 1 for the position of non-padding tokens.
|
598 |
+
dropout (`int`, *optional*):
|
599 |
+
Attention dropout
|
600 |
+
softmax_scale (`float`, *optional*):
|
601 |
+
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
602 |
+
use_sliding_windows (`bool`, *optional*):
|
603 |
+
Whether to activate sliding window attention.
|
604 |
+
"""
|
605 |
+
if not self._flash_attn_uses_top_left_mask:
|
606 |
+
causal = self.is_causal
|
607 |
+
else:
|
608 |
+
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
|
609 |
+
causal = self.is_causal and query_length != 1
|
610 |
+
|
611 |
+
# Ensure attention_mask has the correct shape and values
|
612 |
+
if attention_mask is not None:
|
613 |
+
if attention_mask.dim() == 4:
|
614 |
+
# Convert 4D attention mask to 2D
|
615 |
+
attention_mask = attention_mask.squeeze(1).squeeze(1)
|
616 |
+
elif attention_mask.dim() != 2:
|
617 |
+
raise ValueError(
|
618 |
+
f"Invalid attention mask dimension: {attention_mask.dim()}. Expected 2D or 4D mask."
|
619 |
+
)
|
620 |
+
|
621 |
+
# Ensure attention_mask has values of 0 and 1
|
622 |
+
attention_mask = attention_mask.to(torch.bool).to(torch.int32)
|
623 |
+
|
624 |
+
# Contains at least one padding token in the sequence
|
625 |
+
if attention_mask is not None:
|
626 |
+
batch_size = query_states.shape[0]
|
627 |
+
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
628 |
+
query_states, key_states, value_states, attention_mask, query_length
|
629 |
+
)
|
630 |
+
|
631 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
632 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
633 |
+
|
634 |
+
# Create the cu_seqlens_q and cu_seqlens_k tensors
|
635 |
+
q_max_s, k_max_s = query_states.shape[1], key_states.shape[1]
|
636 |
+
qkv_max_s = max(q_max_s, k_max_s)
|
637 |
|
638 |
+
q_seqlens = torch.full((batch_size,), q_max_s, dtype=torch.int32, device=query_states.device)
|
639 |
+
k_seqlens = torch.full((batch_size,), k_max_s, dtype=torch.int32, device=key_states.device)
|
640 |
+
|
641 |
+
# Adjust the attention mask to match the sequence lengths
|
642 |
if attention_mask is not None:
|
643 |
+
q_seqlens = attention_mask.sum(dim=1).int()
|
644 |
+
k_seqlens = attention_mask.sum(dim=1).int()
|
645 |
+
|
646 |
+
# Convert seqlens to cumulative sequence lengths
|
647 |
+
cu_seqlens_q = torch.cat([torch.zeros(1, dtype=torch.int32, device=q_seqlens.device), q_seqlens.cumsum(dim=0)])
|
648 |
+
cu_seqlens_k = torch.cat([torch.zeros(1, dtype=torch.int32, device=k_seqlens.device), k_seqlens.cumsum(dim=0)])
|
649 |
+
|
650 |
+
if not use_sliding_windows:
|
651 |
+
attn_output_unpad = flash_attn_varlen_func(
|
652 |
+
query_states,
|
653 |
+
key_states,
|
654 |
+
value_states,
|
655 |
+
cu_seqlens_q=cu_seqlens_q,
|
656 |
+
cu_seqlens_k=cu_seqlens_k,
|
657 |
+
max_seqlen_q=qkv_max_s,
|
658 |
+
max_seqlen_k=qkv_max_s,
|
659 |
+
dropout_p=dropout,
|
660 |
+
softmax_scale=softmax_scale,
|
661 |
+
causal=causal,
|
662 |
+
)
|
663 |
+
else:
|
664 |
+
attn_output_unpad = flash_attn_varlen_func(
|
665 |
+
query_states,
|
666 |
+
key_states,
|
667 |
+
value_states,
|
668 |
+
cu_seqlens_q=cu_seqlens_q,
|
669 |
+
cu_seqlens_k=cu_seqlens_k,
|
670 |
+
max_seqlen_q=qkv_max_s,
|
671 |
+
max_seqlen_k=qkv_max_s,
|
672 |
+
dropout_p=dropout,
|
673 |
+
softmax_scale=softmax_scale,
|
674 |
+
causal=causal,
|
675 |
+
window_size=(self.config.sliding_window, self.config.sliding_window),
|
676 |
)
|
677 |
|
678 |
+
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
679 |
+
else:
|
680 |
+
if not use_sliding_windows:
|
681 |
+
attn_output = flash_attn_func(
|
682 |
+
query_states,
|
683 |
+
key_states,
|
684 |
+
value_states,
|
685 |
+
dropout,
|
686 |
+
softmax_scale=softmax_scale,
|
687 |
+
causal=causal,
|
688 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
689 |
else:
|
690 |
+
attn_output = flash_attn_func(
|
691 |
+
query_states,
|
692 |
+
key_states,
|
693 |
+
value_states,
|
694 |
+
dropout,
|
695 |
+
softmax_scale=softmax_scale,
|
696 |
+
causal=causal,
|
697 |
+
window_size=(self.config.sliding_window, self.config.sliding_window),
|
698 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
699 |
|
700 |
+
return attn_output
|
701 |
|
702 |
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
703 |
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
|
|
|
1864 |
[shift_labels, padding],
|
1865 |
dim=-1
|
1866 |
)
|
1867 |
+
|
|
|
|
|
|
|
1868 |
|
1869 |
# print((new_rm_tokens > self.vocab_size - 1).any().item())
|
1870 |
new_rm_tokens = torch.clamp(new_rm_tokens, 0, self.vocab_size - 1)
|