Crystalcareai commited on
Commit
581b060
·
verified ·
1 Parent(s): 6fc8cf8

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +6 -14
modeling_quiet.py CHANGED
@@ -194,20 +194,13 @@ class QuietRotaryEmbedding(nn.Module):
194
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
195
 
196
  def forward(self, x, seq_len=None):
197
- if seq_len is None:
198
- seq_len = x.shape[-2]
199
- if seq_len > self.max_position_embeddings:
200
- # If the sequence length is greater than the maximum position embeddings,
201
- # extend the rotary embeddings by repeating them
202
- num_repeats = (seq_len - 1) // self.max_position_embeddings + 1
203
- cos_cached = self.cos_cached.repeat(num_repeats, 1)[:seq_len]
204
- sin_cached = self.sin_cached.repeat(num_repeats, 1)[:seq_len]
205
- else:
206
- cos_cached = self.cos_cached[:seq_len]
207
- sin_cached = self.sin_cached[:seq_len]
208
  return (
209
- cos_cached.to(dtype=x.dtype),
210
- sin_cached.to(dtype=x.dtype),
211
  )
212
 
213
 
@@ -355,7 +348,6 @@ class QuietAttention(nn.Module):
355
  "with a layer index."
356
  )
357
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
358
-
359
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
360
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
361
 
 
194
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
195
 
196
  def forward(self, x, seq_len=None):
197
+ # x: [bs, num_attention_heads, seq_len, head_size]
198
+ if seq_len > self.max_seq_len_cached:
199
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
200
+
 
 
 
 
 
 
 
201
  return (
202
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
203
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
204
  )
205
 
206
 
 
348
  "with a layer index."
349
  )
350
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
351
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
352
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
353