Question Answering
Transformers
Safetensors
English
doge
text-generation
custom_code
JingzeShi committed on
Commit a84e78f
verified
1 Parent(s): 25b547e

Upload DogeForCausalLM

Files changed (5)
  1. config.json +44 -37
  2. configuration_doge.py +58 -35
  3. generation_config.json +7 -7
  4. model.safetensors +2 -2
  5. modeling_doge.py +327 -199
config.json CHANGED
@@ -1,37 +1,44 @@
-{
-  "_name_or_path": "./results/Doge-20M/checkpoint-1792",
-  "architectures": [
-    "DogeForCausalLM"
-  ],
-  "attention_dropout": 0.0,
-  "auto_map": {
-    "AutoConfig": "configuration_doge.DogeConfig",
-    "AutoModelForCausalLM": "modeling_doge.DogeForCausalLM"
-  },
-  "bos_token_id": 1,
-  "eos_token_id": 2,
-  "expert_retrieval_size": 256,
-  "hidden_act": "silu",
-  "hidden_bias": false,
-  "hidden_dropout": 0.0,
-  "hidden_size": 256,
-  "initializer_range": 0.02,
-  "intermediate_size": 1024,
-  "is_moe": false,
-  "max_position_embeddings": 2048,
-  "model_type": "doge",
-  "num_attention_heads": 2,
-  "num_cdmmoe_experts": 4096,
-  "num_cdmmoe_experts_per_head": 8,
-  "num_cdmmoe_heads": 4,
-  "num_hidden_layers": 4,
-  "pad_token_id": 0,
-  "rms_norm_eps": 1e-06,
-  "rope_scaling": null,
-  "rope_theta": 10000.0,
-  "tie_word_embeddings": false,
-  "torch_dtype": "float32",
-  "transformers_version": "4.46.1",
-  "use_cache": true,
-  "vocab_size": 32768
-}
+{
+  "_name_or_path": "./results/Doge-20M-Instruct-DPO",
+  "architectures": [
+    "DogeForCausalLM"
+  ],
+  "attention_dropout": 0.0,
+  "auto_map": {
+    "AutoConfig": "configuration_doge.DogeConfig",
+    "AutoModelForCausalLM": "modeling_doge.DogeForCausalLM"
+  },
+  "bos_token_id": 0,
+  "dynamic_mask_ratio": 0.0,
+  "eos_token_id": 1,
+  "expert_retrieval_size": 256,
+  "hidden_act": "silu",
+  "hidden_bias": false,
+  "hidden_dropout": 0.0,
+  "hidden_size": 256,
+  "initializer_range": 0.02,
+  "intermediate_size": 512,
+  "is_moe": false,
+  "max_position_embeddings": 2048,
+  "model_type": "doge",
+  "num_attention_heads": 2,
+  "num_cdmmoe_experts": 2048,
+  "num_cdmmoe_experts_per_head": 8,
+  "num_cdmmoe_heads": 4,
+  "num_channels": 3,
+  "num_hidden_layers": 8,
+  "num_key_value_heads": 1,
+  "pad_token_id": 2,
+  "patch_size": 16,
+  "rms_norm_eps": 1e-06,
+  "rope_scaling": {
+    "factor": 4.0,
+    "original_max_position_embeddings": 2048,
+    "rope_type": "dynamic"
+  },
+  "rope_theta": 10000.0,
+  "torch_dtype": "float32",
+  "transformers_version": "4.49.0.dev0",
+  "use_cache": true,
+  "vocab_size": 32768
+}
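The updated config adds grouped-query attention (`num_key_value_heads`), `dynamic_mask_ratio`, `num_channels`/`patch_size`, a dynamic RoPE scaling dict, and remapped special-token ids. As a minimal sketch (not part of this commit), the custom classes referenced in `auto_map` are loaded through the Auto classes with `trust_remote_code=True`; the repo id below is taken from the docstring example ([JingzeShi/Doge-20M](https://huggingface.co/JingzeShi/Doge-20M)) and is only an assumption for illustration — it may differ from the repository this commit belongs to.

```python
# Minimal sketch: loading a checkpoint whose config maps the Auto classes to the
# custom code shipped in this repo (configuration_doge.py / modeling_doge.py).
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "JingzeShi/Doge-20M"  # assumed repo id for illustration

# trust_remote_code=True is required because auto_map points at custom classes.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
# Fields introduced by this commit:
print(config.num_key_value_heads, config.dynamic_mask_ratio, config.rope_scaling)

model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
```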
configuration_doge.py CHANGED
@@ -25,20 +25,23 @@ from transformers.modeling_rope_utils import rope_config_validation
 class DogeConfig(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`DogeModel`]. It is used to instantiate an Doge
-    model according to the specified arguments, defining the model architecture like [LoserCheems/doge-tiny-test](https://huggingface.co/LoserCheems/doge-tiny-test)
+    model according to the specified arguments, defining the model architecture like [JingzeShi/Doge-20M](https://huggingface.co/JingzeShi/Doge-20M).
 
     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
     documentation from [`PretrainedConfig`] for more information.
 
     Args:
         vocab_size (`int`, *optional*, defaults to 32768):
-            Vocabulary size of the Doge model. Defines the number of different tokens that can be represented by the
-            `inputs_ids` passed when calling [`DogeModel`]
+            Vocabulary size of the Doge model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`DogeModel`]
+        num_channels (`int`, *optional*, defaults to 3):
+            Number of channels in the input image.
+        patch_size (`int`, *optional*, defaults to 16):
+            Patch size of Vision Transformer Embeddings.
         hidden_size (`int`, *optional*, defaults to 1024):
            Dimension of the hidden representations.
-        intermediate_size (`int`, *optional*, defaults to 4096):
+        intermediate_size (`int`, *optional*, defaults to 2048):
            Dimension of the CDMoE representations.
-        num_hidden_layers (`int`, *optional*, defaults to 16):
+        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        hidden_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the hidden layers.
@@ -51,24 +54,21 @@ class DogeConfig(PretrainedConfig):
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
-            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
-            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
-            accordingly.
+            Dictionary containing the scaling configuration for the RoPE embeddings.
+            NOTE: if you apply new rope type and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value accordingly.
            Expected contents:
                `rope_type` (`str`):
-                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
-                    'llama3'], with 'default' being the original RoPE implementation.
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
-                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
-                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
-                    original maximum pre-trained length.
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings.
+                    In most scaling types, a `factor` of x will enable the model to handle sequences of length x * original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
-                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
-                    pretraining.
+                    Used with 'dynamic', 'longrope' and 'llama3'.
+                    The original max position embeddings used during pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
-                    computation. If unspecified, it defaults to value recommended by the implementation, using the
-                    `factor` field to infer the suggested value.
+                    computation.
+                    If unspecified, it defaults to value recommended by the implementation, using the `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
@@ -76,13 +76,11 @@ class DogeConfig(PretrainedConfig):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
-                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
-                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
-                    size divided by the number of attention heads divided by 2
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<`original_max_position_embeddings`).
+                    Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
-                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
-                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
-                    size divided by the number of attention heads divided by 2
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<`original_max_position_embeddings`).
+                    Must be a list of numbers with the same length as the hidden size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
@@ -100,15 +98,24 @@ class DogeConfig(PretrainedConfig):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
-        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie weight embeddings
        num_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*, defaults to `None`):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention.
+            If `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used.
+            When converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed by meanpooling all the original heads within that group.
+            For more details checkout [this paper](https://arxiv.org/pdf/2305.13245.pdf).
+            If it is not specified, will default to `num_attention_heads`.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
+        dynamic_mask_ratio (`float`, *optional*, defaults to 0.0, range [0, 1]):
+            The ratio to control the proportion of the dynamic mask filled with the minimum value.
        is_moe (`bool`, *optional*, defaults to `False`):
            Whether to use the Cross Domain Mixture of Experts, if `True`, the MoE will inherit the MLP to initialize
-        num_cdmmoe_experts (`int`, *optional*, defaults to 4096):
+        num_cdmmoe_experts (`int`, *optional*, defaults to 2048):
            Number of Private Experts for the Cross Domain Mixture of Experts.
        num_cdmmoe_heads (`int`, *optional*, defaults to 4):
            Number of heads of Private Experts for the Cross Domain Mixture of Experts.
@@ -124,32 +131,42 @@ class DogeConfig(PretrainedConfig):
    def __init__(
        self,
        vocab_size=32768,
+        num_channels=3,
+        patch_size=16,
        hidden_size=1024,
-        intermediate_size=4096,
-        num_hidden_layers=16,
+        intermediate_size=2048,
+        num_hidden_layers=32,
        hidden_bias=False,
        hidden_dropout=0.0,
        hidden_act="silu",
        max_position_embeddings=2048,
        rope_theta=10000.0,
-        rope_scaling=None,
+        rope_scaling={
+            "rope_type": "dynamic",
+            "factor": 4.0,
+            "original_max_position_embeddings": 2048,
+        },
        initializer_range=0.02,
        rms_norm_eps=1e-06,
        use_cache=True,
-        pad_token_id=0,
-        bos_token_id=1,
-        eos_token_id=2,
-        tie_word_embeddings=False,
+        bos_token_id=0,
+        eos_token_id=1,
+        pad_token_id=2,
+        tie_word_embeddings=True,
        num_attention_heads=8,
+        num_key_value_heads=None,
        attention_dropout=0.0,
+        dynamic_mask_ratio=0.0,
        is_moe=False,
-        num_cdmmoe_experts=4096,
+        num_cdmmoe_experts=2048,
        num_cdmmoe_heads=4,
        num_cdmmoe_experts_per_head=8,
        expert_retrieval_size=256,
        **kwargs,
    ):
        self.vocab_size = vocab_size
+        self.num_channels = num_channels
+        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
@@ -162,12 +179,14 @@ class DogeConfig(PretrainedConfig):
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
-        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
+        self.pad_token_id = pad_token_id
        self.tie_word_embeddings = tie_word_embeddings
        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
        self.attention_dropout = attention_dropout
+        self.dynamic_mask_ratio = dynamic_mask_ratio
        self.is_moe = is_moe
        self.num_cdmmoe_experts = num_cdmmoe_experts
        self.num_cdmmoe_heads = num_cdmmoe_heads
@@ -180,10 +199,14 @@ class DogeConfig(PretrainedConfig):
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)
 
+        # for backward compatibility
+        if num_key_value_heads is None:
+            self.num_key_value_heads = num_attention_heads
+
        super().__init__(
-            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
+            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
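For reference, a minimal sketch (assuming `configuration_doge.py` from this commit is importable locally) of instantiating `DogeConfig` with the values this checkpoint uses; it also shows the backward-compatibility path where `num_key_value_heads=None` falls back to `num_attention_heads`.

```python
# Minimal sketch, assuming configuration_doge.py from this commit is on the Python path.
from configuration_doge import DogeConfig

config = DogeConfig(
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=8,
    num_attention_heads=2,
    num_key_value_heads=1,   # 1 -> Multi Query Attention; equal to num_attention_heads -> MHA
    dynamic_mask_ratio=0.0,
    rope_scaling={"rope_type": "dynamic", "factor": 4.0, "original_max_position_embeddings": 2048},
)
assert config.num_key_value_heads == 1

# Backward-compatibility branch added in this commit:
legacy = DogeConfig(num_attention_heads=8, num_key_value_heads=None)
assert legacy.num_key_value_heads == 8
```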
generation_config.json CHANGED
@@ -1,7 +1,7 @@
-{
-  "_from_model_config": true,
-  "bos_token_id": 1,
-  "eos_token_id": 2,
-  "pad_token_id": 0,
-  "transformers_version": "4.46.1"
-}
+{
+  "_from_model_config": true,
+  "bos_token_id": 0,
+  "eos_token_id": 1,
+  "pad_token_id": 2,
+  "transformers_version": "4.49.0.dev0"
+}
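The generation config now carries the remapped special-token ids (`bos=0`, `eos=1`, `pad=2`), so `generate()` picks them up without extra arguments. A minimal sketch, assuming the repository also ships a tokenizer (not part of this diff) and reusing the same assumed repo id as above:

```python
# Minimal sketch; the repo id and the presence of a tokenizer are assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "JingzeShi/Doge-20M"  # assumed for illustration
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Hey, how are you doing?", return_tensors="pt")
# eos_token_id=1 / pad_token_id=2 come from generation_config.json, so they need not be passed here.
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```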
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:e3e3a6abbaf8f67363291e553d97a4a937204c8f43c97abf14f7e8fa8f64ab54
-size 83917640
+oid sha256:ae49b37c117138c1880aa9d2f1c140436772eb3cc1f9d1c73a2e2a0643100b2b
+size 52482152
modeling_doge.py CHANGED
@@ -39,6 +39,7 @@ from transformers.modeling_utils import PreTrainedModel
39
  from transformers.utils import (
40
  add_start_docstrings,
41
  add_start_docstrings_to_model_forward,
 
42
  logging,
43
  replace_return_docstrings,
44
  )
@@ -49,6 +50,9 @@ try:
49
  except ImportError:
50
  einx_add = None
51
 
 
 
 
52
 
53
  logger = logging.get_logger(__name__)
54
 
@@ -79,7 +83,7 @@ class Residual(nn.Module):
79
  def __init__(self, hidden_size):
80
  super().__init__()
81
  self.weight = nn.Parameter(torch.ones(hidden_size))
82
-
83
  def forward(self, residual_states, hidden_states):
84
  return self.weight * residual_states + hidden_states
85
 
@@ -92,10 +96,10 @@ class RotaryEmbedding(nn.Module):
92
  super().__init__()
93
  self.rope_kwargs = {}
94
 
95
- if config.rope_scaling is None:
96
- self.rope_type = "default"
97
  else:
98
- self.rope_type = config.rope_scaling
99
  self.max_seq_len_cached = config.max_position_embeddings
100
  self.original_max_seq_len = config.max_position_embeddings
101
  self.base = config.rope_theta
@@ -133,6 +137,7 @@ class RotaryEmbedding(nn.Module):
133
  # core RoPE block
134
  inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
135
  position_ids_expanded = position_ids[:, None, :].float()
 
136
  device_type = x.device.type
137
  device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
138
  with torch.autocast(device_type=device_type, enabled=False):
@@ -141,6 +146,7 @@ class RotaryEmbedding(nn.Module):
141
  cos = emb.cos()
142
  sin = emb.sin()
143
 
 
144
  cos = cos * self.attention_scaling
145
  sin = sin * self.attention_scaling
146
 
@@ -168,11 +174,10 @@ def apply_QK_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
168
  Deprecated and unused.
169
  unsqueeze_dim (`int`, *optional*, defaults to 1):
170
  The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
171
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
172
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
173
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
174
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
175
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
176
  Returns:
177
  `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
178
  """
@@ -183,6 +188,18 @@ def apply_QK_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
183
  return q_embed, k_embed
184
 
185
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  class DogeDynamicMaskAttention(nn.Module):
187
  """Dynamic Mask Attention from 'Wonderful Matrices' paper."""
188
 
@@ -193,46 +210,26 @@ class DogeDynamicMaskAttention(nn.Module):
193
  self.layer_idx = layer_idx
194
  if layer_idx is None:
195
  logger.warning_once(
196
- f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
197
- "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
198
- "when creating this class."
199
  )
200
 
201
  self.hidden_dim = config.hidden_size
202
- self.num_attention_heads = config.num_attention_heads
 
 
 
203
  self.attention_dropout = config.attention_dropout
204
- self.attention_head_dim = self.hidden_dim // self.num_attention_heads
205
 
206
  # Q K V O projections
207
- self.q_proj = nn.Linear(
208
- self.hidden_dim,
209
- self.num_attention_heads * self.attention_head_dim,
210
- bias=config.hidden_bias,
211
- )
212
- self.k_proj = nn.Linear(
213
- self.hidden_dim,
214
- self.num_attention_heads * self.attention_head_dim,
215
- bias=config.hidden_bias,
216
- )
217
  # dynamic mask for the QK^T attention score matrix
218
- self.A = nn.Parameter(
219
- torch.ones(self.num_attention_heads)
220
- )
221
- self.dt_proj = nn.Linear(
222
- self.hidden_dim,
223
- self.num_attention_heads,
224
- bias=config.hidden_bias,
225
- )
226
- self.v_proj = nn.Linear(
227
- self.hidden_dim,
228
- self.num_attention_heads * self.attention_head_dim,
229
- bias=config.hidden_bias,
230
- )
231
- self.o_proj = nn.Linear(
232
- self.hidden_dim,
233
- self.hidden_dim,
234
- bias=config.hidden_bias,
235
- )
236
 
237
  def forward(
238
  self,
@@ -250,15 +247,9 @@ class DogeDynamicMaskAttention(nn.Module):
250
  key_states = self.k_proj(hidden_states)
251
  value_states = self.v_proj(hidden_states)
252
 
253
- query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.attention_head_dim).transpose(
254
- 1, 2
255
- )
256
- key_states = key_states.view(bsz, q_len, self.num_attention_heads, self.attention_head_dim).transpose(
257
- 1, 2
258
- )
259
- value_states = value_states.view(bsz, q_len, self.num_attention_heads, self.attention_head_dim).transpose(
260
- 1, 2
261
- )
262
 
263
  cos, sin = position_embeddings
264
  query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin)
@@ -268,16 +259,25 @@ class DogeDynamicMaskAttention(nn.Module):
268
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
269
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
270
 
 
 
 
 
 
 
 
 
271
  # compute attention scores matrix
272
- attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) / math.sqrt(self.attention_head_dim)
273
 
274
  # add mask to attention scores
275
- if attention_mask is not None:
276
- dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
277
- dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
278
- dynamic_mask = dynamic_mask < 1.0
279
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]].masked_fill(dynamic_mask[:, :, None, :], torch.finfo(hidden_states.dtype).min)
280
- attn_weights = attn_weights + causal_mask
 
281
 
282
  # upcast attention scores to fp32
283
  attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@@ -292,8 +292,35 @@ class DogeDynamicMaskAttention(nn.Module):
292
 
293
  return attn_output, past_key_value
294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
 
296
- class DogeSdpaDynamicMaskAttn(DogeDynamicMaskAttention):
297
 
298
  def forward(
299
  self,
@@ -311,9 +338,9 @@ class DogeSdpaDynamicMaskAttn(DogeDynamicMaskAttention):
311
  key_states = self.k_proj(hidden_states)
312
  value_states = self.v_proj(hidden_states)
313
 
314
- query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.attention_head_dim).transpose(1, 2)
315
- key_states = key_states.view(bsz, q_len, self.num_attention_heads, self.attention_head_dim).transpose(1, 2)
316
- value_states = value_states.view(bsz, q_len, self.num_attention_heads, self.attention_head_dim).transpose(1, 2)
317
 
318
  cos, sin = position_embeddings
319
  query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin)
@@ -322,23 +349,31 @@ class DogeSdpaDynamicMaskAttn(DogeDynamicMaskAttention):
322
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
323
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
324
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 
 
 
325
 
326
- if attention_mask is not None:
327
- dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
328
- dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
329
- dynamic_mask = dynamic_mask < 1.0
330
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]].masked_fill(dynamic_mask[:, :, None, :], torch.finfo(hidden_states.dtype).min)
 
331
 
332
  query_states = query_states.contiguous()
333
  key_states = key_states.contiguous()
334
  value_states = value_states.contiguous()
335
 
 
 
336
  attn_output = F.scaled_dot_product_attention(
337
  query_states,
338
  key_states,
339
  value_states,
340
- attn_mask=causal_mask,
341
- dropout_p=self.attention_dropout,
 
342
  )
343
 
344
  attn_output = attn_output.transpose(1, 2).contiguous()
@@ -348,9 +383,70 @@ class DogeSdpaDynamicMaskAttn(DogeDynamicMaskAttention):
348
  return attn_output, past_key_value
349
 
350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
  DOGE_ATTENTION_CLASSES = {
 
352
  "eager": DogeDynamicMaskAttention,
353
- "sdpa": DogeSdpaDynamicMaskAttn,
354
  }
355
 
356
 
@@ -362,21 +458,9 @@ class DogeMLP(nn.Module):
362
  self.intermediate_dim = config.intermediate_size
363
  self.act_fn = ACT2FN[config.hidden_act]
364
 
365
- self.gate_proj = nn.Linear(
366
- self.hidden_dim,
367
- self.intermediate_dim,
368
- bias=config.hidden_bias,
369
- )
370
- self.up_proj = nn.Linear(
371
- self.hidden_dim,
372
- self.intermediate_dim,
373
- bias=config.hidden_bias,
374
- )
375
- self.down_proj = nn.Linear(
376
- self.intermediate_dim,
377
- self.hidden_dim,
378
- bias=config.hidden_bias,
379
- )
380
 
381
  def forward(
382
  self,
@@ -402,30 +486,12 @@ class DogeCDMoE(DogeMLP):
402
  self.num_keys = int(math.sqrt(self.num_cdmmoe_experts))
403
 
404
  # queries and keys for retrieval experts
405
- self.queries = nn.Linear(
406
- self.hidden_dim,
407
- self.num_cdmmoe_heads * self.expert_retrieval_dim,
408
- bias=False,
409
- )
410
- self.keys = nn.Parameter(
411
- torch.zeros(
412
- self.num_cdmmoe_heads,
413
- self.num_keys,
414
- 2,
415
- self.expert_retrieval_dim // 2,
416
- )
417
- )
418
 
419
  # experts
420
- self.down_embed = nn.Embedding(
421
- self.num_cdmmoe_experts,
422
- self.hidden_dim,
423
- )
424
- self.up_embed = nn.Embedding(
425
- self.num_cdmmoe_experts,
426
- self.hidden_dim,
427
- )
428
-
429
 
430
  def forward(
431
  self,
@@ -468,13 +534,13 @@ class DogeDecoderLayer(nn.Module):
468
  super().__init__()
469
  self.hidden_dropout = config.hidden_dropout
470
 
471
- self.pre_sequence_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
472
- self.attn = DOGE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
473
- self.post_sequence_residual = Residual(config.hidden_size)
474
 
475
- self.pre_state_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
476
  self.feed_forward = DogeMLP(config) if config.is_moe == False else DogeCDMoE(config)
477
- self.post_state_residual = Residual(config.hidden_size)
478
 
479
  def forward(
480
  self,
@@ -492,29 +558,25 @@ class DogeDecoderLayer(nn.Module):
492
  Args:
493
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
494
  attention_mask (`torch.FloatTensor`, *optional*):
495
- attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
496
- query_sequence_length, key_sequence_length)` if default attention is used.
497
  output_attentions (`bool`, *optional*):
498
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
499
- returned tensors for more detail.
500
  use_cache (`bool`, *optional*):
501
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
502
- (see `past_key_values`).
503
  past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
504
  cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
505
  Indices depicting the position of the input sequence tokens in the sequence
506
  position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
507
- Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
508
- with `head_dim` being the embedding dimension of each attention head.
509
  kwargs (`dict`, *optional*):
510
- Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
511
- into the model
512
  """
513
 
514
  # sequence transformation
515
  residual = hidden_states
516
- hidden_states = self.pre_sequence_layernorm(hidden_states)
517
- hidden_states, present_key_value = self.attn(
518
  hidden_states=hidden_states,
519
  attention_mask=attention_mask,
520
  position_ids=position_ids,
@@ -525,14 +587,14 @@ class DogeDecoderLayer(nn.Module):
525
  )
526
  self_attn_weights = None
527
  hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
528
- hidden_states = self.post_sequence_residual(residual, hidden_states)
529
 
530
  # state transformation
531
  residual = hidden_states
532
- hidden_states = self.pre_state_layernorm(hidden_states)
533
  hidden_states = self.feed_forward(hidden_states)
534
  hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
535
- hidden_states = self.post_state_residual(residual, hidden_states)
536
 
537
  outputs = (hidden_states,)
538
 
@@ -552,6 +614,7 @@ class DogePreTrainedModel(PreTrainedModel):
552
  supports_gradient_checkpointing = True
553
  _no_split_modules = ["DogeDecoderLayer"]
554
  _skip_keys_device_placement = ["past_key_values"]
 
555
  _supports_sdpa = True
556
  _supports_cache_class = True
557
  _supports_quantized_cache = True
@@ -572,11 +635,10 @@ class DogePreTrainedModel(PreTrainedModel):
572
  DOGE_INPUTS_DOCSTRING = r"""
573
  Args:
574
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
575
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
576
- it.
577
 
578
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
579
- [`PreTrainedTokenizer.__call__`] for details.
580
 
581
  [What are input IDs?](../glossary#input-ids)
582
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -587,60 +649,48 @@ DOGE_INPUTS_DOCSTRING = r"""
587
 
588
  [What are attention masks?](../glossary#attention-mask)
589
 
590
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
591
- [`PreTrainedTokenizer.__call__`] for details.
592
 
593
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
594
- `past_key_values`).
595
 
596
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
597
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
598
- information on the default strategy.
599
 
600
  - 1 indicates the head is **not masked**,
601
  - 0 indicates the head is **masked**.
602
  position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
603
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
604
- config.n_positions - 1]`.
605
 
606
  [What are position IDs?](../glossary#position-ids)
607
  past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
608
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
609
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
610
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
611
 
612
  Two formats are allowed:
613
- - a [`~cache_utils.Cache`] instance, see our
614
- [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
615
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
616
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
617
- cache format.
618
-
619
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
620
- legacy cache format will be returned.
621
-
622
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
623
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
624
- of shape `(batch_size, sequence_length)`.
625
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
626
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
627
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
628
- model's internal embedding lookup matrix.
629
  use_cache (`bool`, *optional*):
630
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
631
- `past_key_values`).
632
  output_attentions (`bool`, *optional*):
633
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
634
- tensors for more detail.
635
  output_hidden_states (`bool`, *optional*):
636
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
637
- more detail.
638
  return_dict (`bool`, *optional*):
639
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
640
  cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
641
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
642
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
643
- the complete sequence length.
644
  """
645
 
646
 
@@ -711,9 +761,9 @@ class DogeModel(DogePreTrainedModel):
711
  else:
712
  past_key_values = DynamicCache.from_legacy_cache(past_key_values)
713
  logger.warning_once(
714
- "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
715
- "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
716
- "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
717
  )
718
 
719
  if cache_position is None:
@@ -739,7 +789,7 @@ class DogeModel(DogePreTrainedModel):
739
  all_self_attns = () if output_attentions else None
740
  next_decoder_cache = None
741
 
742
- for decoder_layer in self.layers:
743
  if output_hidden_states:
744
  all_hidden_states += (hidden_states,)
745
 
@@ -842,18 +892,15 @@ class DogeModel(DogePreTrainedModel):
842
  **kwargs,
843
  ):
844
  """
845
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
846
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
847
 
848
  Args:
849
  attention_mask (`torch.Tensor`):
850
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
851
- `(batch_size, 1, query_length, key_value_length)`.
852
  sequence_length (`int`):
853
  The sequence length being processed.
854
  target_length (`int`):
855
- The target length: when generating with static cache, the mask should be as long as the static cache,
856
- to account for the 0 padding, the part of the cache that is not filled yet.
857
  dtype (`torch.dtype`):
858
  The dtype to use for the 4D attention mask.
859
  device (`torch.device`):
@@ -912,13 +959,13 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
912
 
913
  def set_output_embeddings(self, new_embeddings):
914
  self.lm_head = new_embeddings
 
 
 
915
 
916
  def set_decoder(self, decoder):
917
  self.model = decoder
918
 
919
- def get_decoder(self):
920
- return self.model
921
-
922
  @add_start_docstrings_to_model_forward(DOGE_INPUTS_DOCSTRING)
923
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
924
  def forward(
@@ -926,7 +973,7 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
926
  input_ids: torch.LongTensor = None,
927
  attention_mask: Optional[torch.Tensor] = None,
928
  position_ids: Optional[torch.LongTensor] = None,
929
- past_key_values: Optional[torch.Tensor] = None,
930
  inputs_embeds: Optional[torch.FloatTensor] = None,
931
  labels: Optional[torch.LongTensor] = None,
932
  use_cache: Optional[bool] = None,
@@ -935,19 +982,19 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
935
  return_dict: Optional[bool] = None,
936
  cache_position: Optional[torch.LongTensor] = None,
937
  num_logits_to_keep: int = 0,
938
- **loss_kwargs,
939
  ) -> Union[Tuple, CausalLMOutputWithPast]:
940
  r"""
941
  Args:
942
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
943
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
944
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
945
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
946
 
947
  num_logits_to_keep (`int`, *optional*):
948
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
949
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
950
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
951
 
952
  Returns:
953
  """
@@ -969,6 +1016,7 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
969
  output_hidden_states=output_hidden_states,
970
  return_dict=return_dict,
971
  cache_position=cache_position,
 
972
  )
973
 
974
  hidden_states = outputs[0]
@@ -978,7 +1026,7 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
978
 
979
  loss = None
980
  if labels is not None:
981
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **loss_kwargs)
982
 
983
  if not return_dict:
984
  output = (logits,) + outputs[1:]
@@ -993,18 +1041,98 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
993
  )
994
 
995
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
996
  @add_start_docstrings(
997
  """
998
  The Doge Model transformer with a sequence classification head on top (linear layer).
999
 
1000
- [`DogeForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1001
- (e.g. GPT-2) do.
1002
 
1003
- Since it does classification on the last token, it requires to know the position of the last token. If a
1004
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1005
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1006
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1007
- each row of the batch).
1008
  """
1009
  )
1010
  class DogeForSequenceClassification(DogePreTrainedModel):
@@ -1041,9 +1169,9 @@ class DogeForSequenceClassification(DogePreTrainedModel):
1041
  ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1042
  r"""
1043
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1044
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1045
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1046
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1047
  """
1048
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1049
 
 
39
  from transformers.utils import (
40
  add_start_docstrings,
41
  add_start_docstrings_to_model_forward,
42
+ is_torch_greater_or_equal,
43
  logging,
44
  replace_return_docstrings,
45
  )
 
50
  except ImportError:
51
  einx_add = None
52
 
53
+ if is_torch_greater_or_equal("2.5"):
54
+ from torch.nn.attention.flex_attention import flex_attention
55
+
56
 
57
  logger = logging.get_logger(__name__)
58
 
 
83
  def __init__(self, hidden_size):
84
  super().__init__()
85
  self.weight = nn.Parameter(torch.ones(hidden_size))
86
+
87
  def forward(self, residual_states, hidden_states):
88
  return self.weight * residual_states + hidden_states
89
 
 
96
  super().__init__()
97
  self.rope_kwargs = {}
98
 
99
+ if config.rope_scaling is not None:
100
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
101
  else:
102
+ self.rope_type = "default"
103
  self.max_seq_len_cached = config.max_position_embeddings
104
  self.original_max_seq_len = config.max_position_embeddings
105
  self.base = config.rope_theta
 
137
  # core RoPE block
138
  inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
139
  position_ids_expanded = position_ids[:, None, :].float()
140
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
141
  device_type = x.device.type
142
  device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
143
  with torch.autocast(device_type=device_type, enabled=False):
 
146
  cos = emb.cos()
147
  sin = emb.sin()
148
 
149
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
150
  cos = cos * self.attention_scaling
151
  sin = sin * self.attention_scaling
152
 
 
174
  Deprecated and unused.
175
  unsqueeze_dim (`int`, *optional*, defaults to 1):
176
  The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
177
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k.
178
+ For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim].
179
+ Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k.
180
+ Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
 
181
  Returns:
182
  `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
183
  """
 
188
  return q_embed, k_embed
189
 
190
 
191
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
192
+ """
193
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
194
+ The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
195
+ """
196
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
197
+ if n_rep == 1:
198
+ return hidden_states
199
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
200
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
201
+
202
+
203
  class DogeDynamicMaskAttention(nn.Module):
204
  """Dynamic Mask Attention from 'Wonderful Matrices' paper."""
205
 
 
210
  self.layer_idx = layer_idx
211
  if layer_idx is None:
212
  logger.warning_once(
213
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will lead to errors during the forward call if caching is used. "
214
+ "Please make sure to provide a `layer_idx` when creating this class."
 
215
  )
216
 
217
  self.hidden_dim = config.hidden_size
218
+ self.num_heads = config.num_attention_heads
219
+ self.head_dim = self.hidden_dim // self.num_heads
220
+ self.num_key_value_heads = config.num_key_value_heads
221
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
222
  self.attention_dropout = config.attention_dropout
223
+ self.dynamic_mask_ratio = config.dynamic_mask_ratio
224
 
225
  # Q K V O projections
226
+ self.q_proj = nn.Linear(self.hidden_dim, self.num_heads * self.head_dim, bias=config.hidden_bias)
227
+ self.k_proj = nn.Linear(self.hidden_dim, self.num_key_value_heads * self.head_dim, bias=config.hidden_bias)
228
+ self.v_proj = nn.Linear(self.hidden_dim, self.num_key_value_heads * self.head_dim, bias=config.hidden_bias)
 
 
 
 
 
 
 
229
  # dynamic mask for the QK^T attention score matrix
230
+ self.A = nn.Parameter(torch.ones(self.num_heads))
231
+ self.dt_proj = nn.Linear(self.num_key_value_heads * self.head_dim, self.num_heads, bias=config.hidden_bias)
232
+ self.o_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=config.hidden_bias)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
  def forward(
235
  self,
 
247
  key_states = self.k_proj(hidden_states)
248
  value_states = self.v_proj(hidden_states)
249
 
250
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
251
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
252
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
 
 
 
 
 
253
 
254
  cos, sin = position_embeddings
255
  query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin)
 
259
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
260
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
261
 
262
+ # calculate dynamic mask from value_states
263
+ dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
264
+ dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
265
+
266
+ # repeat key and value states
267
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
268
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
269
+
270
  # compute attention scores matrix
271
+ attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) / math.sqrt(self.head_dim)
272
 
273
  # add mask to attention scores
274
+ attn_mask = self.prepare_dynamic_mask(
275
+ hidden_states=hidden_states,
276
+ dynamic_mask=dynamic_mask,
277
+ dynamic_mask_ratio=0.1,
278
+ attention_mask=attention_mask,
279
+ )
280
+ attn_weights = attn_weights + attn_mask
281
 
282
  # upcast attention scores to fp32
283
  attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
 
292
 
293
  return attn_output, past_key_value
294
 
295
+ def prepare_dynamic_mask(
296
+ self,
297
+ hidden_states: torch.Tensor,
298
+ dynamic_mask: torch.Tensor,
299
+ dynamic_mask_ratio: float = 0.0,
300
+ attention_mask: Optional[torch.Tensor] = None,
301
+ ):
302
+ """
303
+ Combine `dynamic_mask` with `attention_mask` to generate the final `attn_mask`.
304
+
305
+ Args:
306
+ hidden_states (`torch.Tensor`): The input hidden_states, used to determine the minimum value of the current input precision.
307
+ dynamic_mask (`torch.Tensor`): dynamic mask of shape `(batch_size, num_heads, key_sequence_length)`.
308
+ dynamic_mask_ratio (`float`, *optional*): Ratio from 0.0 to 1.0 used to control the proportion of the dynamic mask filled with the minimum value.
309
+ attention_mask (`torch.Tensor`, *optional*): attention mask of shape `(batch_size, 1, query_sequence_length, key_sequence_length)`.
310
+ """
311
+ min_type = torch.finfo(hidden_states.dtype).min
312
+ attn_mask = dynamic_mask[:, :, None, :]
313
+ if 0.0 < dynamic_mask_ratio < 1.0:
314
+ num_dynamic_mask = int(attn_mask.shape[-1] * dynamic_mask_ratio)
315
+ if num_dynamic_mask > 0:
316
+ rate_value = torch.kthvalue(attn_mask, num_dynamic_mask, dim=-1, keepdim=True).values
317
+ attn_mask = attn_mask.masked_fill(attn_mask < rate_value, min_type)
318
+ if attention_mask is not None:
319
+ attn_mask = attn_mask.masked_fill(attention_mask[:, :, :, : hidden_states.shape[-2]] == min_type, min_type)
320
+ return attn_mask
321
+
322
 
323
+ class DogeSdpaDynamicMaskAttention(DogeDynamicMaskAttention):
324
 
325
  def forward(
326
  self,
 
338
  key_states = self.k_proj(hidden_states)
339
  value_states = self.v_proj(hidden_states)
340
 
341
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
342
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
343
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
344
 
345
  cos, sin = position_embeddings
346
  query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin)
 
349
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
350
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
351
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
352
+
353
+ # calculate dynamic mask from value_states
354
+ dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
355
+ dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
356
 
357
+ attn_mask = self.prepare_dynamic_mask(
358
+ hidden_states=hidden_states,
359
+ dynamic_mask=dynamic_mask,
360
+ dynamic_mask_ratio=self.dynamic_mask_ratio,
361
+ attention_mask=attention_mask,
362
+ )
363
 
364
  query_states = query_states.contiguous()
365
  key_states = key_states.contiguous()
366
  value_states = value_states.contiguous()
367
 
368
+ # NOTE: As of pytorch 2.5.1, cuDNN's SDPA backward pass is still incorrect, so we disable cuDNN SDPA (see https://github.com/pytorch/pytorch/issues/138581)
369
+ torch.backends.cuda.enable_cudnn_sdp(False)
370
  attn_output = F.scaled_dot_product_attention(
371
  query_states,
372
  key_states,
373
  value_states,
374
+ attn_mask=attn_mask,
375
+ dropout_p=self.attention_dropout if self.training else 0.0,
376
+ enable_gqa=True,
377
  )
378
 
379
  attn_output = attn_output.transpose(1, 2).contiguous()
 
383
  return attn_output, past_key_value
384
 
385
 
386
+ class DogeFlexDynamicMaskAttention(DogeDynamicMaskAttention):
387
+
388
+ def forward(
389
+ self,
390
+ hidden_states: torch.Tensor,
391
+ attention_mask: Optional[torch.Tensor] = None,
392
+ position_ids: Optional[torch.LongTensor] = None,
393
+ past_key_value: Optional[Cache] = None,
394
+ cache_position: Optional[torch.LongTensor] = None,
395
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
396
+ **kwargs,
397
+ ) -> Tuple[torch.Tensor, Optional[Cache]]:
398
+ bsz, q_len, _ = hidden_states.shape
399
+
400
+ query_states = self.q_proj(hidden_states)
401
+ key_states = self.k_proj(hidden_states)
402
+ value_states = self.v_proj(hidden_states)
403
+
404
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
405
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
406
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
407
+
408
+ cos, sin = position_embeddings
409
+ query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin)
410
+
411
+ if past_key_value is not None:
412
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
413
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
414
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
415
+
416
+ dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
417
+ dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
418
+
419
+ attn_mask = self.prepare_dynamic_mask(
420
+ hidden_states=hidden_states,
421
+ dynamic_mask=dynamic_mask,
422
+ dynamic_mask_ratio=self.dynamic_mask_ratio,
423
+ attention_mask=attention_mask,
424
+ )
425
+ # TODO: flex_attention: Captured buffers that require grad are not yet supported.
426
+ # NOTE: So we only use flex_attention in inference mode.
427
+ def dynamic_mask_mod(score, batch, head, q_idx, kv_idx):
428
+ score = score + attn_mask[batch][head][q_idx][kv_idx]
429
+ return score
430
+
431
+ attn_output = flex_attention(
432
+ query_states,
433
+ key_states,
434
+ value_states,
435
+ score_mod=dynamic_mask_mod,
436
+ enable_gqa=True,
437
+ )
438
+
439
+ attn_output = attn_output.transpose(1, 2).contiguous()
440
+ attn_output = attn_output.view(bsz, q_len, -1)
441
+ attn_output = self.o_proj(attn_output)
442
+
443
+ return attn_output, past_key_value
444
+
445
+
446
  DOGE_ATTENTION_CLASSES = {
447
+ "flex_attention": DogeFlexDynamicMaskAttention,
448
  "eager": DogeDynamicMaskAttention,
449
+ "sdpa": DogeSdpaDynamicMaskAttention,
450
  }
451
 
452
 
 
458
  self.intermediate_dim = config.intermediate_size
459
  self.act_fn = ACT2FN[config.hidden_act]
460
 
461
+ self.gate_proj = nn.Linear(self.hidden_dim, self.intermediate_dim, bias=config.hidden_bias)
462
+ self.up_proj = nn.Linear(self.hidden_dim, self.intermediate_dim, bias=config.hidden_bias)
463
+ self.down_proj = nn.Linear(self.intermediate_dim, self.hidden_dim, bias=config.hidden_bias)
 
 
 
 
 
 
 
 
 
 
 
 
464
 
465
  def forward(
466
  self,
 
486
  self.num_keys = int(math.sqrt(self.num_cdmmoe_experts))
487
 
488
  # queries and keys for retrieval experts
489
+ self.queries = nn.Linear(self.hidden_dim, self.num_cdmmoe_heads * self.expert_retrieval_dim, bias=False)
490
+ self.keys = nn.Parameter(torch.zeros(self.num_cdmmoe_heads, self.num_keys, 2, self.expert_retrieval_dim // 2))
 
 
 
 
 
 
 
 
 
 
 
491
 
492
  # experts
493
+ self.down_embed = nn.Embedding(self.num_cdmmoe_experts, self.hidden_dim)
494
+ self.up_embed = nn.Embedding(self.num_cdmmoe_experts, self.hidden_dim)
 
 
 
 
 
 
 
495
 
496
  def forward(
497
  self,
 
534
  super().__init__()
535
  self.hidden_dropout = config.hidden_dropout
536
 
537
+ self.pre_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
538
+ self.self_attn = DOGE_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
539
+ self.pre_residual = Residual(config.hidden_size)
540
 
541
+ self.post_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
542
  self.feed_forward = DogeMLP(config) if config.is_moe == False else DogeCDMoE(config)
543
+ self.post_residual = Residual(config.hidden_size)
544
 
545
  def forward(
546
  self,
 
558
  Args:
559
  hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
560
  attention_mask (`torch.FloatTensor`, *optional*):
561
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used.
 
562
  output_attentions (`bool`, *optional*):
563
+ Whether or not to return the attentions tensors of all attention layers.
564
+ See `attentions` under returned tensors for more detail.
565
  use_cache (`bool`, *optional*):
566
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`).
 
567
  past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
568
  cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
569
  Indices depicting the position of the input sequence tokens in the sequence
570
  position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
571
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, with `head_dim` being the embedding dimension of each attention head.
 
572
  kwargs (`dict`, *optional*):
573
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code into the model
 
574
  """
575
 
576
  # sequence transformation
577
  residual = hidden_states
578
+ hidden_states = self.pre_layernorm(hidden_states)
579
+ hidden_states, present_key_value = self.self_attn(
580
  hidden_states=hidden_states,
581
  attention_mask=attention_mask,
582
  position_ids=position_ids,
 
587
  )
588
  self_attn_weights = None
589
  hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
590
+ hidden_states = self.pre_residual(residual, hidden_states)
591
 
592
  # state transformation
593
  residual = hidden_states
594
+ hidden_states = self.post_layernorm(hidden_states)
595
  hidden_states = self.feed_forward(hidden_states)
596
  hidden_states = F.dropout(hidden_states, p=self.hidden_dropout, training=self.training)
597
+ hidden_states = self.post_residual(residual, hidden_states)
598
 
599
  outputs = (hidden_states,)
600
 
 
614
  supports_gradient_checkpointing = True
615
  _no_split_modules = ["DogeDecoderLayer"]
616
  _skip_keys_device_placement = ["past_key_values"]
617
+ _supports_flex_attn = True
618
  _supports_sdpa = True
619
  _supports_cache_class = True
620
  _supports_quantized_cache = True
 
635
  DOGE_INPUTS_DOCSTRING = r"""
636
  Args:
637
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
638
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it.
 
639
 
640
+ Indices can be obtained using [`AutoTokenizer`].
641
+ See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
642
 
643
  [What are input IDs?](../glossary#input-ids)
644
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
 
649
 
650
  [What are attention masks?](../glossary#attention-mask)
651
 
652
+ Indices can be obtained using [`AutoTokenizer`].
653
+ See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
654
 
655
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`).
 
656
 
657
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify it to your needs.
658
+ See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
 
659
 
660
  - 1 indicates the head is **not masked**,
661
  - 0 indicates the head is **masked**.
662
  position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
663
+ Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0, config.n_positions - 1]`.
 
664
 
665
  [What are position IDs?](../glossary#position-ids)
666
  past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
667
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding.
668
+ This typically consists of the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
 
669
 
670
  Two formats are allowed:
671
+ - a [`~cache_utils.Cache`] instance, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
672
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy cache format.
673
+
674
+ The model will output the same cache format that is fed as input.
675
+ If no `past_key_values` are passed, the legacy cache format will be returned.
676
+
677
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` of shape `(batch_size, sequence_length)`.
 
 
 
 
 
678
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
679
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
680
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
 
681
  use_cache (`bool`, *optional*):
682
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`).
 
683
  output_attentions (`bool`, *optional*):
684
+ Whether or not to return the attentions tensors of all attention layers.
685
+ See `attentions` under returned tensors for more detail.
686
  output_hidden_states (`bool`, *optional*):
687
+ Whether or not to return the hidden states of all layers.
688
+ See `hidden_states` under returned tensors for more detail.
689
  return_dict (`bool`, *optional*):
690
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
691
  cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
692
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, this tensor is not affected by padding.
693
+ It is used to update the cache in the correct position and to infer the complete sequence length.
 
694
  """
695
 
696
 
 
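# A minimal usage sketch for the inputs documented above, passing a `Cache` instance instead
# of the legacy tuple-of-tuples. The checkpoint path is a placeholder; loading this model
# needs `trust_remote_code=True` because it ships custom configuration/modeling code.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

path = "./results/Doge-20M-Instruct-DPO"  # placeholder path to a Doge checkpoint
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)

inputs = tokenizer("Hello", return_tensors="pt")
past_key_values = DynamicCache()
with torch.no_grad():
    out = model(**inputs, past_key_values=past_key_values, use_cache=True)
next_token_id = out.logits[:, -1].argmax(dim=-1)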
761
  else:
762
  past_key_values = DynamicCache.from_legacy_cache(past_key_values)
763
  logger.warning_once(
764
+ "We detected that you are passing `past_key_values` as a tuple of tuples. "
765
+ "This is deprecated and will be removed in v4.47. "
766
+ "Please convert your cache or use an appropriate `Cache` class (https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
767
  )
768
 
769
  if cache_position is None:
 
789
  all_self_attns = () if output_attentions else None
790
  next_decoder_cache = None
791
 
792
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
793
  if output_hidden_states:
794
  all_hidden_states += (hidden_states,)
795
 
 
892
  **kwargs,
893
  ):
894
  """
895
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
 
896
 
897
  Args:
898
  attention_mask (`torch.Tensor`):
899
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
 
900
  sequence_length (`int`):
901
  The sequence length being processed.
902
  target_length (`int`):
903
+ The target length: when generating with a static cache, the mask should be as long as the static cache in order to account for the zero padding, i.e. the part of the cache that is not yet filled.
 
904
  dtype (`torch.dtype`):
905
  The dtype to use for the 4D attention mask.
906
  device (`torch.device`):
 
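# An illustrative construction of the 4D additive mask described above, assuming the usual
# recipe: build a causal pattern of shape (sequence_length, target_length) filled with the
# dtype minimum, zero out the allowed positions, then merge the 2D padding mask.
import torch

batch_size, sequence_length, target_length = 2, 4, 6
dtype, device = torch.float32, "cpu"
min_dtype = torch.finfo(dtype).min

cache_position = torch.arange(target_length - sequence_length, target_length, device=device)
causal = torch.full((sequence_length, target_length), min_dtype, dtype=dtype, device=device)
causal = causal.masked_fill(torch.arange(target_length, device=device) <= cache_position[:, None], 0.0)
mask_4d = causal[None, None, :, :].expand(batch_size, 1, -1, -1).clone()

attention_mask = torch.ones(batch_size, target_length, dtype=torch.long, device=device)
attention_mask[1, :2] = 0  # e.g. left padding in the second sequence
mask_4d = mask_4d.masked_fill(attention_mask[:, None, None, :] == 0, min_dtype)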
959
 
960
  def set_output_embeddings(self, new_embeddings):
961
  self.lm_head = new_embeddings
962
+
963
+ def get_decoder(self):
964
+ return self.model
965
 
966
  def set_decoder(self, decoder):
967
  self.model = decoder
968
 
 
 
 
969
  @add_start_docstrings_to_model_forward(DOGE_INPUTS_DOCSTRING)
970
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
971
  def forward(
 
973
  input_ids: torch.LongTensor = None,
974
  attention_mask: Optional[torch.Tensor] = None,
975
  position_ids: Optional[torch.LongTensor] = None,
976
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
977
  inputs_embeds: Optional[torch.FloatTensor] = None,
978
  labels: Optional[torch.LongTensor] = None,
979
  use_cache: Optional[bool] = None,
 
982
  return_dict: Optional[bool] = None,
983
  cache_position: Optional[torch.LongTensor] = None,
984
  num_logits_to_keep: int = 0,
985
+ **kwargs,
986
  ) -> Union[Tuple, CausalLMOutputWithPast]:
987
  r"""
988
  Args:
989
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
990
+ Labels for computing the masked language modeling loss.
991
+ Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring).
992
+ Tokens with indices set to `-100` are ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
993
 
994
  num_logits_to_keep (`int`, *optional*):
995
+ Calculate logits for the last `num_logits_to_keep` tokens.
996
+ If `0`, calculate logits for all `input_ids` (special case).
997
+ Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
998
 
999
  Returns:
1000
  """
 
1016
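# What `num_logits_to_keep` amounts to, sketched on dummy tensors: slice the last positions
# of the hidden states before the LM head so only those logits are materialized (sizes are
# illustrative; `num_logits_to_keep=0` keeps every position).
import torch
import torch.nn as nn

hidden_size, vocab_size, num_logits_to_keep = 256, 32768, 1
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
hidden_states = torch.randn(2, 10, hidden_size)               # (batch, seq_len, hidden_size)
logits = lm_head(hidden_states[:, -num_logits_to_keep:, :])   # (batch, 1, vocab_size)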
  output_hidden_states=output_hidden_states,
1017
  return_dict=return_dict,
1018
  cache_position=cache_position,
1019
+ **kwargs,
1020
  )
1021
 
1022
  hidden_states = outputs[0]
 
1026
 
1027
  loss = None
1028
  if labels is not None:
1029
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
1030
 
1031
  if not return_dict:
1032
  output = (logits,) + outputs[1:]
 
1041
  )
1042
 
1043
 
1044
+ class DogePatchEmbedding(nn.Module):
1045
+ """
1046
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial `hidden_states` of shape `(batch_size, seq_len, hidden_size)` to be consumed by a Transformer.
1047
+ """
1048
+
1049
+ def __init__(self, config: DogeConfig):
1050
+ super().__init__()
1051
+
1052
+ self.num_channels = config.num_channels
1053
+ self.patch_size = config.patch_size
1054
+ self.hidden_dim = config.hidden_size
1055
+
1056
+ self.sequence_proj = nn.Conv2d(self.num_channels, self.hidden_dim, kernel_size=self.patch_size, stride=self.patch_size)
1057
+ self.state_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=config.hidden_bias)
1058
+
1059
+ def forward(
1060
+ self,
1061
+ pixel_values: torch.Tensor,
1062
+ ) -> torch.Tensor:
1063
+ image_embedding = self.sequence_proj(pixel_values).flatten(2).transpose(1, 2)
1064
+ image_embedding = self.state_proj(image_embedding)
1065
+ return image_embedding
1066
+
1067
+
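# Shape walk-through of the patch embedding above, assuming patch_size=16, num_channels=3 and
# hidden_size=256, with an illustrative 224x224 input:
import torch
import torch.nn as nn

patch_size, num_channels, hidden_size = 16, 3, 256
sequence_proj = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

pixel_values = torch.randn(1, num_channels, 224, 224)   # (batch, num_channels, height, width)
patches = sequence_proj(pixel_values)                   # (1, 256, 14, 14)
image_embedding = patches.flatten(2).transpose(1, 2)    # (1, 196, 256) = (batch, seq_len, hidden_size)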
1068
+ class DogeForCausalVLM(DogeForCausalLM):
1069
+ _tied_weights_keys = ["lm_head.weight"]
1070
+
1071
+ def __init__(self, config: DogeConfig):
1072
+ super().__init__(config)
1073
+ self.config = config
1074
+ self.pixel_embed = DogePatchEmbedding(config)
1075
+
1076
+ # Initialize weights and apply final processing
1077
+ self.post_init()
1078
+
1079
+ def forward(
1080
+ self,
1081
+ input_ids: torch.LongTensor = None,
1082
+ pixel_values: torch.FloatTensor = None,
1083
+ attention_mask: Optional[torch.Tensor] = None,
1084
+ position_ids: Optional[torch.LongTensor] = None,
1085
+ past_key_values: Optional[torch.Tensor] = None,
1086
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1087
+ labels: Optional[torch.LongTensor] = None,
1088
+ use_cache: Optional[bool] = None,
1089
+ output_attentions: Optional[bool] = None,
1090
+ output_hidden_states: Optional[bool] = None,
1091
+ return_dict: Optional[bool] = None,
1092
+ cache_position: Optional[torch.LongTensor] = None,
1093
+ num_logits_to_keep: int = 0,
1094
+ **loss_kwargs,
1095
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1096
+ # TODO: @wubingheng111: refer to Llava for implementing the forward method
1097
+ ...
1098
+
1099
+ def prepare_inputs_for_generation(
1100
+ self,
1101
+ input_ids=None,
1102
+ pixel_values=None,
1103
+ past_key_values=None,
1104
+ input_embeds=None,
1105
+ attention_mask=None,
1106
+ cache_position=None,
1107
+ num_logits_to_keep=None,
1108
+ **kwargs,
1109
+ ):
1110
+ model_inputs = self.model.prepare_inputs_for_generation(
1111
+ input_ids,
1112
+ past_key_values=past_key_values,
1113
+ inputs_embeds=input_embeds,
1114
+ attention_mask=attention_mask,
1115
+ cache_position=cache_position,
1116
+ num_logits_to_keep=num_logits_to_keep,
1117
+ **kwargs,
1118
+ )
1119
+
1120
+ if cache_position[0] == 0:
1121
+ model_inputs["pixel_values"] = pixel_values
1122
+
1123
+ return model_inputs
1124
+
1125
+
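# The VLM forward above is left as a TODO. For orientation only, one common Llava-style
# approach (an assumption, not the author's implementation) is to embed the image patches and
# prepend them to the token embeddings before running the language model:
import torch

batch, hidden_size = 2, 256
text_embeds = torch.randn(batch, 10, hidden_size)    # hypothetically: token embeddings of input_ids
image_embeds = torch.randn(batch, 196, hidden_size)  # hypothetically: self.pixel_embed(pixel_values)
inputs_embeds = torch.cat([image_embeds, text_embeds], dim=1)  # (batch, 196 + 10, hidden_size)
# ...then the language model would be called with inputs_embeds=inputs_embeds, with labels
# shifted/padded so that the loss is only computed on the text positions.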
1126
  @add_start_docstrings(
1127
  """
1128
  The Doge Model transformer with a sequence classification head on top (linear layer).
1129
 
1130
+ [`DogeForSequenceClassification`] uses the last token in order to do the classification, as other causal models (e.g. GPT-2) do.
 
1131
 
1132
+ Since it does classification on the last token, it needs to know the position of the last token.
1133
+ If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row.
1134
+ If no `pad_token_id` is defined, it simply takes the last value in each row of the batch.
1135
+ Since it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in each row of the batch).
 
1136
  """
1137
  )
1138
  class DogeForSequenceClassification(DogePreTrainedModel):
 
1169
  ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1170
  r"""
1171
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1172
+ Labels for computing the sequence classification/regression loss.
1173
+ Indices should be in `[0, ..., config.num_labels - 1]`.
1174
+ If `config.num_labels == 1`, a regression loss is computed (Mean-Square loss); if `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
1175
  """
1176
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1177