Crystalcareai
commited on
Update modeling_quiet.py
Browse files- modeling_quiet.py +9 -58
modeling_quiet.py
CHANGED
@@ -18,9 +18,7 @@
|
|
18 |
# See the License for the specific language governing permissions and
|
19 |
# limitations under the License.
|
20 |
""" PyTorch Quiet model."""
|
21 |
-
import inspect
|
22 |
import math
|
23 |
-
import pdb
|
24 |
import warnings
|
25 |
from collections import defaultdict
|
26 |
from typing import List, Optional, Tuple, Union
|
@@ -31,8 +29,7 @@ import torch.utils.checkpoint
|
|
31 |
from torch import nn
|
32 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
33 |
from transformers.generation.utils import GenerationMixin
|
34 |
-
from transformers
|
35 |
-
from transformers import TextStreamer, AutoTokenizer
|
36 |
import transformers
|
37 |
|
38 |
from transformers.activations import ACT2FN
|
@@ -43,8 +40,6 @@ from transformers.modeling_utils import PreTrainedModel
|
|
43 |
from transformers.utils import (
|
44 |
add_start_docstrings,
|
45 |
add_start_docstrings_to_model_forward,
|
46 |
-
is_flash_attn_2_available,
|
47 |
-
is_flash_attn_greater_or_equal_2_10,
|
48 |
logging,
|
49 |
replace_return_docstrings,
|
50 |
)
|
@@ -240,7 +235,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|
240 |
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
241 |
"""
|
242 |
|
243 |
-
# pdb.set_trace()
|
244 |
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
245 |
if n_rep == 1:
|
246 |
return hidden_states
|
@@ -332,7 +326,7 @@ class QuietAttention(nn.Module):
|
|
332 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
333 |
|
334 |
if past_key_value is not None:
|
335 |
-
cache_kwargs = {"sin": sin, "cos": cos}
|
336 |
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
337 |
|
338 |
# repeat k/v heads if n_kv_heads < n_heads
|
@@ -377,8 +371,7 @@ class QuietAttention(nn.Module):
|
|
377 |
)
|
378 |
|
379 |
attn_weights = attn_weights + attention_mask
|
380 |
-
|
381 |
-
# upcast attention to fp32
|
382 |
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
383 |
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
384 |
attn_output = torch.matmul(attn_weights, value_states)
|
@@ -851,16 +844,12 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
851 |
self.model = QuietModel(config)
|
852 |
self.vocab_size = config.vocab_size
|
853 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
854 |
-
# self.router_aux_loss_coef = config.router_aux_loss_coef
|
855 |
-
# self.num_experts = config.num_experts
|
856 |
-
# self.num_experts_per_tok = config.num_experts_per_tok
|
857 |
self.max_thoughts = config.max_thoughts
|
858 |
self.merged_lm_and_talk_heads = config.merged_lm_and_talk_heads
|
859 |
self.use_concat_talk_head = config.use_concat_talk_head
|
860 |
self.use_shallow_talk = config.use_shallow_talk
|
861 |
self.use_complex_talk_head = config.use_complex_talk_head
|
862 |
self.use_weighted_talk_head = config.use_weighted_talk_head
|
863 |
-
# the weighted head will output a single value, so it can't be passed to the lm head
|
864 |
assert not (self.use_weighted_talk_head and self.use_shallow_talk)
|
865 |
|
866 |
self.n_ahead = 1
|
@@ -931,7 +920,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
931 |
self.thinking_threshold = 0.5
|
932 |
self.thinking_usefulness_loss_weight = 1e-2
|
933 |
|
934 |
-
# Not used in the paper:
|
935 |
self.use_thought_prefix = False
|
936 |
self.use_reparam_for_thought_embeddings = False
|
937 |
self.use_upper_triangular = False
|
@@ -939,7 +927,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
939 |
self.comparison_mode = False
|
940 |
self.gumbel_detach = False
|
941 |
|
942 |
-
# For visualization
|
943 |
self.eval_mode = False
|
944 |
|
945 |
num_talk = 1
|
@@ -968,7 +955,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
968 |
# Add dropout regularization
|
969 |
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
970 |
|
971 |
-
# Initialize weights and apply final processing
|
972 |
self.post_init()
|
973 |
|
974 |
def get_input_embeddings(self):
|
@@ -1219,20 +1205,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1219 |
n_passes_to_restore = self.n_passes
|
1220 |
self.n_ahead_talk = 1
|
1221 |
self.n_passes = 1
|
1222 |
-
|
1223 |
-
# aux_loss = None
|
1224 |
-
# output_router_logits = output_router_logits if output_router_logits is not None else self.config.output_router_logits
|
1225 |
-
# if output_router_logits:
|
1226 |
-
# router_logits = outputs.router_logits if return_dict else outputs[-1]
|
1227 |
-
# if router_logits is not None:
|
1228 |
-
# aux_loss = load_balancing_loss_func(
|
1229 |
-
# router_logits,
|
1230 |
-
# self.num_experts,
|
1231 |
-
# self.num_experts_per_tok,
|
1232 |
-
# attention_mask,
|
1233 |
-
# )
|
1234 |
-
# if labels is not None:
|
1235 |
-
# loss += self.router_aux_loss_coef * aux_loss.to(loss.device)
|
1236 |
if input_ids.dim() == 1:
|
1237 |
input_ids = input_ids.unsqueeze(0)
|
1238 |
attention_mask = attention_mask.unsqueeze(0) if attention_mask is not None else None
|
@@ -1300,7 +1272,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1300 |
self.start_token_id = self.tokenizer.bos_token_id
|
1301 |
self.tokenizer_has_start_thought_token = False
|
1302 |
elif self.use_start_thought_token:
|
1303 |
-
# base_start_id = self.tokenizer.convert_tokens_to_ids(self.initial_start_token)
|
1304 |
base_start_id = self.tokenizer.encode(self.initial_start_token, add_special_tokens=False)[0]
|
1305 |
if self.initialize_thought_embedding_to_normal:
|
1306 |
self.start_embedding.data = torch.zeros_like(self.start_embedding.data)
|
@@ -1313,7 +1284,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1313 |
self.end_token_id = self.tokenizer.eos_token_id
|
1314 |
self.tokenizer_has_end_thought_token = False
|
1315 |
elif self.use_end_thought_token:
|
1316 |
-
# base_end_id = self.tokenizer.convert_tokens_to_ids(self.initial_end_token)
|
1317 |
base_end_id = self.tokenizer.encode(self.initial_end_token, add_special_tokens=False)[0]
|
1318 |
if self.initialize_thought_embedding_to_normal:
|
1319 |
self.end_embedding.data = torch.zeros_like(self.end_embedding.data)
|
@@ -1332,7 +1302,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1332 |
else:
|
1333 |
# convert to identity transform
|
1334 |
def lambda_transform(cur_head):
|
1335 |
-
# pdb.set_trace()
|
1336 |
if cur_head.weight.data.shape[0] != cur_head.weight.data.shape[1]:
|
1337 |
return torch.cat([
|
1338 |
torch.eye(
|
@@ -1360,28 +1329,23 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1360 |
self.talk_head[-1].weight.data = lambda_transform(self.talk_head[0])
|
1361 |
|
1362 |
loss = None
|
1363 |
-
prev_rm_tokens = None
|
1364 |
cur_rm_tokens = None
|
1365 |
-
prev_rm_logits = None
|
1366 |
prev_sample_probs = None
|
1367 |
did_skip_sampling = None
|
1368 |
skip_sampling = None
|
1369 |
sample_probs = None
|
1370 |
hidden_states = None
|
1371 |
logits = None
|
1372 |
-
talk_kl_penalty = None
|
1373 |
rm_logits = None
|
1374 |
residual_logits = None
|
1375 |
probabilities_2d = None
|
1376 |
prev_probabilities_2d = None
|
1377 |
policy_reward = None
|
1378 |
-
logits_to_output = None
|
1379 |
batch_size, seq_len = input_ids.shape
|
1380 |
base_input_ids = input_ids.clone()
|
1381 |
loss_list = []
|
1382 |
dqn_loss_list = []
|
1383 |
sampled_token_history = []
|
1384 |
-
sample_probs_history = []
|
1385 |
action_loglikelihoods_list = []
|
1386 |
|
1387 |
temperature = self.temperature
|
@@ -1397,7 +1361,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1397 |
if self.train_only_thinking_embedding:
|
1398 |
base_embeddings = base_embeddings.detach()
|
1399 |
|
1400 |
-
#
|
1401 |
fwd_iters = 1 if self.original_mode else self.n_ahead + self.n_ahead_talk - 1
|
1402 |
for ahead_idx in range(fwd_iters):
|
1403 |
past_key_values_length = 0
|
@@ -1442,15 +1406,12 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1442 |
base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
|
1443 |
base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
|
1444 |
attention_mask = base_attention_mask
|
1445 |
-
# breakpoint()
|
1446 |
elif attention_mask.dim() == 2:
|
1447 |
if seq_len + past_key_values_length != attention_mask.shape[-1]:
|
1448 |
-
# breakpoint()
|
1449 |
attention_mask = torch.cat(
|
1450 |
[torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
|
1451 |
dim=-1
|
1452 |
)
|
1453 |
-
# # if the attention mask
|
1454 |
attention_mask = _prepare_4d_causal_attention_mask(
|
1455 |
attention_mask,
|
1456 |
(batch_size, seq_len),
|
@@ -1460,7 +1421,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1460 |
)
|
1461 |
|
1462 |
outputs = self.model(
|
1463 |
-
# input_ids=input_ids,
|
1464 |
attention_mask=attention_mask,
|
1465 |
position_ids=position_ids,
|
1466 |
past_key_values=past_key_values,
|
@@ -1468,14 +1428,13 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1468 |
use_cache=use_cache,
|
1469 |
output_attentions=output_attentions,
|
1470 |
output_hidden_states=output_hidden_states,
|
1471 |
-
# output_router_logits=output_router_logits,
|
1472 |
return_dict=return_dict,
|
1473 |
)
|
1474 |
|
1475 |
prev_hidden_states = hidden_states
|
1476 |
hidden_states = outputs[0]
|
1477 |
-
prev_rm_logits = rm_logits
|
1478 |
-
prev_rm_tokens = cur_rm_tokens
|
1479 |
|
1480 |
if ahead_idx == 0:
|
1481 |
hidden_states_lm = hidden_states
|
@@ -1521,7 +1480,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1521 |
assert sum([self.cumulative_residual, self.clever_residual, self.skip_residual, self.no_residual]) == 1
|
1522 |
if self.clever_residual:
|
1523 |
if ahead_idx >= self.n_ahead - 1:
|
1524 |
-
# get the logits shifted according to the current talk ahead
|
1525 |
cur_base_logits = torch.cat([
|
1526 |
base_logits[..., ahead_idx - self.n_ahead + 1:, :],
|
1527 |
base_logits[..., :ahead_idx - self.n_ahead + 1, :]
|
@@ -1566,7 +1524,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1566 |
|
1567 |
attempted = False
|
1568 |
talk_loss_list = []
|
1569 |
-
if self.original_mode or (self.n_ahead == 1) or (self.comparison_mode and ahead_idx == 0)
|
1570 |
loss = None
|
1571 |
attempted = True
|
1572 |
|
@@ -1597,7 +1555,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1597 |
|
1598 |
if not attempted or self.comparison_mode:
|
1599 |
rm_hidden_states = hidden_states
|
1600 |
-
# print("Magnitude of RM hidden states before RM head", rm_hidden_states.norm())
|
1601 |
rm_logits = apply_head(self.lm_head, rm_hidden_states, detach=self.optimize_lm_head_only_at_start)
|
1602 |
|
1603 |
# don't allow it to predict the thinking token
|
@@ -1626,9 +1583,8 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1626 |
probabilities_2d[:, override_token] = 1.0
|
1627 |
skip_sampling = True
|
1628 |
elif ahead_idx >= self.n_ahead - 1:
|
1629 |
-
if labels is not None:
|
1630 |
cur_talk_n = ahead_idx - (self.n_ahead - 1) + 1
|
1631 |
-
# print("Setting rm to labels", cur_talk_n, "during", ahead_idx)
|
1632 |
shift_labels = labels[..., cur_talk_n:].contiguous().to(probabilities_2d.device)
|
1633 |
padding = torch.full_like(
|
1634 |
labels[..., :cur_talk_n],
|
@@ -1640,11 +1596,9 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1640 |
[shift_labels, padding],
|
1641 |
dim=-1
|
1642 |
)
|
1643 |
-
|
1644 |
-
# print((new_rm_tokens > self.vocab_size - 1).any().item())
|
1645 |
new_rm_tokens = torch.clamp(new_rm_tokens, 0, self.vocab_size - 1)
|
1646 |
|
1647 |
-
# Now safely convert rm tokens to one-hot
|
1648 |
probabilities_2d = F.one_hot(new_rm_tokens, num_classes=self.vocab_size).reshape(-1, self.vocab_size).to(probabilities_2d.dtype)
|
1649 |
else:
|
1650 |
continue
|
@@ -1704,7 +1658,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1704 |
new_attention = original_attention
|
1705 |
else:
|
1706 |
original_attention = original_attention == attention_mask.max()
|
1707 |
-
# because eye isn't implemented for BF16, we need to handle the case
|
1708 |
if not attention_mask.dtype == torch.bfloat16:
|
1709 |
new_attention = torch.eye(
|
1710 |
seq_len, dtype=attention_mask.dtype, device=attention_mask.device
|
@@ -1742,9 +1695,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1742 |
# if shift_labels.min() == self.tokenizer.pad_token_id:
|
1743 |
shift_labels = torch.where(shift_labels == self.tokenizer.pad_token_id, -100, shift_labels)
|
1744 |
unreduced_loss = loss_fct(shift_logits, shift_labels)
|
1745 |
-
# print("Loss:", unreduced_loss.item()) # Print the loss before checking for NaN values
|
1746 |
if torch.any(unreduced_loss != unreduced_loss):
|
1747 |
-
# pdb.set_trace()
|
1748 |
raise ValueError("NaN loss")
|
1749 |
unreduced_loss = unreduced_loss.reshape(logits.shape[0], -1)
|
1750 |
loss_list.append(unreduced_loss)
|
|
|
18 |
# See the License for the specific language governing permissions and
|
19 |
# limitations under the License.
|
20 |
""" PyTorch Quiet model."""
|
|
|
21 |
import math
|
|
|
22 |
import warnings
|
23 |
from collections import defaultdict
|
24 |
from typing import List, Optional, Tuple, Union
|
|
|
29 |
from torch import nn
|
30 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
31 |
from transformers.generation.utils import GenerationMixin
|
32 |
+
from transformers import AutoTokenizer
|
|
|
33 |
import transformers
|
34 |
|
35 |
from transformers.activations import ACT2FN
|
|
|
40 |
from transformers.utils import (
|
41 |
add_start_docstrings,
|
42 |
add_start_docstrings_to_model_forward,
|
|
|
|
|
43 |
logging,
|
44 |
replace_return_docstrings,
|
45 |
)
|
|
|
235 |
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
236 |
"""
|
237 |
|
|
|
238 |
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
239 |
if n_rep == 1:
|
240 |
return hidden_states
|
|
|
326 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
327 |
|
328 |
if past_key_value is not None:
|
329 |
+
cache_kwargs = {"sin": sin, "cos": cos}
|
330 |
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
331 |
|
332 |
# repeat k/v heads if n_kv_heads < n_heads
|
|
|
371 |
)
|
372 |
|
373 |
attn_weights = attn_weights + attention_mask
|
374 |
+
|
|
|
375 |
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
376 |
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
377 |
attn_output = torch.matmul(attn_weights, value_states)
|
|
|
844 |
self.model = QuietModel(config)
|
845 |
self.vocab_size = config.vocab_size
|
846 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
|
|
|
|
|
|
847 |
self.max_thoughts = config.max_thoughts
|
848 |
self.merged_lm_and_talk_heads = config.merged_lm_and_talk_heads
|
849 |
self.use_concat_talk_head = config.use_concat_talk_head
|
850 |
self.use_shallow_talk = config.use_shallow_talk
|
851 |
self.use_complex_talk_head = config.use_complex_talk_head
|
852 |
self.use_weighted_talk_head = config.use_weighted_talk_head
|
|
|
853 |
assert not (self.use_weighted_talk_head and self.use_shallow_talk)
|
854 |
|
855 |
self.n_ahead = 1
|
|
|
920 |
self.thinking_threshold = 0.5
|
921 |
self.thinking_usefulness_loss_weight = 1e-2
|
922 |
|
|
|
923 |
self.use_thought_prefix = False
|
924 |
self.use_reparam_for_thought_embeddings = False
|
925 |
self.use_upper_triangular = False
|
|
|
927 |
self.comparison_mode = False
|
928 |
self.gumbel_detach = False
|
929 |
|
|
|
930 |
self.eval_mode = False
|
931 |
|
932 |
num_talk = 1
|
|
|
955 |
# Add dropout regularization
|
956 |
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
957 |
|
|
|
958 |
self.post_init()
|
959 |
|
960 |
def get_input_embeddings(self):
|
|
|
1205 |
n_passes_to_restore = self.n_passes
|
1206 |
self.n_ahead_talk = 1
|
1207 |
self.n_passes = 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1208 |
if input_ids.dim() == 1:
|
1209 |
input_ids = input_ids.unsqueeze(0)
|
1210 |
attention_mask = attention_mask.unsqueeze(0) if attention_mask is not None else None
|
|
|
1272 |
self.start_token_id = self.tokenizer.bos_token_id
|
1273 |
self.tokenizer_has_start_thought_token = False
|
1274 |
elif self.use_start_thought_token:
|
|
|
1275 |
base_start_id = self.tokenizer.encode(self.initial_start_token, add_special_tokens=False)[0]
|
1276 |
if self.initialize_thought_embedding_to_normal:
|
1277 |
self.start_embedding.data = torch.zeros_like(self.start_embedding.data)
|
|
|
1284 |
self.end_token_id = self.tokenizer.eos_token_id
|
1285 |
self.tokenizer_has_end_thought_token = False
|
1286 |
elif self.use_end_thought_token:
|
|
|
1287 |
base_end_id = self.tokenizer.encode(self.initial_end_token, add_special_tokens=False)[0]
|
1288 |
if self.initialize_thought_embedding_to_normal:
|
1289 |
self.end_embedding.data = torch.zeros_like(self.end_embedding.data)
|
|
|
1302 |
else:
|
1303 |
# convert to identity transform
|
1304 |
def lambda_transform(cur_head):
|
|
|
1305 |
if cur_head.weight.data.shape[0] != cur_head.weight.data.shape[1]:
|
1306 |
return torch.cat([
|
1307 |
torch.eye(
|
|
|
1329 |
self.talk_head[-1].weight.data = lambda_transform(self.talk_head[0])
|
1330 |
|
1331 |
loss = None
|
|
|
1332 |
cur_rm_tokens = None
|
|
|
1333 |
prev_sample_probs = None
|
1334 |
did_skip_sampling = None
|
1335 |
skip_sampling = None
|
1336 |
sample_probs = None
|
1337 |
hidden_states = None
|
1338 |
logits = None
|
|
|
1339 |
rm_logits = None
|
1340 |
residual_logits = None
|
1341 |
probabilities_2d = None
|
1342 |
prev_probabilities_2d = None
|
1343 |
policy_reward = None
|
|
|
1344 |
batch_size, seq_len = input_ids.shape
|
1345 |
base_input_ids = input_ids.clone()
|
1346 |
loss_list = []
|
1347 |
dqn_loss_list = []
|
1348 |
sampled_token_history = []
|
|
|
1349 |
action_loglikelihoods_list = []
|
1350 |
|
1351 |
temperature = self.temperature
|
|
|
1361 |
if self.train_only_thinking_embedding:
|
1362 |
base_embeddings = base_embeddings.detach()
|
1363 |
|
1364 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
1365 |
fwd_iters = 1 if self.original_mode else self.n_ahead + self.n_ahead_talk - 1
|
1366 |
for ahead_idx in range(fwd_iters):
|
1367 |
past_key_values_length = 0
|
|
|
1406 |
base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
|
1407 |
base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
|
1408 |
attention_mask = base_attention_mask
|
|
|
1409 |
elif attention_mask.dim() == 2:
|
1410 |
if seq_len + past_key_values_length != attention_mask.shape[-1]:
|
|
|
1411 |
attention_mask = torch.cat(
|
1412 |
[torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
|
1413 |
dim=-1
|
1414 |
)
|
|
|
1415 |
attention_mask = _prepare_4d_causal_attention_mask(
|
1416 |
attention_mask,
|
1417 |
(batch_size, seq_len),
|
|
|
1421 |
)
|
1422 |
|
1423 |
outputs = self.model(
|
|
|
1424 |
attention_mask=attention_mask,
|
1425 |
position_ids=position_ids,
|
1426 |
past_key_values=past_key_values,
|
|
|
1428 |
use_cache=use_cache,
|
1429 |
output_attentions=output_attentions,
|
1430 |
output_hidden_states=output_hidden_states,
|
|
|
1431 |
return_dict=return_dict,
|
1432 |
)
|
1433 |
|
1434 |
prev_hidden_states = hidden_states
|
1435 |
hidden_states = outputs[0]
|
1436 |
+
prev_rm_logits = rm_logits
|
1437 |
+
prev_rm_tokens = cur_rm_tokens
|
1438 |
|
1439 |
if ahead_idx == 0:
|
1440 |
hidden_states_lm = hidden_states
|
|
|
1480 |
assert sum([self.cumulative_residual, self.clever_residual, self.skip_residual, self.no_residual]) == 1
|
1481 |
if self.clever_residual:
|
1482 |
if ahead_idx >= self.n_ahead - 1:
|
|
|
1483 |
cur_base_logits = torch.cat([
|
1484 |
base_logits[..., ahead_idx - self.n_ahead + 1:, :],
|
1485 |
base_logits[..., :ahead_idx - self.n_ahead + 1, :]
|
|
|
1524 |
|
1525 |
attempted = False
|
1526 |
talk_loss_list = []
|
1527 |
+
if self.original_mode or (self.n_ahead == 1) or (self.comparison_mode and ahead_idx == 0):
|
1528 |
loss = None
|
1529 |
attempted = True
|
1530 |
|
|
|
1555 |
|
1556 |
if not attempted or self.comparison_mode:
|
1557 |
rm_hidden_states = hidden_states
|
|
|
1558 |
rm_logits = apply_head(self.lm_head, rm_hidden_states, detach=self.optimize_lm_head_only_at_start)
|
1559 |
|
1560 |
# don't allow it to predict the thinking token
|
|
|
1583 |
probabilities_2d[:, override_token] = 1.0
|
1584 |
skip_sampling = True
|
1585 |
elif ahead_idx >= self.n_ahead - 1:
|
1586 |
+
if labels is not None:
|
1587 |
cur_talk_n = ahead_idx - (self.n_ahead - 1) + 1
|
|
|
1588 |
shift_labels = labels[..., cur_talk_n:].contiguous().to(probabilities_2d.device)
|
1589 |
padding = torch.full_like(
|
1590 |
labels[..., :cur_talk_n],
|
|
|
1596 |
[shift_labels, padding],
|
1597 |
dim=-1
|
1598 |
)
|
1599 |
+
|
|
|
1600 |
new_rm_tokens = torch.clamp(new_rm_tokens, 0, self.vocab_size - 1)
|
1601 |
|
|
|
1602 |
probabilities_2d = F.one_hot(new_rm_tokens, num_classes=self.vocab_size).reshape(-1, self.vocab_size).to(probabilities_2d.dtype)
|
1603 |
else:
|
1604 |
continue
|
|
|
1658 |
new_attention = original_attention
|
1659 |
else:
|
1660 |
original_attention = original_attention == attention_mask.max()
|
|
|
1661 |
if not attention_mask.dtype == torch.bfloat16:
|
1662 |
new_attention = torch.eye(
|
1663 |
seq_len, dtype=attention_mask.dtype, device=attention_mask.device
|
|
|
1695 |
# if shift_labels.min() == self.tokenizer.pad_token_id:
|
1696 |
shift_labels = torch.where(shift_labels == self.tokenizer.pad_token_id, -100, shift_labels)
|
1697 |
unreduced_loss = loss_fct(shift_logits, shift_labels)
|
|
|
1698 |
if torch.any(unreduced_loss != unreduced_loss):
|
|
|
1699 |
raise ValueError("NaN loss")
|
1700 |
unreduced_loss = unreduced_loss.reshape(logits.shape[0], -1)
|
1701 |
loss_list.append(unreduced_loss)
|