Crystalcareai committed on
Commit ef8c51a · verified · 1 Parent(s): 9e5308f

Upload 3 files

Files changed (3):
  1. config.json (+14 -4)
  2. configuration_quiet.py (+26 -5)
  3. modeling_quiet.py (+1050 -107)
config.json CHANGED
@@ -1,7 +1,9 @@
 {
+  "_name_or_path": "Crystalcareai/Quiet-Star-Custom",
   "architectures": [
     "QuietForCausalLM"
   ],
+  "attention_dropout": 0.0,
   "auto_map": {
     "AutoConfig": "configuration_quiet.QuietConfig",
     "AutoModel": "modeling_quiet.QuietModel",
@@ -14,9 +16,11 @@
   "initializer_range": 0.02,
   "intermediate_size": 14336,
   "max_position_embeddings": 32768,
+  "max_thoughts": 10,
+  "merged_lm_and_talk_heads": false,
+  "merged_lm_and_think_heads": true,
+  "merged_talk_heads": true,
   "model_type": "quiet",
-  "max_thoughts": 3,
-  "thought_length": 10,
   "num_attention_heads": 32,
   "num_hidden_layers": 32,
   "num_key_value_heads": 8,
@@ -25,7 +29,13 @@
   "sliding_window": 4096,
   "tie_word_embeddings": false,
   "torch_dtype": "bfloat16",
-  "transformers_version": "4.34.0.dev0",
+  "transformers_version": "4.37.0.dev0",
   "use_cache": true,
-  "vocab_size": 32000
+  "use_complex_talk_head": true,
+  "use_complex_think_head": false,
+  "use_concat_talk_head": true,
+  "use_shallow_talk": false,
+  "use_shallow_think": true,
+  "use_weighted_talk_head": true,
+  "vocab_size": 32002
 }
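The new config swaps the old max_thoughts/thought_length pair for the Quiet-STaR head switches and grows the vocabulary from 32000 to 32002, presumably making room for the start/end-of-thought tokens used in modeling_quiet.py. A minimal loading sketch, assuming the `_name_or_path` repo above is accessible (`trust_remote_code` is needed because `QuietConfig` ships inside the repo, not in transformers):

    from transformers import AutoConfig

    # Hypothetical usage; the repo id is taken from "_name_or_path" above.
    config = AutoConfig.from_pretrained("Crystalcareai/Quiet-Star-Custom", trust_remote_code=True)
    assert config.max_thoughts == 10 and config.vocab_size == 32002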
configuration_quiet.py CHANGED
@@ -20,6 +20,11 @@ from transformers.utils import logging

 logger = logging.get_logger(__name__)

+QUIET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+    "quietai/Quiet-7B-v0.1": "https://huggingface.co/quietai/Quiet-7B-v0.1/resolve/main/config.json",
+    "quietai/Quiet-7B-Instruct-v0.1": "https://huggingface.co/quietai/Quiet-7B-Instruct-v0.1/resolve/main/config.json",
+}
+

 class QuietConfig(PretrainedConfig):
     r"""
@@ -111,13 +116,21 @@ class QuietConfig(PretrainedConfig):
         use_cache=True,
         pad_token_id=None,
         bos_token_id=1,
-        max_thoughts: int = 3,
-        thought_length: int = 10,
         eos_token_id=2,
         tie_word_embeddings=False,
         rope_theta=10000.0,
         sliding_window=4096,
         attention_dropout=0.0,
+        max_thoughts=16,
+        merged_talk_heads=True,
+        merged_lm_and_talk_heads=False,
+        merged_lm_and_think_heads=True,
+        use_concat_talk_head=True,
+        use_shallow_think=True,
+        use_shallow_talk=False,
+        use_complex_think_head=False,
+        use_complex_talk_head=True,
+        use_weighted_talk_head=True,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -137,10 +150,18 @@ class QuietConfig(PretrainedConfig):
         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
         self.use_cache = use_cache
-        self.max_thoughts = max_thoughts
-        self.thought_length = thought_length
         self.rope_theta = rope_theta
         self.attention_dropout = attention_dropout
+        self.max_thoughts = max_thoughts
+        self.merged_talk_heads = merged_talk_heads
+        self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
+        self.merged_lm_and_think_heads = merged_lm_and_think_heads
+        self.use_concat_talk_head = use_concat_talk_head
+        self.use_shallow_think = use_shallow_think
+        self.use_shallow_talk = use_shallow_talk
+        self.use_complex_think_head = use_complex_think_head
+        self.use_complex_talk_head = use_complex_talk_head
+        self.use_weighted_talk_head = use_weighted_talk_head

         super().__init__(
             pad_token_id=pad_token_id,
@@ -148,4 +169,4 @@ class QuietConfig(PretrainedConfig):
             eos_token_id=eos_token_id,
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,
-        )
+        )
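Note that the constructor default (`max_thoughts=16`) differs from the value shipped in config.json (`max_thoughts=10`); on load, the values stored in config.json win over the defaults. A quick construction sketch, assuming `configuration_quiet.py` is importable from the working directory:

    from configuration_quiet import QuietConfig

    config = QuietConfig(max_thoughts=10)   # override the constructor default of 16
    print(config.use_concat_talk_head)      # True
    print(config.use_weighted_talk_head)    # True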
modeling_quiet.py CHANGED
@@ -20,7 +20,20 @@
 """ PyTorch Quiet model."""
 import inspect
 import math
+import copy
+import os
+import time
+import pandas as pd
+import seaborn as sns
+import matplotlib.pyplot as plt
+import wandb
+from termcolor import colored
+from tqdm import tqdm
+import random
+import numpy as np
+from matplotlib.colors import LinearSegmentedColormap, LogNorm
 import warnings
+from collections import defaultdict
 from typing import List, Optional, Tuple, Union

 import torch
@@ -56,13 +69,79 @@ logger = logging.get_logger(__name__)

 _CONFIG_FOR_DOC = "QuietConfig"

+from reportlab.pdfgen import canvas
+from reportlab.lib.pagesizes import letter
+from reportlab.lib.colors import HexColor
+
+def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_file="text.pdf", eps=0.2, eps2=0.5):
+    c = canvas.Canvas(output_file, pagesize=letter)
+    c.setFont("Courier", 8)
+    x, y = 50, 750
+    previous_text = ""
+    current_text = ""
+    for token_idx, reward in enumerate(token_rewards):
+        current_text = tokenizer.decode(input_ids[: token_idx + 1])
+        if current_text != previous_text:
+            diff_text = current_text[len(previous_text) :]
+            if "\n" in diff_text:
+                lines = diff_text.split("\n")
+                for line_idx, line in enumerate(lines):
+                    if line_idx > 0:
+                        x = 50
+                        y -= 12
+                    if abs(reward) < eps:
+                        opacity = 0
+                    elif abs(reward) > eps2:
+                        opacity = 0.8
+                    else:
+                        opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps)
+                    text_width = c.stringWidth(line)
+                    if reward > 0:
+                        highlight_color = HexColor("#4CCD99")
+                    else:
+                        highlight_color = HexColor("#FFC700")
+                    highlight_color.alpha = opacity
+                    c.setFillColor(highlight_color)
+                    c.rect(x, y - 2, text_width, 10, fill=True, stroke=False)
+                    c.setFillColor(HexColor("#000000"))
+                    c.drawString(x, y, line)
+                    x += text_width
+            else:
+                if abs(reward) < eps:
+                    opacity = 0
+                elif abs(reward) > eps2:
+                    opacity = 0.8
+                else:
+                    opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps)
+                text_width = c.stringWidth(diff_text)
+                if reward > 0:
+                    highlight_color = HexColor("#4CCD99")
+                else:
+                    highlight_color = HexColor("#FFC700")
+                highlight_color.alpha = opacity
+                c.setFillColor(highlight_color)
+                c.rect(x, y - 2, text_width, 10, fill=True, stroke=False)
+                c.setFillColor(HexColor("#000000"))
+                c.drawString(x, y, diff_text)
+                x += text_width
+            if x > 550:
+                x = 50
+                y -= 12
+            if y < 50:
+                c.showPage()
+                y = 750
+                x = 50
+        previous_text = current_text
+    c.showPage()
+    c.save()
+

 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
     max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
     return (
         indices,
         cu_seqlens,
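save_tokens_with_rewards_to_pdf renders each decoded token onto a PDF, highlighting positive rewards in green and negative ones in yellow, with opacity ramping between `eps` and `eps2`. A hedged usage sketch; the tokenizer choice and the per-token rewards below are fabricated for illustration:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # assumption: any HF tokenizer works
    ids = tokenizer("Thinking before talking helps.")["input_ids"]
    rewards = [0.0, 0.6, -0.3, 0.9, 0.1, 0.4, 0.0][: len(ids)]  # made-up per-token rewards
    save_tokens_with_rewards_to_pdf(ids, rewards, tokenizer, output_file="rewards.pdf")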
@@ -85,11 +164,10 @@ class QuietRMSNorm(nn.Module):
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
+        return hidden_states.to(input_dtype) * self.weight.to(hidden_states.device)


-# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Quiet
-# TODO @Arthur no longer copied from LLama after static cache
+# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Quiet
 class QuietRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
@@ -97,7 +175,7 @@ class QuietRotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)

         # Build here to make `torch.jit.trace` work.
@@ -107,7 +185,7 @@ class QuietRotaryEmbedding(nn.Module):

     def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

         freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
@@ -134,8 +212,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-# copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-# TODO @Arthur no longer copied from LLama after static cache
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

@@ -204,8 +281,8 @@ class QuietAttention(nn.Module):
         self.layer_idx = layer_idx
         if layer_idx is None:
             logger.warning_once(
-                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
-                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
+                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                 "when creating this class."
             )

@@ -496,7 +573,7 @@ class QuietFlashAttention2(QuietAttention):
             attention_mask (`torch.Tensor`):
                 The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                 position of padding tokens and 1 for the position of non-padding tokens.
-            dropout (`float`):
+            dropout (`float`, *optional*):
                 Attention dropout
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
@@ -614,8 +691,7 @@ class QuietFlashAttention2(QuietAttention):
     )


-# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
-# TODO @Arthur no longer copied from LLama after static cache
+# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
 class QuietSdpaAttention(QuietAttention):
     """
     Quiet attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -689,14 +765,14 @@ class QuietSdpaAttention(QuietAttention):
             query_states,
             key_states,
             value_states,
-            attn_mask=attention_mask,
+            attn_mask=attention_mask.to(query_states.device) if attention_mask is not None else None,
             dropout_p=self.attention_dropout if self.training else 0.0,
             # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
             is_causal=self.is_causal and attention_mask is None and q_len > 1,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

         attn_output = self.o_proj(attn_output)

@@ -762,7 +838,7 @@ class QuietDecoderLayer(nn.Module):
             output_attentions=output_attentions,
             use_cache=use_cache,
         )
-        hidden_states = residual + hidden_states
+        hidden_states = residual.to(hidden_states.device) + hidden_states

         # Fully Connected
         residual = hidden_states
@@ -928,35 +1004,6 @@ class QuietModel(QuietPreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value

-    def _generate_thoughts(self, hidden_states, max_length):
-        thought_ids = []
-        thought_embeddings = []
-
-        for _ in range(self.config.max_thoughts):
-            thought_id = torch.LongTensor([[self.config.start_token_id]]).to(hidden_states.device)
-            thought_embedding = self.embed_tokens(thought_id)
-
-            for _ in range(max_length):
-                outputs = self.forward(
-                    inputs_embeds=thought_embedding,
-                    attention_mask=None,
-                    use_cache=True,
-                )
-                logits = outputs.logits[:, -1, :]
-                next_token_id = torch.argmax(logits, dim=-1)
-
-                if next_token_id == self.config.end_token_id:
-                    break
-
-                thought_id = torch.cat([thought_id, next_token_id.unsqueeze(0)], dim=-1)
-                thought_embedding = torch.cat([thought_embedding, self.embed_tokens(next_token_id.unsqueeze(0))], dim=1)
-
-            thought_ids.append(thought_id.squeeze(0))
-            thought_embeddings.append(thought_embedding.squeeze(0))
-
-        return thought_ids, thought_embeddings
-
-
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     def forward(
         self,
@@ -1027,7 +1074,7 @@ class QuietModel(QuietPreTrainedModel):
         if self._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._attn_implementation == "sdpa" and not output_attentions:
+        elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
             # output_attentions=True can not be supported when using SDPA, and we fall back on
             # the manual implementation that requires a 4D causal mask in all cases.
             attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
@@ -1036,7 +1083,7 @@
             inputs_embeds,
             past_key_values_length,
         )
-        else:
+        elif attention_mask is None or attention_mask.dim() == 2:
             # 4d mask is passed through the layers
             attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask,
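The appended `and False` makes the SDPA branch unreachable, so a 2D mask now always flows into `_prepare_4d_causal_attention_mask` (flash-attention keeps its 2D passthrough). A reduced sketch of the resulting branch logic, not the model's own code:

    def pick_mask_path(attn_impl, output_attentions, mask_dim):
        if attn_impl == "flash_attention_2":
            return "2d-passthrough"
        elif attn_impl == "sdpa" and not output_attentions and mask_dim == 2 and False:
            return "sdpa-4d"  # dead branch: `and False` can never hold
        elif mask_dim is None or mask_dim == 2:
            return "manual-4d"

    assert pick_mask_path("sdpa", False, 2) == "manual-4d"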
@@ -1104,37 +1151,132 @@ class QuietModel(QuietPreTrainedModel):
         attentions=all_self_attns,
     )

+def nonzero_mean(x, axis=None):
+    if axis is not None:
+        return x.sum(axis) / (x != 0).sum(axis)
+    return x.sum() / (x != 0).sum()
+
+def loss_mean(x):
+    return x.sum() / (x != 0).sum()

 class QuietForCausalLM(QuietPreTrainedModel):
+    _tied_weights_keys = ["lm_head.weight"]
+
     def __init__(self, config):
         super().__init__(config)
         self.model = QuietModel(config)
+        self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.mixing_head = nn.Sequential(
-            nn.Linear(config.hidden_size * 2, config.hidden_size),
-            nn.ReLU(),
-            nn.Linear(config.hidden_size, 1),
-        )
-
         self.max_thoughts = config.max_thoughts
-        self.thought_length = config.thought_length
+        self.merged_lm_and_talk_heads = config.merged_lm_and_talk_heads
+        self.use_concat_talk_head = config.use_concat_talk_head
+        self.use_shallow_talk = config.use_shallow_talk
+        self.use_complex_talk_head = config.use_complex_talk_head
+        self.use_weighted_talk_head = config.use_weighted_talk_head
+        # the weighted head will output a single value, so it can't be passed to the lm head
+        assert not (self.use_weighted_talk_head and self.use_shallow_talk)
+
+        self.n_ahead = 1
+        self.n_ahead_talk = 1
+        self.n_passes = 1
+        self.n_tokens_print = 1
+        self.gradient_accumulation_steps = 1
+        self.training_steps = 0
+        self.tokenizer = None
+        self.start_token_id = None
+        self.end_token_id = None
+        self.rm_initialized = False
+        self.residual_talk_head = True
+        self.thought_init_std_scale = 1e-2
+
+        self.final_only_mode = False
+        self.first_and_last_mode = True
+        self.first_only = False
+        self.original_loss_weight = 0.5
+
+        self.cumulative_residual = False
+        self.clever_residual = False
+        self.skip_residual = False
+        self.no_residual = True
+
+        self.optimize_lm_head_only_at_start = False
+        self.optimize_model_only_at_start = False
+
+        if self.optimize_model_only_at_start:
+            raise NotImplementedError
+        self.train_only_thinking_embedding = False
+        self.weighted_embeddings = False
+        self.use_start_thought_token = True
+        self.use_end_thought_token = True
+        self.initialize_thought_embedding_to_normal = False
+        self.initial_start_token = "---"
+        self.initial_end_token = "---"
+        self.output_logits_at_the_end = True
+
+        self.wandb_enabled = False
+        self.gumbel_temperature = 0.001
+
         self.use_policy_loss = True
+        self.include_policy_loss = True
+        self.trice_mode = True
         self.remove_negative_rewards = True
+        self.use_policy_loss_for_end_thought = True

-        self.post_init()
+        self.base_original_mode = False
+        self.original_mode = False
+
+        self.thought_prefix = "(Let's think step by step"
+        self.tokenized_thought_prefix = None
+        self.log_dict = defaultdict(int)
+        self.eval_log_dict = defaultdict(int)
+        self.print_final_only = True
+        self.loss_mean = loss_mean
+        self.all_rewards = []
+        self.all_unreduced_losses = []
+        self.kill_after = 100
+
+        self.start_embedding = nn.Parameter(torch.zeros(2, self.model.config.hidden_size))
+        self.end_embedding = nn.Parameter(torch.zeros(2, self.model.config.hidden_size))
+
+        self.policy_loss_beta = 1e6
+        self.embedding_scale = 1e2
+        self.reinforce_temperature = 3
+        self.base_loss_beta = 1
+
+        # Not used in the paper:
+        self.use_thought_prefix = False
+        self.use_reparam_for_thought_embeddings = False
+        self.use_upper_triangular = False
+        self.subtract_mean_reward = False
+        self.comparison_mode = False
+        self.gumbel_detach = True
+
+        # For visualization
+        self.eval_mode = False

-    def calculate_policy_loss(self, thoughts, rewards):
-        thought_log_probs = []
-        for thought in thoughts:
-            thought_log_prob = self.lm_head(thought).log_softmax(dim=-1)
-            thought_log_probs.append(thought_log_prob)
-
-        thought_log_probs = torch.stack(thought_log_probs, dim=1)  # (batch_size, num_thoughts, seq_length, vocab_size)
-        thought_probs = torch.exp(thought_log_probs)
-
-        policy_loss = -torch.mean(thought_log_probs * rewards.unsqueeze(-1).unsqueeze(-1))
-
-        return policy_loss
+        num_talk = 1
+        talk_input_dim = config.hidden_size if not self.use_concat_talk_head else config.hidden_size * 2
+        if self.use_weighted_talk_head:
+            talk_output_dim = 1
+        else:
+            talk_output_dim = config.hidden_size if self.use_shallow_talk else config.vocab_size
+
+        if not self.merged_lm_and_talk_heads:
+            if self.use_complex_talk_head:
+                self.talk_head = nn.ModuleList([nn.Sequential(
+                    nn.Linear(talk_input_dim, config.hidden_size),
+                    nn.ReLU(),
+                    nn.Linear(config.hidden_size, config.hidden_size),
+                    nn.ReLU(),
+                    nn.Linear(config.hidden_size, talk_output_dim, bias=False)
+                )])
+            else:
+                self.talk_head = nn.ModuleList([nn.Sequential(
+                    nn.Linear(talk_input_dim, talk_output_dim, bias=False)
+                )])

+        # Initialize weights and apply final processing
+        self.post_init()

     def get_input_embeddings(self):
         return self.model.embed_tokens
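With the defaults set above (`use_concat_talk_head=True`, `use_complex_talk_head=True`, `use_weighted_talk_head=True`), the talk head takes the base and post-thought hidden states concatenated and emits one mixing weight per position. A standalone shape sketch, with `hidden_size` taken from config.json and the batch/sequence sizes fabricated:

    import torch
    import torch.nn as nn

    hidden_size, batch, seq = 4096, 2, 8
    talk_head = nn.Sequential(                   # mirrors the use_complex_talk_head branch
        nn.Linear(hidden_size * 2, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, 1, bias=False),   # talk_output_dim == 1 for the weighted head
    )
    base_hidden = torch.randn(batch, seq, hidden_size)
    talk_hidden = torch.randn(batch, seq, hidden_size)
    w = talk_head(torch.cat([base_hidden, talk_hidden], dim=-1))  # (batch, seq, 1)
    mixed = base_hidden * (1 - w) + talk_hidden * w               # interpolation used by the weighted head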
@@ -1154,6 +1296,125 @@ class QuietForCausalLM(QuietPreTrainedModel):
     def get_decoder(self):
         return self.model

+    @torch.no_grad()
+    def infer(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        batch_size, seq_len = input_ids.shape
+
+        # Save the original input_ids and attention_mask for later use
+        original_input_ids = input_ids.clone()
+        original_attention_mask = attention_mask.clone() if attention_mask is not None else None
+
+        # Append the start thought token to the input sequence
+        start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
+        input_ids = torch.cat([input_ids, torch.tensor([[start_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
+        seq_len += 1
+
+        # Update the attention mask
+        if attention_mask is not None:
+            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
+
+        # Generate the continuation
+        continuation_length = self.n_ahead - 2
+        new_key_values = past_key_values
+        generated_tokens = []
+
+        for continuation_idx in range(continuation_length):
+            outputs = self.model(
+                input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_values=new_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=True,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+            new_key_values = outputs.past_key_values
+            hidden_states = outputs[0]
+            logits = self.lm_head(hidden_states)
+            logits = logits[:, -1, :]  # Only consider the last token
+
+            # Apply Gumbel-Softmax to the logits
+            next_token_logits = F.gumbel_softmax(logits, tau=self.gumbel_temperature, hard=True, dim=-1)
+            next_token_id = torch.argmax(next_token_logits, dim=-1)
+
+            # Append the generated token to the input sequence
+            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
+            generated_tokens.append(next_token_id)
+            seq_len += 1
+
+            # Update the attention mask
+            if attention_mask is not None:
+                attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
+
+            # Update the position ids
+            if position_ids is not None:
+                position_ids = torch.cat([position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=-1)
+
+        # Append the end thought token to the input sequence
+        end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
+        input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
+        seq_len += 1
+
+        # Update the attention mask
+        if attention_mask is not None:
+            attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
+
+        # Get the hidden states before and after the thought
+        outputs_before = self.model(
+            input_ids=original_input_ids,
+            attention_mask=original_attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states_before = outputs_before[0][:, -1:, :]
+
+        # two new tokens: last continuation token and end thought token
+        outputs_after = self.model(
+            input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor(end_thought_token_id).unsqueeze(-1).unsqueeze(-1).to(input_ids.device)], dim=-1),
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=new_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states_after = outputs_after[0][:, -1:, :]
+
+        # Apply the talk head to get the mixing weight
+        mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1))
+
+        # Apply the mixing weight to the hidden states
+        mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
+
+        # Apply the language model head to get the final logits
+        logits = self.lm_head(mixed_hidden_states)
+
+        # Decode the logits to get the generated text
+        generated_tokens = torch.cat(generated_tokens, dim=-1)
+        generated_text = self.tokenizer.decode(generated_tokens.squeeze(), skip_special_tokens=True)
+
+        return generated_text
+
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
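infer wraps the prompt in <|startthought|> ... <|endthought|>, rolls the model forward `n_ahead - 2` steps, and samples each thought token with a hard Gumbel-softmax at `gumbel_temperature`. The sampling step in isolation (logits fabricated; 32002 matches the new `vocab_size`):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(1, 32002)  # stand-in next-token logits
    # hard=True yields a one-hot sample while keeping a straight-through gradient path
    one_hot = F.gumbel_softmax(logits, tau=0.001, hard=True, dim=-1)
    next_token_id = one_hot.argmax(dim=-1)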
 
1455
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1456
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1457
  ```"""
1458
+ log_dict = self.log_dict if self.training else self.eval_log_dict
1459
+
1460
+ if self.training and self.kill_after is not None and self.training_steps // self.gradient_accumulation_steps > self.kill_after:
1461
+ raise ValueError("Killed after")
1462
+
1463
+ if not self.training:
1464
+ n_ahead_talk_to_restore = self.n_ahead_talk
1465
+ n_passes_to_restore = self.n_passes
1466
+ self.n_ahead_talk = 1
1467
+ self.n_passes = 1
1468
 
1469
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1470
  output_hidden_states = (
 
1472
  )
1473
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1474
 
1475
+ assert self.cumulative_residual or self.clever_residual or self.skip_residual or self.no_residual
1476
+ assert not (self.skip_residual and self.use_policy_loss)
 
 
 
 
 
 
 
 
 
 
 
 
 
1477
 
1478
+ if self.tokenized_thought_prefix is None and self.use_thought_prefix:
1479
+ self.tokenized_thought_prefix = self.tokenizer(self.thought_prefix, return_tensors="pt", add_special_tokens=False)["input_ids"]
 
1480
 
1481
+ def apply_head(head, states, detach=False):
1482
+ if detach:
1483
+ head_weight = head.weight.detach()
1484
+ else:
1485
+ head_weight = head.weight
1486
+ head_weight = head_weight.to(states.device)
1487
+ return (head_weight @ states.transpose(-1, -2)).transpose(-1, -2).contiguous()
1488
+
1489
+ def idx_if_sequential(head, idx=0):
1490
+ if isinstance(head, nn.Sequential) or isinstance(head, nn.ModuleList):
1491
+ return idx_if_sequential(head[idx], idx=idx)
1492
+ return head
1493
+
1494
+ def none_repeat_interleave(x, n):
1495
+ if x is None:
1496
+ return x
1497
+ return x.repeat_interleave(n, dim=0)
1498
+
1499
+ if self.n_passes > 1:
1500
+ input_ids = none_repeat_interleave(input_ids, self.n_passes)
1501
+ attention_mask = none_repeat_interleave(attention_mask, self.n_passes)
1502
+ position_ids = none_repeat_interleave(position_ids, self.n_passes)
1503
+ inputs_embeds = none_repeat_interleave(inputs_embeds, self.n_passes)
1504
+ labels = none_repeat_interleave(labels, self.n_passes)
1505
+ if past_key_values is not None:
1506
+ past_key_values = [none_repeat_interleave(p, self.n_passes) for p in past_key_values]
1507
+ cur_token_indices = torch.arange(input_ids.shape[1], device=input_ids.device)
1508
+
1509
+ self.tokenizer_has_start_thought_token = True
1510
+ self.tokenizer_has_end_thought_token = True
1511
+ if self.start_token_id is None:
1512
+ self.start_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
1513
+ if self.start_token_id == 0:
1514
+ self.start_token_id = self.tokenizer.bos_token_id
1515
+ self.tokenizer_has_start_thought_token = False
1516
+ elif self.use_start_thought_token:
1517
+ # base_start_id = self.tokenizer.convert_tokens_to_ids(self.initial_start_token)
1518
+ base_start_id = self.tokenizer.encode(self.initial_start_token, add_special_tokens=False)[0]
1519
+ if self.initialize_thought_embedding_to_normal:
1520
+ self.start_embedding.data = torch.zeros_like(self.start_embedding.data)
1521
+ else:
1522
+ self.start_embedding.data[0] = self.model.embed_tokens.weight.data[base_start_id].clone().detach() / self.embedding_scale
1523
+ self.start_embedding.data[1] = torch.log(self.model.embed_tokens.weight.data.std(dim=0) * self.thought_init_std_scale / self.embedding_scale)
1524
+ if self.end_token_id is None:
1525
+ self.end_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
1526
+ if self.end_token_id == 0:
1527
+ self.end_token_id = self.tokenizer.eos_token_id
1528
+ self.tokenizer_has_end_thought_token = False
1529
+ elif self.use_end_thought_token:
1530
+ # base_end_id = self.tokenizer.convert_tokens_to_ids(self.initial_end_token)
1531
+ base_end_id = self.tokenizer.encode(self.initial_end_token, add_special_tokens=False)[0]
1532
+ if self.initialize_thought_embedding_to_normal:
1533
+ self.end_embedding.data = torch.zeros_like(self.end_embedding.data)
1534
+ else:
1535
+ self.end_embedding.data[0] = self.model.embed_tokens.weight.data[base_end_id].clone().detach() / self.embedding_scale
1536
+ self.end_embedding.data[1] = torch.log(self.model.embed_tokens.weight.data.std(dim=0) * self.thought_init_std_scale / self.embedding_scale)
1537
+
1538
+ if not self.rm_initialized and (self.n_ahead > 1 or not self.base_original_mode):
1539
+ self.rm_initialized = True
1540
+ if not self.use_shallow_talk:
1541
+ head = self.talk_head[0]
1542
+ cur_head = head[-1] if isinstance(head, nn.Sequential) else head
1543
+ talk_input_dim = cur_head.weight.data.shape[1]
1544
+ talk_output_dim = 1 if self.use_weighted_talk_head else self.lm_head.weight.data.shape[0]
1545
+ cur_head.weight.data = torch.zeros(talk_output_dim, talk_input_dim, device=cur_head.weight.device, dtype=cur_head.weight.dtype)
1546
+ else:
1547
+ # convert to identity transform
1548
+ def lambda_transform(cur_head):
1549
+ if cur_head.weight.data.shape[0] != cur_head.weight.data.shape[1]:
1550
+ return torch.cat([
1551
+ torch.eye(
1552
+ cur_head.weight.data.shape[0],
1553
+ device=cur_head.weight.device,
1554
+ dtype=cur_head.weight.dtype
1555
+ ),
1556
+ torch.zeros(
1557
+ cur_head.weight.data.shape[0],
1558
+ cur_head.weight.data.shape[1] - cur_head.weight.data.shape[0],
1559
+ device=cur_head.weight.device,
1560
+ dtype=cur_head.weight.dtype
1561
+ )], dim=1)
1562
+ return torch.eye(
1563
+ cur_head.weight.data.shape[0],
1564
+ device=cur_head.weight.device,
1565
+ dtype=cur_head.weight.dtype
1566
+ )
1567
+ if isinstance(self.talk_head[0], nn.Sequential):
1568
+ for cur_head in self.talk_head[0]:
1569
+ # if it has weights
1570
+ if hasattr(cur_head, "weight"):
1571
+ cur_head.weight.data = lambda_transform(cur_head)
1572
+ else:
1573
+ self.talk_head[-1].weight.data = lambda_transform(self.talk_head[0])
1574
 
1575
  loss = None
1576
+ prev_rm_tokens = None
1577
+ cur_rm_tokens = None
1578
+ prev_rm_logits = None
1579
+ prev_sample_probs = None
1580
+ did_skip_sampling = None
1581
+ skip_sampling = None
1582
+ sample_probs = None
1583
+ hidden_states = None
1584
+ logits = None
1585
+ talk_kl_penalty = None
1586
+ rm_logits = None
1587
+ residual_logits = None
1588
+ probabilities_2d = None
1589
+ prev_probabilities_2d = None
1590
+ policy_reward = None
1591
+ logits_to_output = None
1592
+ batch_size, seq_len = input_ids.shape
1593
+ base_input_ids = input_ids.clone()
1594
+ loss_list = []
1595
+ dqn_loss_list = []
1596
+ sampled_token_history = []
1597
+ sample_probs_history = []
1598
+ action_loglikelihoods_list = []
1599
+
1600
+ if self.use_end_thought_token or self.use_start_thought_token:
1601
+ if not self.use_reparam_for_thought_embeddings:
1602
+ start_embedding = self.start_embedding[0].unsqueeze(0) * self.embedding_scale
1603
+ end_embedding = self.end_embedding[0].unsqueeze(0) * self.embedding_scale
1604
+ else:
1605
+ start_embedding = self.start_embedding * self.embedding_scale
1606
+ end_embedding = self.end_embedding * self.embedding_scale
1607
+ base_embeddings = self.model.embed_tokens.weight
1608
+ if self.train_only_thinking_embedding:
1609
+ base_embeddings = base_embeddings.detach()
1610
+ # # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1611
+ fwd_iters = 1 if self.original_mode else self.n_ahead + self.n_ahead_talk - 1
1612
+ for ahead_idx in range(fwd_iters):
1613
+ past_key_values_length = 0
1614
+ if past_key_values is not None:
1615
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1616
+ if use_legacy_cache:
1617
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1618
+ past_key_values_length = past_key_values.get_usable_length(seq_len)
1619
+
1620
+ if position_ids is None:
1621
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1622
+ position_ids = torch.arange(
1623
+ past_key_values_length, seq_len + past_key_values_length, dtype=torch.long, device=device
1624
+ )
1625
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
1626
+ else:
1627
+ position_ids = position_ids.view(-1, seq_len).long()
1628
+
1629
+ if inputs_embeds is None:
1630
+ contains_start = self.use_start_thought_token and (input_ids == self.start_token_id).any()
1631
+ contains_end = self.use_end_thought_token and (input_ids == self.end_token_id).any()
1632
+ contains_thought = contains_start or contains_end
1633
+ if contains_thought:
1634
+ thought_id = self.start_token_id if contains_start else self.end_token_id
1635
+ cur_thought_embedding = start_embedding if contains_start else end_embedding
1636
+ if self.use_reparam_for_thought_embeddings:
1637
+ inputs_embeds = torch.randn(batch_size, seq_len, self.model.config.hidden_size, device=input_ids.device, dtype=cur_thought_embedding.dtype)
1638
+ inputs_embeds = inputs_embeds.detach() * torch.exp(cur_thought_embedding[1]) + cur_thought_embedding[0]
1639
+ if contains_start:
1640
+ sampled_start = inputs_embeds.clone().detach()
1641
+ if contains_end:
1642
+ sampled_end = inputs_embeds.clone().detach()
1643
+ else:
1644
+ inputs_embeds = cur_thought_embedding.unsqueeze(0).repeat(batch_size, seq_len, 1)
1645
+ else:
1646
+ with torch.set_grad_enabled(not self.train_only_thinking_embedding):
1647
+ inputs_embeds = self.model.embed_tokens(input_ids)
1648
+
1649
+ if self.n_ahead != 1 or self.n_ahead_talk != 1 or self.comparison_mode:
1650
+ if attention_mask is None:
1651
+ base_attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).to(input_ids.device)
1652
+ base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
1653
+ base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
1654
+ attention_mask = base_attention_mask
1655
+ breakpoint()
1656
+ elif attention_mask.dim() == 2:
1657
+ if seq_len + past_key_values_length != attention_mask.shape[-1]:
1658
+ breakpoint()
1659
+ attention_mask = torch.cat(
1660
+ [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
1661
+ dim=-1
1662
+ )
1663
+ # # if the attention mask
1664
+ attention_mask = _prepare_4d_causal_attention_mask(
1665
+ attention_mask,
1666
+ (batch_size, seq_len),
1667
+ inputs_embeds,
1668
+ past_key_values_length,
1669
+ sliding_window=self.config.sliding_window,
1670
+ )
1671
+
1672
+ outputs = self.model(
1673
+ # input_ids=input_ids,
1674
+ attention_mask=attention_mask,
1675
+ position_ids=position_ids,
1676
+ past_key_values=past_key_values,
1677
+ inputs_embeds=inputs_embeds,
1678
+ use_cache=use_cache,
1679
+ output_attentions=output_attentions,
1680
+ output_hidden_states=output_hidden_states,
1681
+ return_dict=return_dict,
1682
+ )
1683
+
1684
+ prev_hidden_states = hidden_states
1685
+ hidden_states = outputs[0]
1686
+ prev_rm_logits = rm_logits # for policy gradient
1687
+ prev_rm_tokens = cur_rm_tokens # for policy gradient
1688
+
1689
+ if ahead_idx == 0:
1690
+ hidden_states_lm = hidden_states
1691
+ logits = self.lm_head(hidden_states_lm)
1692
+ base_hidden_states = hidden_states.clone()
1693
+ initial_loss_logits = logits.clone()
1694
+ if self.optimize_lm_head_only_at_start or self.optimize_model_only_at_start:
1695
+ logits = logits.detach()
1696
+ base_hidden_states = base_hidden_states.detach()
1697
+ if self.optimize_model_only_at_start:
1698
+ hidden_states = hidden_states.detach()
1699
+ base_logits = logits.clone()
1700
+ else:
1701
+ talk_hidden_states = hidden_states
1702
+ if self.merged_lm_and_talk_heads:
1703
+ assert self.no_residual
1704
+ residual_logits = self.lm_head(hidden_states)
1705
+ talk_hidden_states = hidden_states
1706
+ else:
1707
+ if ahead_idx > self.n_ahead - 1:
1708
+ cur_base_hidden = torch.cat([
1709
+ base_hidden_states[..., ahead_idx - self.n_ahead + 1:, :],
1710
+ base_hidden_states[..., :ahead_idx - self.n_ahead + 1, :]
1711
+ ], dim=-2)
1712
+ else:
1713
+ cur_base_hidden = base_hidden_states
1714
+
1715
+ if self.use_concat_talk_head:
1716
+ # concatenate the hidden states with the original hidden states
1717
+ head_input_hidden_states = torch.cat([cur_base_hidden, talk_hidden_states], dim=-1)
1718
+ else:
1719
+ head_input_hidden_states = talk_hidden_states
1720
+
1721
+ residual_logits = self.talk_head[0](head_input_hidden_states)
1722
+ if self.use_shallow_talk:
1723
+ residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start)
1724
+ residual_logits = residual_logits.to(logits.device)
1725
+ if self.use_weighted_talk_head:
1726
+ # combine the cur_base_hidden with the talk_hidden_states according to the weighted head
1727
+ residual_logits = cur_base_hidden * (1 - residual_logits) + talk_hidden_states * residual_logits
1728
+ residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start)
1729
+
1730
+ assert sum([self.cumulative_residual, self.clever_residual, self.skip_residual, self.no_residual]) == 1
1731
+ if self.clever_residual:
1732
+ if ahead_idx >= self.n_ahead - 1:
1733
+ # get the logits shifted according to the current talk ahead
1734
+ cur_base_logits = torch.cat([
1735
+ base_logits[..., ahead_idx - self.n_ahead + 1:, :],
1736
+ base_logits[..., :ahead_idx - self.n_ahead + 1, :]
1737
+ ], dim=-2)
1738
+ if self.optimize_lm_head_only_at_start:
1739
+ cur_base_logits = cur_base_logits.detach()
1740
+ logits = cur_base_logits + residual_logits
1741
+ else:
1742
+ logits += residual_logits / self.n_ahead
1743
+ elif self.cumulative_residual:
1744
+ if self.residual_talk_head:
1745
+ if ahead_idx < self.n_ahead:
1746
+ logits += residual_logits
1747
+ else:
1748
+ # get the logits shifted according to the current talk ahead
1749
+ cur_base_logits = torch.cat([
1750
+ base_logits[..., ahead_idx - self.n_ahead + 1:, :],
1751
+ base_logits[..., :ahead_idx - self.n_ahead + 1, :]
1752
+ ], dim=-2)
1753
+ if self.optimize_lm_head_only_at_start:
1754
+ cur_base_logits = cur_base_logits.detach()
1755
+ logits = cur_base_logits + residual_logits
1756
+ else:
1757
+ if ahead_idx < self.n_ahead:
1758
+ logits += residual_logits
1759
+ else:
1760
+ logits = residual_logits
1761
+ elif self.skip_residual:
1762
+ if ahead_idx >= self.n_ahead:
1763
+ # get the logits shifted according to the current talk ahead
1764
+ cur_base_logits = torch.cat([
1765
+ base_logits[..., ahead_idx - self.n_ahead + 1:, :],
1766
+ base_logits[..., :ahead_idx - self.n_ahead + 1, :]
1767
+ ], dim=-2)
1768
+ if self.optimize_lm_head_only_at_start:
1769
+ cur_base_logits = cur_base_logits.detach()
1770
+ logits = cur_base_logits
1771
+ elif self.no_residual:
1772
+ logits = residual_logits
1773
+ else:
1774
+ logits = base_logits + residual_logits
1775
+
1776
+ attempted = False
1777
+ talk_loss_list = []
1778
+ if self.original_mode or (self.n_ahead == 1) or (self.comparison_mode and ahead_idx == 0):# or (self.optimize_lm_head_only_at_start and ahead_idx == 0):
1779
+ loss = None
1780
+ attempted = True
1781
+
1782
+ if labels is not None:
1783
+ for shift_amount in range(self.n_ahead_talk):
1784
+ # Shift so that tokens < n predict n
1785
+ # ab[cde]f
1786
+ # abc[def]
1787
+ if ahead_idx == 0 and self.optimize_lm_head_only_at_start:
1788
+ loss_logits = initial_loss_logits
1789
+ else:
1790
+ loss_logits = logits
1791
+ shift_logits = loss_logits[..., shift_amount:-1, :].contiguous()
1792
+ shift_labels = labels[..., 1 + shift_amount:].contiguous()
1793
+ # Flatten the tokens
1794
+ loss_fct = CrossEntropyLoss(reduction="none")
1795
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1796
+ shift_labels = shift_labels.view(-1).clone()
1797
+ # Enable model parallelism
1798
+ shift_labels[shift_labels == self.tokenizer.pad_token_id] = -100
1799
+ shift_labels = shift_labels.to(shift_logits.device)
1800
+ loss = loss_fct(shift_logits, shift_labels)
1801
+ if not self.comparison_mode and not (self.optimize_lm_head_only_at_start and (self.n_ahead + self.n_ahead_talk > 2)) or self.original_mode:
1802
+ loss_list.append(loss)
1803
+ talk_loss_list.append(nonzero_mean(loss).detach())
1804
+
1805
+ if not attempted or self.comparison_mode:
1806
+ rm_hidden_states = hidden_states
1807
+ # print("Magnitude of RM hidden states before RM head", rm_hidden_states.norm())
1808
+ rm_logits = apply_head(self.lm_head, rm_hidden_states, detach=self.optimize_lm_head_only_at_start)
1809
+
1810
+ # don't allow it to predict the thinking token
1811
+ if self.tokenizer_has_start_thought_token:
1812
+ rm_logits[..., self.start_token_id] = -1e10
1813
+ if self.tokenizer_has_end_thought_token:
1814
+ rm_logits[..., self.end_token_id] = -1e10
1815
+ probabilities = rm_logits
1816
+ if probabilities_2d is not None:
1817
+ prev_probabilities_2d = probabilities_2d.clone()
1818
+ probabilities_2d = probabilities.view(-1, probabilities.size(-1))
1819
+
1820
+ did_skip_sampling = skip_sampling
1821
+ skip_sampling = False
1822
+ if ahead_idx == 0 and self.use_start_thought_token:
1823
+ override_token = self.start_token_id
1824
+ elif self.use_thought_prefix and ahead_idx < self.tokenized_thought_prefix.shape[-1]:
1825
+ override_token = self.tokenized_thought_prefix[..., ahead_idx]
1826
+ elif ahead_idx == self.n_ahead - 2 and self.use_end_thought_token:
1827
+ override_token = self.end_token_id
1828
+ else:
1829
+ override_token = None
1830
+ if override_token is not None and self.n_ahead > 1:
1831
+ # always start with the start token
1832
+ probabilities_2d = torch.zeros_like(probabilities_2d)
1833
+ probabilities_2d[:, override_token] = 1.0
1834
+ skip_sampling = True
1835
+ elif ahead_idx >= self.n_ahead - 1:
1836
+ if labels is not None: # we're in the talk phase
1837
+ cur_talk_n = ahead_idx - (self.n_ahead - 1) + 1
1838
+ # print("Setting rm to labels", cur_talk_n, "during", ahead_idx)
1839
+ shift_labels = labels[..., cur_talk_n:].contiguous().to(probabilities_2d.device)
1840
+ padding = torch.full_like(
1841
+ labels[..., :cur_talk_n],
1842
+ self.tokenizer.pad_token_id,
1843
+ dtype=torch.long,
1844
+ device=shift_labels.device
1845
+ )
1846
+ new_rm_tokens = torch.cat(
1847
+ [shift_labels, padding],
1848
+ dim=-1
1849
+ )
1850
+ # convert rm tokens to one-hot
1851
+ probabilities_2d = F.one_hot(new_rm_tokens, num_classes=self.vocab_size).reshape(-1, self.vocab_size).to(probabilities_2d.dtype)
1852
+ skip_sampling = True
1853
+ else:
1854
+ continue
1855
+ temperature = self.gumbel_temperature if self.training else 0.001
1856
+ prev_sample_probs = sample_probs
1857
+ sample_probs = probabilities_2d
1858
+ if ahead_idx < self.n_ahead - 1 and not skip_sampling:
1859
+ probabilities_2d = F.gumbel_softmax(sample_probs, tau=temperature, hard=True, dim=-1)
1860
+ if self.gumbel_detach:
1861
+ probabilities_2d = probabilities_2d.detach()
1862
+ sampled_token_history.append(probabilities_2d.argmax(dim=-1).detach().cpu())
1863
+ # convert rm logits directly to embeddings
1864
+ contains_start = self.use_start_thought_token and (probabilities_2d[..., self.start_token_id].sum() > 0)
1865
+ contains_end = self.use_end_thought_token and (probabilities_2d[..., self.end_token_id].sum() > 0)
1866
+ contains_thought = contains_start or contains_end
1867
+
+                    if not contains_thought:
+                        with torch.set_grad_enabled(not self.train_only_thinking_embedding):
+                            inputs_embeds = probabilities_2d @ (self.model.embed_tokens.weight.to(probabilities.device).to(probabilities.dtype))
+                    else:
+                        thought_id = self.start_token_id if contains_start else self.end_token_id
+                        cur_thought_embedding = start_embedding if contains_start else end_embedding
+                        if self.use_reparam_for_thought_embeddings:
+                            inputs_embeds = torch.randn(batch_size, seq_len, self.model.config.hidden_size, device=input_ids.device, dtype=cur_thought_embedding.dtype)
+                            inputs_embeds = inputs_embeds * torch.exp(cur_thought_embedding[1]) + cur_thought_embedding[0]
+                            if contains_start:
+                                sampled_start = inputs_embeds.clone().detach()
+                            else:
+                                sampled_end = inputs_embeds.clone().detach()
+                        else:
+                            inputs_embeds = cur_thought_embedding.unsqueeze(0).repeat(batch_size, seq_len, 1)
+                    inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
+
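+                # Grow the 4D attention mask by one block of seq_len columns per thought step.
+                # The appended block is a validity-masked identity, so position i attends only
+                # to the thought token that position i itself generated; parallel thoughts
+                # never leak between positions.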
+                if len(attention_mask.shape) == 2:
+                    # the mask should have been expanded to 4D (batch, 1, seq, seq) earlier in forward
+                    raise ValueError("Expected a 4D attention mask at this point in the forward pass.")
+                else:
+                    original_attention = attention_mask[..., :attention_mask.shape[-2]]
+                    if self.use_upper_triangular:
+                        new_attention = original_attention
+                    else:
+                        original_attention = original_attention == attention_mask.max()
+                        # torch.eye isn't implemented for bfloat16, so build it in float32 and cast
+                        if not attention_mask.dtype == torch.bfloat16:
+                            new_attention = torch.eye(
+                                seq_len, dtype=attention_mask.dtype, device=attention_mask.device
+                            )
+                        else:
+                            new_attention = torch.eye(
+                                seq_len, dtype=torch.float32, device=attention_mask.device
+                            ).to(attention_mask.dtype)
+
+                        new_attention = new_attention.view(1, 1, seq_len, seq_len).repeat(input_ids.shape[0], 1, 1, 1)
+                        new_attention = new_attention * original_attention
+                        new_attention[new_attention == 0] = attention_mask.min()
+                        new_attention[new_attention == 1] = attention_mask.max()
+                    attention_mask = torch.cat([attention_mask, new_attention], dim=-1)
+                past_key_values = outputs.past_key_values
+                position_ids = position_ids + 1  # each appended thought token advances every stream's position by one
+
+                if labels is not None and (self.n_ahead > 1 or not self.base_original_mode):
+                    # Shift so that tokens < n predict n
+                    # logits: abcdef -> bcdef? -> cdef??
+                    # labels: abcdef -> ?bcdef -> ??cdef
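+                    # e.g. with n_ahead = 4: thinking steps keep the usual one-token shift,
+                    # while talk step k (ahead_idx = n_ahead - 1 + k) predicts k + 1 tokens
+                    # ahead, so shift_idx below works out to 1 + k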
+                    if ahead_idx == 0 and self.optimize_lm_head_only_at_start:
+                        loss_logits = initial_loss_logits
+                    else:
+                        loss_logits = logits
+                    shift_idx = 1 + max(0, ahead_idx - (self.n_ahead - 1))
+                    shift_logits = loss_logits[..., :-shift_idx, :].contiguous()
+                    shift_labels = labels[..., shift_idx:].contiguous()
+                    # Flatten the tokens
+                    loss_fct = CrossEntropyLoss(reduction="none")
+                    shift_logits = shift_logits.view(-1, self.config.vocab_size)
+                    shift_labels = shift_labels.view(-1)
+                    # Enable model parallelism
+                    shift_labels = shift_labels.to(shift_logits.device)
+                    # ignore padding when computing the loss
+                    shift_labels = torch.where(shift_labels == self.tokenizer.pad_token_id, -100, shift_labels)
+                    unreduced_loss = loss_fct(shift_logits, shift_labels)
+                    if torch.any(unreduced_loss != unreduced_loss):  # NaN check: NaN != NaN
+                        raise ValueError("NaN loss")
+                    unreduced_loss = unreduced_loss.reshape(logits.shape[0], -1)
+                    loss_list.append(unreduced_loss)
+
+
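+                # REINFORCE reward: the per-token drop in LM loss attributable to the thought,
+                # e.g. a position whose loss falls from 2.1 to 1.8 nats earns reward +0.3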
+                if self.use_policy_loss and ahead_idx > 0 and (ahead_idx > 1 or not self.use_start_thought_token):
+                    # we treat the change in loss as the reward
+                    previous_loss = loss_list[-2]
+                    # for example, suppose n_ahead = 3 and n_ahead_talk = 2
+                    # note that we end at self.n_ahead + self.n_ahead_talk - 2
+                    # in this case, 5 - 2 = 3, so we end at ahead_idx = 3
+                    # we also predict the next token at ahead_idx = 2
+                    # when we get to ahead_idx = 2, we predict ahead
+                    # so we shift by 1
+                    # note that this is ahead_idx = n_ahead - 1
+                    # when we get to ahead_idx = 3, we predict ahead
+                    # so we shift by 2
+                    # note that this is ahead_idx = n_ahead
+                    if ahead_idx < self.n_ahead - 1:
+                        shift_amount = 0
+                        original_dqn_reward = (previous_loss - unreduced_loss).detach()
+                        if self.first_and_last_mode:
+                            original_dqn_reward = original_dqn_reward * 0.0
+                    else:
+                        # logits vs cur_policy_shift_logits
+                        # let's look at rm_logits and prev_rm_logits
+                        shift_amount = max(0, ahead_idx - (self.n_ahead - 1))
+                        # let's say shift_amount = 2
+                        # abcdefg -> bcdefg? -> cdefg??
+                        # logits = [a b]c d e f[g]
+                        # labels = [a b c]d e f g
+                        cur_policy_shift_logits = initial_loss_logits[..., shift_amount:-1, :].contiguous().detach()
+                        cur_policy_shift_labels = labels[..., 1 + shift_amount:].contiguous()
+                        # Flatten the tokens
+                        cur_policy_loss_fct = CrossEntropyLoss(reduction="none")
+                        cur_policy_shift_logits = cur_policy_shift_logits.view(-1, self.config.vocab_size)
+                        cur_policy_shift_labels = cur_policy_shift_labels.view(-1).clone()
+                        # Enable model parallelism
+                        cur_policy_shift_labels[cur_policy_shift_labels == self.tokenizer.pad_token_id] = -100
+                        cur_policy_reward_base_loss = cur_policy_loss_fct(
+                            cur_policy_shift_logits, cur_policy_shift_labels.to(cur_policy_shift_logits.device)
+                        ).reshape(logits.shape[0], -1)
+                        original_dqn_reward = cur_policy_reward_base_loss.detach() - unreduced_loss
+
+                    if not did_skip_sampling:
+                        nonzero_indices = prev_probabilities_2d.nonzero()
+                        action_loglikelihoods = F.log_softmax(prev_sample_probs / self.reinforce_temperature, dim=-1)[nonzero_indices[:, 0], nonzero_indices[:, 1]]
+                        action_loglikelihoods_2d = action_loglikelihoods.reshape(batch_size, -1)[:, :-1 - shift_amount]
+                        action_loglikelihoods_list.append(action_loglikelihoods_2d)
+                    if policy_reward is None:
+                        policy_reward = original_dqn_reward[:, :-(self.n_ahead_talk - shift_amount)]
+                    else:
+                        if self.n_ahead_talk > shift_amount:
+                            added_reward = original_dqn_reward[:, :-(self.n_ahead_talk - shift_amount)]
+                        else:
+                            added_reward = original_dqn_reward
+                        policy_reward += added_reward
+
+                if self.use_policy_loss and ahead_idx == self.n_ahead + self.n_ahead_talk - 2:
+                    # only compute during the thinking phase
+                    if self.use_reparam_for_thought_embeddings and (self.use_start_thought_token or self.use_end_thought_token):
+                        # sampled_start, sampled_end
+                        # calculate the log likelihood of the start and end embeddings sampled from a multivariate normal distribution
+                        # with mean start_embedding[0] and log standard deviation start_embedding[1]
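+                        # per-dimension: log N(x; mu, sigma) = -0.5 * ((x - mu) / sigma)^2 - log(sigma) - 0.5 * log(2 * pi)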
+                        if self.use_start_thought_token:
+                            exp_start_std = torch.exp(start_embedding[1])
+                            start_loglikelihood = -0.5 * (sampled_start.detach() - start_embedding[0]) ** 2 / exp_start_std ** 2 - start_embedding[1] - 0.5 * math.log(2 * math.pi)
+                            start_loglikelihood = start_loglikelihood.mean(dim=-1)
+                        if self.use_end_thought_token:
+                            exp_end_std = torch.exp(end_embedding[1])
+                            end_loglikelihood = -0.5 * (sampled_end.detach() - end_embedding[0]) ** 2 / exp_end_std ** 2 - end_embedding[1] - 0.5 * math.log(2 * math.pi)
+                            end_loglikelihood = end_loglikelihood.mean(dim=-1)
+                        # we use the mean instead of the sum to prevent dependence on the dimensionality of the embeddings
+                        if self.use_end_thought_token and self.use_policy_loss_for_end_thought:
+                            action_loglikelihoods_list.append(end_loglikelihood)
+                        if self.use_start_thought_token:
+                            action_loglikelihoods_list.append(start_loglikelihood)
+
+                if ahead_idx == self.n_ahead + self.n_ahead_talk - 2 and self.eval_mode:
+                    with torch.no_grad():
+                        # compute quantiles of the absolute rewards to calibrate the visualization thresholds
+                        filtered_tokens = input_ids[:, :policy_reward.shape[-1]].cpu().detach().numpy().flatten()
+                        filtered_tokens_mask = filtered_tokens != self.tokenizer.pad_token_id
+                        filtered_tokens = filtered_tokens[filtered_tokens_mask]
+                        filtered_rewards = policy_reward.float().cpu().detach().numpy()[:, :seq_len - self.n_ahead_talk].flatten()
+                        filtered_rewards = filtered_rewards[filtered_tokens_mask]
+
+                        abs_reward_list = np.abs(policy_reward.float().cpu().detach().numpy()[:, :seq_len - self.n_ahead_talk].flatten())
+                        abs_reward_list = abs_reward_list[filtered_tokens_mask]
+                        medium_quantile = np.quantile(abs_reward_list, 0.5)
+                        upper_quantile = np.quantile(abs_reward_list, 0.95)
+
+                        save_tokens_with_rewards_to_pdf(
+                            filtered_tokens,
+                            [0] + filtered_rewards.tolist(),
+                            self.tokenizer,
+                            output_file=f"texts/rewards_talk_{self.n_ahead_talk}_{self.training_steps}.pdf",
+                            eps=medium_quantile,
+                            eps2=upper_quantile,
+                        )
+
+                        def plot_kde(data, losses):
+                            sns.set(style="whitegrid")
+                            # Create the KDE plot
+                            sns.kdeplot(data, fill=True)
+                            # Set the plot title and labels
+                            plt.title("KDE Plot")
+                            plt.xlabel("Value")
+                            plt.ylabel("Density")
+                            # Save the plot
+                            plt.savefig(f"texts/kde_talk_{self.n_ahead_talk}_{self.training_steps}.pdf")
+                            # Close the plot
+                            plt.close()
+
+                            # Step 1: Create a base color palette
+                            base_colors = sns.color_palette("light:#5A9", n_colors=256)  # more colors for a smoother gradient
+                            base_cmap = LinearSegmentedColormap.from_list("log_light", base_colors)
+                            log_norm = LogNorm(vmin=1e-3, vmax=10)
+
+                            sns.kdeplot(x=data, y=losses, fill=True, levels=20, norm=log_norm, cut=0, linewidths=0)
+                            # limit y to 0 to 25 and x to -1 to 1
+                            plt.xlim(-1, 1)
+                            plt.ylim(0, 25)
+                            plt.savefig(f"texts/jointer_talk_{self.n_ahead_talk}_{self.training_steps}.pdf")
+                            plt.close()
+
+                        self.all_rewards.extend(filtered_rewards)
+                        self.all_unreduced_losses.extend(unreduced_loss[:, :-1].flatten()[filtered_tokens_mask].float().flatten().cpu().detach().numpy())
+                        plot_kde(self.all_rewards, self.all_unreduced_losses)
+
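+        # one REINFORCE term per recorded sampling step: weight the variance-reduced
+        # reward by the log-likelihood of the actions that earned it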
+        for action_loglikelihoods_2d in action_loglikelihoods_list:
+            train_policy_reward = policy_reward
+
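+            # TRICE-style baseline: with n_passes rollouts of the same example, subtracting
+            # the per-example mean reward across passes acts as a control variate and
+            # reduces the variance of the policy gradient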
+            if self.trice_mode and self.n_passes > 1:
+                batched_policy_reward = train_policy_reward.reshape(-1, self.n_passes, train_policy_reward.shape[-1])
+                # average over the passes
+                train_policy_reward = batched_policy_reward - batched_policy_reward.mean(dim=1, keepdim=True)
+                train_policy_reward = train_policy_reward.reshape(-1, train_policy_reward.shape[-1])
+
+            if self.subtract_mean_reward:
+                train_policy_reward = train_policy_reward - train_policy_reward.mean()
+            # discard rewards below the mean
+            if self.remove_negative_rewards:
+                fixed_policy_reward = train_policy_reward.detach().clamp(min=0)
+            else:
+                fixed_policy_reward = train_policy_reward.detach()
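+            # score-function estimator: minimizing -reward * log pi(action) increases the
+            # likelihood of thoughts that achieved positive reward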
+            actor_loss = -fixed_policy_reward * action_loglikelihoods_2d[:, :policy_reward.shape[-1]].to(policy_reward.device)
+            if action_loglikelihoods_2d.mean() < -1e4 and not self.use_policy_loss_just_for_thoughts:
+                # This will only happen when we force the next token to be the end of thought token
+                break
+            dqn_loss_list.append(actor_loss.mean())
+
+        if loss_list:
+            if self.first_and_last_mode:
+                loss = sum(
+                    self.loss_mean(loss_list[-(i + 1)]) for i in range(self.n_ahead_talk)
+                ) * (1 - self.original_loss_weight) / self.n_ahead_talk
+                loss = loss + self.loss_mean(loss_list[0]) * self.original_loss_weight
+                # Let's NaN out the others
+                # e.g. if n_ahead_talk = 2 and the list is 5 long, we want to NaN out 1, 2 but keep 0, 3, 4
+                for i in range(1, len(loss_list) - self.n_ahead_talk):
+                    loss_list[i] = loss_list[i] * math.nan
+            elif self.first_only:
+                loss = self.loss_mean(loss_list[0])
+            elif self.final_only_mode:
+                loss = sum(
+                    self.loss_mean(loss_list[-i]) for i in range(1, self.n_ahead_talk + 1)
+                ) / self.n_ahead_talk
+            else:
+                loss = None
+                for i in range(len(loss_list)):
+                    cur_loss = self.loss_mean(loss_list[i])
+                    if loss is not None:
+                        loss = loss + cur_loss.to(loss.device)
+                    else:
+                        loss = cur_loss
+                loss = loss / len(loss_list)
+
+            loss = loss * self.base_loss_beta
+
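+        # blend the REINFORCE term into the total loss, weighted by policy_loss_beta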
+        if dqn_loss_list:
+            dqn_loss = sum(dqn_loss_list) / len(dqn_loss_list)
+            if self.include_policy_loss:
+                if loss is not None:
+                    loss += dqn_loss * self.policy_loss_beta
+                else:
+                    loss = dqn_loss * self.policy_loss_beta
 
         if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        base_log_dict = {
+            f"loss_{i}": nonzero_mean(loss_list[i]) for i in range(len(loss_list))
+        }
+
+        if loss is not None:
+            base_log_dict["loss_train"] = loss.item()
+
+        for loss_key, loss_val in base_log_dict.items():
+            log_dict[loss_key] += loss_val / self.n_tokens_print
+
+        if self.use_policy_loss and policy_reward is not None:
+            log_dict["policy_loss"] += dqn_loss / self.n_tokens_print
+            log_dict["policy_reward"] += policy_reward.mean() / self.n_tokens_print
 
+        if not loss_list:
+            if loss is not None:
+                log_dict["loss_0"] += loss / self.n_tokens_print
+        else:
+            log_dict["loss_final"] += nonzero_mean(loss_list[-1]) / self.n_tokens_print
+            log_dict["loss_talk"] += sum(nonzero_mean(cur_loss_item) for cur_loss_item in loss_list[-self.n_ahead_talk:]) / self.n_ahead_talk / self.n_tokens_print
+
+        # also log relative losses to loss_0
+        if loss_list:
+            for i in range(len(loss_list)):
+                talk_idx = min(max(i - (self.n_ahead - 1), 0), len(talk_loss_list) - 1)
+                if not talk_loss_list:
+                    cur_talk_loss = nonzero_mean(loss_list[0])
+                else:
+                    cur_talk_loss = talk_loss_list[talk_idx]
+                log_dict[f"rel_loss_{i}"] += (nonzero_mean(loss_list[i]) - cur_talk_loss) / self.n_tokens_print
+        if self.training:
+            self.training_steps += 1
+        try:
+            if self.wandb_enabled:
+                if self.training_steps % self.n_tokens_print == 0 or not self.training:
+                    if not self.training:
+                        new_log_dict = {}
+                        for key in list(log_dict.keys()):
+                            new_log_dict["eval_" + key] = log_dict[key]
+                        log_dict = new_log_dict
+                    log_dict["training_steps"] = self.training_steps
+                    log_dict["batch_size"] = batch_size
+                    log_dict["example_steps"] = self.training_steps * batch_size * self.gradient_accumulation_steps
+                    if self.n_ahead > 1:
+                        log_dict["compute_steps"] = self.training_steps * batch_size * (self.n_ahead + self.n_ahead_talk - 1) * self.gradient_accumulation_steps
+                    else:  # there's no overhead for talk tokens if there's no thinking
+                        log_dict["compute_steps"] = self.training_steps * batch_size * self.gradient_accumulation_steps
+                    # remove all NaN entries so they don't pollute the charts (NaN != NaN)
+                    for key in list(log_dict.keys()):
+                        if log_dict[key] != log_dict[key]:
+                            del log_dict[key]
+                    if self.training:
+                        wandb.log(log_dict)
+                    if self.training:
+                        self.log_dict = defaultdict(int)
+                    else:
+                        self.eval_log_dict = defaultdict(int)
+        except Exception:
+            # logging must never interrupt the forward pass
+            pass
+
+        if not self.training:
+            self.n_ahead_talk = n_ahead_talk_to_restore
+            self.n_passes = n_passes_to_restore
         return CausalLMOutputWithPast(
+            loss=loss,
+            logits=(rm_logits if self.n_ahead > 1 else logits) if not self.output_logits_at_the_end else logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
 
+
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):

 
         # Keep only the unprocessed tokens:
         # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+        # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
         # input)
         if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
             input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]