Crystalcareai committed
Update modeling_quiet.py

modeling_quiet.py CHANGED (+48 -106)
@@ -23,16 +23,16 @@ import math
 import pdb
 import warnings
 from collections import defaultdict
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.generation.utils import GenerationMixin
+from transformers.generation.utils import GenerationMixin
 from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria
-from transformers import TextStreamer
+from transformers import TextStreamer

 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
@@ -143,6 +143,7 @@ class QuietRMSNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps

+
     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
@@ -150,6 +151,7 @@
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return hidden_states.to(input_dtype) * self.weight.to(hidden_states.device)

+
 # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Quiet
 class QuietRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
@@ -235,8 +237,7 @@ class QuietMLP(nn.Module):
         self.act_fn = ACT2FN[config.hidden_act]

     def forward(self, x):
-
-        return hidden_states
+        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


 # Copied from transformers.models.llama.modeling_llama.repeat_kv
@@ -848,7 +849,7 @@ class QuietDecoderLayer(nn.Module):
         residual = hidden_states

         hidden_states = self.input_layernorm(hidden_states)
-
+
         # Self Attention
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
@@ -1022,8 +1023,6 @@ class QuietModel(QuietPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-
-        # print("Hidden states shape after embedding:", inputs_embeds.shape)
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -1072,27 +1071,32 @@
         if self._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask
+        elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
             # output_attentions=True can not be supported when using SDPA, and we fall back on
             # the manual implementation that requires a 4D causal mask in all cases.
             attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask,
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
             )
         elif attention_mask is None or attention_mask.dim() == 2:
             # 4d mask is passed through the layers
             attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask,
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
                 sliding_window=self.config.sliding_window,
             )

-
         hidden_states = inputs_embeds

         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
-
+
         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
@@ -1116,15 +1120,15 @@
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                 )
+
                 hidden_states = layer_outputs[0]
-                # print(f"Hidden states shape after decoder layer {decoder_layer}:", hidden_states.shape)
-        # print("Hidden states shape after decoder layers:", hidden_states.shape)

             if use_cache:
                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]

             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
+
         hidden_states = self.norm(hidden_states)

         # add hidden states from the last decoder layer
@@ -1155,7 +1159,7 @@ def loss_mean(x):
 class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]

-    def __init__(self, config):
+    def __init__(self, config):
         super().__init__(config)
         self.model = QuietModel(config)
         self.vocab_size = config.vocab_size
@@ -1178,7 +1182,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         self.n_tokens_print = 1
         self.gradient_accumulation_steps = 1
         self.training_steps = 0
-        self.tokenizer =
+        self.tokenizer = None
         self.start_token_id = None
         self.end_token_id = None
         self.rm_initialized = False
@@ -1306,14 +1310,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             nn.init.constant_(module.bias, 0)
         elif isinstance(module, nn.Embedding):
             nn.init.xavier_uniform_(module.weight)
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
-        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-        model.tokenizer = tokenizer
-        return model
-

     @torch.no_grad()
     def infer(
@@ -1347,10 +1343,13 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         continuation_length = self.n_ahead - 2
         new_key_values = past_key_values

+        # Initialize generated_ids with input_ids
+        generated_ids = input_ids.clone()
+
         start_time = time.time()
         for continuation_idx in range(continuation_length):
            outputs = self.model(
-                input_ids=
+                input_ids=generated_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(generated_ids.device),
                 attention_mask=attention_mask,
                 position_ids=position_ids,
                 past_key_values=new_key_values,
@@ -1371,86 +1370,33 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             next_token_logits = F.gumbel_softmax(logits, tau=self.gumbel_temperature, hard=True, dim=-1)
             next_token_id = torch.argmax(next_token_logits, dim=-1)

-            # Append the generated token to the
-
-            seq_len += 1
+            # Append the generated token to the generated_ids
+            generated_ids = torch.cat([generated_ids, next_token_id.unsqueeze(-1).to(generated_ids.device)], dim=-1)

-
-            if attention_mask is not None:
-                attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
-
-        # Append the end thought token to the input sequence
+        # Append the end thought token to the generated_ids
         end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
-
-        seq_len += 1
+        generated_ids = torch.cat([generated_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(generated_ids.device)], dim=-1)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        hidden_states_before = outputs_before[0][:, -1:, :]
-
-        outputs_after = self.model(
-            input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1),
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=new_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
+        return generated_ids
+
+
+    @torch.no_grad()
+    def generate(self, *args, **kwargs):
+        # Call the infer method to generate the token ids
+        generated_ids = self.infer(
+            input_ids=kwargs.pop("input_ids", None),
+            attention_mask=kwargs.pop("attention_mask", None),
+            position_ids=kwargs.pop("position_ids", None),
+            past_key_values=kwargs.pop("past_key_values", None),
+            inputs_embeds=kwargs.pop("inputs_embeds", None),
+            use_cache=kwargs.pop("use_cache", None),
+            output_attentions=kwargs.pop("output_attentions", None),
+            output_hidden_states=kwargs.pop("output_hidden_states", None),
+            return_dict=kwargs.pop("return_dict", None),
         )
-        hidden_states_after = outputs_after[0][:, -1:, :]
-
-        # Apply the talk head to get the mixing weight
-        mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1))

-
-        mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
+        return generated_ids

-        # Apply the language model head to get the final logits
-        logits = self.lm_head(mixed_hidden_states)
-
-        return logits
-
-    @torch.no_grad()
-    def generate(
-        self,
-        input_ids: torch.LongTensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict_in_generate: Optional[bool] = None,
-        **model_kwargs,
-    ) -> Union[BaseModelOutputWithPast, torch.LongTensor]:
-        return self.infer(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict_in_generate,
-        )
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
@@ -1641,7 +1587,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):

         complexity_scores = self.compute_complexity_scores(input_ids, attention_mask)
         temperature = self.temperature * complexity_scores.unsqueeze(-1)
-
+
         if self.use_end_thought_token or self.use_start_thought_token:
             if not self.use_reparam_for_thought_embeddings:
                 start_embedding = self.start_embedding[0].unsqueeze(0) * self.embedding_scale * temperature
@@ -1671,10 +1617,7 @@
             position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
         else:
             position_ids = position_ids.view(-1, seq_len).long()
-
-
-        # print("Input IDs shape:", input_ids.shape)
-        # print("Inputs embeds shape before embedding:", inputs_embeds.shape if inputs_embeds is not None else None)
+
         if inputs_embeds is None:
             contains_start = self.use_start_thought_token and (input_ids == self.start_token_id).any()
             contains_end = self.use_end_thought_token and (input_ids == self.end_token_id).any()
@@ -1694,7 +1637,6 @@
         else:
            with torch.set_grad_enabled(not self.train_only_thinking_embedding):
                 inputs_embeds = self.model.embed_tokens(input_ids)
-            # print("Inputs embeds shape after embedding:", inputs_embeds.shape)

         if self.n_ahead != 1 or self.n_ahead_talk != 1 or self.comparison_mode:
             if attention_mask is None:
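For reference, the single return line restored in QuietMLP.forward is the standard gated MLP used by Mistral-style models. A minimal, self-contained sketch of the same computation (the layer sizes and the SiLU activation are illustrative assumptions, not values read from the config):

import torch
from torch import nn


class GatedMLP(nn.Module):
    """Stand-in for QuietMLP: down_proj(act_fn(gate_proj(x)) * up_proj(x))."""

    def __init__(self, hidden_size: int = 32, intermediate_size: int = 64):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.SiLU()  # assumed activation; the real module uses ACT2FN[config.hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The same expression the commit restores as QuietMLP's return value.
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


x = torch.randn(2, 5, 32)   # (batch, seq_len, hidden_size)
print(GatedMLP()(x).shape)  # torch.Size([2, 5, 32])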
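The restored arguments to _prepare_4d_causal_attention_mask ((batch_size, seq_length), inputs_embeds, past_key_values_length) are what let that helper expand a 2D padding mask into the 4D additive causal mask the decoder layers consume. A rough pure-PyTorch illustration of the shape logic for the no-cache case (this is not the transformers implementation, only a sketch of the idea):

import torch


def to_4d_causal_bias(padding_mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Illustrative only: combine a (batch, seq) padding mask with a causal mask
    into a (batch, 1, seq, seq) additive bias, roughly what the helper returns
    when past_key_values_length == 0."""
    batch_size, seq_length = padding_mask.shape
    causal = torch.tril(torch.ones(seq_length, seq_length, dtype=torch.bool))
    allowed = causal[None, None, :, :] & padding_mask[:, None, None, :].bool()
    bias = torch.zeros(batch_size, 1, seq_length, seq_length, dtype=dtype)
    # Disallowed positions get a large negative value so softmax ignores them.
    return bias.masked_fill(~allowed, torch.finfo(dtype).min)


print(to_4d_causal_bias(torch.tensor([[1, 1, 1, 0]])).shape)  # torch.Size([1, 1, 4, 4])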
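Because this commit removes the from_pretrained override that attached a tokenizer automatically and now initializes self.tokenizer = None, while infer() still resolves the <|endthought|> id through self.tokenizer, a caller presumably has to attach the tokenizer by hand before using the new generate wrapper. A usage sketch under that assumption (the repository id below is a placeholder, not a real checkpoint name):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/your-quiet-star-checkpoint"  # placeholder repo id

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

# The custom from_pretrained override is gone, so attach the tokenizer manually;
# infer() uses it to look up the <|endthought|> token id.
model.tokenizer = tokenizer

inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )

print(tokenizer.decode(generated_ids[0], skip_special_tokens=False))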