Crystalcareai
commited on
Update modeling_quiet.py
Browse files- modeling_quiet.py +2 -3
modeling_quiet.py
CHANGED
@@ -1704,11 +1704,10 @@ class QuietForCausalLM(QuietPreTrainedModel):
|
|
1704 |
hidden_states = outputs[0]
|
1705 |
prev_rm_logits = rm_logits # for policy gradient
|
1706 |
prev_rm_tokens = cur_rm_tokens # for policy gradient
|
1707 |
-
|
1708 |
if ahead_idx == 0:
|
1709 |
hidden_states_lm = hidden_states
|
1710 |
-
|
1711 |
-
|
1712 |
base_hidden_states = hidden_states.clone()
|
1713 |
initial_loss_logits = logits.clone()
|
1714 |
if self.optimize_lm_head_only_at_start or self.optimize_model_only_at_start:
|
|
|
1704 |
hidden_states = outputs[0]
|
1705 |
prev_rm_logits = rm_logits # for policy gradient
|
1706 |
prev_rm_tokens = cur_rm_tokens # for policy gradient
|
|
|
1707 |
if ahead_idx == 0:
|
1708 |
hidden_states_lm = hidden_states
|
1709 |
+
logits = self.lm_head(hidden_states_lm)
|
1710 |
+
logits = self.bn_lm_head(logits.transpose(1, 2)).transpose(1, 2)
|
1711 |
base_hidden_states = hidden_states.clone()
|
1712 |
initial_loss_logits = logits.clone()
|
1713 |
if self.optimize_lm_head_only_at_start or self.optimize_model_only_at_start:
|