Crystalcareai committed
Commit 18dad8f · verified · 1 Parent(s): 54575d5

Update modeling_quiet.py

Files changed (1): modeling_quiet.py (+14 -35)
modeling_quiet.py CHANGED
@@ -1070,40 +1070,28 @@ class QuietModel(QuietPreTrainedModel):
                     " this may lead to unexpected behaviour for Flash Attention version of Quiet. Make sure to "
                     " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                 )
-        if self._attn_implementation == "flash_attention_2":
-            num_thought_tokens = 0
-            if self.use_start_thought_token:
-                num_thought_tokens += 1
-            if self.use_end_thought_token:
-                num_thought_tokens += 1
-            original_sequence_length = input_ids.shape[1]
-            seq_length = original_sequence_length + num_thought_tokens
-            # Convert 2D mask to 4D and adjust size
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask,
-                (batch_size, seq_len),  # Adjust size
-                inputs_embeds,
-                past_key_values_length,
-                sliding_window=self.config.sliding_window,
-            )
+        if attention_mask is not None and self._attn_implementation == 'flash_attention_2':
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
+        elif self._attn_implementation == 'sdpa' and not output_attentions:
             # output_attentions=True can not be supported when using SDPA, and we fall back on
             # the manual implementation that requires a 4D causal mask in all cases.
-            seq_length = original_sequence_length + num_thought_tokens
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask,
-                (batch_size, seq_len),
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
                 inputs_embeds,
                 past_key_values_length,
             )
-        elif attention_mask is None or attention_mask.dim() == 2:
+        else:
+            # Check the shape of the attention mask
+            if attention_mask is not None and attention_mask.dim() == 2:
+                # Reshape the attention mask to 4D
+                attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
+
             # 4d mask is passed through the layers
-            seq_length = original_sequence_length + num_thought_tokens
             attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask,
-                (batch_size, seq_len),
+                attention_mask,
+                (batch_size, seq_length),
                 inputs_embeds,
                 past_key_values_length,
                 sliding_window=self.config.sliding_window,
@@ -1668,18 +1656,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
                 dim=-1
             )
-
-            num_thought_tokens = 0
-            if self.use_start_thought_token:
-                num_thought_tokens += 1
-            if self.use_end_thought_token:
-                num_thought_tokens += 1
-
-            original_sequence_length = input_ids.shape[1]
-            seq_length = original_sequence_length + num_thought_tokens
             # # if the attention mask
             attention_mask = _prepare_4d_causal_attention_mask(
-            attention_mask,
+                attention_mask,
                 (batch_size, seq_len),
                 inputs_embeds,
                 past_key_values_length,
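
For readers following the mask logic, here is a minimal, self-contained sketch of the branch selection that the first hunk settles on, written against the mask helpers in transformers.modeling_attn_mask_utils that the diff calls. The build_attention_mask wrapper, its parameter list, and the toy inputs are illustrative stand-ins, not code from modeling_quiet.py; the sketch also passes the helpers the raw 2D padding mask, since _prepare_4d_causal_attention_mask and its SDPA variant expand a 2D mask (or None) to 4D themselves.

# Sketch only: mirrors the updated if/elif/else dispatch, with illustrative names.
import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)

def build_attention_mask(attention_mask, inputs_embeds, past_key_values_length,
                         attn_implementation, output_attentions, sliding_window=None):
    batch_size, seq_length, _ = inputs_embeds.shape
    if attention_mask is not None and attn_implementation == "flash_attention_2":
        # Flash Attention consumes the raw 2D padding mask, or None when nothing is masked.
        return attention_mask if 0 in attention_mask else None
    elif attn_implementation == "sdpa" and not output_attentions:
        # SDPA path: expand the 2D padding mask into the 4D causal mask SDPA expects.
        return _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )
    else:
        # Eager fallback: full 4D causal mask, honouring the sliding window if one is set.
        return _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length,
            sliding_window=sliding_window,
        )

# Example: 3-token batch with one right-padded position, taking the eager path.
emb = torch.zeros(1, 3, 8)
padding_mask = torch.tensor([[1, 1, 0]])
mask_4d = build_attention_mask(padding_mask, emb, 0, "eager", output_attentions=False)
print(mask_4d.shape)  # torch.Size([1, 1, 3, 3])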