jupyterjazz commited on
Commit
90873c4
·
verified ·
1 Parent(s): d8cbc92

Update rotary.py

Browse files
Files changed (1) hide show
  1. rotary.py +17 -4
rotary.py CHANGED
@@ -493,9 +493,15 @@ class RotaryEmbedding(torch.nn.Module):
493
 
494
  @base.setter
495
  def base(self, new_base):
496
- if new_base > 0:
497
- self._base = float(new_base)
498
- self.inv_freq = self._compute_inv_freq(device=self.inv_freq.device)
 
 
 
 
 
 
499
  else:
500
  raise ValueError("Rotary base value must be positive")
501
 
@@ -508,21 +514,27 @@ class RotaryEmbedding(torch.nn.Module):
508
  )
509
  )
510
 
511
- def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
 
 
512
  # Reset the tables if the sequence length has changed,
513
  # if we're on a new device (possibly due to tracing for instance),
514
  # or if we're switching from inference mode to training
 
515
  if (
516
  seqlen > self._seq_len_cached
517
  or self._cos_cached is None
518
  or self._cos_cached.device != device
519
  or self._cos_cached.dtype != dtype
520
  or (self.training and self._cos_cached.is_inference())
 
521
  ):
522
  self._seq_len_cached = seqlen
523
  # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
524
  # And the output of arange can be quite large, so bf16 would lose a lot of precision.
525
  # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
 
 
526
  if self.pos_idx_in_fp32:
527
  t = torch.arange(seqlen, device=device, dtype=torch.float32)
528
  # We want fp32 here as well since inv_freq will be multiplied with t, and the output
@@ -536,6 +548,7 @@ class RotaryEmbedding(torch.nn.Module):
536
  else:
537
  t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
538
  inv_freq = self.inv_freq
 
539
  # Don't do einsum, it converts fp32 to fp16 under AMP
540
  # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
541
  freqs = torch.outer(t, inv_freq)
 
493
 
494
  @base.setter
495
  def base(self, new_base):
496
+ new_base = float(new_base)
497
+ if new_base > 0 and new_base != self._base:
498
+ self._base = new_base
499
+ self._update_cos_sin_cache(
500
+ self._seq_len_cached,
501
+ device=self.inv_freq.device,
502
+ dtype=self._cos_cached.dtype if self._cos_cached is not None else None,
503
+ rotary_base_changed=True,
504
+ )
505
  else:
506
  raise ValueError("Rotary base value must be positive")
507
 
 
514
  )
515
  )
516
 
517
+ def _update_cos_sin_cache(
518
+ self, seqlen, device=None, dtype=None, rotary_base_changed=False
519
+ ):
520
  # Reset the tables if the sequence length has changed,
521
  # if we're on a new device (possibly due to tracing for instance),
522
  # or if we're switching from inference mode to training
523
+ # or if the rotary base value was changed
524
  if (
525
  seqlen > self._seq_len_cached
526
  or self._cos_cached is None
527
  or self._cos_cached.device != device
528
  or self._cos_cached.dtype != dtype
529
  or (self.training and self._cos_cached.is_inference())
530
+ or rotary_base_changed
531
  ):
532
  self._seq_len_cached = seqlen
533
  # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
534
  # And the output of arange can be quite large, so bf16 would lose a lot of precision.
535
  # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
536
+ if rotary_base_changed:
537
+ self.inv_freq = self._compute_inv_freq(device=self.inv_freq.device)
538
  if self.pos_idx_in_fp32:
539
  t = torch.arange(seqlen, device=device, dtype=torch.float32)
540
  # We want fp32 here as well since inv_freq will be multiplied with t, and the output
 
548
  else:
549
  t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
550
  inv_freq = self.inv_freq
551
+
552
  # Don't do einsum, it converts fp32 to fp16 under AMP
553
  # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
554
  freqs = torch.outer(t, inv_freq)