Crystalcareai commited on
Commit
fe225f7
·
verified ·
1 Parent(s): 5e41833

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +1 -1
modeling_quiet.py CHANGED
@@ -158,7 +158,7 @@ class QuietRMSNorm(nn.Module):
158
 
159
  def forward(self, hidden_states):
160
  input_dtype = hidden_states.dtype
161
- hidden_states = hidden_states.to(torch.float32)
162
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
163
  hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
164
  return hidden_states.to(input_dtype) * self.weight.to(hidden_states.device)
 
158
 
159
  def forward(self, hidden_states):
160
  input_dtype = hidden_states.dtype
161
+ hidden_states = hidden_states.to(torch.bfloat16)
162
  variance = hidden_states.pow(2).mean(-1, keepdim=True)
163
  hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
164
  return hidden_states.to(input_dtype) * self.weight.to(hidden_states.device)