Update modeling_quiet.py
Browse files- modeling_quiet.py +1 -1
modeling_quiet.py
CHANGED
@@ -158,7 +158,7 @@ class QuietRMSNorm(nn.Module):
|
|
158 |
|
159 |
def forward(self, hidden_states):
|
160 |
input_dtype = hidden_states.dtype
|
161 |
-
hidden_states = hidden_states.to(torch.
|
162 |
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
163 |
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
164 |
return hidden_states.to(input_dtype) * self.weight.to(hidden_states.device)
|
|
|
158 |
|
159 |
def forward(self, hidden_states):
|
160 |
input_dtype = hidden_states.dtype
|
161 |
+
hidden_states = hidden_states.to(torch.bfloat16)
|
162 |
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
163 |
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
164 |
return hidden_states.to(input_dtype) * self.weight.to(hidden_states.device)
|