Crystalcareai committed
Commit 72e45de · verified · 1 Parent(s): a00ce27

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +14 -71
modeling_quiet.py CHANGED
@@ -42,12 +42,12 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
-from transformers.activations import ACT2FN
-from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
-from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import (
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
@@ -134,34 +134,6 @@ def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_
         previous_text = current_text
     c.showPage()
     c.save()
-
-def _prepare_4d_causal_attention_mask_for_sdpa(
-    attn_mask: Optional[torch.Tensor],
-    shape: Tuple[int, int],
-    inputs_embeds: Optional[torch.Tensor] = None,
-    past_key_values_length: int = 0,
-) -> torch.Tensor:
-    batch_size, seq_len = shape
-    if attn_mask is None:
-        attn_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=inputs_embeds.device)
-    else:
-        attn_mask = attn_mask.bool()
-
-    # Extend the attention mask to account for past key/value states
-    if past_key_values_length > 0:
-        extended_attn_mask = torch.cat(
-            [
-                attn_mask.new_zeros(batch_size, seq_len, past_key_values_length),
-                attn_mask.unsqueeze(2),
-            ],
-            dim=2,
-        )
-        attn_mask = extended_attn_mask
-
-    attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
-    causal_mask = torch.tril(torch.ones(seq_len, seq_len + past_key_values_length, device=attn_mask.device)).bool()
-    attn_mask = attn_mask & causal_mask.unsqueeze(0).unsqueeze(0)
-    return attn_mask
 
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
@@ -309,8 +281,8 @@ class QuietAttention(nn.Module):
         self.layer_idx = layer_idx
         if layer_idx is None:
             logger.warning_once(
-                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
-                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                 "when creating this class."
             )
 
@@ -601,7 +573,7 @@ class QuietFlashAttention2(QuietAttention):
             attention_mask (`torch.Tensor`):
                 The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                 position of padding tokens and 1 for the position of non-padding tokens.
-            dropout (`int`, *optional*):
+            dropout (`float`):
                 Attention dropout
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
@@ -719,7 +691,8 @@ class QuietFlashAttention2(QuietAttention):
         )
 
 
-# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
+# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
+# TODO @Arthur no longer copied from LLama after static cache
 class QuietSdpaAttention(QuietAttention):
     """
     Quiet attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -793,14 +766,14 @@ class QuietSdpaAttention(QuietAttention):
             query_states,
             key_states,
             value_states,
-            attn_mask=attention_mask.to(query_states.device) if attention_mask is not None else None,
+            attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
             # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
             is_causal=self.is_causal and attention_mask is None and q_len > 1,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
 
         attn_output = self.o_proj(attn_output)
 
@@ -1665,37 +1638,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
         else:
             with torch.set_grad_enabled(not self.train_only_thinking_embedding):
                 inputs_embeds = self.model.embed_tokens(input_ids)
-
-    def _update_inputs_for_thought_tokens(
-        self, input_ids, attention_mask, contains_start, contains_end
-    ):
-        batch_size = input_ids.size(0)
-        seq_len = input_ids.size(1)
-
-        if contains_start:
-            start_token_ids = torch.tensor(
-                [[self.start_token_id]] * batch_size, device=input_ids.device
-            )
-            input_ids = torch.cat([input_ids, start_token_ids], dim=1)
-            if attention_mask is not None:
-                start_attention_mask = torch.ones(
-                    (batch_size, 1), device=attention_mask.device
-                )
-                attention_mask = torch.cat([attention_mask, start_attention_mask], dim=1)
-
-        if contains_end:
-            end_token_ids = torch.tensor(
-                [[self.end_token_id]] * batch_size, device=input_ids.device
-            )
-            input_ids = torch.cat([input_ids, end_token_ids], dim=1)
-            if attention_mask is not None:
-                end_attention_mask = torch.ones(
-                    (batch_size, 1), device=attention_mask.device
-                )
-                attention_mask = torch.cat([attention_mask, end_attention_mask], dim=1)
-
-        return input_ids, attention_mask
-
+
         if self.n_ahead != 1 or self.n_ahead_talk != 1 or self.comparison_mode:
             if attention_mask is None:
                 base_attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).to(input_ids.device)
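For context, a minimal sketch (not part of this commit; the shapes and padding pattern below are illustrative assumptions) of how the upstream _prepare_4d_causal_attention_mask_for_sdpa helper imported here is typically paired with torch.nn.functional.scaled_dot_product_attention, which is why the deleted local reimplementation and the attention_mask.to(query_states.device) handling are no longer needed:

# Sketch: pairing the upstream SDPA mask helper with F.scaled_dot_product_attention.
# Shapes and the padding pattern are illustrative assumptions, not values from the model.
import torch
import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa

batch_size, seq_len, num_heads, head_dim = 2, 8, 4, 16
inputs_embeds = torch.randn(batch_size, seq_len, num_heads * head_dim)

# 2D padding mask (1 = attend, 0 = padding); here the second sequence has two pad tokens.
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
attention_mask[1, :2] = 0

# Expands the 2D mask into the 4D additive mask SDPA expects, created on the same
# device and dtype as inputs_embeds (it can also return None when the mask is all
# ones and the causal shortcut can be used instead).
mask_4d = _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask,
    (batch_size, seq_len),
    inputs_embeds,
    past_key_values_length=0,
)

query = key = value = torch.randn(batch_size, num_heads, seq_len, head_dim)
attn_output = F.scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=mask_4d,
    dropout_p=0.0,
    # Mirrors the pattern in QuietSdpaAttention: rely on is_causal only when no mask is given.
    is_causal=mask_4d is None and seq_len > 1,
)
print(attn_output.shape)  # torch.Size([2, 4, 8, 16])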