Crystalcareai committed
Commit 6b08966 · verified · 1 Parent(s): d3e1600

Update modeling_quiet.py

Files changed (1): modeling_quiet.py (+156, −64)
modeling_quiet.py CHANGED
@@ -32,11 +32,11 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.generation.utils import GenerationMixin
 from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria
-from transformers import TextStreamer, AutoTokenizer
+from transformers import TextStreamer
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import (
@@ -65,62 +65,62 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "QuietConfig"
 
 
-# def _prepare_4d_causal_attention_mask_for_sdpa(attention_mask, input_shape, inputs_embeds, past_key_values_length):
-#     # Compute the attention mask correctly
-#     bsz, tgt_len = input_shape
-
-#     # Create a 4D attention mask from a 2D tensor mask.
-#     # The shape of the output attention mask is (batch_size, 1, tgt_len, src_len)
-#     # The values are either 0 or 1, where 0 means padding and 1 means non-padding.
-#     combined_attention_mask = None
-#     if attention_mask is not None:
-#         # What if attention_mask is not None and has a shape of (batch_size, 1, tgt_len, src_len)
-#         # In this case, we can just use it directly.
-#         if attention_mask.dim() == 4:
-#             combined_attention_mask = attention_mask
-#         # What if attention_mask is not None and has a shape of (batch_size, 1, tgt_len)
-#         # In this case, we need to expand it to (batch_size, 1, tgt_len, src_len)
-#         elif attention_mask.dim() == 3:
-#             expanded_attn_mask = attention_mask[:, None, :, :]
-#             combined_attention_mask = expanded_attn_mask
-#         # What if attention_mask is not None and has a shape of (batch_size, tgt_len)
-#         # In this case, we need to expand it to (batch_size, 1, tgt_len, src_len)
-#         elif attention_mask.dim() == 2:
-#             # Provided a padding mask of dimensions [batch_size, seq_length]
-#             # - if the model is a decoder, apply a causal mask in addition to the padding mask
-#             # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
-#             if past_key_values_length > 0:
-#                 attention_mask = attention_mask.to(dtype=torch.long)
-#                 attention_mask = attention_mask[:, past_key_values_length:]
-#             expanded_attn_mask = attention_mask[:, None, None, :]
-#             combined_attention_mask = expanded_attn_mask
-#         else:
-#             raise ValueError(
-#                 "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
-#                     input_shape, attention_mask.shape
-#                 )
-#             )
-
-#     # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
-#     # masked positions, this operation will create a tensor which is 0.0 for
-#     # positions we want to attend and -10000.0 for masked positions.
-#     # Since we are adding it to the raw scores before the softmax, this is
-#     # effectively the same as removing these entirely.
-#     if combined_attention_mask is not None:
-#         # Ensure the attention mask values are within a reasonable range
-#         combined_attention_mask = combined_attention_mask.clamp(min=0, max=1)
-
-#         # Convert the attention mask to bfloat16
-#         combined_attention_mask = combined_attention_mask.to(torch.bfloat16)
-
-#         # Normalize the attention mask values to be between 0 and 1
-#         combined_attention_mask = (1.0 - combined_attention_mask) * -10000.0
-#     else:
-#         combined_attention_mask = torch.zeros(
-#             (bsz, 1, tgt_len, tgt_len), dtype=torch.bfloat16, device=inputs_embeds.device
-#         )
-
-#     return combined_attention_mask
+def _prepare_4d_causal_attention_mask_for_sdpa(attention_mask, input_shape, inputs_embeds, past_key_values_length):
+    # Compute the attention mask correctly
+    bsz, tgt_len = input_shape
+
+    # Create a 4D attention mask from a 2D tensor mask.
+    # The shape of the output attention mask is (batch_size, 1, tgt_len, src_len)
+    # The values are either 0 or 1, where 0 means padding and 1 means non-padding.
+    combined_attention_mask = None
+    if attention_mask is not None:
+        # What if attention_mask is not None and has a shape of (batch_size, 1, tgt_len, src_len)
+        # In this case, we can just use it directly.
+        if attention_mask.dim() == 4:
+            combined_attention_mask = attention_mask
+        # What if attention_mask is not None and has a shape of (batch_size, 1, tgt_len)
+        # In this case, we need to expand it to (batch_size, 1, tgt_len, src_len)
+        elif attention_mask.dim() == 3:
+            expanded_attn_mask = attention_mask[:, None, :, :]
+            combined_attention_mask = expanded_attn_mask
+        # What if attention_mask is not None and has a shape of (batch_size, tgt_len)
+        # In this case, we need to expand it to (batch_size, 1, tgt_len, src_len)
+        elif attention_mask.dim() == 2:
+            # Provided a padding mask of dimensions [batch_size, seq_length]
+            # - if the model is a decoder, apply a causal mask in addition to the padding mask
+            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            if past_key_values_length > 0:
+                attention_mask = attention_mask.to(dtype=torch.long)
+                attention_mask = attention_mask[:, past_key_values_length:]
+            expanded_attn_mask = attention_mask[:, None, None, :]
+            combined_attention_mask = expanded_attn_mask
+        else:
+            raise ValueError(
+                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+                    input_shape, attention_mask.shape
+                )
+            )
+
+    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+    # masked positions, this operation will create a tensor which is 0.0 for
+    # positions we want to attend and -10000.0 for masked positions.
+    # Since we are adding it to the raw scores before the softmax, this is
+    # effectively the same as removing these entirely.
+    if combined_attention_mask is not None:
+        # Ensure the attention mask values are within a reasonable range
+        combined_attention_mask = combined_attention_mask.clamp(min=0, max=1)
+
+        # Convert the attention mask to bfloat16
+        combined_attention_mask = combined_attention_mask.to(torch.bfloat16)
+
+        # Normalize the attention mask values to be between 0 and 1
+        combined_attention_mask = (1.0 - combined_attention_mask) * -10000.0
+    else:
+        combined_attention_mask = torch.zeros(
+            (bsz, 1, tgt_len, tgt_len), dtype=torch.bfloat16, device=inputs_embeds.device
+        )
+
+    return combined_attention_mask
 
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
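The previously commented-out `_prepare_4d_causal_attention_mask_for_sdpa` helper is now active, replacing the helper of the same name whose import from `transformers.modeling_attn_mask_utils` was removed above. A minimal sketch (not from the commit, toy tensors only) of what its 2D-mask branch returns; note that this branch produces a `(bsz, 1, 1, src_len)` additive padding mask in bfloat16 and does not itself add the causal triangle.

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])        # (bsz=1, tgt_len=4); 0 marks padding
expanded = attention_mask[:, None, None, :]          # 2D branch: -> (1, 1, 1, 4)
expanded = expanded.clamp(min=0, max=1).to(torch.bfloat16)
additive = (1.0 - expanded) * -10000.0               # 0.0 where attended, ~-10000.0 where padded

print(additive.shape)  # torch.Size([1, 1, 1, 4])
```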
@@ -774,7 +774,7 @@ class QuietSdpaAttention(QuietAttention):
             raise ValueError(
                 f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
             )
-
+
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
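For context, the torch issue referenced in that comment concerns `scaled_dot_product_attention` with a custom `attn_mask` on non-contiguous CUDA tensors; the usual workaround in the upstream Mistral/Llama SDPA attention (which this class follows) is to make query, key, and value contiguous before the call. A generic sketch of that pattern, not the file's own code:

```python
import torch
import torch.nn.functional as F

# A transposed view is non-contiguous; combined with a custom attn_mask on CUDA
# this hits the torch<=2.1.2 SDPA issue referenced above.
query = torch.randn(1, 16, 8, 64).transpose(1, 2)   # (bsz, heads, seq, head_dim), non-contiguous
key = torch.randn(1, 8, 16, 64)
value = torch.randn(1, 8, 16, 64)
attn_mask = torch.zeros(1, 1, 16, 16)

if query.device.type == "cuda" and attn_mask is not None:
    query, key, value = query.contiguous(), key.contiguous(), value.contiguous()

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```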
@@ -1182,7 +1182,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         self.n_tokens_print = 1
         self.gradient_accumulation_steps = 1
         self.training_steps = 0
-        self.tokenizer = AutoTokenizer.from_pretrained("Crystalcareai/Quiet-Star-Custom")
+        self.tokenizer = None
         self.start_token_id = None
         self.end_token_id = None
         self.rm_initialized = False
@@ -1238,6 +1238,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         self.embedding_scale = 1e2
         self.temperature = nn.Parameter(torch.ones(1))
         self.max_temperature = config.max_temperature
+        self.complexity_factor = config.complexity_factor
         self.reinforce_temperature = 3
         self.base_loss_beta = 1
         self.thinking_usefulness_head = nn.Linear(self.model.config.hidden_size, 1)
@@ -1424,9 +1425,67 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         logits = self.lm_head(mixed_hidden_states)
         return logits
 
-    def generate(self, input_ids, attention_mask=None, max_length=None, temperature=1.0, **kwargs):
-        from .generate import generate
-        return generate(self, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, temperature=temperature, **kwargs)
+    @torch.no_grad()
+    def generate(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        max_length: Optional[int] = None,
+        min_length: Optional[int] = None,
+        do_sample: Optional[bool] = None,
+        early_stopping: Optional[bool] = None,
+        num_beams: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        repetition_penalty: Optional[float] = None,
+        bad_words_ids: Optional[Iterable[int]] = None,
+        bos_token_id: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        length_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        encoder_no_repeat_ngram_size: Optional[int] = None,
+        num_return_sequences: Optional[int] = None,
+        max_time: Optional[float] = None,
+        max_new_tokens: Optional[int] = None,
+        decoder_start_token_id: Optional[int] = None,
+        use_cache: Optional[bool] = None,
+        num_beam_groups: Optional[int] = None,
+        diversity_penalty: Optional[float] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_scores: Optional[bool] = None,
+        return_dict_in_generate: Optional[bool] = None,
+        forced_bos_token_id: Optional[int] = None,
+        forced_eos_token_id: Optional[int] = None,
+        remove_invalid_values: Optional[bool] = None,
+        synced_gpus: Optional[bool] = False,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        **model_kwargs,
+    ):
+        # Validate stopping criteria
+        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
+
+        # Set default values for attention_mask and max_new_tokens
+        if "attention_mask" not in model_kwargs:
+            attention_mask = torch.where(input_ids != self.tokenizer.pad_token_id, 1, 0).to(input_ids.device)
+            model_kwargs["attention_mask"] = attention_mask
+        if max_new_tokens is None:
+            max_new_tokens = 512
+
+        streamer = TextStreamer(self.tokenizer, skip_prompt=False, skip_special_tokens=True)
+
+        # Call the custom generate function
+        output_ids, _ = custom_generate(
+            self,
+            input_ids=input_ids,
+            streamer=streamer,
+            max_new_tokens=max_new_tokens,
+            **model_kwargs,
+        )
+
+        return output_ids
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
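Because `self.tokenizer` is now initialized to `None` (see the `@@ -1182,7 +1182,7 @@` hunk above) while the new `generate` reads `self.tokenizer.pad_token_id` and builds a `TextStreamer` from it, a tokenizer presumably has to be attached to the loaded model before calling `generate`. A hypothetical usage sketch under that assumption; it passes the attention mask explicitly so the pad-token fallback is never needed, and it assumes `custom_generate` (defined elsewhere in the repo) returns `(batch, seq)` token ids:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Crystalcareai/Quiet-Star-Custom"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16)

model.tokenizer = tokenizer  # generate() reads self.tokenizer for the streamer and pad-token fallback

inputs = tokenizer("Q: What is 2 + 2?\nA:", return_tensors="pt")
output_ids = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],  # passed explicitly so the pad_token_id fallback is skipped
    max_new_tokens=64,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```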
@@ -1616,7 +1675,8 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         sample_probs_history = []
         action_loglikelihoods_list = []
 
-        temperature = self.temperature
+        complexity_scores = self.compute_complexity_scores(input_ids, attention_mask)
+        temperature = self.temperature * complexity_scores.unsqueeze(-1)
 
         if self.use_end_thought_token or self.use_start_thought_token:
             if not self.use_reparam_for_thought_embeddings:
@@ -1674,12 +1734,15 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
             base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
             attention_mask = base_attention_mask
+            breakpoint()
         elif attention_mask.dim() == 2:
             if seq_len + past_key_values_length != attention_mask.shape[-1]:
+                breakpoint()
                 attention_mask = torch.cat(
                     [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
                     dim=-1
                 )
+            # # if the attention mask
             attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask,
                 (batch_size, seq_len),
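When a KV cache is in use, the 2D attention mask covers only the new tokens, so the branch above left-pads it with ones for the `past_key_values_length` positions before `_prepare_4d_causal_attention_mask` builds the 4D mask. A toy shape check (illustrative values only):

```python
import torch

past_key_values_length = 3
attention_mask = torch.tensor([[1, 1]])   # (bsz=1, seq_len=2), covers new tokens only

attention_mask = torch.cat(
    [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype), attention_mask],
    dim=-1,
)
print(attention_mask)   # tensor([[1, 1, 1, 1, 1]]) -> (bsz, past_key_values_length + seq_len)
```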
@@ -1697,8 +1760,10 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
                 use_cache=use_cache,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
+                # output_router_logits=output_router_logits,
                 return_dict=return_dict,
             )
+
             prev_hidden_states = hidden_states
             hidden_states = outputs[0]
             prev_rm_logits = rm_logits  # for policy gradient
@@ -2125,6 +2190,33 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             attentions=outputs.attentions,
         )
 
+
+
+    def compute_complexity_scores(self, input_ids, attention_mask):
+        # Compute complexity scores based on input sequence characteristics
+        # Example: Normalize sequence lengths and consider the presence of rare tokens
+        seq_lengths = torch.sum(attention_mask, dim=-1)
+        max_length = torch.max(seq_lengths)
+        length_scores = seq_lengths / max_length
+
+        # Compute the proportion of rare tokens in each sequence
+        rare_token_ids = self.get_rare_token_ids()
+        rare_token_mask = torch.isin(input_ids, rare_token_ids)
+        rare_token_counts = torch.sum(rare_token_mask, dim=-1)
+        rare_token_scores = rare_token_counts / seq_lengths
+
+        # Combine length scores and rare token scores
+        complexity_scores = self.complexity_factor * length_scores + (1 - self.complexity_factor) * rare_token_scores
+        return complexity_scores
+
+    def get_rare_token_ids(self):
+        # Get the IDs of rare tokens based on a predefined frequency threshold
+        frequency_threshold = 1e-4
+        token_counts = torch.bincount(self.model.embed_tokens.weight.argmax(dim=-1))
+        total_tokens = torch.sum(token_counts)
+        rare_token_mask = token_counts / total_tokens < frequency_threshold
+        rare_token_ids = torch.nonzero(rare_token_mask).squeeze(-1)
+        return rare_token_ids
 
 
     def prepare_inputs_for_generation(
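The new `compute_complexity_scores` blends a normalized-length score with a rare-token ratio using `config.complexity_factor`, and the result scales the sampling temperature in the `@@ -1616,7 +1675,8 @@` hunk above. A standalone sketch of just that blend, with a made-up `complexity_factor` and a hand-picked rare-token set (the committed `get_rare_token_ids` instead derives its counts from `torch.bincount` over the embedding-matrix argmax):

```python
import torch

input_ids = torch.tensor([[5, 9, 9, 2], [5, 7, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
rare_token_ids = torch.tensor([2, 7])      # pretend these ids are "rare"
complexity_factor = 0.5                    # illustrative value for config.complexity_factor

seq_lengths = attention_mask.sum(dim=-1)                 # tensor([4, 2])
length_scores = seq_lengths / seq_lengths.max()          # tensor([1.0, 0.5])

rare_counts = torch.isin(input_ids, rare_token_ids).sum(dim=-1)   # tensor([1, 1])
rare_token_scores = rare_counts / seq_lengths                     # tensor([0.25, 0.5])

complexity = complexity_factor * length_scores + (1 - complexity_factor) * rare_token_scores
print(complexity)   # tensor([0.6250, 0.5000])
```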
 