Crystalcareai committed
Update modeling_quiet.py
Files changed: modeling_quiet.py (+10 -97)

modeling_quiet.py CHANGED
@@ -23,7 +23,7 @@ import math
 import pdb
 import warnings
 from collections import defaultdict
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -32,7 +32,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.generation.utils import GenerationMixin
 from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria
-from transformers import TextStreamer
+from transformers import TextStreamer, AutoTokenizer
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
@@ -774,7 +774,7 @@ class QuietSdpaAttention(QuietAttention):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
-
+        attention_mask = attention_mask.to(query_states.dtype)
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
         if query_states.device.type == "cuda" and attention_mask is not None:
@@ -1182,7 +1182,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         self.n_tokens_print = 1
         self.gradient_accumulation_steps = 1
         self.training_steps = 0
-        self.tokenizer =
+        self.tokenizer = AutoTokenizer.from_pretrained("Crystalcareai/Quiet-Star-Custom")
         self.start_token_id = None
         self.end_token_id = None
         self.rm_initialized = False
@@ -1238,7 +1238,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         self.embedding_scale = 1e2
         self.temperature = nn.Parameter(torch.ones(1))
         self.max_temperature = config.max_temperature
-        self.complexity_factor = config.complexity_factor
         self.reinforce_temperature = 3
         self.base_loss_beta = 1
         self.thinking_usefulness_head = nn.Linear(self.model.config.hidden_size, 1)
@@ -1425,67 +1424,9 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         logits = self.lm_head(mixed_hidden_states)
         return logits
 
-
-    def generate(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        max_length: Optional[int] = None,
-        min_length: Optional[int] = None,
-        do_sample: Optional[bool] = None,
-        early_stopping: Optional[bool] = None,
-        num_beams: Optional[int] = None,
-        temperature: Optional[float] = None,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-        repetition_penalty: Optional[float] = None,
-        bad_words_ids: Optional[Iterable[int]] = None,
-        bos_token_id: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
-        length_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        encoder_no_repeat_ngram_size: Optional[int] = None,
-        num_return_sequences: Optional[int] = None,
-        max_time: Optional[float] = None,
-        max_new_tokens: Optional[int] = None,
-        decoder_start_token_id: Optional[int] = None,
-        use_cache: Optional[bool] = None,
-        num_beam_groups: Optional[int] = None,
-        diversity_penalty: Optional[float] = None,
-        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        output_scores: Optional[bool] = None,
-        return_dict_in_generate: Optional[bool] = None,
-        forced_bos_token_id: Optional[int] = None,
-        forced_eos_token_id: Optional[int] = None,
-        remove_invalid_values: Optional[bool] = None,
-        synced_gpus: Optional[bool] = False,
-        stopping_criteria: Optional[StoppingCriteriaList] = None,
-        **model_kwargs,
-    ):
-        # # Validate stopping criteria
-        # stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
-
-        # Set default values for attention_mask and max_new_tokens
-        if "attention_mask" not in model_kwargs:
-            attention_mask = torch.where(input_ids != self.tokenizer.pad_token_id, 1, 0).to(input_ids.device)
-            model_kwargs["attention_mask"] = attention_mask
-        if max_new_tokens is None:
-            max_new_tokens = 512
-
-        streamer = TextStreamer(self.tokenizer, skip_prompt=False, skip_special_tokens=True)
-
-        # Call the custom generate function
-        output_ids, _ = custom_generate(
-            self,
-            input_ids=input_ids,
-            streamer=streamer,
-            max_new_tokens=max_new_tokens,
-            **model_kwargs,
-        )
-
-        return output_ids
+    def generate(self, input_ids, attention_mask=None, max_length=None, temperature=1.0, **kwargs):
+        from .generate import generate
+        return generate(self, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, temperature=temperature, **kwargs)
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -1675,8 +1616,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         sample_probs_history = []
         action_loglikelihoods_list = []
 
-
-        temperature = self.temperature * complexity_scores.unsqueeze(-1)
+        temperature = self.temperature
 
         if self.use_end_thought_token or self.use_start_thought_token:
             if not self.use_reparam_for_thought_embeddings:
@@ -1734,10 +1674,10 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
            base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
             attention_mask = base_attention_mask
-            breakpoint()
+            # breakpoint()
         elif attention_mask.dim() == 2:
             if seq_len + past_key_values_length != attention_mask.shape[-1]:
-                breakpoint()
+                # breakpoint()
                 attention_mask = torch.cat(
                     [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
                     dim=-1
@@ -2190,33 +2130,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             attentions=outputs.attentions,
         )
 
-
-
-    def compute_complexity_scores(self, input_ids, attention_mask):
-        # Compute complexity scores based on input sequence characteristics
-        # Example: Normalize sequence lengths and consider the presence of rare tokens
-        seq_lengths = torch.sum(attention_mask, dim=-1)
-        max_length = torch.max(seq_lengths)
-        length_scores = seq_lengths / max_length
-
-        # Compute the proportion of rare tokens in each sequence
-        rare_token_ids = self.get_rare_token_ids()
-        rare_token_mask = torch.isin(input_ids, rare_token_ids)
-        rare_token_counts = torch.sum(rare_token_mask, dim=-1)
-        rare_token_scores = rare_token_counts / seq_lengths
-
-        # Combine length scores and rare token scores
-        complexity_scores = self.complexity_factor * length_scores + (1 - self.complexity_factor) * rare_token_scores
-        return complexity_scores
-
-    def get_rare_token_ids(self):
-        # Get the IDs of rare tokens based on a predefined frequency threshold
-        frequency_threshold = 1e-4
-        token_counts = torch.bincount(self.model.embed_tokens.weight.argmax(dim=-1))
-        total_tokens = torch.sum(token_counts)
-        rare_token_mask = token_counts / total_tokens < frequency_threshold
-        rare_token_ids = torch.nonzero(rare_token_mask).squeeze(-1)
-        return rare_token_ids
 
 
     def prepare_inputs_for_generation(
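
For reference, a minimal usage sketch of the new generate entry point introduced by this commit. It assumes the Crystalcareai/Quiet-Star-Custom repository ships this custom modeling code plus the generate.py module that the patched method imports from, and that the model is loaded with trust_remote_code=True; the prompt text, max_length, and temperature below are illustrative only, and the shape of the returned ids is assumed to match the tensor the old method returned.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "Crystalcareai/Quiet-Star-Custom"
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,  # loads QuietForCausalLM and generate.py from the repo
)

inputs = tokenizer("Example prompt for the model.", return_tensors="pt")
# The patched QuietForCausalLM.generate forwards these arguments to .generate.generate().
output_ids = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=256,
    temperature=1.0,
)
# Assumes the delegated function returns a batch of token ids, as the removed method did.
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))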