Crystalcareai committed
Update modeling_quiet.py

modeling_quiet.py  +14 -71  CHANGED
@@ -42,12 +42,12 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
-from
-from
-from
-from
-from
-from
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
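Note: the `...` relative imports added here only resolve when the file sits inside the transformers source tree. A sketch of the equivalent absolute imports for standalone use of the file (e.g. when loaded with trust_remote_code); this assumes a transformers release around 4.36-4.38, where all of these symbols, including _prepare_4d_causal_attention_mask_for_sdpa, are exported from these modules:

# Equivalent absolute imports (assumption: transformers ~4.36-4.38)
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
)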
@@ -134,34 +134,6 @@ def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_
         previous_text = current_text
     c.showPage()
     c.save()
-
-def _prepare_4d_causal_attention_mask_for_sdpa(
-    attn_mask: Optional[torch.Tensor],
-    shape: Tuple[int, int],
-    inputs_embeds: Optional[torch.Tensor] = None,
-    past_key_values_length: int = 0,
-) -> torch.Tensor:
-    batch_size, seq_len = shape
-    if attn_mask is None:
-        attn_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=inputs_embeds.device)
-    else:
-        attn_mask = attn_mask.bool()
-
-    # Extend the attention mask to account for past key/value states
-    if past_key_values_length > 0:
-        extended_attn_mask = torch.cat(
-            [
-                attn_mask.new_zeros(batch_size, seq_len, past_key_values_length),
-                attn_mask.unsqueeze(2),
-            ],
-            dim=2,
-        )
-        attn_mask = extended_attn_mask
-
-    attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
-    causal_mask = torch.tril(torch.ones(seq_len, seq_len + past_key_values_length, device=attn_mask.device)).bool()
-    attn_mask = attn_mask & causal_mask.unsqueeze(0).unsqueeze(0)
-    return attn_mask
 
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
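Note: the hand-rolled _prepare_4d_causal_attention_mask_for_sdpa removed above is superseded by the helper of the same name now imported from modeling_attn_mask_utils (see the import hunk). A minimal sketch of how that library helper is typically called, assuming transformers ~4.36-4.38; the shapes below are illustrative, and the actual call site inside QuietModel is not part of this diff:

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa

batch_size, seq_len, hidden_size = 2, 8, 32
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)  # 2D padding mask
inputs_embeds = torch.randn(batch_size, seq_len, hidden_size)

# Expands the 2D padding mask into the 4D causal mask expected by SDPA;
# it may instead return None when SDPA can rely on is_causal=True.
mask_4d = _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask,
    (batch_size, seq_len),
    inputs_embeds,
    past_key_values_length=0,
)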
@@ -309,8 +281,8 @@ class QuietAttention(nn.Module):
         self.layer_idx = layer_idx
         if layer_idx is None:
             logger.warning_once(
-                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
-                "to errors during the forward call
+                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                 "when creating this class."
             )
 
@@ -601,7 +573,7 @@ class QuietFlashAttention2(QuietAttention):
             attention_mask (`torch.Tensor`):
                 The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                 position of padding tokens and 1 for the position of non-padding tokens.
-            dropout (`
+            dropout (`float`):
                 Attention dropout
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
@@ -719,7 +691,8 @@ class QuietFlashAttention2(QuietAttention):
         )
 
 
-#
+# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
+# TODO @Arthur no longer copied from LLama after static cache
 class QuietSdpaAttention(QuietAttention):
     """
     Quiet attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -793,14 +766,14 @@ class QuietSdpaAttention(QuietAttention):
             query_states,
             key_states,
             value_states,
-            attn_mask=attention_mask
+            attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
             # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
             is_causal=self.is_causal and attention_mask is None and q_len > 1,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.
+        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
 
         attn_output = self.o_proj(attn_output)
 
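Note: for context on the two fixes in this hunk (the missing comma after attn_mask and the missing reshape before o_proj), here is a self-contained sketch of the same call pattern using torch.nn.functional.scaled_dot_product_attention directly; the shapes and dropout value are illustrative and not taken from the model config:

import torch
import torch.nn.functional as F

bsz, num_heads, q_len, head_dim = 2, 4, 8, 16
hidden_size = num_heads * head_dim

query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, q_len, head_dim)
value_states = torch.randn(bsz, num_heads, q_len, head_dim)
attention_mask = None  # a 4D mask tensor, or None to let is_causal handle the masking

attn_output = F.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=attention_mask,  # trailing comma restored by the patch above
    dropout_p=0.0,
    is_causal=attention_mask is None and q_len > 1,
)

# (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, hidden_size)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, hidden_size)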
@@ -1665,37 +1638,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
         else:
             with torch.set_grad_enabled(not self.train_only_thinking_embedding):
                 inputs_embeds = self.model.embed_tokens(input_ids)
-
-    def _update_inputs_for_thought_tokens(
-        self, input_ids, attention_mask, contains_start, contains_end
-    ):
-        batch_size = input_ids.size(0)
-        seq_len = input_ids.size(1)
-
-        if contains_start:
-            start_token_ids = torch.tensor(
-                [[self.start_token_id]] * batch_size, device=input_ids.device
-            )
-            input_ids = torch.cat([input_ids, start_token_ids], dim=1)
-            if attention_mask is not None:
-                start_attention_mask = torch.ones(
-                    (batch_size, 1), device=attention_mask.device
-                )
-                attention_mask = torch.cat([attention_mask, start_attention_mask], dim=1)
-
-        if contains_end:
-            end_token_ids = torch.tensor(
-                [[self.end_token_id]] * batch_size, device=input_ids.device
-            )
-            input_ids = torch.cat([input_ids, end_token_ids], dim=1)
-            if attention_mask is not None:
-                end_attention_mask = torch.ones(
-                    (batch_size, 1), device=attention_mask.device
-                )
-                attention_mask = torch.cat([attention_mask, end_attention_mask], dim=1)
-
-        return input_ids, attention_mask
-
+
         if self.n_ahead != 1 or self.n_ahead_talk != 1 or self.comparison_mode:
             if attention_mask is None:
                 base_attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).to(input_ids.device)