gugarosa committed
Commit 534cce7 · verified · 1 Parent(s): 94d2ad2

fix(root): Updating to the almost-ready release candidate.

Files changed (2)
  1. configuration_phi3.py +4 -0
  2. modeling_phi3.py +132 -92
configuration_phi3.py CHANGED
@@ -87,6 +87,8 @@ class Phi3Config(PretrainedConfig):
            contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and
            the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
            divided by the number of attention heads divided by 2.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 32000):
            The id of the "end-of-sequence" token.
        pad_token_id (`int`, *optional*, defaults to 32000):
@@ -132,6 +134,7 @@ class Phi3Config(PretrainedConfig):
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
+        bos_token_id=1,
        eos_token_id=32000,
        pad_token_id=32000,
        sliding_window=None,
@@ -162,6 +165,7 @@ class Phi3Config(PretrainedConfig):
        self.sliding_window = sliding_window

        super().__init__(
+            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
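
The docstring above pins down the `rope_scaling` contract (`type` must be `su` or `yarn`, and each factor list must have `hidden_size / num_attention_heads / 2` entries, i.e. half the per-head dimension), and `bos_token_id` now defaults to 1 and is forwarded to `PretrainedConfig`. A minimal sketch of a valid instantiation follows; the sizes and factor values are illustrative placeholders, not the shipped defaults, and it assumes `configuration_phi3.py` is importable from the working directory:

    from configuration_phi3 import Phi3Config

    hidden_size, num_attention_heads = 3072, 32
    factor_len = hidden_size // num_attention_heads // 2  # 48 entries per factor list

    config = Phi3Config(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        bos_token_id=1,
        eos_token_id=32000,
        pad_token_id=32000,
        rope_scaling={
            "type": "su",                        # must be "su" or "yarn"
            "short_factor": [1.0] * factor_len,  # placeholder values
            "long_factor": [1.0] * factor_len,   # placeholder values
        },
    )
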
modeling_phi3.py CHANGED
@@ -40,6 +40,7 @@ from transformers.utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
+    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
@@ -107,7 +108,7 @@ def _get_unpad_data(attention_mask):
    )


-# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi3
+# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
class Phi3RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
@@ -115,93 +116,131 @@ class Phi3RotaryEmbedding(nn.Module):
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.register_buffer("inv_freq", None, persistent=False)

-        # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(
-            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
-        )
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
-
-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
-    def forward(self, x, seq_len=None):
+    @torch.no_grad()
+    def forward(self, x, position_ids, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
-        return (
-            self.cos_cached[:seq_len].to(dtype=x.dtype),
-            self.sin_cached[:seq_len].to(dtype=x.dtype),
-        )
-
-
-class _Phi3ScaledRotaryEmbedding(nn.Module):
+        if self.inv_freq is None:
+            self.inv_freq = 1.0 / (
+                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
+            )
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
    def __init__(
        self,
        dim,
        short_factor,
        long_factor,
-        max_position_embeddings=2048,
        original_max_position_embeddings=2048,
+        max_position_embeddings=2048,
        base=10000,
+        device=None,
    ):
-        super().__init__()
+        super().__init__(dim, max_position_embeddings, base, device)

-        self.dim = dim
        self.short_factor = short_factor
        self.long_factor = long_factor
-        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
-        self.base = base

-    def _calc_mscale(self, scale):
-        raise NotImplementedError("`_calc_mscale` should be implemented in subclasses")
+    def _calc_scaling_factor(self, scale):
+        if scale <= 1.0:
+            return 1.0
+        return math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))

    @torch.no_grad()
-    def forward(self, x, seq_len=None):
-        if seq_len is None:
-            seq_len = x.shape[-2]
-        t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
-
-        if seq_len > self.original_max_position_embeddings:
-            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
-            rescale_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+    def forward(self, x, position_ids, seq_len=None):
+        position_ids_expanded = position_ids[:, None, :].float()
+        if position_ids_expanded.shape[-1] > self.original_max_position_embeddings:
+            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
        else:
-            t = torch.arange(self.original_max_position_embeddings, device=x.device, dtype=torch.float32)
-            rescale_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
-
-        inv_freq = 1.0 / (
-            rescale_factors * (self.base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
-        )
-
-        freqs = torch.outer(t, inv_freq)
-        mscale = self._calc_mscale(self.max_position_embeddings / self.original_max_position_embeddings)
-        emb = torch.cat((freqs, freqs), dim=-1)
-
-        return (emb.cos() * mscale).to(x.dtype), (emb.sin() * mscale).to(x.dtype)
-
-
-class Phi3SuScaledRotaryEmbedding(_Phi3ScaledRotaryEmbedding):
-    def _calc_mscale(self, scale):
-        if scale <= 1.0:
-            return 1.0
-        return math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
-
-
-class Phi3YarnScaledRotaryEmbedding(_Phi3ScaledRotaryEmbedding):
-    def _calc_mscale(self, scale):
-        if scale <= 1.0:
-            return 1.0
-        return 0.1 * math.log(scale) + 1.0
+            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
+
+        if self.inv_freq is None:
+            self.inv_freq = 1.0 / (
+                ext_factors
+                * self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
+            )
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            scaling_factor = self._calc_scaling_factor(
+                self.max_position_embeddings / self.original_max_position_embeddings
+            )
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * scaling_factor
+            sin = emb.sin() * scaling_factor
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
+    def __init__(
+        self,
+        dim,
+        short_factor,
+        long_factor,
+        original_max_position_embeddings=2048,
+        max_position_embeddings=2048,
+        base=10000,
+        device=None,
+    ):
+        super().__init__(dim, max_position_embeddings, base, device)
+
+        self.short_factor = short_factor
+        self.long_factor = long_factor
+        self.original_max_position_embeddings = original_max_position_embeddings
+
+    def _calc_scaling_factor(self, scale):
+        if scale <= 1.0:
+            return 1.0
+        return 0.1 * math.log(scale) + 1.0
+
+    @torch.no_grad()
+    def forward(self, x, position_ids, seq_len=None):
+        position_ids_expanded = position_ids[:, None, :].float()
+        if position_ids_expanded.shape[-1] > self.original_max_position_embeddings:
+            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+        else:
+            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
+
+        if self.inv_freq is None:
+            self.inv_freq = 1.0 / (
+                ext_factors
+                * self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
+            )
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            scaling_factor = self._calc_scaling_factor(
+                self.max_position_embeddings / self.original_max_position_embeddings
+            )
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * scaling_factor
+            sin = emb.sin() * scaling_factor
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


# Copied from transformers.models.llama.modeling_llama.rotate_half
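
The reworked rotary classes above build `inv_freq` lazily on first use, take `position_ids` directly (so cos/sin come back already gathered per token, shaped `[batch, seq_len, dim]`), and the `su`/`yarn` variants rescale the frequencies with `short_factor`/`long_factor` before multiplying by a scaling factor: `sqrt(1 + ln(s) / ln(original_max_position_embeddings))` for `su` and `0.1 * ln(s) + 1` for `yarn`, where `s = max_position_embeddings / original_max_position_embeddings`. A standalone sketch of that math, using toy sizes and placeholder factors rather than the released values:

    import math
    import torch

    dim, base = 8, 10000.0                   # toy head_dim; the model uses hidden_size // num_attention_heads
    original_max, max_pos = 4096, 131072     # illustrative long-context extension
    ext_factors = torch.ones(dim // 2)       # placeholder for short_factor / long_factor

    position_ids = torch.arange(6)[None, :]  # [batch=1, seq_len=6]
    inv_freq = 1.0 / (ext_factors * base ** (torch.arange(0, dim, 2).float() / dim))

    freqs = position_ids[..., None].float() * inv_freq  # [1, 6, dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)              # [1, 6, dim]

    scale = max_pos / original_max
    su_scaling = math.sqrt(1 + math.log(scale) / math.log(original_max)) if scale > 1.0 else 1.0
    yarn_scaling = 0.1 * math.log(scale) + 1.0 if scale > 1.0 else 1.0
    cos, sin = emb.cos() * su_scaling, emb.sin() * su_scaling
    print(cos.shape, su_scaling, yarn_scaling)           # torch.Size([1, 6, 8]) ...
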
@@ -212,7 +251,8 @@ def rotate_half(x):
    return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
@@ -220,9 +260,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -233,12 +272,11 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
-    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
-    # Need fp32 here to match logits
-    q_embed = (q.float() * cos.float()) + (rotate_half(q).float() * sin.float())
-    k_embed = (k.float() * cos.float()) + (rotate_half(k).float() * sin.float())
-    return q_embed.to(q.dtype), k_embed.to(k.dtype)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed


class Phi3MLP(nn.Module):
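
After this change, `apply_rotary_pos_emb` expects `cos`/`sin` that are already per position (as returned by the rotary modules above) and only inserts a head axis via `unsqueeze_dim`, applying the rotation in the tensors' native dtype; `position_ids` is kept only for signature compatibility. A self-contained toy sketch of the broadcasting, with `rotate_half` copied from this file and arbitrary sizes:

    import torch

    def rotate_half(x):
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    bsz, num_heads, seq_len, head_dim = 1, 2, 6, 8
    q = torch.randn(bsz, num_heads, seq_len, head_dim)
    k = torch.randn(bsz, num_heads, seq_len, head_dim)
    cos = torch.randn(bsz, seq_len, head_dim)  # shape produced by the rotary modules above
    sin = torch.randn(bsz, seq_len, head_dim)

    # unsqueeze_dim=1 gives cos/sin shape [bsz, 1, seq_len, head_dim] so they broadcast over heads
    cos_, sin_ = cos.unsqueeze(1), sin.unsqueeze(1)
    q_embed = (q * cos_) + (rotate_half(q) * sin_)
    k_embed = (k * cos_) + (rotate_half(k) * sin_)
    print(q_embed.shape, k_embed.shape)        # both torch.Size([1, 2, 6, 8])
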
@@ -252,12 +290,12 @@ class Phi3MLP(nn.Module):
        self.activation_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
-        y = self.gate_up_proj(hidden_states)
+        up_states = self.gate_up_proj(hidden_states)

-        gate, y = y.chunk(2, dim=-1)
-        y = y * self.activation_fn(gate)
+        gate, up_states = up_states.chunk(2, dim=-1)
+        up_states = up_states * self.activation_fn(gate)

-        return self.down_proj(y)
+        return self.down_proj(up_states)


# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
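
`Phi3MLP` keeps a single fused `gate_up_proj` projection whose output is chunked into a gate half and an up half; the activated gate scales the up states before `down_proj`. A toy sketch of the same dataflow, assuming a SiLU activation (the model reads the actual activation from `config.hidden_act`) and made-up sizes:

    import torch
    from torch import nn

    hidden_size, intermediate_size = 16, 32      # toy sizes, not the released config
    gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
    down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
    activation_fn = nn.SiLU()

    hidden_states = torch.randn(2, 5, hidden_size)
    up_states = gate_up_proj(hidden_states)              # [2, 5, 2 * intermediate_size]
    gate, up_states = up_states.chunk(2, dim=-1)         # two [2, 5, intermediate_size] halves
    output = down_proj(up_states * activation_fn(gate))  # [2, 5, hidden_size]
    print(output.shape)
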
@@ -296,6 +334,7 @@ class Phi3Attention(nn.Module):
        self.max_position_embeddings = config.max_position_embeddings
        self.original_max_position_embeddings = config.original_max_position_embeddings
        self.rope_theta = config.rope_theta
+        self.rope_scaling = config.rope_scaling
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
@@ -310,7 +349,7 @@ class Phi3Attention(nn.Module):
        self._init_rope()

    def _init_rope(self):
-        if self.config.rope_scaling is None:
+        if self.rope_scaling is None:
            self.rotary_emb = Phi3RotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
@@ -318,30 +357,30 @@ class Phi3Attention(nn.Module):
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
+            short_factor = self.config.rope_scaling["short_factor"]
+            long_factor = self.config.rope_scaling["long_factor"]
+
            if scaling_type == "su":
                self.rotary_emb = Phi3SuScaledRotaryEmbedding(
                    self.head_dim,
-                    self.config.rope_scaling["short_factor"],
-                    self.config.rope_scaling["long_factor"],
-                    max_position_embeddings=self.config.max_position_embeddings,
-                    original_max_position_embeddings=self.config.original_max_position_embeddings,
-                    base=self.config.rope_theta,
+                    short_factor,
+                    long_factor,
+                    max_position_embeddings=self.max_position_embeddings,
+                    original_max_position_embeddings=self.original_max_position_embeddings,
+                    base=self.rope_theta,
                )
            elif scaling_type == "yarn":
                self.rotary_emb = Phi3YarnScaledRotaryEmbedding(
                    self.head_dim,
-                    self.config.rope_scaling["short_factor"],
-                    self.config.rope_scaling["long_factor"],
-                    max_position_embeddings=self.config.max_position_embeddings,
-                    original_max_position_embeddings=self.config.original_max_position_embeddings,
-                    base=self.config.rope_theta,
+                    short_factor,
+                    long_factor,
+                    max_position_embeddings=self.max_position_embeddings,
+                    original_max_position_embeddings=self.original_max_position_embeddings,
+                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -374,7 +413,8 @@ class Phi3Attention(nn.Module):
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
@@ -494,7 +534,7 @@ class Phi3FlashAttention2(Phi3Attention):

        # Because the input can be padded, the absolute sequence length depends on the max position id.
        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
-        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -781,7 +821,7 @@ class Phi3SdpaAttention(Phi3Attention):
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
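
Across all three attention paths (eager `Phi3Attention`, `Phi3FlashAttention2`, `Phi3SdpaAttention`) the rotary module is now called with `position_ids` and returns per-token cos/sin, which are then passed to the new `apply_rotary_pos_emb`. A minimal end-to-end sketch with toy sizes, assuming the classes above are importable from `modeling_phi3`:

    import torch
    from modeling_phi3 import Phi3RotaryEmbedding, apply_rotary_pos_emb

    bsz, num_heads, seq_len, head_dim = 1, 2, 6, 8
    rotary_emb = Phi3RotaryEmbedding(head_dim, max_position_embeddings=4096, base=10000)

    query_states = torch.randn(bsz, num_heads, seq_len, head_dim)
    key_states = torch.randn(bsz, num_heads, seq_len, head_dim)
    value_states = torch.randn(bsz, num_heads, seq_len, head_dim)
    position_ids = torch.arange(seq_len)[None, :]      # [bsz, seq_len]

    cos, sin = rotary_emb(value_states, position_ids)  # the seq_len kwarg is now optional
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    print(query_states.shape, key_states.shape)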
 
 