Commit 88ca699 (verified) · committed by Crystalcareai · 1 Parent(s): 6caf34c

Update modeling_quiet.py

Files changed (1):
  1. modeling_quiet.py  +3 -25
modeling_quiet.py CHANGED

@@ -1262,6 +1262,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
 
         # For visualization
         self.eval_mode = False
+
         num_talk = 1
         talk_input_dim = config.hidden_size if not self.use_concat_talk_head else config.hidden_size * 2
         if self.use_weighted_talk_head:
@@ -1283,16 +1284,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
             nn.Linear(talk_input_dim, talk_output_dim, bias=False)
         )])
 
-        # Add batch normalization to the model
-        self.bn_lm_head = nn.BatchNorm1d(config.vocab_size)
-        self.bn_talk_head = nn.BatchNorm1d(talk_output_dim)
-
-        # Initialize weights using Xavier initialization
-        self.apply(self._init_weights)
-
-        # Add dropout regularization
-        self.dropout = nn.Dropout(config.hidden_dropout_prob)
-
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -1313,14 +1304,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
 
     def get_decoder(self):
         return self.model
-
-    def _init_weights(self, module):
-        if isinstance(module, nn.Linear):
-            nn.init.xavier_uniform_(module.weight)
-            if module.bias is not None:
-                nn.init.constant_(module.bias, 0)
-        elif isinstance(module, nn.Embedding):
-            nn.init.xavier_uniform_(module.weight)
 
     @torch.no_grad()
     def infer(
@@ -1719,11 +1702,10 @@ class QuietForCausalLM(QuietPreTrainedModel):
             hidden_states = outputs[0]
             prev_rm_logits = rm_logits # for policy gradient
            prev_rm_tokens = cur_rm_tokens # for policy gradient
+
             if ahead_idx == 0:
                 hidden_states_lm = hidden_states
                 logits = self.lm_head(hidden_states_lm)
-                logits = self.bn_lm_head(logits.transpose(1, 2)).transpose(1, 2)
-                logits = self.dropout(logits)
                 base_hidden_states = hidden_states.clone()
                 initial_loss_logits = logits.clone()
                 if self.optimize_lm_head_only_at_start or self.optimize_model_only_at_start:
@@ -1754,12 +1736,8 @@ class QuietForCausalLM(QuietPreTrainedModel):
                     head_input_hidden_states = talk_hidden_states
 
                 residual_logits = self.talk_head[0](head_input_hidden_states)
-                residual_logits = self.bn_talk_head(residual_logits.transpose(1, 2)).transpose(1, 2)
-                residual_logits = self.dropout(residual_logits)
                 if self.use_shallow_talk:
-                    residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start)
-                    residual_logits = self.bn_lm_head(residual_logits.transpose(1, 2)).transpose(1, 2)
-                    residual_logits = self.dropout(residual_logits)
+                    residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start)
                 residual_logits = residual_logits.to(logits.device)
                 if self.use_weighted_talk_head:
                     # combine the cur_base_hidden with the talk_hidden_states according to the weighted head
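For context on what this commit reverts: the removed lines batch-normalized the lm_head and talk_head logits (which requires a transpose round-trip, since nn.BatchNorm1d expects input shaped (batch, channels, length)), applied dropout to those logits, and Xavier-initialized the linear layers via a custom _init_weights hook. Below is a minimal, self-contained sketch of that pattern in isolation; the sizes and the hidden_dropout_prob default are illustrative placeholders, not values taken from the Quiet config.

```python
import torch
import torch.nn as nn


class LogitsHeadWithBN(nn.Module):
    """Sketch of the pattern removed by this commit: an lm_head whose logits
    pass through BatchNorm1d and Dropout, with Xavier-initialized weights.
    Shapes and the dropout rate are illustrative, not the Quiet defaults."""

    def __init__(self, hidden_size: int, vocab_size: int, hidden_dropout_prob: float = 0.1):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        # BatchNorm1d normalizes over dim 1 (channels), so the vocab axis must
        # be moved into position 1 before the call (see forward()).
        self.bn_lm_head = nn.BatchNorm1d(vocab_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # Xavier-uniform init for linear layers, zero bias (mirrors the removed helper).
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden_size)
        logits = self.lm_head(hidden_states)                       # (batch, seq_len, vocab)
        # BatchNorm1d wants (batch, vocab, seq_len), hence the transpose round-trip.
        logits = self.bn_lm_head(logits.transpose(1, 2)).transpose(1, 2)
        logits = self.dropout(logits)
        return logits


head = LogitsHeadWithBN(hidden_size=64, vocab_size=100)
out = head(torch.randn(2, 8, 64))
print(out.shape)  # torch.Size([2, 8, 100])
```

After this change, logits flow straight out of lm_head and talk_head with no BatchNorm1d or Dropout step, and weight initialization falls back to whatever post_init() applies, since the class-level _init_weights override is gone.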