Update modeling_quiet.py
modeling_quiet.py  CHANGED  (+14 -35)
@@ -1070,40 +1070,28 @@ class QuietModel(QuietPreTrainedModel):
                 " this may lead to unexpected behaviour for Flash Attention version of Quiet. Make sure to "
                 " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
             )
-        if self._attn_implementation == "flash_attention_2":
-            num_thought_tokens = 0
-            if self.use_start_thought_token:
-                num_thought_tokens += 1
-            if self.use_end_thought_token:
-                num_thought_tokens += 1
-            original_sequence_length = input_ids.shape[1]
-            seq_length = original_sequence_length + num_thought_tokens
-            # Convert 2D mask to 4D and adjust size
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask,
-                (batch_size, seq_len),  # Adjust size
-                inputs_embeds,
-                past_key_values_length,
-                sliding_window=self.config.sliding_window,
-            )
+        if attention_mask is not None and self._attn_implementation == 'flash_attention_2':
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._attn_implementation == "sdpa" and not output_attentions:
+        elif self._attn_implementation == 'sdpa' and not output_attentions:
             # output_attentions=True can not be supported when using SDPA, and we fall back on
             # the manual implementation that requires a 4D causal mask in all cases.
-
-
-
-                (batch_size, seq_len),
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
                 inputs_embeds,
                 past_key_values_length,
             )
-
+        else:
+            # Check the shape of the attention mask
+            if attention_mask is not None and attention_mask.dim() == 2:
+                # Reshape the attention mask to 4D
+                attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
+
             # 4d mask is passed through the layers
-            seq_length = original_sequence_length + num_thought_tokens
             attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask,
-                (batch_size, seq_length),
+                attention_mask,
+                (batch_size, seq_length),
                 inputs_embeds,
                 past_key_values_length,
                 sliding_window=self.config.sliding_window,
@@ -1668,18 +1656,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
                 dim=-1
             )
-
-            num_thought_tokens = 0
-            if self.use_start_thought_token:
-                num_thought_tokens += 1
-            if self.use_end_thought_token:
-                num_thought_tokens += 1
-
-            original_sequence_length = input_ids.shape[1]
-            seq_length = original_sequence_length + num_thought_tokens
             # # if the attention mask
             attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask,
+                attention_mask,
                 (batch_size, seq_len),
                 inputs_embeds,
                 past_key_values_length,