Crystalcareai committed on
Commit 2fe12cb · verified · 1 Parent(s): 18dad8f

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +33 -9
modeling_quiet.py CHANGED
@@ -44,7 +44,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask,
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import (
@@ -134,6 +134,34 @@ def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_
         previous_text = current_text
     c.showPage()
     c.save()
+
+def _prepare_4d_causal_attention_mask_for_sdpa(
+    attn_mask: Optional[torch.Tensor],
+    shape: Tuple[int, int],
+    inputs_embeds: Optional[torch.Tensor] = None,
+    past_key_values_length: int = 0,
+) -> torch.Tensor:
+    batch_size, seq_len = shape
+    if attn_mask is None:
+        attn_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=inputs_embeds.device)
+    else:
+        attn_mask = attn_mask.bool()
+
+    # Extend the attention mask to account for past key/value states
+    if past_key_values_length > 0:
+        extended_attn_mask = torch.cat(
+            [
+                attn_mask.new_zeros(batch_size, seq_len, past_key_values_length),
+                attn_mask.unsqueeze(2),
+            ],
+            dim=2,
+        )
+        attn_mask = extended_attn_mask
+
+    attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
+    causal_mask = torch.tril(torch.ones(seq_len, seq_len + past_key_values_length, device=attn_mask.device)).bool()
+    attn_mask = attn_mask & causal_mask.unsqueeze(0).unsqueeze(0)
+    return attn_mask
 
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
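
A minimal, illustrative sketch (not from the commit) of calling the newly added local helper with no cached keys/values. Unlike transformers' `_prepare_4d_causal_attention_mask`, which produces an additive float mask, this helper returns a boolean keep/ignore mask; the import path below is an assumption.

import torch

# Assumed import path: modeling_quiet.py (with the helper from this commit) on the Python path.
from modeling_quiet import _prepare_4d_causal_attention_mask_for_sdpa

# 2x4 padding mask: the last position of the second sequence is padding.
attn_mask_2d = torch.tensor([[1, 1, 1, 1],
                             [1, 1, 1, 0]])

mask_4d = _prepare_4d_causal_attention_mask_for_sdpa(
    attn_mask_2d,
    shape=(2, 4),
    inputs_embeds=None,          # only consulted when attn_mask is None
    past_key_values_length=0,
)

print(mask_4d.dtype, mask_4d.shape)  # torch.bool torch.Size([2, 1, 4, 4])
print(mask_4d[1, 0])                 # lower-triangular, with the padded column 3 masked out
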
@@ -1070,10 +1098,11 @@ class QuietModel(QuietPreTrainedModel):
                     " this may lead to unexpected behaviour for Flash Attention version of Quiet. Make sure to "
                     " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                 )
-        if attention_mask is not None and self._attn_implementation == 'flash_attention_2':
+
+        if self._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._attn_implementation == 'sdpa' and not output_attentions:
+        elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
             # output_attentions=True can not be supported when using SDPA, and we fall back on
             # the manual implementation that requires a 4D causal mask in all cases.
             attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
@@ -1082,12 +1111,7 @@ class QuietModel(QuietPreTrainedModel):
                 inputs_embeds,
                 past_key_values_length,
             )
-        else:
-            # Check the shape of the attention mask
-            if attention_mask is not None and attention_mask.dim() == 2:
-                # Reshape the attention mask to 4D
-                attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
-
+        elif attention_mask is None or attention_mask.dim() == 2:
             # 4d mask is passed through the layers
             attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask,
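
Read together, the last two hunks replace the old `else` fallback with an explicit `elif`, and the literal `and False` in the SDPA condition means that branch is never taken, so a 2D (or absent) mask in non-flash configurations falls through to `_prepare_4d_causal_attention_mask`. A minimal, illustrative sketch (not from the commit) that mirrors the branch conditions to show which path a 2D mask takes; the function name and return labels are stand-ins.

import torch

def select_mask_branch(attn_implementation: str, output_attentions: bool, attention_mask):
    # Mirrors the if/elif chain introduced by this commit (conditions copied, bodies replaced by labels).
    if attn_implementation == "flash_attention_2":
        return "2d mask passed through (or None when there is no padding)"
    elif attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
        return "local _prepare_4d_causal_attention_mask_for_sdpa"  # unreachable: literal False
    elif attention_mask is None or attention_mask.dim() == 2:
        return "_prepare_4d_causal_attention_mask (expand to 4d causal)"
    return "mask already 4d, used as-is"

mask = torch.ones(2, 4)
print(select_mask_branch("sdpa", False, mask))               # -> _prepare_4d_causal_attention_mask (expand to 4d causal)
print(select_mask_branch("flash_attention_2", False, mask))  # -> 2d mask passed through (or None when there is no padding)
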
 