Crystalcareai
commited on
Update modeling_quiet.py
Browse files- modeling_quiet.py +6 -14
modeling_quiet.py
CHANGED
@@ -194,20 +194,13 @@ class QuietRotaryEmbedding(nn.Module):
|
|
194 |
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
195 |
|
196 |
def forward(self, x, seq_len=None):
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
# extend the rotary embeddings by repeating them
|
202 |
-
num_repeats = (seq_len - 1) // self.max_position_embeddings + 1
|
203 |
-
cos_cached = self.cos_cached.repeat(num_repeats, 1)[:seq_len]
|
204 |
-
sin_cached = self.sin_cached.repeat(num_repeats, 1)[:seq_len]
|
205 |
-
else:
|
206 |
-
cos_cached = self.cos_cached[:seq_len]
|
207 |
-
sin_cached = self.sin_cached[:seq_len]
|
208 |
return (
|
209 |
-
cos_cached.to(dtype=x.dtype),
|
210 |
-
sin_cached.to(dtype=x.dtype),
|
211 |
)
|
212 |
|
213 |
|
@@ -355,7 +348,6 @@ class QuietAttention(nn.Module):
|
|
355 |
"with a layer index."
|
356 |
)
|
357 |
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
358 |
-
|
359 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
360 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
361 |
|
|
|
194 |
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
195 |
|
196 |
def forward(self, x, seq_len=None):
|
197 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
198 |
+
if seq_len > self.max_seq_len_cached:
|
199 |
+
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
200 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
return (
|
202 |
+
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
203 |
+
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
204 |
)
|
205 |
|
206 |
|
|
|
348 |
"with a layer index."
|
349 |
)
|
350 |
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
|
|
351 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
352 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
353 |
|