Crystalcareai committed
Commit 9fdcb7b · verified · 1 Parent: fe225f7

Update modeling_quiet.py

Files changed (1): modeling_quiet.py (+5 -2)
modeling_quiet.py CHANGED
@@ -158,7 +158,7 @@ class QuietRMSNorm(nn.Module):
 
     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.bfloat16)
+        hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return hidden_states.to(input_dtype) * self.weight.to(hidden_states.device)
@@ -327,9 +327,12 @@ class QuietAttention(nn.Module):
                 "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
             )
         bsz, q_len, _ = hidden_states.size()
-
+
+        hidden_states = hidden_states.to(self.q_proj.weight.dtype)
         query_states = self.q_proj(hidden_states)
+        hidden_states = hidden_states.to(self.k_proj.weight.dtype)
         key_states = self.k_proj(hidden_states)
+        hidden_states = hidden_states.to(self.v_proj.weight.dtype)
         value_states = self.v_proj(hidden_states)
 
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
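
Taken together, the change moves QuietRMSNorm's internal statistics from bfloat16 to float32 (bfloat16's 8-bit mantissa loses precision in the mean-of-squares) and, in QuietAttention, casts hidden_states to each projection's weight dtype before the matmul, guarding against dtype-mismatch errors under mixed precision. A minimal sketch of the resulting pattern follows; the hidden size, epsilon default, and the standalone q_proj example are illustrative assumptions, not taken from the file:

import torch
import torch.nn as nn

class QuietRMSNorm(nn.Module):
    # Mirrors the patched forward(): statistics in float32, output in input dtype.
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)  # upcast for stable pow/mean/rsqrt
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return hidden_states.to(input_dtype) * self.weight.to(hidden_states.device)

# The attention-side casts follow the same idea: align the activation dtype
# with the layer's weights before projecting (toy q_proj, assumed shapes).
q_proj = nn.Linear(64, 64).to(torch.bfloat16)
hidden = torch.randn(2, 8, 64, dtype=torch.bfloat16)
hidden = QuietRMSNorm(64)(hidden)
query_states = q_proj(hidden.to(q_proj.weight.dtype))  # no dtype-mismatch error

Note that the norm still returns the caller's dtype; the explicit .to(weight.dtype) casts in QuietAttention are a belt-and-braces guard for setups where activations and projection weights end up in different precisions.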