fix(root): Updating to the almost-ready release candidate.
- configuration_phi3.py +4 -0
- modeling_phi3.py +132 -92
configuration_phi3.py
CHANGED
@@ -87,6 +87,8 @@ class Phi3Config(PretrainedConfig):
             contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and
             the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
             divided by the number of attention heads divided by 2.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            The id of the "beginning-of-sequence" token.
         eos_token_id (`int`, *optional*, defaults to 32000):
             The id of the "end-of-sequence" token.
         pad_token_id (`int`, *optional*, defaults to 32000):
@@ -132,6 +134,7 @@ class Phi3Config(PretrainedConfig):
         tie_word_embeddings=False,
         rope_theta=10000.0,
         rope_scaling=None,
+        bos_token_id=1,
         eos_token_id=32000,
         pad_token_id=32000,
         sliding_window=None,
@@ -162,6 +165,7 @@ class Phi3Config(PretrainedConfig):
         self.sliding_window = sliding_window

         super().__init__(
+            bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
             pad_token_id=pad_token_id,
             tie_word_embeddings=tie_word_embeddings,
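Note on the configuration change: the new `bos_token_id` argument defaults to 1 and is forwarded to `PretrainedConfig.__init__` alongside the existing eos/pad ids. A minimal usage sketch, assuming the class is imported from this repo's `configuration_phi3` module (the import path is an assumption, not part of the diff):

```python
# Hypothetical usage sketch for the new default; the import path is an assumption.
from configuration_phi3 import Phi3Config

config = Phi3Config()
print(config.bos_token_id)  # 1 (new default added in this commit)
print(config.eos_token_id)  # 32000 (unchanged)
```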
modeling_phi3.py
CHANGED
@@ -40,6 +40,7 @@ from transformers.utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
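The newly imported `is_flash_attn_2_available` is the standard transformers guard for the optional flash-attn dependency. The snippet below is an illustrative sketch of that pattern, not code taken from this file:

```python
from transformers.utils import is_flash_attn_2_available

# Guarded optional import: only touch flash_attn when the package is actually installed.
if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
```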
@@ -107,7 +108,7 @@ def _get_unpad_data(attention_mask):
 )


-# Copied from transformers.models.
+# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
 class Phi3RotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
@@ -115,93 +116,131 @@ class Phi3RotaryEmbedding(nn.Module):
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        self._set_cos_sin_cache(
-            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
-        )
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
-
-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
-    def forward(self, x, seq_len=None):
+        self.register_buffer("inv_freq", None, persistent=False)
+
+    @torch.no_grad()
+    def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
+        if self.inv_freq is None:
+            self.inv_freq = 1.0 / (
+                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
+            )
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
     def __init__(
         self,
         dim,
         short_factor,
         long_factor,
-        max_position_embeddings=2048,
         original_max_position_embeddings=2048,
+        max_position_embeddings=2048,
         base=10000,
+        device=None,
     ):
-        super().__init__()
+        super().__init__(dim, max_position_embeddings, base, device)

-        self.dim = dim
         self.short_factor = short_factor
         self.long_factor = long_factor
-        self.max_position_embeddings = max_position_embeddings
         self.original_max_position_embeddings = original_max_position_embeddings
-        self.base = base
+
+    def _calc_scaling_factor(self, scale):
+        if scale <= 1.0:
+            return 1.0
+        return math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))

     @torch.no_grad()
-    def forward(self, x, seq_len=None):
-        if seq_len > self.original_max_position_embeddings:
-            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
-            rescale_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+    def forward(self, x, position_ids, seq_len=None):
+        position_ids_expanded = position_ids[:, None, :].float()
+        if position_ids_expanded.shape[-1] > self.original_max_position_embeddings:
+            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
         else:
-            rescale_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
+            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)

-        freqs = torch.outer(t, inv_freq)
-        mscale = self._calc_mscale(self.max_position_embeddings / self.original_max_position_embeddings)
-        emb = torch.cat((freqs, freqs), dim=-1)
+        if self.inv_freq is None:
+            self.inv_freq = 1.0 / (
+                ext_factors
+                * self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
+            )
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            scaling_factor = self._calc_scaling_factor(
+                self.max_position_embeddings / self.original_max_position_embeddings
+            )
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * scaling_factor
+            sin = emb.sin() * scaling_factor
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


+class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
+    def __init__(
+        self,
+        dim,
+        short_factor,
+        long_factor,
+        original_max_position_embeddings=2048,
+        max_position_embeddings=2048,
+        base=10000,
+        device=None,
+    ):
+        super().__init__(dim, max_position_embeddings, base, device)
+
+        self.short_factor = short_factor
+        self.long_factor = long_factor
+        self.original_max_position_embeddings = original_max_position_embeddings

-    def _calc_mscale(self, scale):
+    def _calc_scaling_factor(self, scale):
         if scale <= 1.0:
             return 1.0
+        return 0.1 * math.log(scale) + 1.0
+
+    @torch.no_grad()
+    def forward(self, x, position_ids, seq_len=None):
+        position_ids_expanded = position_ids[:, None, :].float()
+        if position_ids_expanded.shape[-1] > self.original_max_position_embeddings:
+            ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+        else:
+            ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
+
+        if self.inv_freq is None:
+            self.inv_freq = 1.0 / (
+                ext_factors
+                * self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
+            )
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            scaling_factor = self._calc_scaling_factor(
+                self.max_position_embeddings / self.original_max_position_embeddings
+            )
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * scaling_factor
+            sin = emb.sin() * scaling_factor
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


 # Copied from transformers.models.llama.modeling_llama.rotate_half
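The rewritten embeddings no longer cache `cos`/`sin` by sequence length; they compute them on the fly from `position_ids` and return tensors of shape `[batch, seq_len, dim]`. A self-contained sketch of that computation with toy sizes (dim=8, base=10000; this is a standalone illustration, not an invocation of the classes above):

```python
import torch

dim, base = 8, 10000.0
x = torch.randn(1, 4, 6, dim, dtype=torch.bfloat16)  # [bs, num_heads, seq_len, head_dim]
position_ids = torch.arange(6).unsqueeze(0)          # [bs, seq_len]

# Same math as the new forward(): outer product of inverse frequencies and positions, in float32.
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))               # [dim // 2]
inv_freq_expanded = inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)              # [bs, seq_len, dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)                                          # [bs, seq_len, dim]
cos, sin = emb.cos().to(x.dtype), emb.sin().to(x.dtype)
print(cos.shape)  # torch.Size([1, 6, 8])
```

For the `su` and `yarn` variants the same `cos`/`sin` are additionally multiplied by a scaling factor: `sqrt(1 + ln(s) / ln(original_max_position_embeddings))` for `su` and `0.1 * ln(s) + 1.0` for `yarn`, where `s = max_position_embeddings / original_max_position_embeddings`.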
@@ -212,7 +251,8 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.

     Args:
@@ -220,9 +260,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         k (`torch.Tensor`): The key tensor.
         cos (`torch.Tensor`): The cosine part of the rotary embedding.
         sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
         unsqueeze_dim (`int`, *optional*, defaults to 1):
             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -233,12 +272,11 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
-    return q_embed.to(q.dtype), k_embed.to(k.dtype)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed


 class Phi3MLP(nn.Module):
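Since `cos`/`sin` now arrive per position with shape `[bs, seq_len, head_dim]`, `unsqueeze_dim=1` broadcasts them across the head dimension. A toy shape check of the same arithmetic (random values, `rotate_half` reimplemented locally for the sketch):

```python
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

bs, num_heads, seq_len, head_dim = 1, 2, 6, 8
q = torch.randn(bs, num_heads, seq_len, head_dim)
k = torch.randn(bs, num_heads, seq_len, head_dim)
cos = torch.randn(bs, seq_len, head_dim)
sin = torch.randn(bs, seq_len, head_dim)

# Matches the updated apply_rotary_pos_emb with unsqueeze_dim=1.
q_embed = (q * cos.unsqueeze(1)) + (rotate_half(q) * sin.unsqueeze(1))
k_embed = (k * cos.unsqueeze(1)) + (rotate_half(k) * sin.unsqueeze(1))
print(q_embed.shape, k_embed.shape)  # torch.Size([1, 2, 6, 8]) torch.Size([1, 2, 6, 8])
```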
@@ -252,12 +290,12 @@ class Phi3MLP(nn.Module):
         self.activation_fn = ACT2FN[config.hidden_act]

     def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+        up_states = self.gate_up_proj(hidden_states)

+        gate, up_states = up_states.chunk(2, dim=-1)
+        up_states = up_states * self.activation_fn(gate)

+        return self.down_proj(up_states)


 # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
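The MLP change reads both halves of the fused `gate_up_proj` projection and splits them with `chunk(2, dim=-1)`. A standalone sketch with made-up toy sizes (the SiLU activation is an assumption for illustration; the real choice comes from `config.hidden_act`):

```python
import torch
from torch import nn

hidden_size, intermediate_size = 16, 32
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)  # fused gate + up projection
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
activation_fn = nn.SiLU()  # assumption; Phi3MLP uses ACT2FN[config.hidden_act]

hidden_states = torch.randn(1, 5, hidden_size)
up_states = gate_up_proj(hidden_states)        # [..., 2 * intermediate_size]
gate, up_states = up_states.chunk(2, dim=-1)   # split into two [..., intermediate_size] halves
out = down_proj(up_states * activation_fn(gate))
print(out.shape)  # torch.Size([1, 5, 16])
```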
@@ -296,6 +334,7 @@ class Phi3Attention(nn.Module):
         self.max_position_embeddings = config.max_position_embeddings
         self.original_max_position_embeddings = config.original_max_position_embeddings
         self.rope_theta = config.rope_theta
+        self.rope_scaling = config.rope_scaling
         self.is_causal = True

         if (self.head_dim * self.num_heads) != self.hidden_size:
@@ -310,7 +349,7 @@ class Phi3Attention(nn.Module):
         self._init_rope()

     def _init_rope(self):
-        if self.config.rope_scaling is None:
+        if self.rope_scaling is None:
             self.rotary_emb = Phi3RotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
@@ -318,30 +357,30 @@ class Phi3Attention(nn.Module):
             )
         else:
             scaling_type = self.config.rope_scaling["type"]
+            short_factor = self.config.rope_scaling["short_factor"]
+            long_factor = self.config.rope_scaling["long_factor"]
+
             if scaling_type == "su":
                 self.rotary_emb = Phi3SuScaledRotaryEmbedding(
                     self.head_dim,
+                    short_factor,
+                    long_factor,
+                    max_position_embeddings=self.max_position_embeddings,
+                    original_max_position_embeddings=self.original_max_position_embeddings,
+                    base=self.rope_theta,
                 )
             elif scaling_type == "yarn":
                 self.rotary_emb = Phi3YarnScaledRotaryEmbedding(
                     self.head_dim,
+                    short_factor,
+                    long_factor,
+                    max_position_embeddings=self.max_position_embeddings,
+                    original_max_position_embeddings=self.original_max_position_embeddings,
+                    base=self.rope_theta,
                 )
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
     def forward(
         self,
         hidden_states: torch.Tensor,
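`_init_rope` now pulls `short_factor` and `long_factor` directly out of `config.rope_scaling` and passes them positionally. Per the configuration docstring, each factor list must have length `hidden_size // num_attention_heads // 2` (i.e. `head_dim // 2`). A hypothetical fragment with toy values (head_dim = 8, so four factors each; not real Phi-3 numbers):

```python
# Hypothetical rope_scaling entry; the factor values are illustrative only.
rope_scaling = {
    "type": "su",                          # must be "su" or "yarn"
    "short_factor": [1.0, 1.0, 1.0, 1.0],  # length == head_dim // 2
    "long_factor": [1.0, 1.1, 1.2, 1.3],
}
```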
@@ -374,7 +413,8 @@ class Phi3Attention(nn.Module):
                     "with a layer index."
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

         if past_key_value is not None:
@@ -494,7 +534,7 @@ class Phi3FlashAttention2(Phi3Attention):

         # Because the input can be padded, the absolute sequence length depends on the max position id.
         rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
-        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len)

         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

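In the flash-attention path the rotary length is still derived from both the KV length and the largest position id, since padded inputs can carry position ids beyond the tensor length. A toy evaluation of that expression:

```python
import torch

position_ids = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 2]])  # toy padded batch
kv_seq_len = 4
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
print(rotary_seq_len)  # 5
```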
@@ -781,7 +821,7 @@ class Phi3SdpaAttention(Phi3Attention):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)

         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
