Crystalcareai committed (verified)
Commit 45d3082 · Parent: c7e43b2

Update modeling_quiet.py

Files changed (1):
  modeling_quiet.py  +10 -97
modeling_quiet.py CHANGED
@@ -23,7 +23,7 @@ import math
 import pdb
 import warnings
 from collections import defaultdict
-from typing import List, Optional, Tuple, Union, Iterable, Callable
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -32,7 +32,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.generation.utils import GenerationMixin
 from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria
-from transformers import TextStreamer
+from transformers import TextStreamer, AutoTokenizer
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
@@ -774,7 +774,7 @@ class QuietSdpaAttention(QuietAttention):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
-
+        attention_mask = attention_mask.to(query_states.dtype)
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
         if query_states.device.type == "cuda" and attention_mask is not None:
@@ -1182,7 +1182,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         self.n_tokens_print = 1
         self.gradient_accumulation_steps = 1
         self.training_steps = 0
-        self.tokenizer = None
+        self.tokenizer = AutoTokenizer.from_pretrained("Crystalcareai/Quiet-Star-Custom")
         self.start_token_id = None
         self.end_token_id = None
         self.rm_initialized = False
@@ -1238,7 +1238,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         self.embedding_scale = 1e2
         self.temperature = nn.Parameter(torch.ones(1))
         self.max_temperature = config.max_temperature
-        self.complexity_factor = config.complexity_factor
         self.reinforce_temperature = 3
         self.base_loss_beta = 1
         self.thinking_usefulness_head = nn.Linear(self.model.config.hidden_size, 1)
@@ -1425,67 +1424,9 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         logits = self.lm_head(mixed_hidden_states)
         return logits
 
-    @torch.no_grad()
-    def generate(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        max_length: Optional[int] = None,
-        min_length: Optional[int] = None,
-        do_sample: Optional[bool] = None,
-        early_stopping: Optional[bool] = None,
-        num_beams: Optional[int] = None,
-        temperature: Optional[float] = None,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-        repetition_penalty: Optional[float] = None,
-        bad_words_ids: Optional[Iterable[int]] = None,
-        bos_token_id: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
-        length_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        encoder_no_repeat_ngram_size: Optional[int] = None,
-        num_return_sequences: Optional[int] = None,
-        max_time: Optional[float] = None,
-        max_new_tokens: Optional[int] = None,
-        decoder_start_token_id: Optional[int] = None,
-        use_cache: Optional[bool] = None,
-        num_beam_groups: Optional[int] = None,
-        diversity_penalty: Optional[float] = None,
-        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        output_scores: Optional[bool] = None,
-        return_dict_in_generate: Optional[bool] = None,
-        forced_bos_token_id: Optional[int] = None,
-        forced_eos_token_id: Optional[int] = None,
-        remove_invalid_values: Optional[bool] = None,
-        synced_gpus: Optional[bool] = False,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
-        **model_kwargs,
-    ):
-        # # Validate stopping criteria
-        # stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
-
-        # Set default values for attention_mask and max_new_tokens
-        if "attention_mask" not in model_kwargs:
-            attention_mask = torch.where(input_ids != self.tokenizer.pad_token_id, 1, 0).to(input_ids.device)
-            model_kwargs["attention_mask"] = attention_mask
-        if max_new_tokens is None:
-            max_new_tokens = 512
-
-        streamer = TextStreamer(self.tokenizer, skip_prompt=False, skip_special_tokens=True)
-
-        # Call the custom generate function
-        output_ids, _ = custom_generate(
-            self,
-            input_ids=input_ids,
-            streamer=streamer,
-            max_new_tokens=max_new_tokens,
-            **model_kwargs,
-        )
-
-        return output_ids
+    def generate(self, input_ids, attention_mask=None, max_length=None, temperature=1.0, **kwargs):
+        from .generate import generate
+        return generate(self, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, temperature=temperature, **kwargs)
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -1675,8 +1616,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         sample_probs_history = []
         action_loglikelihoods_list = []
 
-        complexity_scores = self.compute_complexity_scores(input_ids, attention_mask)
-        temperature = self.temperature * complexity_scores.unsqueeze(-1)
+        temperature = self.temperature
 
         if self.use_end_thought_token or self.use_start_thought_token:
             if not self.use_reparam_for_thought_embeddings:
@@ -1734,10 +1674,10 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
            base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
            attention_mask = base_attention_mask
-            breakpoint()
+            # breakpoint()
        elif attention_mask.dim() == 2:
            if seq_len + past_key_values_length != attention_mask.shape[-1]:
-                breakpoint()
+                # breakpoint()
                attention_mask = torch.cat(
                    [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
                    dim=-1
@@ -2190,33 +2130,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             attentions=outputs.attentions,
         )
 
-
-
-    def compute_complexity_scores(self, input_ids, attention_mask):
-        # Compute complexity scores based on input sequence characteristics
-        # Example: Normalize sequence lengths and consider the presence of rare tokens
-        seq_lengths = torch.sum(attention_mask, dim=-1)
-        max_length = torch.max(seq_lengths)
-        length_scores = seq_lengths / max_length
-
-        # Compute the proportion of rare tokens in each sequence
-        rare_token_ids = self.get_rare_token_ids()
-        rare_token_mask = torch.isin(input_ids, rare_token_ids)
-        rare_token_counts = torch.sum(rare_token_mask, dim=-1)
-        rare_token_scores = rare_token_counts / seq_lengths
-
-        # Combine length scores and rare token scores
-        complexity_scores = self.complexity_factor * length_scores + (1 - self.complexity_factor) * rare_token_scores
-        return complexity_scores
-
-    def get_rare_token_ids(self):
-        # Get the IDs of rare tokens based on a predefined frequency threshold
-        frequency_threshold = 1e-4
-        token_counts = torch.bincount(self.model.embed_tokens.weight.argmax(dim=-1))
-        total_tokens = torch.sum(token_counts)
-        rare_token_mask = token_counts / total_tokens < frequency_threshold
-        rare_token_ids = torch.nonzero(rare_token_mask).squeeze(-1)
-        return rare_token_ids
 
 
     def prepare_inputs_for_generation(
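
Note on the QuietSdpaAttention hunk: the added line casts the additive attention mask to the query's dtype before the SDPA call. Below is a minimal standalone sketch of that pattern, not taken from modeling_quiet.py; the shapes are toy values and the only assumption is stock PyTorch. Fused scaled_dot_product_attention backends expect a floating-point mask in the query's dtype, and, as the comment retained in the diff notes, non-contiguous q/k/v with a custom mask hit the CUDA bug tracked in pytorch/pytorch#112577.

# Standalone sketch only; toy shapes, not the model's real dimensions.
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

bsz, n_heads, q_len, head_dim = 2, 4, 8, 16
query = torch.randn(bsz, n_heads, q_len, head_dim, device=device, dtype=dtype)
key = torch.randn(bsz, n_heads, q_len, head_dim, device=device, dtype=dtype)
value = torch.randn(bsz, n_heads, q_len, head_dim, device=device, dtype=dtype)

# Additive mask: 0 where attention is allowed, a large negative value where it is not.
attention_mask = torch.zeros(bsz, 1, q_len, q_len, device=device)
attention_mask[..., q_len // 2 :] = torch.finfo(dtype).min

# Mirrors the added line: keep the float mask in the query's dtype so the
# fused SDPA backends accept it when the model runs in fp16/bf16.
attention_mask = attention_mask.to(query.dtype)

# Mirrors the existing workaround for https://github.com/pytorch/pytorch/issues/112577:
# make q/k/v contiguous on CUDA when a custom mask is passed.
if query.device.type == "cuda" and attention_mask is not None:
    query, key, value = query.contiguous(), key.contiguous(), value.contiguous()

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
print(out.shape)  # torch.Size([2, 4, 8, 16])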
 
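Note on the generate() hunk: the long in-file generate() override is replaced by a three-line wrapper that defers to generate() in the repo's generate.py, and __init__ now loads the tokenizer from Crystalcareai/Quiet-Star-Custom. A hedged usage sketch, assuming the model is loaded from that same repo with trust_remote_code enabled and that the external helper returns a [batch, seq] tensor of token ids as the removed implementation did; the dtype, prompt, and max_length below are illustrative.

# Usage sketch only; the repo id comes from the tokenizer line in this diff,
# everything else is an assumption for illustration.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "Crystalcareai/Quiet-Star-Custom"
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,   # required so this modeling_quiet.py / generate.py are used
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

inputs = tokenizer("What is 2 + 2?", return_tensors="pt").to(model.device)

# The new wrapper forwards input_ids, attention_mask, max_length, temperature
# and any extra kwargs straight to generate.py's generate().
output_ids = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=64,
    temperature=1.0,
)

# Assumes the helper returns token ids, as the removed in-file method did.
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))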