Crystalcareai committed
Update modeling_quiet.py

modeling_quiet.py  CHANGED  (+3 -25)
@@ -1262,6 +1262,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
 
         # For visualization
         self.eval_mode = False
+
         num_talk = 1
         talk_input_dim = config.hidden_size if not self.use_concat_talk_head else config.hidden_size * 2
         if self.use_weighted_talk_head:
@@ -1283,16 +1284,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
             nn.Linear(talk_input_dim, talk_output_dim, bias=False)
         )])
 
-        # Add batch normalization to the model
-        self.bn_lm_head = nn.BatchNorm1d(config.vocab_size)
-        self.bn_talk_head = nn.BatchNorm1d(talk_output_dim)
-
-        # Initialize weights using Xavier initialization
-        self.apply(self._init_weights)
-
-        # Add dropout regularization
-        self.dropout = nn.Dropout(config.hidden_dropout_prob)
-
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -1313,14 +1304,6 @@ class QuietForCausalLM(QuietPreTrainedModel):
 
     def get_decoder(self):
         return self.model
-
-    def _init_weights(self, module):
-        if isinstance(module, nn.Linear):
-            nn.init.xavier_uniform_(module.weight)
-            if module.bias is not None:
-                nn.init.constant_(module.bias, 0)
-        elif isinstance(module, nn.Embedding):
-            nn.init.xavier_uniform_(module.weight)
 
     @torch.no_grad()
     def infer(
@@ -1719,11 +1702,10 @@ class QuietForCausalLM(QuietPreTrainedModel):
             hidden_states = outputs[0]
             prev_rm_logits = rm_logits # for policy gradient
            prev_rm_tokens = cur_rm_tokens # for policy gradient
+
             if ahead_idx == 0:
                 hidden_states_lm = hidden_states
                 logits = self.lm_head(hidden_states_lm)
-                logits = self.bn_lm_head(logits.transpose(1, 2)).transpose(1, 2)
-                logits = self.dropout(logits)
                 base_hidden_states = hidden_states.clone()
                 initial_loss_logits = logits.clone()
                 if self.optimize_lm_head_only_at_start or self.optimize_model_only_at_start:
@@ -1754,12 +1736,8 @@ class QuietForCausalLM(QuietPreTrainedModel):
                     head_input_hidden_states = talk_hidden_states
 
                 residual_logits = self.talk_head[0](head_input_hidden_states)
-                residual_logits = self.bn_talk_head(residual_logits.transpose(1, 2)).transpose(1, 2)
-                residual_logits = self.dropout(residual_logits)
                 if self.use_shallow_talk:
-                    residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start)
-                    residual_logits = self.bn_lm_head(residual_logits.transpose(1, 2)).transpose(1, 2)
-                    residual_logits = self.dropout(residual_logits)
+                    residual_logits = apply_head(self.lm_head, residual_logits, detach=self.optimize_lm_head_only_at_start)
                 residual_logits = residual_logits.to(logits.device)
                 if self.use_weighted_talk_head:
                     # combine the cur_base_hidden with the talk_hidden_states according to the weighted head
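For context on the removed normalization calls: nn.BatchNorm1d normalizes over dimension 1, so logits of shape (batch, seq_len, vocab_size) must be transposed to (batch, vocab_size, seq_len) before the norm and back afterwards, which is why the deleted lines wrap the call in a transpose(1, 2) round trip. A minimal, self-contained sketch of that pattern (the tensor sizes below are illustrative, not taken from the actual config):

import torch
import torch.nn as nn

batch, seq_len, vocab_size = 2, 8, 32  # illustrative sizes only
logits = torch.randn(batch, seq_len, vocab_size)

# BatchNorm1d expects the normalized features in dim 1 ((N, C) or (N, C, L)),
# hence the transpose before and after the call.
bn_lm_head = nn.BatchNorm1d(vocab_size)
normed = bn_lm_head(logits.transpose(1, 2)).transpose(1, 2)
assert normed.shape == logits.shape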
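Likewise, the removed _init_weights helper was wired in through self.apply(self._init_weights); nn.Module.apply walks the module tree and invokes the given function on every submodule. A small standalone sketch of that mechanism, using a made-up toy module rather than the real model:

import torch.nn as nn

def init_weights(module):
    # Xavier-initialize Linear/Embedding weights and zero Linear biases,
    # mirroring the helper that this commit removes.
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)
    elif isinstance(module, nn.Embedding):
        nn.init.xavier_uniform_(module.weight)

toy = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 10))  # hypothetical toy module
toy.apply(init_weights)  # apply() visits the Sequential, the Embedding, and the Linear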