Crystalcareai commited on
Commit
3af5c08
·
verified ·
1 Parent(s): 620b59e

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +2 -3
modeling_quiet.py CHANGED
@@ -1704,11 +1704,10 @@ class QuietForCausalLM(QuietPreTrainedModel):
1704
  hidden_states = outputs[0]
1705
  prev_rm_logits = rm_logits # for policy gradient
1706
  prev_rm_tokens = cur_rm_tokens # for policy gradient
1707
-
1708
  if ahead_idx == 0:
1709
  hidden_states_lm = hidden_states
1710
- logits = self.lm_head(hidden_states_lm)
1711
- logits = self.bn_lm_head(logits.transpose(1, 2)).transpose(1, 2)
1712
  base_hidden_states = hidden_states.clone()
1713
  initial_loss_logits = logits.clone()
1714
  if self.optimize_lm_head_only_at_start or self.optimize_model_only_at_start:
 
1704
  hidden_states = outputs[0]
1705
  prev_rm_logits = rm_logits # for policy gradient
1706
  prev_rm_tokens = cur_rm_tokens # for policy gradient
 
1707
  if ahead_idx == 0:
1708
  hidden_states_lm = hidden_states
1709
+ logits = self.lm_head(hidden_states_lm)
1710
+ logits = self.bn_lm_head(logits.transpose(1, 2)).transpose(1, 2)
1711
  base_hidden_states = hidden_states.clone()
1712
  initial_loss_logits = logits.clone()
1713
  if self.optimize_lm_head_only_at_start or self.optimize_model_only_at_start: