jupyterjazz commited on
Commit
071760a
·
verified ·
1 Parent(s): 90873c4

Update rotary.py

Browse files
Files changed (1) hide show
  1. rotary.py +1 -1
rotary.py CHANGED
@@ -534,7 +534,7 @@ class RotaryEmbedding(torch.nn.Module):
534
  # And the output of arange can be quite large, so bf16 would lose a lot of precision.
535
  # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
536
  if rotary_base_changed:
537
- self.inv_freq = self._compute_inv_freq(device=self.inv_freq.device)
538
  if self.pos_idx_in_fp32:
539
  t = torch.arange(seqlen, device=device, dtype=torch.float32)
540
  # We want fp32 here as well since inv_freq will be multiplied with t, and the output
 
534
  # And the output of arange can be quite large, so bf16 would lose a lot of precision.
535
  # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
536
  if rotary_base_changed:
537
+ self.inv_freq = self._compute_inv_freq(device=device)
538
  if self.pos_idx_in_fp32:
539
  t = torch.arange(seqlen, device=device, dtype=torch.float32)
540
  # We want fp32 here as well since inv_freq will be multiplied with t, and the output