Crystalcareai committed on
Commit edf1879 · verified · 1 Parent(s): 6e928c7

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +54 -259
modeling_quiet.py CHANGED
@@ -20,7 +20,7 @@
  """ PyTorch Quiet model."""
  import inspect
  import math
- # import pdb
+ import pdb
  import warnings
  from collections import defaultdict
  from typing import List, Optional, Tuple, Union
@@ -32,8 +32,8 @@ from torch import nn
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  from transformers.generation.utils import GenerationMixin
  from transformers.generation.stopping_criteria import StoppingCriteriaList, validate_stopping_criteria
- from transformers import TextStreamer
- from transformers import AutoTokenizer
+ from transformers import TextStreamer, AutoTokenizer
+
  from transformers.activations import ACT2FN
  from transformers.cache_utils import Cache, DynamicCache
  from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
@@ -48,7 +48,7 @@ from transformers.utils import (
  replace_return_docstrings,
  )
  from .configuration_quiet import QuietConfig
- # from .generate import generate
+
  import time
  from typing import Optional, List
 
@@ -354,26 +354,28 @@ class QuietAttention(nn.Module):
  f" {attn_weights.size()}"
  )
  if self._attn_implementation == "flash_attention_2":
- # 2d mask is passed through the layers
+ # Prepare attention mask for flash-attn
  attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
- elif self._attn_implementation == "sdpa" and not output_attentions and (attention_mask is None or attention_mask.dim() == 2) and False:
- # output_attentions=True can not be supported when using SDPA, and we fall back on
- # the manual implementation that requires a 4D causal mask in all cases.
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
- attention_mask,
- (batch_size, seq_length),
- inputs_embeds,
- past_key_values_length,
- )
- elif attention_mask is None or attention_mask.dim() == 2:
- # 4d mask is passed through the layers
- attention_mask = _prepare_4d_causal_attention_mask(
- attention_mask,
- (batch_size, seq_length),
- inputs_embeds,
- past_key_values_length,
- sliding_window=self.config.sliding_window,
- )
+ elif self._attn_implementation == "sdpa":
+ # Prepare attention mask for SDPA
+ if attention_mask is None or attention_mask.dim() == 2:
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ sliding_window=self.config.sliding_window,
+ )
+ else:
+ # Prepare attention mask for other implementations
+ if attention_mask is None or attention_mask.dim() == 2:
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ sliding_window=self.config.sliding_window,
+ )

  if attention_mask is not None:
  if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
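
For reference, the _prepare_4d_causal_attention_mask helper used on both new branches turns a 2D padding mask into the additive 4D mask of shape (bsz, 1, q_len, kv_seq_len) that the size check in the hunk above expects. A minimal sketch of the equivalent construction, not part of the commit:

```python
# Illustrative only: build the additive 4D causal mask by hand.
import torch

bsz, seq_len = 2, 5
# 2D padding mask: 1 = real token, 0 = padding (second sequence ends early)
padding_mask = torch.tensor([[1, 1, 1, 1, 1],
                             [1, 1, 1, 0, 0]])

# Lower-triangular causal structure: query i may attend to keys j <= i
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
# A position is allowed if it is causal AND the key is not padding
allowed = causal.unsqueeze(0) & padding_mask.bool().unsqueeze(1)      # (bsz, q_len, kv_seq_len)

min_value = torch.finfo(torch.float32).min
mask_4d = (~allowed).unsqueeze(1).to(torch.float32) * min_value       # (bsz, 1, q_len, kv_seq_len)
print(mask_4d.shape)  # torch.Size([2, 1, 5, 5])
```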
@@ -772,7 +774,7 @@ class QuietSdpaAttention(QuietAttention):
  raise ValueError(
  f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
  )
-
+ attention_mask = attention_mask.to(query_states.dtype)
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
  # Reference: https://github.com/pytorch/pytorch/issues/112577.
  if query_states.device.type == "cuda" and attention_mask is not None:
@@ -784,7 +786,7 @@ class QuietSdpaAttention(QuietAttention):
  query_states,
  key_states,
  value_states,
- attn_mask=attention_mask.to(torch.bool).to(query_states.device) if attention_mask is not None else None,
+ attn_mask=attention_mask.to(query_states.device) if attention_mask is not None else None,
  dropout_p=self.attention_dropout if self.training else 0.0,
  # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
  is_causal=self.is_causal and attention_mask is None and q_len > 1,
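
The change above stops casting the mask to bool before handing it to SDPA; combined with the attention_mask = attention_mask.to(query_states.dtype) cast added in the previous hunk, an additive float mask is now passed instead. Both forms are accepted by torch.nn.functional.scaled_dot_product_attention and give the same result. A small sketch with toy shapes, not part of the commit:

```python
# Illustrative only: boolean vs. additive float attention masks for SDPA.
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 5, 64)  # (bsz, num_heads, q_len, head_dim)
k = torch.randn(1, 8, 5, 64)
v = torch.randn(1, 8, 5, 64)

bool_mask = torch.tril(torch.ones(5, 5, dtype=torch.bool))              # True = attend
float_mask = torch.zeros(5, 5).masked_fill(~bool_mask, float("-inf"))   # 0 = attend, -inf = drop

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
assert torch.allclose(out_bool, out_float, atol=1e-6)
```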
@@ -1069,7 +1071,7 @@ class QuietModel(QuietPreTrainedModel):
  if self._attn_implementation == "flash_attention_2":
  # 2d mask is passed through the layers
  attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
- elif self._attn_implementation == "sdpa" and not output_attentions and (attention_mask is None or (attention_mask.dim() == 2 and False)):
+ elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
  # output_attentions=True can not be supported when using SDPA, and we fall back on
  # the manual implementation that requires a 4D causal mask in all cases.
  attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
@@ -1078,16 +1080,15 @@ class QuietModel(QuietPreTrainedModel):
  inputs_embeds,
  past_key_values_length,
  )
- else:
+ elif attention_mask is None or attention_mask.dim() == 2:
  # 4d mask is passed through the layers
- if attention_mask is None or attention_mask.dim() == 2:
- attention_mask = _prepare_4d_causal_attention_mask(
- attention_mask,
- (batch_size, seq_length),
- inputs_embeds,
- past_key_values_length,
- sliding_window=self.config.sliding_window,
- )
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ sliding_window=self.config.sliding_window,
+ )

  hidden_states = inputs_embeds
@@ -1309,7 +1310,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
  elif isinstance(module, nn.Embedding):
  nn.init.xavier_uniform_(module.weight)

-
  @torch.no_grad()
  def infer(
  self,
@@ -1342,6 +1342,9 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
  continuation_length = self.n_ahead - 2
  new_key_values = past_key_values

+ # Initialize next_token_id with a default value
+ next_token_id = torch.zeros(batch_size, dtype=torch.long).to(input_ids.device)
+
  start_time = time.time()
  for continuation_idx in range(continuation_length):
  outputs = self.model(
@@ -1367,7 +1370,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
  next_token_id = torch.argmax(next_token_logits, dim=-1)

  # Append the generated token to the input sequence
- input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
+ # input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
  seq_len += 1

  # Update the attention mask
@@ -1399,8 +1402,8 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):

  # two new tokens: last continuation token and end thought token
  outputs_after = self.model(
- input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor(end_thought_token_id).unsqueeze(-1).unsqueeze(-1).to(input_ids.device)], dim=-1),
- attention_mask=attention_mask,
+ input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1),
+ attention_mask=torch.cat([attention_mask[:, -1:], torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1),
  position_ids=position_ids,
  past_key_values=new_key_values,
  inputs_embeds=inputs_embeds,
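
The rewritten input_ids and attention_mask arguments above fix the batched case: the old end-thought tensor had shape (1, 1), which only lines up with next_token_id.unsqueeze(-1) when batch_size == 1. A shape-only sketch (the token id below is a placeholder), not part of the commit:

```python
# Illustrative only: per-batch end-thought token, shapes as used in infer().
import torch

batch_size = 3
end_thought_token_id = 32001                 # placeholder id, for illustration
next_token_id = torch.tensor([5, 7, 9])      # one greedy token per batch element

old_end = torch.tensor(end_thought_token_id).unsqueeze(-1).unsqueeze(-1)  # (1, 1)
new_end = torch.tensor([[end_thought_token_id]] * batch_size)             # (batch_size, 1)

two_tokens = torch.cat([next_token_id.unsqueeze(-1), new_end], dim=-1)
print(old_end.shape, new_end.shape, two_tokens.shape)
# torch.Size([1, 1]) torch.Size([3, 1]) torch.Size([3, 2])
```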
@@ -1421,218 +1424,10 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
  logits = self.lm_head(mixed_hidden_states)
  return logits

- # from transformers.generation.utils import (
- # GenerationMixin,
- # validate_stopping_criteria,
- # StoppingCriteriaList,
- # )
-
- # logger = logging.get_logger(__name__)
-
- # def custom_generate(
- # self,
- # input_ids,
- # attention_mask=None,
- # max_length=None,
- # min_length=None,
- # do_sample=None,
- # early_stopping=None,
- # num_beams=None,
- # temperature=None,
- # top_k=None,
- # top_p=None,
- # repetition_penalty=None,
- # bad_words_ids=None,
- # bos_token_id=None,
- # pad_token_id=None,
- # eos_token_id=None,
- # streamer=None,
- # length_penalty=None,
- # no_repeat_ngram_size=None,
- # num_return_sequences=None,
- # decoder_start_token_id=None,
- # use_cache=None,
- # num_beam_groups=None,
- # diversity_penalty=None,
- # prefix_allowed_tokens_fn=None,
- # output_attentions=None,
- # output_hidden_states=None,
- # output_scores=None,
- # return_dict_in_generate=None,
- # forced_bos_token_id=None,
- # forced_eos_token_id=None,
- # remove_invalid_values=None,
- # synced_gpus=None,
- # **kwargs,
- # ):
- # with torch.no_grad():
- # finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
-
- # while not finished_generating.all() and input_ids.shape[1] < max_length:
- # # Sample the next token
- # new_ids = self(
- # input_ids[~finished_generating],
- # attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
- # **kwargs
- # )['logits']
-
- # # Mask out the start and end thought tokens so we don't accidentally sample them
- # new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
-
- # for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
- # # Find the index of the last token that is not padding
- # base_answer_ids = input_ids[answer_idx]
- # new_answer_ids = new_ids[list_idx]
- # last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
-
- # new_ids_sampled = torch.multinomial(
- # torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1)
-
- # # Assign the new id to the last token
- # if last_token_idx + 1 >= len(base_answer_ids):
- # # Add padding everywhere
- # new_padding = torch.full((len(input_ids), 1), self.tokenizer.pad_token_id, dtype=torch.long,
- # device=input_ids.device)
- # input_ids = torch.cat([input_ids, new_padding], dim=-1)
- # if attention_mask is not None:
- # attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
-
- # if attention_mask is not None:
- # attention_mask[answer_idx, last_token_idx + 1] = 1
- # input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
-
- # if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
- # finished_generating[answer_idx] = 1
-
- # # Check if the end token is generated
- # if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("</s>"):
- # finished_generating[answer_idx] = 1
-
- # if streamer is not None:
- # streamer.put(new_ids_sampled)
-
- # generated_token_ids = input_ids.tolist()
-
- # return generated_token_ids
-
-
- # def use_generate(
- # self,
- # input_ids,
- # attention_mask=None,
- # max_length=None,
- # min_length=None,
- # do_sample=None,
- # early_stopping=None,
- # num_beams=None,
- # temperature=None,
- # streamer=None,
- # top_k=None,
- # top_p=None,
- # repetition_penalty=None,
- # bad_words_ids=None,
- # bos_token_id=None,
- # pad_token_id=None,
- # eos_token_id=None,
- # length_penalty=None,
- # no_repeat_ngram_size=None,
- # num_return_sequences=None,
- # decoder_start_token_id=None,
- # use_cache=None,
- # num_beam_groups=None,
- # diversity_penalty=None,
- # prefix_allowed_tokens_fn=None,
- # output_attentions=None,
- # output_hidden_states=None,
- # output_scores=None,
- # return_dict_in_generate=None,
- # forced_bos_token_id=None,
- # forced_eos_token_id=None,
- # remove_invalid_values=None,
- # synced_gpus=None,
- # n_ahead=8,
- # n_ahead_talk=4,
- # merged_talk_heads=True,
- # merged_lm_and_talk_heads=False,
- # merged_lm_and_think_heads=True,
- # use_concat_talk_head=True,
- # use_shallow_think=True,
- # use_shallow_talk=False,
- # use_complex_think_head=False,
- # use_complex_talk_head=True,
- # use_weighted_talk_head=True,
- # trust_remote_code=True,
- # torch_dtype=torch.bfloat16,
- # **model_kwargs,
- # ):
- # # Set model attributes
- # self.max_thoughts = n_ahead + n_ahead_talk + 1
- # self.merged_talk_heads = merged_talk_heads
- # self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
- # self.merged_lm_and_think_heads = merged_lm_and_think_heads
- # self.use_concat_talk_head = use_concat_talk_head
- # self.use_shallow_think = use_shallow_think
- # self.use_shallow_talk = use_shallow_talk
- # self.use_complex_think_head = use_complex_think_head
- # self.use_complex_talk_head = use_complex_talk_head
- # self.use_weighted_talk_head = use_weighted_talk_head
-
- # # Set model properties
- # self.use_end_thought_token = True
- # self.use_start_thought_token = True
- # self.wandb_enabled = True
- # self.n_ahead = n_ahead
- # self.n_passes = 1
- # self.eval_mode = True
- # self.first_run = False
- # self.kill_after = 100
- # self.rm_initialized = True
- # self.original_mode = False
-
- # # Generate using the custom generate function
- # generated_token_ids = custom_generate(
- # self,
- # input_ids=input_ids,
- # attention_mask=attention_mask,
- # max_length=max_length,
- # min_length=min_length,
- # do_sample=do_sample,
- # early_stopping=early_stopping,
- # num_beams=num_beams,
- # temperature=temperature,
- # top_k=top_k,
- # top_p=top_p,
- # repetition_penalty=repetition_penalty,
- # bad_words_ids=bad_words_ids,
- # bos_token_id=bos_token_id,
- # pad_token_id=pad_token_id,
- # eos_token_id=eos_token_id,
- # length_penalty=length_penalty,
- # no_repeat_ngram_size=no_repeat_ngram_size,
- # num_return_sequences=num_return_sequences,
- # decoder_start_token_id=decoder_start_token_id,
- # use_cache=use_cache,
- # num_beam_groups=num_beam_groups,
- # diversity_penalty=diversity_penalty,
- # prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
- # output_attentions=output_attentions,
- # output_hidden_states=output_hidden_states,
- # output_scores=output_scores,
- # return_dict_in_generate=return_dict_in_generate,
- # forced_bos_token_id=forced_bos_token_id,
- # forced_eos_token_id=forced_eos_token_id,
- # remove_invalid_values=remove_invalid_values,
- # synced_gpus=synced_gpus,
- # streamer=streamer,
- # **model_kwargs,
- # )
-
- # return generated_token_ids
-
-
- # def generate(self, input_ids, attention_mask=None, max_length=None, temperature=1.0, **kwargs):
- # from .generate import generate
- # return generate(self, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, temperature=temperature, **kwargs)
+ def generate(self, input_ids, attention_mask=None, max_length=None, temperature=1.0, **kwargs):
+ from .generate import generate
+ return generate(self, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, temperature=temperature, **kwargs)
+
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  def forward(
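
With the commented-out generation helpers removed, generation now goes through the small generate wrapper added above, which delegates to the repo's generate.py. A hedged usage sketch (the repo id is a placeholder, and the exact return type of the delegated generate is defined in generate.py, not shown here):

```python
# Illustrative only: loading the custom-code model and calling the new wrapper.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "org/quiet-model"  # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,   # required so this modeling file and generate.py are used
    torch_dtype=torch.bfloat16,
)

inputs = tokenizer("What is 2 + 2?", return_tensors="pt")
output_ids = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_length=64,
    temperature=0.7,
)
print(tokenizer.decode(output_ids[0]))  # assumes generate.py returns token ids per sequence
```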
@@ -1648,7 +1443,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  return_dict: Optional[bool] = None,
- streamer: Optional[TextStreamer] = None,
  ) -> Union[Tuple, CausalLMOutputWithPast]:
  r"""
  Args:
@@ -1822,17 +1616,15 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
  sample_probs_history = []
  action_loglikelihoods_list = []

-
- # complexity_scores = self.compute_complexity_scores(input_ids, attention_mask)
- temperature = self.temperature #* complexity_scores.unsqueeze(-1)
+ temperature = self.temperature

  if self.use_end_thought_token or self.use_start_thought_token:
  if not self.use_reparam_for_thought_embeddings:
- start_embedding = self.start_embedding[0].unsqueeze(0) * self.embedding_scale
- end_embedding = self.end_embedding[0].unsqueeze(0) * self.embedding_scale
+ start_embedding = self.start_embedding[0].unsqueeze(0) * self.embedding_scale * temperature
+ end_embedding = self.end_embedding[0].unsqueeze(0) * self.embedding_scale * temperature
  else:
- start_embedding = self.start_embedding * self.embedding_scale
- end_embedding = self.end_embedding * self.embedding_scale
+ start_embedding = self.start_embedding * self.embedding_scale * temperature
+ end_embedding = self.end_embedding * self.embedding_scale * temperature
  base_embeddings = self.model.embed_tokens.weight
  if self.train_only_thinking_embedding:
  base_embeddings = base_embeddings.detach()
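
The branch above now folds the sampling temperature into the scale applied to the learned start/end thought embeddings. A toy sketch of that scaling (hidden size and scale value below are made up), not part of the commit:

```python
# Illustrative only: scaling a learned thought embedding by embedding_scale * temperature.
import torch
import torch.nn as nn

hidden_size = 8                              # made-up size
embedding_scale = float(hidden_size) ** 0.5  # assumed sqrt(d)-style scale
temperature = 0.9

start_embedding = nn.Parameter(torch.randn(2, hidden_size))  # mirrors self.start_embedding

# Non-reparameterized path from the diff: take row 0 and scale it
scaled = start_embedding[0].unsqueeze(0) * embedding_scale * temperature
print(scaled.shape)  # torch.Size([1, 8])
```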
@@ -2328,6 +2120,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
  del start_embedding
  del end_embedding
  torch.cuda.empty_cache()
+

  return CausalLMOutputWithPast(
  loss=loss if loss is not None else None,
@@ -2336,6 +2129,8 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
  hidden_states=outputs.hidden_states,
  attentions=outputs.attentions,
  )
+
+

  def prepare_inputs_for_generation(
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
 