Crystalcareai committed
Update modeling_quiet.py
modeling_quiet.py  +156 -64
CHANGED
@@ -32,11 +32,11 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.generation.utils import GenerationMixin
 from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria
-from transformers import TextStreamer
+from transformers import TextStreamer
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import (
@@ -65,62 +65,62 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "QuietConfig"
 
 
-# (56 removed lines: a commented-out block whose text is not shown in this rendering)
+def _prepare_4d_causal_attention_mask_for_sdpa(attention_mask, input_shape, inputs_embeds, past_key_values_length):
+    # Build the 4D additive attention mask expected by the SDPA attention path
+    bsz, tgt_len = input_shape
+
+    # Create a 4D attention mask from a lower-dimensional tensor mask.
+    # The shape of the output attention mask is (batch_size, 1, tgt_len, src_len).
+    # The input values are either 0 or 1, where 0 means padding and 1 means non-padding.
+    combined_attention_mask = None
+    if attention_mask is not None:
+        # If attention_mask already has a shape of (batch_size, 1, tgt_len, src_len),
+        # it can be used directly.
+        if attention_mask.dim() == 4:
+            combined_attention_mask = attention_mask
+        # If attention_mask has a shape of (batch_size, 1, tgt_len),
+        # expand it to (batch_size, 1, tgt_len, src_len).
+        elif attention_mask.dim() == 3:
+            expanded_attn_mask = attention_mask[:, None, :, :]
+            combined_attention_mask = expanded_attn_mask
+        # If attention_mask has a shape of (batch_size, tgt_len),
+        # expand it to (batch_size, 1, tgt_len, src_len).
+        elif attention_mask.dim() == 2:
+            # Provided a padding mask of dimensions [batch_size, seq_length]
+            # - if the model is a decoder, apply a causal mask in addition to the padding mask
+            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            if past_key_values_length > 0:
+                attention_mask = attention_mask.to(dtype=torch.long)
+                attention_mask = attention_mask[:, past_key_values_length:]
+            expanded_attn_mask = attention_mask[:, None, None, :]
+            combined_attention_mask = expanded_attn_mask
+        else:
+            raise ValueError(
+                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
+                    input_shape, attention_mask.shape
+                )
+            )
+
+    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+    # masked positions, this operation will create a tensor which is 0.0 for
+    # positions we want to attend and -10000.0 for masked positions.
+    # Since we are adding it to the raw scores before the softmax, this is
+    # effectively the same as removing these entirely.
+    if combined_attention_mask is not None:
+        # Clamp the mask values to the 0/1 range before converting
+        combined_attention_mask = combined_attention_mask.clamp(min=0, max=1)
+
+        # Convert the attention mask to bfloat16
+        combined_attention_mask = combined_attention_mask.to(torch.bfloat16)
+
+        # Turn the 0/1 mask into an additive mask: 0.0 where attended, -10000.0 where masked
+        combined_attention_mask = (1.0 - combined_attention_mask) * -10000.0
+    else:
+        combined_attention_mask = torch.zeros(
+            (bsz, 1, tgt_len, tgt_len), dtype=torch.bfloat16, device=inputs_embeds.device
+        )
+
+    return combined_attention_mask
 
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
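For reference, a minimal sketch of what the new _prepare_4d_causal_attention_mask_for_sdpa helper returns for a padded batch. The import path and tensors below are illustrative assumptions, not part of the commit:

import torch
from modeling_quiet import _prepare_4d_causal_attention_mask_for_sdpa  # assumed local import path

# Two sequences of length 4; the last position of the second one is padding.
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 1, 0]])
inputs_embeds = torch.zeros(2, 4, 8)  # only consulted for its device, and only when no mask is given

mask_4d = _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask, (2, 4), inputs_embeds, past_key_values_length=0
)
print(mask_4d.shape)     # torch.Size([2, 1, 1, 4]) -- the 2D branch broadcasts over query positions
print(mask_4d[1, 0, 0])  # 0.0 for visible positions, roughly -1e4 (in bfloat16) for the padded one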
@@ -774,7 +774,7 @@ class QuietSdpaAttention(QuietAttention):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
-
+
        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
@@ -1182,7 +1182,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         self.n_tokens_print = 1
         self.gradient_accumulation_steps = 1
         self.training_steps = 0
-        self.tokenizer =
+        self.tokenizer = None
         self.start_token_id = None
         self.end_token_id = None
         self.rm_initialized = False
@@ -1238,6 +1238,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         self.embedding_scale = 1e2
         self.temperature = nn.Parameter(torch.ones(1))
         self.max_temperature = config.max_temperature
+        self.complexity_factor = config.complexity_factor
         self.reinforce_temperature = 3
         self.base_loss_beta = 1
         self.thinking_usefulness_head = nn.Linear(self.model.config.hidden_size, 1)
@@ -1424,9 +1425,67 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         logits = self.lm_head(mixed_hidden_states)
         return logits
 
-
-
-
+    @torch.no_grad()
+    def generate(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        max_length: Optional[int] = None,
+        min_length: Optional[int] = None,
+        do_sample: Optional[bool] = None,
+        early_stopping: Optional[bool] = None,
+        num_beams: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        repetition_penalty: Optional[float] = None,
+        bad_words_ids: Optional[Iterable[int]] = None,
+        bos_token_id: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        length_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        encoder_no_repeat_ngram_size: Optional[int] = None,
+        num_return_sequences: Optional[int] = None,
+        max_time: Optional[float] = None,
+        max_new_tokens: Optional[int] = None,
+        decoder_start_token_id: Optional[int] = None,
+        use_cache: Optional[bool] = None,
+        num_beam_groups: Optional[int] = None,
+        diversity_penalty: Optional[float] = None,
+        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_scores: Optional[bool] = None,
+        return_dict_in_generate: Optional[bool] = None,
+        forced_bos_token_id: Optional[int] = None,
+        forced_eos_token_id: Optional[int] = None,
+        remove_invalid_values: Optional[bool] = None,
+        synced_gpus: Optional[bool] = False,
+        stopping_criteria: Optional[StoppingCriteriaList] = None,
+        **model_kwargs,
+    ):
+        # Validate stopping criteria
+        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
+
+        # Set default values for attention_mask and max_new_tokens
+        if "attention_mask" not in model_kwargs:
+            attention_mask = torch.where(input_ids != self.tokenizer.pad_token_id, 1, 0).to(input_ids.device)
+            model_kwargs["attention_mask"] = attention_mask
+        if max_new_tokens is None:
+            max_new_tokens = 512
+
+        streamer = TextStreamer(self.tokenizer, skip_prompt=False, skip_special_tokens=True)
+
+        # Call the custom generate function
+        output_ids, _ = custom_generate(
+            self,
+            input_ids=input_ids,
+            streamer=streamer,
+            max_new_tokens=max_new_tokens,
+            **model_kwargs,
+        )
+
+        return output_ids
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
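As a usage sketch (not part of the commit): the new generate() override builds its attention mask from self.tokenizer.pad_token_id, streams through a TextStreamer, and delegates to the module's custom_generate helper, so a tokenizer has to be attached to the model first. The checkpoint name below is a placeholder:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Crystalcareai/quiet-star-custom"  # placeholder checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token  # the override indexes pad_token_id when no mask is passed

model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
model.tokenizer = tokenizer  # initialized to None in this commit, so the caller must set it

inputs = tokenizer("What is 12 * 7?", return_tensors="pt")
output_ids = model.generate(input_ids=inputs.input_ids, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))  # assumes custom_generate returns token ids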
@@ -1616,7 +1675,8 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         sample_probs_history = []
         action_loglikelihoods_list = []
 
-
+        complexity_scores = self.compute_complexity_scores(input_ids, attention_mask)
+        temperature = self.temperature * complexity_scores.unsqueeze(-1)
 
         if self.use_end_thought_token or self.use_start_thought_token:
             if not self.use_reparam_for_thought_embeddings:
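A toy illustration of what the two added lines compute, with made-up scores and evaluated outside the model:

import torch

temperature = torch.ones(1)                    # mirrors nn.Parameter(torch.ones(1)) in __init__
complexity_scores = torch.tensor([0.4, 0.9])   # one score per sequence in the batch
per_sequence_temperature = temperature * complexity_scores.unsqueeze(-1)
print(per_sequence_temperature)                # tensor([[0.4000], [0.9000]]) -- shape (batch, 1)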
@@ -1674,12 +1734,15 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
             base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
             attention_mask = base_attention_mask
+            breakpoint()
         elif attention_mask.dim() == 2:
             if seq_len + past_key_values_length != attention_mask.shape[-1]:
+                breakpoint()
                 attention_mask = torch.cat(
                     [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
                     dim=-1
                 )
+            # # if the attention mask
             attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask,
                 (batch_size, seq_len),
@@ -1697,8 +1760,10 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
                 use_cache=use_cache,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
+                # output_router_logits=output_router_logits,
                 return_dict=return_dict,
             )
+
             prev_hidden_states = hidden_states
             hidden_states = outputs[0]
             prev_rm_logits = rm_logits # for policy gradient
@@ -2125,6 +2190,33 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             attentions=outputs.attentions,
         )
 
+
+
+    def compute_complexity_scores(self, input_ids, attention_mask):
+        # Compute complexity scores based on input sequence characteristics
+        # Example: normalize sequence lengths and consider the presence of rare tokens
+        seq_lengths = torch.sum(attention_mask, dim=-1)
+        max_length = torch.max(seq_lengths)
+        length_scores = seq_lengths / max_length
+
+        # Compute the proportion of rare tokens in each sequence
+        rare_token_ids = self.get_rare_token_ids()
+        rare_token_mask = torch.isin(input_ids, rare_token_ids)
+        rare_token_counts = torch.sum(rare_token_mask, dim=-1)
+        rare_token_scores = rare_token_counts / seq_lengths
+
+        # Combine length scores and rare token scores
+        complexity_scores = self.complexity_factor * length_scores + (1 - self.complexity_factor) * rare_token_scores
+        return complexity_scores
+
+    def get_rare_token_ids(self):
+        # Approximate rare token IDs with an embedding-matrix heuristic and a frequency threshold
+        frequency_threshold = 1e-4
+        token_counts = torch.bincount(self.model.embed_tokens.weight.argmax(dim=-1))
+        total_tokens = torch.sum(token_counts)
+        rare_token_mask = token_counts / total_tokens < frequency_threshold
+        rare_token_ids = torch.nonzero(rare_token_mask).squeeze(-1)
+        return rare_token_ids
 
 
     def prepare_inputs_for_generation(
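Finally, a standalone sketch of the scoring arithmetic in compute_complexity_scores, with toy tensors and a hypothetical complexity_factor of 0.7 (in the model the factor comes from config.complexity_factor and the rare IDs from get_rare_token_ids):

import torch

complexity_factor = 0.7                      # stands in for config.complexity_factor
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])
input_ids = torch.tensor([[5, 17, 9042, 2],
                          [5, 8, 0, 0]])
rare_token_ids = torch.tensor([9042])        # pretend this ID falls under the frequency threshold

seq_lengths = attention_mask.sum(dim=-1)                                        # tensor([4, 2])
length_scores = seq_lengths / seq_lengths.max()                                 # tensor([1.0000, 0.5000])
rare_scores = torch.isin(input_ids, rare_token_ids).sum(dim=-1) / seq_lengths   # tensor([0.2500, 0.0000])

complexity_scores = complexity_factor * length_scores + (1 - complexity_factor) * rare_scores
print(complexity_scores)                     # tensor([0.7750, 0.3500])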
|