Crystalcareai committed
Commit 8d2b1eb · verified · 1 Parent(s): c0c99ee

Upload 3 files

Files changed (3):
  1. config.json +4 -14
  2. configuration_quiet.py +26 -44
  3. modeling_quiet.py +114 -1115
config.json CHANGED
@@ -1,9 +1,7 @@
 {
-  "_name_or_path": "Crystalcareai/Quiet-Star-Custom",
   "architectures": [
     "QuietForCausalLM"
   ],
-  "attention_dropout": 0.0,
   "auto_map": {
     "AutoConfig": "configuration_quiet.QuietConfig",
     "AutoModel": "modeling_quiet.QuietModel",
@@ -16,11 +14,9 @@
   "initializer_range": 0.02,
   "intermediate_size": 14336,
   "max_position_embeddings": 32768,
-  "max_thoughts": 10,
-  "merged_lm_and_talk_heads": false,
-  "merged_lm_and_think_heads": true,
-  "merged_talk_heads": true,
   "model_type": "quiet",
+  "max_thoughts": 3,
+  "thought_length": 10,
   "num_attention_heads": 32,
   "num_hidden_layers": 32,
   "num_key_value_heads": 8,
@@ -29,13 +25,7 @@
   "sliding_window": 4096,
   "tie_word_embeddings": false,
   "torch_dtype": "bfloat16",
-  "transformers_version": "4.37.0.dev0",
+  "transformers_version": "4.34.0.dev0",
   "use_cache": true,
-  "use_complex_talk_head": true,
-  "use_complex_think_head": false,
-  "use_concat_talk_head": true,
-  "use_shallow_talk": false,
-  "use_shallow_think": true,
-  "use_weighted_talk_head": true,
-  "vocab_size": 32002
+  "vocab_size": 32000
 }
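
Note (not part of the commit): a minimal sketch of how the updated config.json can be inspected, assuming the repo id Crystalcareai/Quiet-Star-Custom (taken from the removed `_name_or_path` field) and that the custom code referenced in `auto_map` loads via `trust_remote_code`:

```python
from transformers import AutoConfig

# Hypothetical repo id; adjust to wherever this config.json is hosted.
config = AutoConfig.from_pretrained(
    "Crystalcareai/Quiet-Star-Custom",
    trust_remote_code=True,  # auto_map points at configuration_quiet.QuietConfig
)

print(config.max_thoughts)    # 3 after this commit (was 10)
print(config.thought_length)  # 10, newly added by this commit
print(config.vocab_size)      # 32000 after this commit (was 32002)
```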
configuration_quiet.py CHANGED
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2023 Quiet AI and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,28 +12,26 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Quiet model configuration"""
+""" Mistral model configuration"""

-from transformers.configuration_utils import PretrainedConfig
-from transformers.utils import logging
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging


 logger = logging.get_logger(__name__)

-QUIET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    "quietai/Quiet-7B-v0.1": "https://huggingface.co/quietai/Quiet-7B-v0.1/resolve/main/config.json",
-    "quietai/Quiet-7B-Instruct-v0.1": "https://huggingface.co/quietai/Quiet-7B-Instruct-v0.1/resolve/main/config.json",
-}

+from ..deprecated._archive_maps import MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP  # noqa: F401, E402

-class QuietConfig(PretrainedConfig):
+
+class MistralConfig(PretrainedConfig):
     r"""
-    This is the configuration class to store the configuration of a [`QuietModel`]. It is used to instantiate an
-    Quiet model according to the specified arguments, defining the model architecture. Instantiating a configuration
-    with the defaults will yield a similar configuration to that of the Quiet-7B-v0.1 or Quiet-7B-Instruct-v0.1.
+    This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an
+    Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.

-    [quietai/Quiet-7B-v0.1](https://huggingface.co/quietai/Quiet-7B-v0.1)
-    [quietai/Quiet-7B-Instruct-v0.1](https://huggingface.co/quietai/Quiet-7B-Instruct-v0.1)
+    [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
+    [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)

     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
@@ -41,8 +39,8 @@ class QuietConfig(PretrainedConfig):

     Args:
         vocab_size (`int`, *optional*, defaults to 32000):
-            Vocabulary size of the Quiet model. Defines the number of different tokens that can be represented by the
-            `inputs_ids` passed when calling [`QuietModel`]
+            Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`MistralModel`]
         hidden_size (`int`, *optional*, defaults to 4096):
             Dimension of the hidden representations.
         intermediate_size (`int`, *optional*, defaults to 14336):
@@ -61,7 +59,7 @@ class QuietConfig(PretrainedConfig):
         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
             The non-linear activation function (function or string) in the decoder.
         max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
-            The maximum sequence length that this model might ever be used with. Quiet's sliding window attention
+            The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
             allows sequence of up to 4096*32 tokens.
         initializer_range (`float`, *optional*, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
@@ -86,19 +84,19 @@ class QuietConfig(PretrainedConfig):
             The dropout ratio for the attention probabilities.

     ```python
-    >>> from transformers import QuietModel, QuietConfig
+    >>> from transformers import MistralModel, MistralConfig

-    >>> # Initializing a Quiet 7B style configuration
-    >>> configuration = QuietConfig()
+    >>> # Initializing a Mistral 7B style configuration
+    >>> configuration = MistralConfig()

-    >>> # Initializing a model from the Quiet 7B style configuration
-    >>> model = QuietModel(configuration)
+    >>> # Initializing a model from the Mistral 7B style configuration
+    >>> model = MistralModel(configuration)

     >>> # Accessing the model configuration
     >>> configuration = model.config
     ```"""

-    model_type = "quiet"
+    model_type = "mistral"
     keys_to_ignore_at_inference = ["past_key_values"]

     def __init__(
@@ -116,21 +114,13 @@ class QuietConfig(PretrainedConfig):
         use_cache=True,
         pad_token_id=None,
         bos_token_id=1,
+        max_thoughts: int = 3,
+        thought_length: int = 10,
         eos_token_id=2,
         tie_word_embeddings=False,
         rope_theta=10000.0,
         sliding_window=4096,
         attention_dropout=0.0,
-        max_thoughts=16,
-        merged_talk_heads=True,
-        merged_lm_and_talk_heads=False,
-        merged_lm_and_think_heads=True,
-        use_concat_talk_head=True,
-        use_shallow_think=True,
-        use_shallow_talk=False,
-        use_complex_think_head=False,
-        use_complex_talk_head=True,
-        use_weighted_talk_head=True,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -150,18 +140,10 @@ class QuietConfig(PretrainedConfig):
         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
         self.use_cache = use_cache
+        self.max_thoughts = max_thoughts
+        self.thought_length = thought_length
         self.rope_theta = rope_theta
         self.attention_dropout = attention_dropout
-        self.max_thoughts = max_thoughts
-        self.merged_talk_heads = merged_talk_heads
-        self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
-        self.merged_lm_and_think_heads = merged_lm_and_think_heads
-        self.use_concat_talk_head = use_concat_talk_head
-        self.use_shallow_think = use_shallow_think
-        self.use_shallow_talk = use_shallow_talk
-        self.use_complex_think_head = use_complex_think_head
-        self.use_complex_talk_head = use_complex_talk_head
-        self.use_weighted_talk_head = use_weighted_talk_head

         super().__init__(
             pad_token_id=pad_token_id,
@@ -169,4 +151,4 @@ class QuietConfig(PretrainedConfig):
             eos_token_id=eos_token_id,
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,
-        )
+        )
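
Note (illustration only, not part of the commit): after this change the configuration class stores the two new thought-related arguments as plain attributes. The sketch below assumes the file is importable on its own; as committed it uses transformers-internal relative imports (`from ...configuration_utils import PretrainedConfig`), which suggests it was copied from the transformers source tree and may need absolute `transformers.*` imports to load as standalone remote code. The import path below is hypothetical.

```python
# Hypothetical import; assumes the relative imports in configuration_quiet.py
# have been resolved so the module can be loaded directly.
from configuration_quiet import MistralConfig

cfg = MistralConfig(max_thoughts=3, thought_length=10)

# __init__ stores the new arguments verbatim alongside the usual Mistral fields.
assert cfg.max_thoughts == 3
assert cfg.thought_length == 10
assert cfg.vocab_size == 32000  # docstring default for vocab_size
```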
modeling_quiet.py CHANGED
@@ -1,5 +1,5 @@
1
  # coding=utf-8
2
- # Copyright 2023 Quiet AI and the HuggingFace Inc. team. All rights reserved.
3
  #
4
  # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
  # and OPT implementations in this library. It has been modified from its
@@ -17,23 +17,10 @@
17
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
  # See the License for the specific language governing permissions and
19
  # limitations under the License.
20
- """ PyTorch Quiet model."""
21
  import inspect
22
  import math
23
- import copy
24
- import os
25
- import time
26
- import pandas as pd
27
- import seaborn as sns
28
- import matplotlib.pyplot as plt
29
- import wandb
30
- from termcolor import colored
31
- from tqdm import tqdm
32
- import random
33
- import numpy as np
34
- from matplotlib.colors import LinearSegmentedColormap, LogNorm
35
  import warnings
36
- from collections import defaultdict
37
  from typing import List, Optional, Tuple, Union
38
 
39
  import torch
@@ -42,12 +29,12 @@ import torch.utils.checkpoint
42
  from torch import nn
43
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
44
 
45
- from transformers.activations import ACT2FN
46
- from transformers.cache_utils import Cache, DynamicCache
47
- from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
48
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
49
- from transformers.modeling_utils import PreTrainedModel
50
- from transformers.utils import (
51
  add_start_docstrings,
52
  add_start_docstrings_to_model_forward,
53
  is_flash_attn_2_available,
@@ -55,7 +42,7 @@ from transformers.utils import (
55
  logging,
56
  replace_return_docstrings,
57
  )
58
- from .configuration_quiet import QuietConfig
59
 
60
 
61
  if is_flash_attn_2_available():
@@ -67,73 +54,7 @@ if is_flash_attn_2_available():
67
 
68
  logger = logging.get_logger(__name__)
69
 
70
- _CONFIG_FOR_DOC = "QuietConfig"
71
-
72
- from reportlab.pdfgen import canvas
73
- from reportlab.lib.pagesizes import letter
74
- from reportlab.lib.colors import HexColor
75
-
76
- def save_tokens_with_rewards_to_pdf(input_ids, token_rewards, tokenizer, output_file="text.pdf", eps=0.2, eps2=0.5):
77
- c = canvas.Canvas(output_file, pagesize=letter)
78
- c.setFont("Courier", 8)
79
- x, y = 50, 750
80
- previous_text = ""
81
- current_text = ""
82
- for token_idx, reward in enumerate(token_rewards):
83
- current_text = tokenizer.decode(input_ids[: token_idx + 1])
84
- if current_text != previous_text:
85
- diff_text = current_text[len(previous_text) :]
86
- if "\n" in diff_text:
87
- lines = diff_text.split("\n")
88
- for line_idx, line in enumerate(lines):
89
- if line_idx > 0:
90
- x = 50
91
- y -= 12
92
- if abs(reward) < eps:
93
- opacity = 0
94
- elif abs(reward) > eps2:
95
- opacity = 0.8
96
- else:
97
- opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps)
98
- text_width = c.stringWidth(line)
99
- if reward > 0:
100
- highlight_color = HexColor("#4CCD99")
101
- else:
102
- highlight_color = HexColor("#FFC700")
103
- highlight_color.alpha = opacity
104
- c.setFillColor(highlight_color)
105
- c.rect(x, y - 2, text_width, 10, fill=True, stroke=False)
106
- c.setFillColor(HexColor("#000000"))
107
- c.drawString(x, y, line)
108
- x += text_width
109
- else:
110
- if abs(reward) < eps:
111
- opacity = 0
112
- elif abs(reward) > eps2:
113
- opacity = 0.8
114
- else:
115
- opacity = 0.8 * (abs(reward) - eps) / (eps2 - eps)
116
- text_width = c.stringWidth(diff_text)
117
- if reward > 0:
118
- highlight_color = HexColor("#4CCD99")
119
- else:
120
- highlight_color = HexColor("#FFC700")
121
- highlight_color.alpha = opacity
122
- c.setFillColor(highlight_color)
123
- c.rect(x, y - 2, text_width, 10, fill=True, stroke=False)
124
- c.setFillColor(HexColor("#000000"))
125
- c.drawString(x, y, diff_text)
126
- x += text_width
127
- if x > 550:
128
- x = 50
129
- y -= 12
130
- if y < 50:
131
- c.showPage()
132
- y = 750
133
- x = 50
134
- previous_text = current_text
135
- c.showPage()
136
- c.save()
137
 
138
 
139
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
@@ -141,7 +62,7 @@ def _get_unpad_data(attention_mask):
141
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
142
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
143
  max_seqlen_in_batch = seqlens_in_batch.max().item()
144
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
145
  return (
146
  indices,
147
  cu_seqlens,
@@ -149,11 +70,11 @@ def _get_unpad_data(attention_mask):
149
  )
150
 
151
 
152
- # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Quiet
153
- class QuietRMSNorm(nn.Module):
154
  def __init__(self, hidden_size, eps=1e-6):
155
  """
156
- QuietRMSNorm is equivalent to T5LayerNorm
157
  """
158
  super().__init__()
159
  self.weight = nn.Parameter(torch.ones(hidden_size))
@@ -164,18 +85,19 @@ class QuietRMSNorm(nn.Module):
164
  hidden_states = hidden_states.to(torch.float32)
165
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
166
  hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
167
- return hidden_states.to(input_dtype) * self.weight.to(hidden_states.device)
168
 
169
 
170
- # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Quiet
171
- class QuietRotaryEmbedding(nn.Module):
 
172
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
173
  super().__init__()
174
 
175
  self.dim = dim
176
  self.max_position_embeddings = max_position_embeddings
177
  self.base = base
178
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
179
  self.register_buffer("inv_freq", inv_freq, persistent=False)
180
 
181
  # Build here to make `torch.jit.trace` work.
@@ -185,7 +107,7 @@ class QuietRotaryEmbedding(nn.Module):
185
 
186
  def _set_cos_sin_cache(self, seq_len, device, dtype):
187
  self.max_seq_len_cached = seq_len
188
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
189
 
190
  freqs = torch.outer(t, self.inv_freq)
191
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
@@ -212,7 +134,8 @@ def rotate_half(x):
212
  return torch.cat((-x2, x1), dim=-1)
213
 
214
 
215
- # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
 
216
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
217
  """Applies Rotary Position Embedding to the query and key tensors.
218
 
@@ -241,7 +164,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
241
  return q_embed, k_embed
242
 
243
 
244
- class QuietMLP(nn.Module):
245
  def __init__(self, config):
246
  super().__init__()
247
  self.config = config
@@ -269,20 +192,20 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
269
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
270
 
271
 
272
- class QuietAttention(nn.Module):
273
  """
274
  Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
275
  and "Generating Long Sequences with Sparse Transformers".
276
  """
277
 
278
- def __init__(self, config: QuietConfig, layer_idx: Optional[int] = None):
279
  super().__init__()
280
  self.config = config
281
  self.layer_idx = layer_idx
282
  if layer_idx is None:
283
  logger.warning_once(
284
- f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
285
- "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
286
  "when creating this class."
287
  )
288
 
@@ -306,7 +229,7 @@ class QuietAttention(nn.Module):
306
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
307
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
308
 
309
- self.rotary_emb = QuietRotaryEmbedding(
310
  self.head_dim,
311
  max_position_embeddings=self.max_position_embeddings,
312
  base=self.rope_theta,
@@ -397,9 +320,9 @@ class QuietAttention(nn.Module):
397
  return attn_output, attn_weights, past_key_value
398
 
399
 
400
- class QuietFlashAttention2(QuietAttention):
401
  """
402
- Quiet flash attention module. This module inherits from `QuietAttention` as the weights of the module stays
403
  untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
404
  flash attention and deal with padding tokens in case the input contains any of them.
405
  """
@@ -573,7 +496,7 @@ class QuietFlashAttention2(QuietAttention):
573
  attention_mask (`torch.Tensor`):
574
  The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
575
  position of padding tokens and 1 for the position of non-padding tokens.
576
- dropout (`int`, *optional*):
577
  Attention dropout
578
  softmax_scale (`float`, *optional*):
579
  The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
@@ -691,15 +614,16 @@ class QuietFlashAttention2(QuietAttention):
691
  )
692
 
693
 
694
- # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Quiet
695
- class QuietSdpaAttention(QuietAttention):
 
696
  """
697
- Quiet attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
698
- `QuietAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
699
  SDPA API.
700
  """
701
 
702
- # Adapted from QuietAttention.forward
703
  def forward(
704
  self,
705
  hidden_states: torch.Tensor,
@@ -712,7 +636,7 @@ class QuietSdpaAttention(QuietAttention):
712
  if output_attentions:
713
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
714
  logger.warning_once(
715
- "QuietModel is using QuietSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
716
  'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
717
  )
718
  return super().forward(
@@ -765,37 +689,37 @@ class QuietSdpaAttention(QuietAttention):
765
  query_states,
766
  key_states,
767
  value_states,
768
- attn_mask=attention_mask.to(query_states.device) if attention_mask is not None else None,
769
  dropout_p=self.attention_dropout if self.training else 0.0,
770
  # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
771
  is_causal=self.is_causal and attention_mask is None and q_len > 1,
772
  )
773
 
774
  attn_output = attn_output.transpose(1, 2).contiguous()
775
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
776
 
777
  attn_output = self.o_proj(attn_output)
778
 
779
  return attn_output, None, past_key_value
780
 
781
 
782
- QUIET_ATTENTION_CLASSES = {
783
- "eager": QuietAttention,
784
- "flash_attention_2": QuietFlashAttention2,
785
- "sdpa": QuietSdpaAttention,
786
  }
787
 
788
 
789
- class QuietDecoderLayer(nn.Module):
790
- def __init__(self, config: QuietConfig, layer_idx: int):
791
  super().__init__()
792
  self.hidden_size = config.hidden_size
793
 
794
- self.self_attn = QUIET_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
795
 
796
- self.mlp = QuietMLP(config)
797
- self.input_layernorm = QuietRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
798
- self.post_attention_layernorm = QuietRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
799
 
800
  def forward(
801
  self,
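
Aside (not part of the diff): the `*_ATTENTION_CLASSES` dict above is what the decoder layer's `__init__` indexes with `config._attn_implementation`, so the attention backend is selected at load time. A hedged sketch, assuming the repo id below and that the custom model code loads via `trust_remote_code`:

```python
import torch
from transformers import AutoModelForCausalLM

# Hypothetical repo id; attn_implementation must be one of the keys in the
# attention-classes dict: "eager", "sdpa", or "flash_attention_2".
model = AutoModelForCausalLM.from_pretrained(
    "Crystalcareai/Quiet-Star-Custom",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    trust_remote_code=True,
)
```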
@@ -838,7 +762,7 @@ class QuietDecoderLayer(nn.Module):
838
  output_attentions=output_attentions,
839
  use_cache=use_cache,
840
  )
841
- hidden_states = residual.to(hidden_states.device) + hidden_states
842
 
843
  # Fully Connected
844
  residual = hidden_states
@@ -857,7 +781,7 @@ class QuietDecoderLayer(nn.Module):
857
  return outputs
858
 
859
 
860
- QUIET_START_DOCSTRING = r"""
861
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
862
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
863
  etc.)
@@ -867,7 +791,7 @@ QUIET_START_DOCSTRING = r"""
867
  and behavior.
868
 
869
  Parameters:
870
- config ([`QuietConfig`]):
871
  Model configuration class with all the parameters of the model. Initializing with a config file does not
872
  load the weights associated with the model, only the configuration. Check out the
873
  [`~PreTrainedModel.from_pretrained`] method to load the model weights.
@@ -875,14 +799,14 @@ QUIET_START_DOCSTRING = r"""
875
 
876
 
877
  @add_start_docstrings(
878
- "The bare Quiet Model outputting raw hidden-states without any specific head on top.",
879
- QUIET_START_DOCSTRING,
880
  )
881
- class QuietPreTrainedModel(PreTrainedModel):
882
- config_class = QuietConfig
883
  base_model_prefix = "model"
884
  supports_gradient_checkpointing = True
885
- _no_split_modules = ["QuietDecoderLayer"]
886
  _skip_keys_device_placement = "past_key_values"
887
  _supports_flash_attn_2 = True
888
  _supports_sdpa = True
@@ -900,7 +824,7 @@ class QuietPreTrainedModel(PreTrainedModel):
900
  module.weight.data[module.padding_idx].zero_()
901
 
902
 
903
- QUIET_INPUTS_DOCSTRING = r"""
904
  Args:
905
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
906
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
@@ -971,28 +895,28 @@ QUIET_INPUTS_DOCSTRING = r"""
971
 
972
 
973
  @add_start_docstrings(
974
- "The bare Quiet Model outputting raw hidden-states without any specific head on top.",
975
- QUIET_START_DOCSTRING,
976
  )
977
- class QuietModel(QuietPreTrainedModel):
978
  """
979
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`QuietDecoderLayer`]
980
 
981
  Args:
982
- config: QuietConfig
983
  """
984
 
985
- def __init__(self, config: QuietConfig):
986
  super().__init__(config)
987
  self.padding_idx = config.pad_token_id
988
  self.vocab_size = config.vocab_size
989
 
990
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
991
  self.layers = nn.ModuleList(
992
- [QuietDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
993
  )
994
  self._attn_implementation = config._attn_implementation
995
- self.norm = QuietRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
996
 
997
  self.gradient_checkpointing = False
998
  # Initialize weights and apply final processing
@@ -1004,7 +928,7 @@ class QuietModel(QuietPreTrainedModel):
1004
  def set_input_embeddings(self, value):
1005
  self.embed_tokens = value
1006
 
1007
- @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
1008
  def forward(
1009
  self,
1010
  input_ids: torch.LongTensor = None,
@@ -1067,14 +991,14 @@ class QuietModel(QuietPreTrainedModel):
1067
  if is_padding_right:
1068
  raise ValueError(
1069
  "You are attempting to perform batched generation with padding_side='right'"
1070
- " this may lead to unexpected behaviour for Flash Attention version of Quiet. Make sure to "
1071
  " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
1072
  )
1073
 
1074
  if self._attn_implementation == "flash_attention_2":
1075
  # 2d mask is passed through the layers
1076
  attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1077
- elif self._attn_implementation == "sdpa" and not output_attentions and attention_mask.dim() == 2 and False:
1078
  # output_attentions=True can not be supported when using SDPA, and we fall back on
1079
  # the manual implementation that requires a 4D causal mask in all cases.
1080
  attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
@@ -1083,7 +1007,7 @@ class QuietModel(QuietPreTrainedModel):
1083
  inputs_embeds,
1084
  past_key_values_length,
1085
  )
1086
- elif attention_mask is None or attention_mask.dim() == 2:
1087
  # 4d mask is passed through the layers
1088
  attention_mask = _prepare_4d_causal_attention_mask(
1089
  attention_mask,
@@ -1151,129 +1075,15 @@ class QuietModel(QuietPreTrainedModel):
1151
  attentions=all_self_attns,
1152
  )
1153
 
1154
- def nonzero_mean(x, axis=None):
1155
- if axis is not None:
1156
- return x.sum(axis) / (x != 0).sum(axis)
1157
- return x.sum() / (x != 0).sum()
1158
-
1159
- def loss_mean(x):
1160
- return x.sum() / (x != 0).sum()
1161
 
1162
- class QuietForCausalLM(QuietPreTrainedModel):
1163
  _tied_weights_keys = ["lm_head.weight"]
1164
 
1165
  def __init__(self, config):
1166
  super().__init__(config)
1167
- self.model = QuietModel(config)
1168
  self.vocab_size = config.vocab_size
1169
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1170
- self.max_thoughts = config.max_thoughts
1171
- self.merged_lm_and_talk_heads = config.merged_lm_and_talk_heads
1172
- self.use_concat_talk_head = config.use_concat_talk_head
1173
- self.use_shallow_talk = config.use_shallow_talk
1174
- self.use_complex_talk_head = config.use_complex_talk_head
1175
- self.use_weighted_talk_head = config.use_weighted_talk_head
1176
- # the weighted head will output a single value, so it can't be passed to the lm head
1177
- assert not (self.use_weighted_talk_head and self.use_shallow_talk)
1178
-
1179
- self.n_ahead = 1
1180
- self.n_ahead_talk = 1
1181
- self.n_passes = 1
1182
- self.n_tokens_print = 1
1183
- self.gradient_accumulation_steps = 1
1184
- self.training_steps = 0
1185
- self.tokenizer = None
1186
- self.start_token_id = None
1187
- self.end_token_id = None
1188
- self.rm_initialized = False
1189
- self.residual_talk_head = True
1190
- self.thought_init_std_scale = 1e-2
1191
-
1192
- self.final_only_mode = False
1193
- self.first_and_last_mode = True
1194
- self.first_only = False
1195
- self.original_loss_weight = 0.5
1196
-
1197
- self.cumulative_residual = False
1198
- self.clever_residual = False
1199
- self.skip_residual = False
1200
- self.no_residual = True
1201
-
1202
- self.optimize_lm_head_only_at_start = False
1203
- self.optimize_model_only_at_start = False
1204
-
1205
- if self.optimize_model_only_at_start:
1206
- raise NotImplementedError
1207
- self.train_only_thinking_embedding = False
1208
- self.weighted_embeddings = False
1209
- self.use_start_thought_token = True
1210
- self.use_end_thought_token = True
1211
- self.initialize_thought_embedding_to_normal = False
1212
- self.initial_start_token = "---"
1213
- self.initial_end_token = "---"
1214
- self.output_logits_at_the_end = True
1215
-
1216
- self.wandb_enabled = False
1217
- self.gumbel_temperature = 0.001
1218
-
1219
- self.use_policy_loss = True
1220
- self.include_policy_loss = True
1221
- self.trice_mode = True
1222
- self.remove_negative_rewards = True
1223
- self.use_policy_loss_for_end_thought = True
1224
-
1225
- self.base_original_mode = False
1226
- self.original_mode = False
1227
-
1228
- self.thought_prefix = "(Let's think step by step"
1229
- self.tokenized_thought_prefix = None
1230
- self.log_dict = defaultdict(int)
1231
- self.eval_log_dict = defaultdict(int)
1232
- self.print_final_only = True
1233
- self.loss_mean = loss_mean
1234
- self.all_rewards = []
1235
- self.all_unreduced_losses = []
1236
- self.kill_after = 100
1237
-
1238
- self.start_embedding = nn.Parameter(torch.zeros(2, self.model.config.hidden_size))
1239
- self.end_embedding = nn.Parameter(torch.zeros(2, self.model.config.hidden_size))
1240
-
1241
- self.policy_loss_beta = 1e6
1242
- self.embedding_scale = 1e2
1243
- self.reinforce_temperature = 3
1244
- self.base_loss_beta = 1
1245
-
1246
- # Not used in the paper:
1247
- self.use_thought_prefix = False
1248
- self.use_reparam_for_thought_embeddings = False
1249
- self.use_upper_triangular = False
1250
- self.subtract_mean_reward = False
1251
- self.comparison_mode = False
1252
- self.gumbel_detach = True
1253
-
1254
- # For visualization
1255
- self.eval_mode = False
1256
-
1257
- num_talk = 1
1258
- talk_input_dim = config.hidden_size if not self.use_concat_talk_head else config.hidden_size * 2
1259
- if self.use_weighted_talk_head:
1260
- talk_output_dim = 1
1261
- else:
1262
- talk_output_dim = config.hidden_size if self.use_shallow_talk else config.vocab_size
1263
-
1264
- if not self.merged_lm_and_talk_heads:
1265
- if self.use_complex_talk_head:
1266
- self.talk_head = nn.ModuleList([nn.Sequential(
1267
- nn.Linear(talk_input_dim, config.hidden_size),
1268
- nn.ReLU(),
1269
- nn.Linear(config.hidden_size, config.hidden_size),
1270
- nn.ReLU(),
1271
- nn.Linear(config.hidden_size, talk_output_dim, bias=False)
1272
- )])
1273
- else:
1274
- self.talk_head = nn.ModuleList([nn.Sequential(
1275
- nn.Linear(talk_input_dim, talk_output_dim, bias=False)
1276
- )])
1277
 
1278
  # Initialize weights and apply final processing
1279
  self.post_init()
@@ -1296,126 +1106,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
1296
  def get_decoder(self):
1297
  return self.model
1298
 
1299
- @torch.no_grad()
1300
- def infer(
1301
- self,
1302
- input_ids: torch.LongTensor,
1303
- attention_mask: Optional[torch.Tensor] = None,
1304
- position_ids: Optional[torch.LongTensor] = None,
1305
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1306
- inputs_embeds: Optional[torch.FloatTensor] = None,
1307
- use_cache: Optional[bool] = None,
1308
- output_attentions: Optional[bool] = None,
1309
- output_hidden_states: Optional[bool] = None,
1310
- return_dict: Optional[bool] = None,
1311
- ):
1312
- batch_size, seq_len = input_ids.shape
1313
-
1314
- # Save the original input_ids and attention_mask for later use
1315
- original_input_ids = input_ids.clone()
1316
- original_attention_mask = attention_mask.clone() if attention_mask is not None else None
1317
-
1318
- # Append the start thought token to the input sequence
1319
- start_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
1320
- input_ids = torch.cat([input_ids, torch.tensor([[start_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
1321
- seq_len += 1
1322
-
1323
- # Update the attention mask
1324
- if attention_mask is not None:
1325
- attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
1326
-
1327
- # Generate the continuation
1328
- continuation_length = self.n_ahead - 2
1329
- new_key_values = past_key_values
1330
- generated_tokens = []
1331
-
1332
- for continuation_idx in range(continuation_length):
1333
- outputs = self.model(
1334
- input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
1335
- attention_mask=attention_mask,
1336
- position_ids=position_ids,
1337
- past_key_values=new_key_values,
1338
- inputs_embeds=inputs_embeds,
1339
- use_cache=True,
1340
- output_attentions=output_attentions,
1341
- output_hidden_states=output_hidden_states,
1342
- return_dict=return_dict,
1343
- )
1344
- new_key_values = outputs.past_key_values
1345
- hidden_states = outputs[0]
1346
- logits = self.lm_head(hidden_states)
1347
- logits = logits[:, -1, :] # Only consider the last token
1348
-
1349
- # Apply Gumbel-Softmax to the logits
1350
- next_token_logits = F.gumbel_softmax(logits, tau=self.gumbel_temperature, hard=True, dim=-1)
1351
- next_token_id = torch.argmax(next_token_logits, dim=-1)
1352
-
1353
- # Append the generated token to the input sequence
1354
- input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
1355
- generated_tokens.append(next_token_id)
1356
- seq_len += 1
1357
-
1358
- # Update the attention mask
1359
- if attention_mask is not None:
1360
- attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
1361
-
1362
- # Update the position ids
1363
- if position_ids is not None:
1364
- position_ids = torch.cat([position_ids, (position_ids[:, -1] + 1).unsqueeze(-1)], dim=-1)
1365
-
1366
- # Append the end thought token to the input sequence
1367
- end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
1368
- input_ids = torch.cat([input_ids, torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1)
1369
- seq_len += 1
1370
-
1371
- # Update the attention mask
1372
- if attention_mask is not None:
1373
- attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
1374
-
1375
- # Get the hidden states before and after the thought
1376
- outputs_before = self.model(
1377
- input_ids=original_input_ids,
1378
- attention_mask=original_attention_mask,
1379
- position_ids=position_ids,
1380
- past_key_values=past_key_values,
1381
- inputs_embeds=inputs_embeds,
1382
- use_cache=use_cache,
1383
- output_attentions=output_attentions,
1384
- output_hidden_states=output_hidden_states,
1385
- return_dict=return_dict,
1386
- )
1387
- hidden_states_before = outputs_before[0][:, -1:, :]
1388
-
1389
- # two new tokens: last continuation token and end thought token
1390
- outputs_after = self.model(
1391
- input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor(end_thought_token_id).unsqueeze(-1).unsqueeze(-1).to(input_ids.device)], dim=-1),
1392
- attention_mask=attention_mask,
1393
- position_ids=position_ids,
1394
- past_key_values=new_key_values,
1395
- inputs_embeds=inputs_embeds,
1396
- use_cache=use_cache,
1397
- output_attentions=output_attentions,
1398
- output_hidden_states=output_hidden_states,
1399
- return_dict=return_dict,
1400
- )
1401
- hidden_states_after = outputs_after[0][:, -1:, :]
1402
-
1403
- # Apply the talk head to get the mixing weight
1404
- mixing_weight = self.talk_head[0](torch.cat([hidden_states_before, hidden_states_after], dim=-1))
1405
-
1406
- # Apply the mixing weight to the hidden states
1407
- mixed_hidden_states = (1 - mixing_weight) * hidden_states_before + mixing_weight * hidden_states_after
1408
-
1409
- # Apply the language model head to get the final logits
1410
- logits = self.lm_head(mixed_hidden_states)
1411
-
1412
- # Decode the logits to get the generated text
1413
- generated_tokens = torch.cat(generated_tokens, dim=-1)
1414
- generated_text = self.tokenizer.decode(generated_tokens.squeeze(), skip_special_tokens=True)
1415
-
1416
- return generated_text
1417
-
1418
- @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
1419
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1420
  def forward(
1421
  self,
@@ -1442,10 +1133,10 @@ class QuietForCausalLM(QuietPreTrainedModel):
1442
  Example:
1443
 
1444
  ```python
1445
- >>> from transformers import AutoTokenizer, QuietForCausalLM
1446
 
1447
- >>> model = QuietForCausalLM.from_pretrained("quietai/Quiet-7B-v0.1")
1448
- >>> tokenizer = AutoTokenizer.from_pretrained("quietai/Quiet-7B-v0.1")
1449
 
1450
  >>> prompt = "Hey, are you conscious? Can you talk to me?"
1451
  >>> inputs = tokenizer(prompt, return_tensors="pt")
@@ -1455,16 +1146,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
1455
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1456
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1457
  ```"""
1458
- log_dict = self.log_dict if self.training else self.eval_log_dict
1459
-
1460
- if self.training and self.kill_after is not None and self.training_steps // self.gradient_accumulation_steps > self.kill_after:
1461
- raise ValueError("Killed after")
1462
-
1463
- if not self.training:
1464
- n_ahead_talk_to_restore = self.n_ahead_talk
1465
- n_passes_to_restore = self.n_passes
1466
- self.n_ahead_talk = 1
1467
- self.n_passes = 1
1468
 
1469
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1470
  output_hidden_states = (
@@ -1472,730 +1153,48 @@ class QuietForCausalLM(QuietPreTrainedModel):
1472
  )
1473
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1474
 
1475
- assert self.cumulative_residual or self.clever_residual or self.skip_residual or self.no_residual
1476
- assert not (self.skip_residual and self.use_policy_loss)
1477
-
1478
- if self.tokenized_thought_prefix is None and self.use_thought_prefix:
1479
- self.tokenized_thought_prefix = self.tokenizer(self.thought_prefix, return_tensors="pt", add_special_tokens=False)["input_ids"]
 
 
 
 
 
 
 
1480
 
1481
- def apply_head(head, states, detach=False):
1482
- if detach:
1483
- head_weight = head.weight.detach()
1484
- else:
1485
- head_weight = head.weight
1486
- head_weight = head_weight.to(states.device)
1487
- return (head_weight @ states.transpose(-1, -2)).transpose(-1, -2).contiguous()
1488
-
1489
- def idx_if_sequential(head, idx=0):
1490
- if isinstance(head, nn.Sequential) or isinstance(head, nn.ModuleList):
1491
- return idx_if_sequential(head[idx], idx=idx)
1492
- return head
1493
-
1494
- def none_repeat_interleave(x, n):
1495
- if x is None:
1496
- return x
1497
- return x.repeat_interleave(n, dim=0)
1498
-
1499
- if self.n_passes > 1:
1500
- input_ids = none_repeat_interleave(input_ids, self.n_passes)
1501
- attention_mask = none_repeat_interleave(attention_mask, self.n_passes)
1502
- position_ids = none_repeat_interleave(position_ids, self.n_passes)
1503
- inputs_embeds = none_repeat_interleave(inputs_embeds, self.n_passes)
1504
- labels = none_repeat_interleave(labels, self.n_passes)
1505
- if past_key_values is not None:
1506
- past_key_values = [none_repeat_interleave(p, self.n_passes) for p in past_key_values]
1507
- cur_token_indices = torch.arange(input_ids.shape[1], device=input_ids.device)
1508
-
1509
- self.tokenizer_has_start_thought_token = True
1510
- self.tokenizer_has_end_thought_token = True
1511
- if self.start_token_id is None:
1512
- self.start_token_id = self.tokenizer.convert_tokens_to_ids("<|startthought|>")
1513
- if self.start_token_id == 0:
1514
- self.start_token_id = self.tokenizer.bos_token_id
1515
- self.tokenizer_has_start_thought_token = False
1516
- elif self.use_start_thought_token:
1517
- # base_start_id = self.tokenizer.convert_tokens_to_ids(self.initial_start_token)
1518
- base_start_id = self.tokenizer.encode(self.initial_start_token, add_special_tokens=False)[0]
1519
- if self.initialize_thought_embedding_to_normal:
1520
- self.start_embedding.data = torch.zeros_like(self.start_embedding.data)
1521
- else:
1522
- self.start_embedding.data[0] = self.model.embed_tokens.weight.data[base_start_id].clone().detach() / self.embedding_scale
1523
- self.start_embedding.data[1] = torch.log(self.model.embed_tokens.weight.data.std(dim=0) * self.thought_init_std_scale / self.embedding_scale)
1524
- if self.end_token_id is None:
1525
- self.end_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
1526
- if self.end_token_id == 0:
1527
- self.end_token_id = self.tokenizer.eos_token_id
1528
- self.tokenizer_has_end_thought_token = False
1529
- elif self.use_end_thought_token:
1530
- # base_end_id = self.tokenizer.convert_tokens_to_ids(self.initial_end_token)
1531
- base_end_id = self.tokenizer.encode(self.initial_end_token, add_special_tokens=False)[0]
1532
- if self.initialize_thought_embedding_to_normal:
1533
- self.end_embedding.data = torch.zeros_like(self.end_embedding.data)
1534
- else:
1535
- self.end_embedding.data[0] = self.model.embed_tokens.weight.data[base_end_id].clone().detach() / self.embedding_scale
1536
- self.end_embedding.data[1] = torch.log(self.model.embed_tokens.weight.data.std(dim=0) * self.thought_init_std_scale / self.embedding_scale)
1537
-
1538
- if not self.rm_initialized and (self.n_ahead > 1 or not self.base_original_mode):
1539
- self.rm_initialized = True
1540
- if not self.use_shallow_talk:
1541
- head = self.talk_head[0]
1542
- cur_head = head[-1] if isinstance(head, nn.Sequential) else head
1543
- talk_input_dim = cur_head.weight.data.shape[1]
1544
- talk_output_dim = 1 if self.use_weighted_talk_head else self.lm_head.weight.data.shape[0]
1545
- cur_head.weight.data = torch.zeros(talk_output_dim, talk_input_dim, device=cur_head.weight.device, dtype=cur_head.weight.dtype)
1546
- else:
1547
- # convert to identity transform
1548
- def lambda_transform(cur_head):
1549
- if cur_head.weight.data.shape[0] != cur_head.weight.data.shape[1]:
1550
- return torch.cat([
1551
- torch.eye(
1552
- cur_head.weight.data.shape[0],
1553
- device=cur_head.weight.device,
1554
- dtype=cur_head.weight.dtype
1555
- ),
1556
- torch.zeros(
1557
- cur_head.weight.data.shape[0],
1558
- cur_head.weight.data.shape[1] - cur_head.weight.data.shape[0],
1559
- device=cur_head.weight.device,
1560
- dtype=cur_head.weight.dtype
1561
- )], dim=1)
1562
- return torch.eye(
1563
- cur_head.weight.data.shape[0],
1564
- device=cur_head.weight.device,
1565
- dtype=cur_head.weight.dtype
1566
- )
1567
- if isinstance(self.talk_head[0], nn.Sequential):
1568
- for cur_head in self.talk_head[0]:
1569
- # if it has weights
1570
- if hasattr(cur_head, "weight"):
1571
- cur_head.weight.data = lambda_transform(cur_head)
1572
- else:
1573
- self.talk_head[-1].weight.data = lambda_transform(self.talk_head[0])
1574
 
1575
  loss = None
1576
- prev_rm_tokens = None
1577
- cur_rm_tokens = None
1578
- prev_rm_logits = None
1579
- prev_sample_probs = None
1580
- did_skip_sampling = None
1581
- skip_sampling = None
1582
- sample_probs = None
1583
- hidden_states = None
1584
- logits = None
1585
- talk_kl_penalty = None
1586
- rm_logits = None
1587
- residual_logits = None
1588
- probabilities_2d = None
1589
- prev_probabilities_2d = None
1590
- policy_reward = None
1591
- logits_to_output = None
1592
- batch_size, seq_len = input_ids.shape
1593
- base_input_ids = input_ids.clone()
1594
- loss_list = []
1595
- dqn_loss_list = []
1596
- sampled_token_history = []
1597
- sample_probs_history = []
1598
- action_loglikelihoods_list = []
1599
-
1600
- if self.use_end_thought_token or self.use_start_thought_token:
1601
- if not self.use_reparam_for_thought_embeddings:
1602
- start_embedding = self.start_embedding[0].unsqueeze(0) * self.embedding_scale
1603
- end_embedding = self.end_embedding[0].unsqueeze(0) * self.embedding_scale
1604
- else:
1605
- start_embedding = self.start_embedding * self.embedding_scale
1606
- end_embedding = self.end_embedding * self.embedding_scale
1607
- base_embeddings = self.model.embed_tokens.weight
1608
- if self.train_only_thinking_embedding:
1609
- base_embeddings = base_embeddings.detach()
1610
- # # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1611
- fwd_iters = 1 if self.original_mode else self.n_ahead + self.n_ahead_talk - 1
1612
- for ahead_idx in range(fwd_iters):
1613
- past_key_values_length = 0
1614
- if past_key_values is not None:
1615
- use_legacy_cache = not isinstance(past_key_values, Cache)
1616
- if use_legacy_cache:
1617
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1618
- past_key_values_length = past_key_values.get_usable_length(seq_len)
1619
-
1620
- if position_ids is None:
1621
- device = input_ids.device if input_ids is not None else inputs_embeds.device
1622
- position_ids = torch.arange(
1623
- past_key_values_length, seq_len + past_key_values_length, dtype=torch.long, device=device
1624
- )
1625
- position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
1626
- else:
1627
- position_ids = position_ids.view(-1, seq_len).long()
1628
-
1629
- if inputs_embeds is None:
1630
- contains_start = self.use_start_thought_token and (input_ids == self.start_token_id).any()
1631
- contains_end = self.use_end_thought_token and (input_ids == self.end_token_id).any()
1632
- contains_thought = contains_start or contains_end
1633
- if contains_thought:
1634
- thought_id = self.start_token_id if contains_start else self.end_token_id
1635
- cur_thought_embedding = start_embedding if contains_start else end_embedding
1636
- if self.use_reparam_for_thought_embeddings:
1637
- inputs_embeds = torch.randn(batch_size, seq_len, self.model.config.hidden_size, device=input_ids.device, dtype=cur_thought_embedding.dtype)
1638
- inputs_embeds = inputs_embeds.detach() * torch.exp(cur_thought_embedding[1]) + cur_thought_embedding[0]
1639
- if contains_start:
1640
- sampled_start = inputs_embeds.clone().detach()
1641
- if contains_end:
1642
- sampled_end = inputs_embeds.clone().detach()
1643
- else:
1644
- inputs_embeds = cur_thought_embedding.unsqueeze(0).repeat(batch_size, seq_len, 1)
1645
- else:
1646
- with torch.set_grad_enabled(not self.train_only_thinking_embedding):
1647
- inputs_embeds = self.model.embed_tokens(input_ids)
1648
-
1649
- if self.n_ahead != 1 or self.n_ahead_talk != 1 or self.comparison_mode:
1650
- if attention_mask is None:
1651
- base_attention_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=0).to(input_ids.device)
1652
- base_attention_mask = base_attention_mask.view(1, 1, seq_len, seq_len)
1653
- base_attention_mask = base_attention_mask.repeat(input_ids.shape[0], 1, 1, 1)
1654
- attention_mask = base_attention_mask
1655
- breakpoint()
1656
- elif attention_mask.dim() == 2:
1657
- if seq_len + past_key_values_length != attention_mask.shape[-1]:
1658
- breakpoint()
1659
- attention_mask = torch.cat(
1660
- [torch.ones((attention_mask.shape[0], past_key_values_length), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask],
1661
- dim=-1
1662
- )
1663
- # # if the attention mask
1664
- attention_mask = _prepare_4d_causal_attention_mask(
1665
- attention_mask,
1666
- (batch_size, seq_len),
1667
- inputs_embeds,
1668
- past_key_values_length,
1669
- sliding_window=self.config.sliding_window,
1670
- )
1671
-
1672
- outputs = self.model(
1673
- # input_ids=input_ids,
1674
- attention_mask=attention_mask,
1675
- position_ids=position_ids,
1676
- past_key_values=past_key_values,
1677
- inputs_embeds=inputs_embeds,
1678
- use_cache=use_cache,
1679
- output_attentions=output_attentions,
1680
- output_hidden_states=output_hidden_states,
1681
- return_dict=return_dict,
1682
- )
1683
-
1684
- prev_hidden_states = hidden_states
1685
- hidden_states = outputs[0]
1686
- prev_rm_logits = rm_logits # for policy gradient
1687
- prev_rm_tokens = cur_rm_tokens # for policy gradient
1688
-
1689
- if ahead_idx == 0:
1690
- hidden_states_lm = hidden_states
1691
- logits = self.lm_head(hidden_states_lm)
1692
- base_hidden_states = hidden_states.clone()
1693
- initial_loss_logits = logits.clone()
1694
- if self.optimize_lm_head_only_at_start or self.optimize_model_only_at_start:
1695
- logits = logits.detach()
1696
- base_hidden_states = base_hidden_states.detach()
1697
- if self.optimize_model_only_at_start:
1698
- hidden_states = hidden_states.detach()
1699
- base_logits = logits.clone()
1700
- else:
1701
- talk_hidden_states = hidden_states
1702
- if self.merged_lm_and_talk_heads:
1703
- assert self.no_residual
1704
- residual_logits = self.lm_head(hidden_states)
1705
- talk_hidden_states = hidden_states
1706
- else:
1707
- if ahead_idx > self.n_ahead - 1:
1708
- cur_base_hidden = torch.cat([
1709
- base_hidden_states[..., ahead_idx - self.n_ahead + 1:, :],
1710
- base_hidden_states[..., :ahead_idx - self.n_ahead + 1, :]
1711
- ], dim=-2)
1712
- else:
1713
- cur_base_hidden = base_hidden_states
1714
-
1715
- if self.use_concat_talk_head:
1716
- # concatenate the hidden states with the original hidden states
1717
- head_input_hidden_states = torch.cat([cur_base_hidden, talk_hidden_states], dim=-1)
1718
- else:
1719
- head_input_hidden_states = talk_hidden_states
1720
-
1721
- residual_logits = self.talk_head[0](head_input_hidden_states)
1722
- if self.use_shallow_talk:
1723
- residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start)
1724
- residual_logits = residual_logits.to(logits.device)
1725
- if self.use_weighted_talk_head:
1726
- # combine the cur_base_hidden with the talk_hidden_states according to the weighted head
1727
- residual_logits = cur_base_hidden * (1 - residual_logits) + talk_hidden_states * residual_logits
1728
- residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start)
1729
-
1730
- assert sum([self.cumulative_residual, self.clever_residual, self.skip_residual, self.no_residual]) == 1
1731
- if self.clever_residual:
1732
- if ahead_idx >= self.n_ahead - 1:
1733
- # get the logits shifted according to the current talk ahead
1734
- cur_base_logits = torch.cat([
1735
- base_logits[..., ahead_idx - self.n_ahead + 1:, :],
1736
- base_logits[..., :ahead_idx - self.n_ahead + 1, :]
1737
- ], dim=-2)
1738
- if self.optimize_lm_head_only_at_start:
1739
- cur_base_logits = cur_base_logits.detach()
1740
- logits = cur_base_logits + residual_logits
1741
- else:
1742
- logits += residual_logits / self.n_ahead
1743
- elif self.cumulative_residual:
1744
- if self.residual_talk_head:
1745
- if ahead_idx < self.n_ahead:
1746
- logits += residual_logits
1747
- else:
1748
- # get the logits shifted according to the current talk ahead
1749
- cur_base_logits = torch.cat([
1750
- base_logits[..., ahead_idx - self.n_ahead + 1:, :],
1751
- base_logits[..., :ahead_idx - self.n_ahead + 1, :]
1752
- ], dim=-2)
1753
- if self.optimize_lm_head_only_at_start:
1754
- cur_base_logits = cur_base_logits.detach()
1755
- logits = cur_base_logits + residual_logits
1756
- else:
1757
- if ahead_idx < self.n_ahead:
1758
- logits += residual_logits
1759
- else:
1760
- logits = residual_logits
1761
- elif self.skip_residual:
1762
- if ahead_idx >= self.n_ahead:
1763
- # get the logits shifted according to the current talk ahead
1764
- cur_base_logits = torch.cat([
1765
- base_logits[..., ahead_idx - self.n_ahead + 1:, :],
1766
- base_logits[..., :ahead_idx - self.n_ahead + 1, :]
1767
- ], dim=-2)
1768
- if self.optimize_lm_head_only_at_start:
1769
- cur_base_logits = cur_base_logits.detach()
1770
- logits = cur_base_logits
1771
- elif self.no_residual:
1772
- logits = residual_logits
1773
- else:
1774
- logits = base_logits + residual_logits
1775
-
1776
- attempted = False
1777
- talk_loss_list = []
1778
- if self.original_mode or (self.n_ahead == 1) or (self.comparison_mode and ahead_idx == 0):# or (self.optimize_lm_head_only_at_start and ahead_idx == 0):
1779
- loss = None
1780
- attempted = True
1781
-
1782
- if labels is not None:
1783
- for shift_amount in range(self.n_ahead_talk):
1784
- # Shift so that tokens < n predict n
1785
- # ab[cde]f
1786
- # abc[def]
1787
- if ahead_idx == 0 and self.optimize_lm_head_only_at_start:
1788
- loss_logits = initial_loss_logits
1789
- else:
1790
- loss_logits = logits
1791
- shift_logits = loss_logits[..., shift_amount:-1, :].contiguous()
1792
- shift_labels = labels[..., 1 + shift_amount:].contiguous()
1793
- # Flatten the tokens
1794
- loss_fct = CrossEntropyLoss(reduction="none")
1795
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1796
- shift_labels = shift_labels.view(-1).clone()
1797
- # Enable model parallelism
1798
- shift_labels[shift_labels == self.tokenizer.pad_token_id] = -100
1799
- shift_labels = shift_labels.to(shift_logits.device)
1800
- loss = loss_fct(shift_logits, shift_labels)
1801
- if not self.comparison_mode and not (self.optimize_lm_head_only_at_start and (self.n_ahead + self.n_ahead_talk > 2)) or self.original_mode:
1802
- loss_list.append(loss)
1803
- talk_loss_list.append(nonzero_mean(loss).detach())
1804
-
1805
- if not attempted or self.comparison_mode:
1806
- rm_hidden_states = hidden_states
1807
- # print("Magnitude of RM hidden states before RM head", rm_hidden_states.norm())
1808
- rm_logits = apply_head(self.lm_head, rm_hidden_states, detach=self.optimize_lm_head_only_at_start)
1809
-
1810
- # don't allow it to predict the thinking token
1811
- if self.tokenizer_has_start_thought_token:
1812
- rm_logits[..., self.start_token_id] = -1e10
1813
- if self.tokenizer_has_end_thought_token:
1814
- rm_logits[..., self.end_token_id] = -1e10
1815
- probabilities = rm_logits
1816
- if probabilities_2d is not None:
1817
- prev_probabilities_2d = probabilities_2d.clone()
1818
- probabilities_2d = probabilities.view(-1, probabilities.size(-1))
1819
-
1820
- did_skip_sampling = skip_sampling
1821
- skip_sampling = False
1822
- if ahead_idx == 0 and self.use_start_thought_token:
1823
- override_token = self.start_token_id
1824
- elif self.use_thought_prefix and ahead_idx < self.tokenized_thought_prefix.shape[-1]:
1825
- override_token = self.tokenized_thought_prefix[..., ahead_idx]
1826
- elif ahead_idx == self.n_ahead - 2 and self.use_end_thought_token:
1827
- override_token = self.end_token_id
1828
- else:
1829
- override_token = None
1830
- if override_token is not None and self.n_ahead > 1:
1831
- # always start with the start token
1832
- probabilities_2d = torch.zeros_like(probabilities_2d)
1833
- probabilities_2d[:, override_token] = 1.0
1834
- skip_sampling = True
1835
- elif ahead_idx >= self.n_ahead - 1:
1836
- if labels is not None: # we're in the talk phase
1837
- cur_talk_n = ahead_idx - (self.n_ahead - 1) + 1
1838
- # print("Setting rm to labels", cur_talk_n, "during", ahead_idx)
1839
- shift_labels = labels[..., cur_talk_n:].contiguous().to(probabilities_2d.device)
1840
- padding = torch.full_like(
1841
- labels[..., :cur_talk_n],
1842
- self.tokenizer.pad_token_id,
1843
- dtype=torch.long,
1844
- device=shift_labels.device
1845
- )
1846
- new_rm_tokens = torch.cat(
1847
- [shift_labels, padding],
1848
- dim=-1
1849
- )
1850
- # convert rm tokens to one-hot
1851
- probabilities_2d = F.one_hot(new_rm_tokens, num_classes=self.vocab_size).reshape(-1, self.vocab_size).to(probabilities_2d.dtype)
1852
- skip_sampling = True
1853
- else:
1854
- continue
1855
- temperature = self.gumbel_temperature if self.training else 0.001
1856
- prev_sample_probs = sample_probs
1857
- sample_probs = probabilities_2d
1858
- if ahead_idx < self.n_ahead - 1 and not skip_sampling:
1859
- probabilities_2d = F.gumbel_softmax(sample_probs, tau=temperature, hard=True, dim=-1)
1860
- if self.gumbel_detach:
1861
- probabilities_2d = probabilities_2d.detach()
1862
- sampled_token_history.append(probabilities_2d.argmax(dim=-1).detach().cpu())
1863
- # convert rm logits directly to embeddings
1864
- contains_start = self.use_start_thought_token and (probabilities_2d[..., self.start_token_id].sum() > 0)
1865
- contains_end = self.use_end_thought_token and (probabilities_2d[..., self.end_token_id].sum() > 0)
1866
- contains_thought = contains_start or contains_end
1867
-
1868
- if not contains_thought:
1869
- with torch.set_grad_enabled(not self.train_only_thinking_embedding):
1870
- inputs_embeds = probabilities_2d @ (self.model.embed_tokens.weight.to(probabilities.device).to(probabilities.dtype))
1871
- else:
1872
- thought_id = self.start_token_id if contains_start else self.end_token_id
1873
- cur_thought_embedding = start_embedding if contains_start else end_embedding
1874
- if self.use_reparam_for_thought_embeddings:
1875
- inputs_embeds = torch.randn(batch_size, seq_len, self.model.config.hidden_size, device=input_ids.device, dtype=cur_thought_embedding.dtype)
1876
- inputs_embeds = inputs_embeds * torch.exp(cur_thought_embedding[1]) + cur_thought_embedding[0]
1877
- if contains_start:
1878
- sampled_start = inputs_embeds.clone().detach()
1879
- else:
1880
- sampled_end = inputs_embeds.clone().detach()
1881
- else:
1882
- inputs_embeds = cur_thought_embedding.unsqueeze(0).repeat(batch_size, seq_len, 1)
1883
- inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
1884
- inputs_embeds = inputs_embeds.view(probabilities.size(0), probabilities.size(1), -1).to(self.model.embed_tokens.weight.dtype)
1885
-
1886
- if len(attention_mask.shape) == 2:
1887
- breakpoint()
1888
- else:
1889
- original_attention = attention_mask[..., :attention_mask.shape[-2]]
1890
- if self.use_upper_triangular:
1891
- new_attention = original_attention
1892
- else:
1893
- original_attention = original_attention == attention_mask.max()
1894
- # because eye isn't implemented for BF16, we need to handle the case
1895
- if not attention_mask.dtype == torch.bfloat16:
1896
- new_attention = torch.eye(
1897
- seq_len, dtype=attention_mask.dtype, device=attention_mask.device
1898
- )
1899
- else:
1900
- new_attention = torch.eye(
1901
- seq_len, dtype=torch.float32, device=attention_mask.device
1902
- ).to(attention_mask.dtype)
1903
-
1904
- new_attention = new_attention.view(1, 1, seq_len, seq_len).repeat(input_ids.shape[0], 1, 1, 1)
1905
- new_attention = new_attention * original_attention
1906
- new_attention[new_attention == 0] = attention_mask.min()
1907
- new_attention[new_attention == 1] = attention_mask.max()
1908
- attention_mask = torch.cat([attention_mask, new_attention], dim=-1)
1909
- past_key_values = outputs.past_key_values
1910
- position_ids = position_ids + 1
1911
-
1912
- if labels is not None and (self.n_ahead > 1 or not self.base_original_mode):
1913
- # Shift so that tokens < n predict n
1914
- # logits: abcdef -> bcdef? -> cdef??
1915
- # labels: abcdef -> ?bcdef -> ??cdef
1916
- if ahead_idx == 0 and self.optimize_lm_head_only_at_start:
1917
- loss_logits = initial_loss_logits
1918
- else:
1919
- loss_logits = logits
1920
- shift_idx = 1 + max(0, ahead_idx - (self.n_ahead - 1))
1921
- shift_logits = loss_logits[..., :-shift_idx, :].contiguous()
1922
- shift_labels = labels[..., shift_idx:].contiguous()
1923
- # Flatten the tokens
1924
- loss_fct = CrossEntropyLoss(reduction="none")
1925
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
1926
- shift_labels = shift_labels.view(-1)
1927
- # Enable model parallelism
1928
- shift_labels = shift_labels.to(shift_logits.device)
1929
- # if shift_labels.min() == self.tokenizer.pad_token_id:
1930
- shift_labels = torch.where(shift_labels == self.tokenizer.pad_token_id, -100, shift_labels)
1931
- unreduced_loss = loss_fct(shift_logits, shift_labels)
1932
- if torch.any(unreduced_loss != unreduced_loss):
1933
- raise ValueError("NaN loss")
1934
- unreduced_loss = unreduced_loss.reshape(logits.shape[0], -1)
1935
- loss_list.append(unreduced_loss)
1936
-
1937
-
1938
- if self.use_policy_loss and ahead_idx > 0 and (ahead_idx > 1 or not self.use_start_thought_token):
1939
- # we treat the change in loss as the reward
1940
- previous_loss = loss_list[-2]
1941
- # for example, suppose n_ahead = 3 and n_ahead_talk = 2
1942
- # note that we end at self.n_ahead + self.n_ahead_talk - 2
1943
- # in this case, 5 - 2 = 3, so we end at ahead_idx = 3
1944
- # we also predict the next token at ahead_idx = 2
1945
- # when we get to ahead_idx = 2, we predict ahead
1946
- # so we shift by 1
1947
- # note that this is ahead_idx = n_ahead - 1
1948
- # when we get to ahead_idx = 3, we predict ahead
1949
- # so we shift by 2
1950
- # note that this is ahead_idx = n_ahead
1951
- if ahead_idx < self.n_ahead - 1:
1952
- shift_amount = 0
1953
- original_dqn_reward = (previous_loss - unreduced_loss).detach()
1954
- if self.first_and_last_mode:
1955
- original_dqn_reward = original_dqn_reward * 0.0
1956
- else:
1957
- # logits vs cur_policy_shift_logits
1958
- # let's look at rm_logits and prev_rm_logits
1959
- shift_amount = max(0, ahead_idx - (self.n_ahead - 1))
1960
- # let's say shift_amount = 2
1961
- # abcdefg -> bcdefg? -> cdefg??
1962
- # logits = [a b]c d e f[g]
1963
- # labels = [a b c]d e f g
1964
- cur_policy_shift_logits = initial_loss_logits[..., shift_amount:-1, :].contiguous().detach()
1965
- cur_policy_shift_labels = labels[..., 1 + shift_amount:].contiguous()
1966
- # Flatten the tokens
1967
- cur_policy_loss_fct = CrossEntropyLoss(reduction="none")
1968
- cur_policy_shift_logits = cur_policy_shift_logits.view(-1, self.config.vocab_size)
1969
- cur_policy_shift_labels = cur_policy_shift_labels.view(-1).clone()
1970
- # Enable model parallelism
1971
- cur_policy_shift_labels[cur_policy_shift_labels == self.tokenizer.pad_token_id] = -100
1972
- cur_policy_shift_labels = cur_policy_shift_labels.to(cur_policy_shift_labels.device)
1973
- cur_policy_reward_base_loss = loss_fct(
1974
- cur_policy_shift_logits, cur_policy_shift_labels.to(cur_policy_shift_logits.device)
1975
- ).reshape(logits.shape[0], -1)
1976
- original_dqn_reward = cur_policy_reward_base_loss.detach() - unreduced_loss
1977
-
1978
- if not did_skip_sampling:
1979
- nonzero_indices = prev_probabilities_2d.nonzero()
1980
- action_loglikelihoods = F.log_softmax(prev_sample_probs / self.reinforce_temperature, dim=-1)[nonzero_indices[:, 0], nonzero_indices[:, 1]]
1981
- action_loglikelihoods_2d = action_loglikelihoods.reshape(batch_size, -1)[:, :-1 - shift_amount]
1982
- action_loglikelihoods_list.append(action_loglikelihoods_2d)
1983
- if policy_reward is None:
1984
- policy_reward = original_dqn_reward[:, :-(self.n_ahead_talk - shift_amount)]
1985
- else:
1986
- if self.n_ahead_talk > shift_amount:
1987
- added_reward = original_dqn_reward[:, :-(self.n_ahead_talk - shift_amount)]
1988
- else:
1989
- added_reward = original_dqn_reward
1990
- policy_reward += added_reward
1991
-
1992
- if self.use_policy_loss and ahead_idx == self.n_ahead + self.n_ahead_talk - 2:
1993
- # only compute during the thinking phase
1994
- if self.use_reparam_for_thought_embeddings and (self.use_start_thought_token or self.use_end_thought_token):
1995
- # sampled_start, sampled_end
1996
- # calculate the log likelihood of the start and end embeddings sampled from a multivariate normal distribution
1997
- # with mean start_embedding[0] and standard deviation start_embedding[1]
1998
- if self.use_start_thought_token:
1999
- exp_start_std = torch.exp(start_embedding[1])
2000
- start_loglikelihood = -0.5 * (sampled_start.detach() - start_embedding[0]) ** 2 / exp_start_std ** 2 - start_embedding[1] - 0.5 * math.log(2 * math.pi)
2001
- start_loglikelihood = start_loglikelihood.mean(dim=-1)
2002
- if self.use_end_thought_token:
2003
- exp_end_std = torch.exp(end_embedding[1])
2004
- end_loglikelihood = -0.5 * (sampled_end.detach() - end_embedding[0]) ** 2 / exp_end_std ** 2 - end_embedding[1] - 0.5 * math.log(2 * math.pi)
2005
- end_loglikelihood = end_loglikelihood.mean(dim=-1)
2006
- # we use the mean instead of the sum to prevent dependence on the dimensionality of the embeddings
2007
- if self.use_end_thought_token and self.use_policy_loss_for_end_thought:
2008
- action_loglikelihoods_list.append(end_loglikelihood)
2009
- if self.use_start_thought_token:
2010
- action_loglikelihoods_list.append(start_loglikelihood)
2011
-
2012
- if ahead_idx == self.n_ahead + self.n_ahead_talk - 2 and self.eval_mode:
2013
- with torch.no_grad():
2014
- # calculate the 0.75 quantile of the rewards
2015
- filtered_tokens = input_ids[:, :policy_reward.shape[-1]].cpu().detach().numpy().flatten()
2016
- filtered_tokens_mask = filtered_tokens != self.tokenizer.pad_token_id
2017
- filtered_tokens = filtered_tokens[filtered_tokens_mask]
2018
- filtered_rewards = policy_reward.float().cpu().detach().numpy()[:, :seq_len - self.n_ahead_talk].flatten()
2019
- filtered_rewards = filtered_rewards[filtered_tokens_mask]
2020
-
2021
- abs_reward_list = np.abs(policy_reward.float().cpu().detach().numpy()[:, :seq_len - self.n_ahead_talk].flatten())
2022
- abs_reward_list = abs_reward_list[filtered_tokens_mask]
2023
- medium_quantile = np.quantile(abs_reward_list, 0.5)
2024
- upper_quantile = np.quantile(abs_reward_list, 0.95)
2025
-
2026
- save_tokens_with_rewards_to_pdf(
2027
- filtered_tokens,
2028
- [0] + filtered_rewards.tolist(),
2029
- self.tokenizer,
2030
- output_file=f"texts/rewards_talk_{self.n_ahead_talk}_{self.training_steps}.pdf",
2031
- eps=medium_quantile,
2032
- eps2=upper_quantile,
2033
- )
2034
-
2035
- def plot_kde(data, losses):
2036
- sns.set(style="whitegrid")
2037
- # Create the KDE plot
2038
- sns.kdeplot(data, fill=True)
2039
- # Set the plot title and labels
2040
- plt.title("KDE Plot")
2041
- plt.xlabel("Value")
2042
- plt.ylabel("Density")
2043
- # Save the plot
2044
- plt.savefig(f"texts/kde_talk_{self.n_ahead_talk}_{self.training_steps}.pdf")
2045
- # Close the plot
2046
- plt.close()
2047
-
2048
- # Step 1: Create a base color palette
2049
- base_colors = sns.color_palette("light:#5A9", n_colors=256) # More colors for a smoother gradient
2050
- base_cmap = LinearSegmentedColormap.from_list("log_light", base_colors)
2051
- log_norm = LogNorm(vmin=1e-3, vmax=10)
2052
-
2053
- sns.kdeplot(x=data, y=losses, fill=True, levels=20, norm=log_norm, cut=0, linewidths=0)
2054
- # limit y to 0 to 25 and x to -1 to 1
2055
- plt.xlim(-1, 1)
2056
- plt.ylim(0, 25)
2057
- plt.savefig(f"texts/jointer_talk_{self.n_ahead_talk}_{self.training_steps}.pdf")
2058
- plt.close()
2059
-
2060
- self.all_rewards.extend(filtered_rewards)
2061
- self.all_unreduced_losses.extend(unreduced_loss[:, :-1].flatten()[filtered_tokens_mask].float().flatten().cpu().detach().numpy())
2062
- plot_kde(self.all_rewards, self.all_unreduced_losses)
2063
-
2064
- for action_loglikelihoods_2d in action_loglikelihoods_list:
2065
- train_policy_reward = policy_reward
2066
-
2067
- # discard rewards below the mean
2068
- if self.trice_mode and self.n_passes > 1:
2069
- batched_policy_reward = train_policy_reward.reshape(-1, self.n_passes, train_policy_reward.shape[-1])
2070
- # average over the passes
2071
- train_policy_reward = batched_policy_reward - batched_policy_reward.mean(dim=1, keepdim=True)
2072
- train_policy_reward = train_policy_reward.reshape(-1, train_policy_reward.shape[-1])
2073
-
2074
- if self.subtract_mean_reward:
2075
- train_policy_reward = train_policy_reward - train_policy_reward.mean()
2076
- if self.remove_negative_rewards:
2077
- fixed_policy_reward = train_policy_reward.detach().clamp(min=0)
2078
- else:
2079
- fixed_policy_reward = train_policy_reward.detach()
2080
- actor_loss = -fixed_policy_reward * action_loglikelihoods_2d[:, :policy_reward.shape[-1]].to(policy_reward.device)
2081
- if action_loglikelihoods_2d.mean() < -1e4 and not self.use_policy_loss_just_for_thoughts:
2082
- # This will only happen when we force the next token to be the end of thought token
2083
- break
2084
- dqn_loss_list.append(actor_loss.mean())
2085
-
2086
- if loss_list:
2087
- if self.first_and_last_mode:
2088
- loss = sum(
2089
- self.loss_mean(loss_list[-(i + 1)]) for i in range(self.n_ahead_talk)
2090
- ) * (1 - self.original_loss_weight) / self.n_ahead_talk
2091
- loss = loss + self.loss_mean(loss_list[0]) * self.original_loss_weight
2092
- # Let's NaN out the others
2093
- # e.g. if n_ahead_talk = 2 and the list is 5 long, we want to NaN out 1, 2 but keep 0, 3, 4
2094
- for i in range(1, len(loss_list) - self.n_ahead_talk):
2095
- loss_list[i] = loss_list[i] * math.nan
2096
- elif self.first_only:
2097
- loss = self.loss_mean(loss_list[0])
2098
- elif self.final_only_mode:
2099
- loss = sum(
2100
- self.loss_mean(loss_list[-i]) for i in range(1, self.n_ahead_talk + 1)
2101
- ) / self.n_ahead_talk
2102
- else:
2103
- loss = None
2104
- for i in range(len(loss_list)):
2105
- cur_loss = self.loss_mean(loss_list[i])
2106
- if loss is not None:
2107
- loss = loss + cur_loss.to(loss.device)
2108
- else:
2109
- loss = cur_loss
2110
- loss = loss / len(loss_list)
2111
-
2112
- loss = loss * self.base_loss_beta
2113
-
2114
- if dqn_loss_list:
2115
- dqn_loss = sum(dqn_loss_list) / len(dqn_loss_list)
2116
- if self.include_policy_loss:
2117
- if loss is not None:
2118
- loss += dqn_loss * self.policy_loss_beta
2119
- else:
2120
- loss = dqn_loss * self.policy_loss_beta
2121
 
2122
  if not return_dict:
2123
  output = (logits,) + outputs[1:]
2124
  return (loss,) + output if loss is not None else output
2125
-
2126
- base_log_dict = {
2127
- f"loss_{i}": nonzero_mean(loss_list[i]) for i in range(len(loss_list))
2128
- }
2129
-
2130
- if loss is not None:
2131
- base_log_dict["loss_train"] = loss.item()
2132
-
2133
- for loss_key, loss_val in base_log_dict.items():
2134
- log_dict[loss_key] += loss_val / self.n_tokens_print
2135
-
2136
- if self.use_policy_loss and policy_reward is not None:
2137
- log_dict["policy_loss"] += dqn_loss / self.n_tokens_print
2138
- log_dict["policy_reward"] += policy_reward.mean() / self.n_tokens_print
2139
-
2140
- if not loss_list:
2141
- if loss is not None:
2142
- log_dict["loss_0"] += loss / self.n_tokens_print
2143
- else:
2144
- log_dict["loss_final"] += nonzero_mean(loss_list[-1]) / self.n_tokens_print
2145
- log_dict["loss_talk"] += sum(nonzero_mean(cur_loss_item) for cur_loss_item in loss_list[-self.n_ahead_talk:]) / self.n_ahead_talk / self.n_tokens_print
2146
-
2147
- # also log relative losses to loss_0
2148
- if loss_list:
2149
- for i in range(len(loss_list)):
2150
- talk_idx = min(max(i - (self.n_ahead - 1), 0), len(talk_loss_list) - 1)
2151
- if not talk_loss_list:
2152
- cur_talk_loss = nonzero_mean(loss_list[0])
2153
- else:
2154
- cur_talk_loss = talk_loss_list[talk_idx]
2155
- log_dict[f"rel_loss_{i}"] += (nonzero_mean(loss_list[i]) - cur_talk_loss) / self.n_tokens_print
2156
- if self.training:
2157
- self.training_steps += 1
2158
- try:
2159
- # if self.training_steps % (self.gradient_accumulation_steps * 256) == 0:
2160
- if self.wandb_enabled:
2161
- if self.training_steps % (self.n_tokens_print) == 0 or not self.training:# and "0" in str(loss.device):
2162
- if not self.training:
2163
- new_log_dict = {}
2164
- for key in list(log_dict.keys()):
2165
- new_log_dict["eval_" + key] = log_dict[key]
2166
- log_dict = new_log_dict
2167
- log_dict["training_steps"] = self.training_steps
2168
- log_dict["batch_size"] = batch_size
2169
- log_dict["example_steps"] = self.training_steps * batch_size * self.gradient_accumulation_steps
2170
- if self.n_ahead > 1:
2171
- log_dict["compute_steps"] = self.training_steps * batch_size * (self.n_ahead + self.n_ahead_talk - 1) * self.gradient_accumulation_steps
2172
- else: # There's no overhead for talk tokens if there's no thinking
2173
- log_dict["compute_steps"] = self.training_steps * batch_size * self.gradient_accumulation_steps
2174
- # remove all nans
2175
- for key in list(log_dict.keys()):
2176
- if log_dict[key] != log_dict[key]:
2177
- del log_dict[key]
2178
- if self.training:
2179
- wandb.log(log_dict)
2180
- if self.training:
2181
- self.log_dict = defaultdict(int)
2182
- else:
2183
- self.eval_log_dict = defaultdict(int)
2184
- except Exception as e:
2185
- pass
2186
-
2187
- if not self.training:
2188
- self.n_ahead_talk = n_ahead_talk_to_restore
2189
- self.n_passes = n_passes_to_restore
2190
  return CausalLMOutputWithPast(
2191
- loss=loss if loss is not None else None,
2192
- logits=(rm_logits if self.n_ahead > 1 else logits) if not self.output_logits_at_the_end else logits,
2193
  past_key_values=outputs.past_key_values,
2194
  hidden_states=outputs.hidden_states,
2195
  attentions=outputs.attentions,
2196
  )
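Aside on the removed Quiet-STaR forward pass above: each thought token is drawn with a straight-through Gumbel-softmax, and the resulting (soft) one-hot is pushed back through the embedding matrix so the sampling step stays differentiable. A minimal sketch of that step, with illustrative names rather than the module's real attributes:

```python
import torch
import torch.nn.functional as F

def sample_thought_embedding(rm_logits, embed_weight, tau=1.0, training=True):
    # rm_logits: (batch * seq_len, vocab) scores for the next thought token
    # embed_weight: (vocab, hidden) tied input-embedding matrix
    temperature = tau if training else 1e-3               # near-greedy at eval time
    probs_2d = F.gumbel_softmax(rm_logits, tau=temperature, hard=True, dim=-1)
    # hard=True: one-hot in the forward pass, soft gradients in the backward pass
    next_embeds = probs_2d @ embed_weight                 # differentiable embedding lookup
    return probs_2d.argmax(dim=-1), next_embeds
```

This mirrors the `F.gumbel_softmax(..., hard=True)` call and the `probabilities_2d @ embed_tokens.weight` mixing in the deleted code.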
2197
 
2198
-
2199
  def prepare_inputs_for_generation(
2200
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
2201
  ):
@@ -2211,7 +1210,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
2211
 
2212
  # Keep only the unprocessed tokens:
2213
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
2214
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
2215
  # input)
2216
  if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
2217
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
@@ -2265,9 +1264,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
2265
 
2266
  @add_start_docstrings(
2267
  """
2268
- The Quiet Model transformer with a sequence classification head on top (linear layer).
2269
 
2270
- [`QuietForSequenceClassification`] uses the last token in order to do the classification, as other causal models
2271
  (e.g. GPT-2) do.
2272
 
2273
  Since it does classification on the last token, it requires to know the position of the last token. If a
@@ -2276,14 +1275,14 @@ class QuietForCausalLM(QuietPreTrainedModel):
2276
  padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
2277
  each row of the batch).
2278
  """,
2279
- QUIET_START_DOCSTRING,
2280
  )
2281
- # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Quiet, LLAMA->QUIET
2282
- class QuietForSequenceClassification(QuietPreTrainedModel):
2283
  def __init__(self, config):
2284
  super().__init__(config)
2285
  self.num_labels = config.num_labels
2286
- self.model = QuietModel(config)
2287
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
2288
 
2289
  # Initialize weights and apply final processing
@@ -2295,7 +1294,7 @@ class QuietForSequenceClassification(QuietPreTrainedModel):
2295
  def set_input_embeddings(self, value):
2296
  self.model.embed_tokens = value
2297
 
2298
- @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
2299
  def forward(
2300
  self,
2301
  input_ids: torch.LongTensor = None,
 
1
  # coding=utf-8
2
+ # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
3
  #
4
  # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
  # and OPT implementations in this library. It has been modified from its
 
17
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
  # See the License for the specific language governing permissions and
19
  # limitations under the License.
20
+ """ PyTorch Mistral model."""
21
  import inspect
22
  import math
 
 
23
  import warnings
 
24
  from typing import List, Optional, Tuple, Union
25
 
26
  import torch
 
29
  from torch import nn
30
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
 
32
+ from ...activations import ACT2FN
33
+ from ...cache_utils import Cache, DynamicCache
34
+ from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
35
+ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
36
+ from ...modeling_utils import PreTrainedModel
37
+ from ...utils import (
38
  add_start_docstrings,
39
  add_start_docstrings_to_model_forward,
40
  is_flash_attn_2_available,
 
42
  logging,
43
  replace_return_docstrings,
44
  )
45
+ from .configuration_mistral import MistralConfig
46
 
47
 
48
  if is_flash_attn_2_available():
 
54
 
55
  logger = logging.get_logger(__name__)
56
 
57
+ _CONFIG_FOR_DOC = "MistralConfig"
 
58
 
59
 
60
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 
62
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
63
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
64
  max_seqlen_in_batch = seqlens_in_batch.max().item()
65
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
66
  return (
67
  indices,
68
  cu_seqlens,
 
70
  )
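A quick worked example of the unpadding helper above (nothing model-specific is assumed): the per-row token counts become sequence boundaries via a prefix sum, which is what the flash-attention unpadded path consumes.

```python
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]], dtype=torch.int32)
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)           # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
print(indices)      # tensor([0, 1, 2, 4, 5]): positions of real (non-pad) tokens
print(cu_seqlens)   # tensor([0, 3, 5], dtype=torch.int32): per-sequence boundaries
```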
71
 
72
 
73
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral
74
+ class MistralRMSNorm(nn.Module):
75
  def __init__(self, hidden_size, eps=1e-6):
76
  """
77
+ MistralRMSNorm is equivalent to T5LayerNorm
78
  """
79
  super().__init__()
80
  self.weight = nn.Parameter(torch.ones(hidden_size))
 
85
  hidden_states = hidden_states.to(torch.float32)
86
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
87
  hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
88
+ return self.weight * hidden_states.to(input_dtype)
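The RMSNorm above rescales by the root mean square over the hidden dimension only; there is no mean subtraction and no bias. A functional sketch equivalent to the forward shown:

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    input_dtype = x.dtype
    x = x.to(torch.float32)                        # compute the statistics in fp32
    variance = x.pow(2).mean(-1, keepdim=True)     # mean of squares over the hidden dim
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(input_dtype)              # learned per-channel gain
```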
89
 
90
 
91
+ # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
92
+ # TODO @Arthur no longer copied from LLama after static cache
93
+ class MistralRotaryEmbedding(nn.Module):
94
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
95
  super().__init__()
96
 
97
  self.dim = dim
98
  self.max_position_embeddings = max_position_embeddings
99
  self.base = base
100
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
101
  self.register_buffer("inv_freq", inv_freq, persistent=False)
102
 
103
  # Build here to make `torch.jit.trace` work.
 
107
 
108
  def _set_cos_sin_cache(self, seq_len, device, dtype):
109
  self.max_seq_len_cached = seq_len
110
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
111
 
112
  freqs = torch.outer(t, self.inv_freq)
113
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
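The permutation note above refers to the cos/sin cache layout: instead of interleaving even/odd frequencies as in the RoPE paper, the frequencies are duplicated along the last dimension. A small sketch of how the cache is built; the `torch.cat((freqs, freqs), dim=-1)` line is the standard Llama/Mistral layout and is assumed here, since the hunk cuts off before it:

```python
import torch

dim, base, seq_len = 8, 10000.0, 6
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))   # (dim // 2,)
t = torch.arange(seq_len).float()
freqs = torch.outer(t, inv_freq)                                     # (seq_len, dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                              # duplicated, not interleaved
cos_cached, sin_cached = emb.cos(), emb.sin()                        # cached up to max_seq_len
```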
 
134
  return torch.cat((-x2, x1), dim=-1)
135
 
136
 
137
+ # copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
138
+ # TODO @Arthur no longer copied from LLama after static cache
139
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
140
  """Applies Rotary Position Embedding to the query and key tensors.
141
 
 
164
  return q_embed, k_embed
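With `rotate_half` defined as above, the body of `apply_rotary_pos_emb` amounts to rotating every query/key feature pair by position-dependent angles. A stripped-down sketch; gathering `cos`/`sin` by `position_ids` and the unsqueeze for broadcasting are left out:

```python
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, k, cos, sin):
    # cos/sin are assumed already broadcastable to q and k
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```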
165
 
166
 
167
+ class MistralMLP(nn.Module):
168
  def __init__(self, config):
169
  super().__init__()
170
  self.config = config
 
192
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
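The `repeat_kv` helper ending above is the grouped-query-attention plumbing: the key/value heads are tiled `n_rep = num_heads // num_key_value_heads` times so they line up with the query heads. A self-contained sketch of the same expansion:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, seq_len, head_dim)
    #   -> (batch, num_key_value_heads * n_rep, seq_len, head_dim)
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, slen, head_dim)
```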
193
 
194
 
195
+ class MistralAttention(nn.Module):
196
  """
197
  Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
198
  and "Generating Long Sequences with Sparse Transformers".
199
  """
200
 
201
+ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
202
  super().__init__()
203
  self.config = config
204
  self.layer_idx = layer_idx
205
  if layer_idx is None:
206
  logger.warning_once(
207
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
208
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
209
  "when creating this class."
210
  )
211
 
 
229
  self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
230
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
231
 
232
+ self.rotary_emb = MistralRotaryEmbedding(
233
  self.head_dim,
234
  max_position_embeddings=self.max_position_embeddings,
235
  base=self.rope_theta,
 
320
  return attn_output, attn_weights, past_key_value
321
 
322
 
323
+ class MistralFlashAttention2(MistralAttention):
324
  """
325
+ Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays
326
  untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
327
  flash attention and deal with padding tokens in case the input contains any of them.
328
  """
 
496
  attention_mask (`torch.Tensor`):
497
  The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
498
  position of padding tokens and 1 for the position of non-padding tokens.
499
+ dropout (`float`):
500
  Attention dropout
501
  softmax_scale (`float`, *optional*):
502
  The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
 
614
  )
615
 
616
 
617
+ # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
618
+ # TODO @Arthur no longer copied from LLama after static cache
619
+ class MistralSdpaAttention(MistralAttention):
620
  """
621
+ Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
622
+ `MistralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
623
  SDPA API.
624
  """
625
 
626
+ # Adapted from MistralAttention.forward
627
  def forward(
628
  self,
629
  hidden_states: torch.Tensor,
 
636
  if output_attentions:
637
  # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
638
  logger.warning_once(
639
+ "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
640
  'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
641
  )
642
  return super().forward(
 
689
  query_states,
690
  key_states,
691
  value_states,
692
+ attn_mask=attention_mask,
693
  dropout_p=self.attention_dropout if self.training else 0.0,
694
  # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
695
  is_causal=self.is_causal and attention_mask is None and q_len > 1,
696
  )
697
 
698
  attn_output = attn_output.transpose(1, 2).contiguous()
699
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
700
 
701
  attn_output = self.o_proj(attn_output)
702
 
703
  return attn_output, None, past_key_value
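The SDPA attention variant above ultimately boils down to one call to `torch.nn.functional.scaled_dot_product_attention`. A minimal sketch of that call and the reshape that follows, assuming inputs already shaped `(batch, num_heads, seq_len, head_dim)`:

```python
import torch
import torch.nn.functional as F

def sdpa_attention(query, key, value, attention_mask=None, dropout_p=0.0):
    q_len = query.shape[-2]
    attn_output = F.scaled_dot_product_attention(
        query, key, value,
        attn_mask=attention_mask,                          # additive 4D mask or None
        dropout_p=dropout_p,
        is_causal=attention_mask is None and q_len > 1,    # mirrors the check above
    )
    batch, num_heads, _, head_dim = attn_output.shape
    # (batch, num_heads, q_len, head_dim) -> (batch, q_len, hidden_size)
    return attn_output.transpose(1, 2).reshape(batch, q_len, num_heads * head_dim)
```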
704
 
705
 
706
+ MISTRAL_ATTENTION_CLASSES = {
707
+ "eager": MistralAttention,
708
+ "flash_attention_2": MistralFlashAttention2,
709
+ "sdpa": MistralSdpaAttention,
710
  }
711
 
712
 
713
+ class MistralDecoderLayer(nn.Module):
714
+ def __init__(self, config: MistralConfig, layer_idx: int):
715
  super().__init__()
716
  self.hidden_size = config.hidden_size
717
 
718
+ self.self_attn = MISTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
719
 
720
+ self.mlp = MistralMLP(config)
721
+ self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
722
+ self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
723
 
724
  def forward(
725
  self,
 
762
  output_attentions=output_attentions,
763
  use_cache=use_cache,
764
  )
765
+ hidden_states = residual + hidden_states
766
 
767
  # Fully Connected
768
  residual = hidden_states
 
781
  return outputs
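The decoder layer above follows the usual pre-norm residual pattern: normalize, attend, add the residual back, then normalize, run the MLP, and add the residual again. Schematically, as a sketch with the attention call's mask/position/cache arguments abstracted away:

```python
def decoder_block(x, input_layernorm, self_attn, post_attention_layernorm, mlp):
    # Self-attention sub-block (pre-norm + residual)
    residual = x
    x = residual + self_attn(input_layernorm(x))
    # Feed-forward sub-block (pre-norm + residual)
    residual = x
    x = residual + mlp(post_attention_layernorm(x))
    return x
```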
782
 
783
 
784
+ MISTRAL_START_DOCSTRING = r"""
785
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
786
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
787
  etc.)
 
791
  and behavior.
792
 
793
  Parameters:
794
+ config ([`MistralConfig`]):
795
  Model configuration class with all the parameters of the model. Initializing with a config file does not
796
  load the weights associated with the model, only the configuration. Check out the
797
  [`~PreTrainedModel.from_pretrained`] method to load the model weights.
 
799
 
800
 
801
  @add_start_docstrings(
802
+ "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
803
+ MISTRAL_START_DOCSTRING,
804
  )
805
+ class MistralPreTrainedModel(PreTrainedModel):
806
+ config_class = MistralConfig
807
  base_model_prefix = "model"
808
  supports_gradient_checkpointing = True
809
+ _no_split_modules = ["MistralDecoderLayer"]
810
  _skip_keys_device_placement = "past_key_values"
811
  _supports_flash_attn_2 = True
812
  _supports_sdpa = True
 
824
  module.weight.data[module.padding_idx].zero_()
825
 
826
 
827
+ MISTRAL_INPUTS_DOCSTRING = r"""
828
  Args:
829
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
830
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
 
895
 
896
 
897
  @add_start_docstrings(
898
+ "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
899
+ MISTRAL_START_DOCSTRING,
900
  )
901
+ class MistralModel(MistralPreTrainedModel):
902
  """
903
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`]
904
 
905
  Args:
906
+ config: MistralConfig
907
  """
908
 
909
+ def __init__(self, config: MistralConfig):
910
  super().__init__(config)
911
  self.padding_idx = config.pad_token_id
912
  self.vocab_size = config.vocab_size
913
 
914
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
915
  self.layers = nn.ModuleList(
916
+ [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
917
  )
918
  self._attn_implementation = config._attn_implementation
919
+ self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
920
 
921
  self.gradient_checkpointing = False
922
  # Initialize weights and apply final processing
 
928
  def set_input_embeddings(self, value):
929
  self.embed_tokens = value
930
 
931
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
932
  def forward(
933
  self,
934
  input_ids: torch.LongTensor = None,
 
991
  if is_padding_right:
992
  raise ValueError(
993
  "You are attempting to perform batched generation with padding_side='right'"
994
+ " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
995
  " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
996
  )
997
 
998
  if self._attn_implementation == "flash_attention_2":
999
  # 2d mask is passed through the layers
1000
  attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1001
+ elif self._attn_implementation == "sdpa" and not output_attentions:
1002
  # output_attentions=True can not be supported when using SDPA, and we fall back on
1003
  # the manual implementation that requires a 4D causal mask in all cases.
1004
  attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
 
1007
  inputs_embeds,
1008
  past_key_values_length,
1009
  )
1010
+ else:
1011
  # 4d mask is passed through the layers
1012
  attention_mask = _prepare_4d_causal_attention_mask(
1013
  attention_mask,
 
1075
  attentions=all_self_attns,
1076
  )
1077
 
1078
 
1079
+ class MistralForCausalLM(MistralPreTrainedModel):
1080
  _tied_weights_keys = ["lm_head.weight"]
1081
 
1082
  def __init__(self, config):
1083
  super().__init__(config)
1084
+ self.model = MistralModel(config)
1085
  self.vocab_size = config.vocab_size
1086
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
 
1087
 
1088
  # Initialize weights and apply final processing
1089
  self.post_init()
 
1106
  def get_decoder(self):
1107
  return self.model
1108
 
1109
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
 
1110
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1111
  def forward(
1112
  self,
 
1133
  Example:
1134
 
1135
  ```python
1136
+ >>> from transformers import AutoTokenizer, MistralForCausalLM
1137
 
1138
+ >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
1139
+ >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
1140
 
1141
  >>> prompt = "Hey, are you conscious? Can you talk to me?"
1142
  >>> inputs = tokenizer(prompt, return_tensors="pt")
 
1146
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1147
  "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1148
  ```"""
 
1149
 
1150
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1151
  output_hidden_states = (
 
1153
  )
1154
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1155
 
1156
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1157
+ outputs = self.model(
1158
+ input_ids=input_ids,
1159
+ attention_mask=attention_mask,
1160
+ position_ids=position_ids,
1161
+ past_key_values=past_key_values,
1162
+ inputs_embeds=inputs_embeds,
1163
+ use_cache=use_cache,
1164
+ output_attentions=output_attentions,
1165
+ output_hidden_states=output_hidden_states,
1166
+ return_dict=return_dict,
1167
+ )
1168
 
1169
+ hidden_states = outputs[0]
1170
+ logits = self.lm_head(hidden_states)
1171
+ logits = logits.float()
1172
 
1173
  loss = None
1174
+ if labels is not None:
1175
+ # Shift so that tokens < n predict n
1176
+ shift_logits = logits[..., :-1, :].contiguous()
1177
+ shift_labels = labels[..., 1:].contiguous()
1178
+ # Flatten the tokens
1179
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1180
+ shift_labels = shift_labels.view(-1)
1181
+ # Ensure tensors are on the same device
1182
+ shift_labels = shift_labels.to(shift_logits.device)
1183
+ loss_fct = CrossEntropyLoss()
1184
+ loss = loss_fct(shift_logits, shift_labels)
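The label handling above is the standard next-token objective: logits at position t are scored against the token at position t+1, and any label equal to -100 (the default `ignore_index` of `CrossEntropyLoss`) is skipped. A tiny illustration of the shapes involved:

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 11
logits = torch.randn(2, 5, vocab_size)            # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 5))     # aligned with the input positions

shift_logits = logits[..., :-1, :].contiguous()   # predictions for positions 0..3
shift_labels = labels[..., 1:].contiguous()       # targets are positions 1..4

loss_fct = CrossEntropyLoss()                     # ignore_index defaults to -100
loss = loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
```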
 
1185
 
1186
  if not return_dict:
1187
  output = (logits,) + outputs[1:]
1188
  return (loss,) + output if loss is not None else output
1189
+
 
1190
  return CausalLMOutputWithPast(
1191
+ loss=loss,
1192
+ logits=logits,
1193
  past_key_values=outputs.past_key_values,
1194
  hidden_states=outputs.hidden_states,
1195
  attentions=outputs.attentions,
1196
  )
1197
 
 
1198
  def prepare_inputs_for_generation(
1199
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1200
  ):
 
1210
 
1211
  # Keep only the unprocessed tokens:
1212
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1213
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1214
  # input)
1215
  if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1216
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
 
1264
 
1265
  @add_start_docstrings(
1266
  """
1267
+ The Mistral Model transformer with a sequence classification head on top (linear layer).
1268
 
1269
+ [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1270
  (e.g. GPT-2) do.
1271
 
1272
  Since it does classification on the last token, it requires to know the position of the last token. If a
 
1275
  padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1276
  each row of the batch).
1277
  """,
1278
+ MISTRAL_START_DOCSTRING,
1279
  )
1280
+ # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL
1281
+ class MistralForSequenceClassification(MistralPreTrainedModel):
1282
  def __init__(self, config):
1283
  super().__init__(config)
1284
  self.num_labels = config.num_labels
1285
+ self.model = MistralModel(config)
1286
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1287
 
1288
  # Initialize weights and apply final processing
 
1294
  def set_input_embeddings(self, value):
1295
  self.model.embed_tokens = value
1296
 
1297
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
1298
  def forward(
1299
  self,
1300
  input_ids: torch.LongTensor = None,