Crystalcareai committed
Update modeling_quiet.py
Browse files

modeling_quiet.py (+33 -9)
CHANGED
@@ -44,7 +44,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask,
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask,
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import (
@@ -134,6 +134,34 @@ def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_
         previous_text = current_text
     c.showPage()
     c.save()
+
+def _prepare_4d_causal_attention_mask_for_sdpa(
+    attn_mask: Optional[torch.Tensor],
+    shape: Tuple[int, int],
+    inputs_embeds: Optional[torch.Tensor] = None,
+    past_key_values_length: int = 0,
+) -> torch.Tensor:
+    batch_size, seq_len = shape
+    if attn_mask is None:
+        attn_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=inputs_embeds.device)
+    else:
+        attn_mask = attn_mask.bool()
+
+    # Extend the attention mask to account for past key/value states
+    if past_key_values_length > 0:
+        extended_attn_mask = torch.cat(
+            [
+                attn_mask.new_zeros(batch_size, seq_len, past_key_values_length),
+                attn_mask.unsqueeze(2),
+            ],
+            dim=2,
+        )
+        attn_mask = extended_attn_mask
+
+    attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
+    causal_mask = torch.tril(torch.ones(seq_len, seq_len + past_key_values_length, device=attn_mask.device)).bool()
+    attn_mask = attn_mask & causal_mask.unsqueeze(0).unsqueeze(0)
+    return attn_mask
 
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
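For reference, here is a minimal standalone sketch (not part of the diff) of how the newly added _prepare_4d_causal_attention_mask_for_sdpa helper behaves on the no-cache path; the tensor values are illustrative only:

import torch

batch_size, seq_len = 2, 4
# 1 = attend, 0 = padding (second sequence is left-padded)
padding_mask = torch.tensor([[1, 1, 1, 1],
                             [0, 0, 1, 1]])

mask_4d = _prepare_4d_causal_attention_mask_for_sdpa(
    padding_mask, (batch_size, seq_len), past_key_values_length=0
)
print(mask_4d.shape)  # torch.Size([2, 1, 4, 4])
print(mask_4d.dtype)  # torch.bool; True where a query position may attend to a key

Note that this helper returns a boolean keep-mask, whereas _prepare_4d_causal_attention_mask from transformers returns an additive float mask (0 for kept positions, a large negative value for masked ones).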
@@ -1070,10 +1098,11 @@ class QuietModel(QuietPreTrainedModel):
                 " this may lead to unexpected behaviour for Flash Attention version of Quiet. Make sure to "
                 " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
             )
-
+
+        if self._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._attn_implementation ==
+        elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
             # output_attentions=True can not be supported when using SDPA, and we fall back on
             # the manual implementation that requires a 4D causal mask in all cases.
             attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
@@ -1082,12 +1111,7 @@ class QuietModel(QuietPreTrainedModel):
                 inputs_embeds,
                 past_key_values_length,
             )
-
-            # Check the shape of the attention mask
-            if attention_mask is not None and attention_mask.dim() == 2:
-                # Reshape the attention mask to 4D
-                attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
-
+        elif attention_mask is None or attention_mask.dim() == 2:
             # 4d mask is passed through the layers
             attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask,
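For orientation, the mask dispatch in QuietModel.forward after this change reduces to roughly the sketch below (not part of the diff); batch_size, seq_length, inputs_embeds, and past_key_values_length are assumed to be in scope as in the surrounding method, and the exact argument lists of the two prepare calls are inferred rather than shown in the hunks:

if self._attn_implementation == "flash_attention_2":
    # flash-attn consumes the raw 2D mask; drop it entirely when there is no padding
    attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
    # the trailing `and False` keeps this branch disabled, so the local SDPA helper is not reached
    attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )
elif attention_mask is None or attention_mask.dim() == 2:
    # everything else receives the standard additive 4D causal mask from transformers
    attention_mask = _prepare_4d_causal_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )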