Crystalcareai
committed on
Update modeling_quiet.py
Browse files- modeling_quiet.py +40 -6
modeling_quiet.py
CHANGED
@@ -1246,6 +1246,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1246 |
|
1247 |
self.policy_loss_beta = 1e6
|
1248 |
self.embedding_scale = 1e2
|
|
|
|
|
|
|
1249 |
self.reinforce_temperature = 3
|
1250 |
self.base_loss_beta = 1
|
1251 |
self.thinking_usefulness_head = nn.Linear(self.model.config.hidden_size, 1)
|
@@ -1626,16 +1629,20 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1626 |
sample_probs_history = []
|
1627 |
action_loglikelihoods_list = []
|
1628 |
|
|
|
|
|
|
|
1629 |
if self.use_end_thought_token or self.use_start_thought_token:
|
1630 |
if not self.use_reparam_for_thought_embeddings:
|
1631 |
-
start_embedding = self.start_embedding[0].unsqueeze(0) * self.embedding_scale
|
1632 |
-
end_embedding = self.end_embedding[0].unsqueeze(0) * self.embedding_scale
|
1633 |
else:
|
1634 |
-
start_embedding = self.start_embedding * self.embedding_scale
|
1635 |
-
end_embedding = self.end_embedding * self.embedding_scale
|
1636 |
base_embeddings = self.model.embed_tokens.weight
|
1637 |
if self.train_only_thinking_embedding:
|
1638 |
base_embeddings = base_embeddings.detach()
|
|
|
1639 |
# # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
1640 |
fwd_iters = 1 if self.original_mode else self.n_ahead + self.n_ahead_talk - 1
|
1641 |
for ahead_idx in range(fwd_iters):
|
@@ -1900,9 +1907,10 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1900 |
contains_end = self.use_end_thought_token and (probabilities_2d[..., self.end_token_id].sum() > 0)
|
1901 |
contains_thought = contains_start or contains_end
|
1902 |
|
|
|
1903 |
if not contains_thought:
|
1904 |
with torch.set_grad_enabled(not self.train_only_thinking_embedding):
|
1905 |
-
inputs_embeds = probabilities_2d @ (self.model.embed_tokens.weight.to(probabilities.device).to(probabilities.dtype))
|
1906 |
else:
|
1907 |
thought_id = self.start_token_id if contains_start else self.end_token_id
|
1908 |
cur_thought_embedding = start_embedding if contains_start else end_embedding
|
@@ -1915,7 +1923,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1915 |
sampled_end = inputs_embeds.clone().detach()
|
1916 |
else:
|
1917 |
inputs_embeds = cur_thought_embedding.unsqueeze(0).repeat(batch_size, seq_len, 1)
|
1918 |
-
|
1919 |
inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
|
1920 |
|
1921 |
# Predict the usefulness of thinking at each token position
|
@@ -2127,6 +2135,32 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
2127 |
hidden_states=outputs.hidden_states,
|
2128 |
attentions=outputs.attentions,
|
2129 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2130 |
|
2131 |
|
2132 |
def prepare_inputs_for_generation(
|
|
|
1246 |
|
1247 |
self.policy_loss_beta = 1e6
|
1248 |
self.embedding_scale = 1e2
|
1249 |
+
self.temperature = nn.Parameter(torch.tensor(1.0))
|
1250 |
+
self.max_temperature = config.max_temperature
|
1251 |
+
self.complexity_factor = config.complexity_factor
|
1252 |
self.reinforce_temperature = 3
|
1253 |
self.base_loss_beta = 1
|
1254 |
self.thinking_usefulness_head = nn.Linear(self.model.config.hidden_size, 1)
|
|
|
1629 |
sample_probs_history = []
|
1630 |
action_loglikelihoods_list = []
|
1631 |
|
1632 |
+
complexity_scores = self.compute_complexity_scores(input_ids, attention_mask)
|
1633 |
+
temperature = self.temperature * complexity_scores.unsqueeze(-1)
|
1634 |
+
|
1635 |
if self.use_end_thought_token or self.use_start_thought_token:
|
1636 |
if not self.use_reparam_for_thought_embeddings:
|
1637 |
+
start_embedding = self.start_embedding[0].unsqueeze(0) * self.embedding_scale * temperature
|
1638 |
+
end_embedding = self.end_embedding[0].unsqueeze(0) * self.embedding_scale * temperature
|
1639 |
else:
|
1640 |
+
start_embedding = self.start_embedding * self.embedding_scale * temperature
|
1641 |
+
end_embedding = self.end_embedding * self.embedding_scale * temperature
|
1642 |
base_embeddings = self.model.embed_tokens.weight
|
1643 |
if self.train_only_thinking_embedding:
|
1644 |
base_embeddings = base_embeddings.detach()
|
1645 |
+
|
1646 |
# # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
1647 |
fwd_iters = 1 if self.original_mode else self.n_ahead + self.n_ahead_talk - 1
|
1648 |
for ahead_idx in range(fwd_iters):
|
|
|
1907 |
contains_end = self.use_end_thought_token and (probabilities_2d[..., self.end_token_id].sum() > 0)
|
1908 |
contains_thought = contains_start or contains_end
|
1909 |
|
1910 |
+
|
1911 |
if not contains_thought:
|
1912 |
with torch.set_grad_enabled(not self.train_only_thinking_embedding):
|
1913 |
+
inputs_embeds = probabilities_2d @ (self.model.embed_tokens.weight.to(probabilities.device).to(probabilities.dtype) * temperature)
|
1914 |
else:
|
1915 |
thought_id = self.start_token_id if contains_start else self.end_token_id
|
1916 |
cur_thought_embedding = start_embedding if contains_start else end_embedding
|
|
|
1923 |
sampled_end = inputs_embeds.clone().detach()
|
1924 |
else:
|
1925 |
inputs_embeds = cur_thought_embedding.unsqueeze(0).repeat(batch_size, seq_len, 1)
|
1926 |
+
inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
|
1927 |
inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
|
1928 |
|
1929 |
# Predict the usefulness of thinking at each token position
|
|
|
2135 |
hidden_states=outputs.hidden_states,
|
2136 |
attentions=outputs.attentions,
|
2137 |
)
|
2138 |
+
|
2139 |
+
def compute_complexity_scores(self, input_ids, attention_mask):
    """Return one heuristic complexity score per sequence in the batch.

    The score is a convex combination, weighted by ``self.complexity_factor``,
    of two per-sequence terms:

    * length score: number of attended tokens divided by the largest such
      count in the batch;
    * rare-token score: fraction of attended tokens whose id is in
      ``self.get_rare_token_ids()``.

    Args:
        input_ids: Integer tensor of token ids; the last dimension is the
            sequence dimension.
        attention_mask: Tensor broadcast-compatible with ``input_ids`` where
            non-zero marks a real (non-padding) token, or ``None`` to treat
            every position as attended.

    Returns:
        Floating-point tensor with the sequence dimension reduced away.
        Rows that contain no attended tokens score 0 instead of NaN.
    """
    if attention_mask is None:
        # Callers may omit the mask; treat every position as a real token.
        attention_mask = torch.ones_like(input_ids)

    seq_lengths = attention_mask.sum(dim=-1)
    # Guard against division by zero for all-padding rows (and an
    # all-padding batch); a NaN here would propagate into the embeddings.
    safe_lengths = seq_lengths.clamp(min=1)
    length_scores = seq_lengths / safe_lengths.max()

    # Keep the lookup table on the same device as the ids for torch.isin.
    rare_token_ids = self.get_rare_token_ids().to(input_ids.device)
    # Only count rare tokens at attended positions: the denominator is the
    # unpadded length, so padded positions must not contribute either.
    rare_token_mask = torch.isin(input_ids, rare_token_ids) & attention_mask.to(torch.bool)
    rare_token_scores = rare_token_mask.sum(dim=-1) / safe_lengths

    return (
        self.complexity_factor * length_scores
        + (1 - self.complexity_factor) * rare_token_scores
    )
|
2155 |
+
|
2156 |
+
def get_rare_token_ids(self, frequency_threshold=1e-4):
    """Return the indices whose relative count falls below a threshold.

    The counts are obtained by taking, for every row of
    ``self.model.embed_tokens.weight``, the index of its largest component
    and histogramming those indices with ``torch.bincount``.

    NOTE(review): the histogram is therefore indexed by embedding
    *dimension* (values in ``[0, hidden_size)``), not by token id, and it
    reflects the current embedding weights rather than corpus token
    frequencies. Callers treat the result as token ids; confirm this is the
    intended statistic or replace it with real token-frequency data.

    Args:
        frequency_threshold: Indices whose share of the total count is
            strictly below this value are returned. Defaults to ``1e-4``,
            the value previously hard-coded here.

    Returns:
        1-D integer tensor of the selected indices, on the same device as
        the embedding weights. Empty when nothing is below the threshold.
    """
    dominant_dims = self.model.embed_tokens.weight.argmax(dim=-1)
    bucket_counts = torch.bincount(dominant_dims)
    bucket_freqs = bucket_counts / torch.sum(bucket_counts)
    return torch.nonzero(bucket_freqs < frequency_threshold).squeeze(-1)
|
2164 |
|
2165 |
|
2166 |
def prepare_inputs_for_generation(
|