Crystalcareai committed
Update modeling_quiet.py
modeling_quiet.py +5 -2
modeling_quiet.py
CHANGED
@@ -158,7 +158,7 @@ class QuietRMSNorm(nn.Module):
 
     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.
+        hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return hidden_states.to(input_dtype) * self.weight.to(hidden_states.device)
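The first hunk makes line 161 cast to torch.float32 before the variance is computed, with the result cast back to the caller's dtype on return: the usual mixed-precision RMSNorm pattern, where the mean of squares is accumulated in float32 instead of (b)float16. Below is a minimal, self-contained sketch of that pattern for reference; SketchRMSNorm and the quick check are illustrative stand-ins, not code taken from modeling_quiet.py.

# Minimal sketch of the float32-upcast RMSNorm pattern restored by the hunk
# above. Names are illustrative, not the repository's QuietRMSNorm.
import torch
import torch.nn as nn

class SketchRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Accumulate the mean of squares in float32 rather than half precision.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Cast back so downstream layers see the caller's original dtype.
        return hidden_states.to(input_dtype) * self.weight.to(hidden_states.device)

# Quick check in half precision: the statistics are computed in float32
# internally, but the output dtype matches the input dtype.
norm = SketchRMSNorm(8).half()
x = torch.randn(2, 4, 8, dtype=torch.float16)
out = norm(x)
assert out.dtype == torch.float16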
@@ -327,9 +327,12 @@ class QuietAttention(nn.Module):
                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
             )
         bsz, q_len, _ = hidden_states.size()
-
+
+        hidden_states = hidden_states.to(self.q_proj.weight.dtype)
         query_states = self.q_proj(hidden_states)
+        hidden_states = hidden_states.to(self.k_proj.weight.dtype)
         key_states = self.k_proj(hidden_states)
+        hidden_states = hidden_states.to(self.v_proj.weight.dtype)
         value_states = self.v_proj(hidden_states)
 
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
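The second hunk casts the hidden states to each projection's weight dtype before the q/k/v matmuls, so that float32 activations entering half-precision linear layers (or the reverse) no longer hit a dtype mismatch inside nn.Linear. Below is a hedged, self-contained sketch of that cast-before-projection pattern; SketchProjections and its sizes are illustrative assumptions, not the repository's QuietAttention.

# Illustrative sketch of aligning activation dtype with projection weights,
# as in the second hunk above. Not the actual QuietAttention module.
import torch
import torch.nn as nn

class SketchProjections(nn.Module):
    def __init__(self, hidden_size=8):
        super().__init__()
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, hidden_states):
        # Match the activation dtype to each projection's weight dtype so the
        # matmul works when the module runs in reduced precision but the
        # incoming states are float32 (or vice versa).
        hidden_states = hidden_states.to(self.q_proj.weight.dtype)
        query_states = self.q_proj(hidden_states)
        hidden_states = hidden_states.to(self.k_proj.weight.dtype)
        key_states = self.k_proj(hidden_states)
        hidden_states = hidden_states.to(self.v_proj.weight.dtype)
        value_states = self.v_proj(hidden_states)
        return query_states, key_states, value_states

# float32 input into a bfloat16 module: without the casts this would raise a
# dtype-mismatch error inside nn.Linear.
module = SketchProjections().to(torch.bfloat16)
q, k, v = module(torch.randn(1, 3, 8))  # float32 input
assert q.dtype == torch.bfloat16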