Crystalcareai
committed
Update modeling_quiet.py
modeling_quiet.py  CHANGED  (+4 -5)
@@ -1280,8 +1280,8 @@ class QuietForCausalLM(QuietPreTrainedModel):
         # Generate the continuation
         continuation_length = self.n_ahead - 2
         new_key_values = past_key_values
-
         start_time = time.time()
+
         for continuation_idx in range(continuation_length):
             outputs = self.model(
                 input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
@@ -1295,9 +1295,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 return_dict=return_dict,
             )
             new_key_values = outputs.past_key_values
-
             hidden_states = outputs[0]
-
             logits = self.lm_head(hidden_states)
             logits = logits[:, -1, :]  # Only consider the last token
 
@@ -1336,9 +1334,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
         )
         hidden_states_before = outputs_before[0][:, -1:, :]
 
-        #
+        # Get the hidden states after the thought
         outputs_after = self.model(
-            input_ids=
+            input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=new_key_values,
@@ -1358,6 +1356,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
 
         # Apply the language model head to get the final logits
         logits = self.lm_head(mixed_hidden_states)
+
         return logits
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
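
The whitespace churn in the first, second, and fourth hunks is cosmetic; the substantive fix is in the third hunk, where a bare "#" comment and a dangling "input_ids=" (a syntax error) are completed. For context, the loop being touched is the standard KV-cache incremental-decoding pattern: feed the whole prompt once, then feed only the newest token together with the cached key/values, and read logits off the last position. A minimal standalone sketch of that pattern; the gpt2 checkpoint, the greedy argmax, and the continuation length are illustrative assumptions, not part of this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; the commit itself patches a custom Quiet model.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids
past_key_values = None
next_token_id = None
continuation_length = 8  # stands in for self.n_ahead - 2

with torch.no_grad():
    for continuation_idx in range(continuation_length):
        outputs = model(
            # Full prompt on the first step, then only the newest token:
            input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1),
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values  # carry the cache forward
        logits = outputs.logits[:, -1, :]          # only the last position matters
        next_token_id = logits.argmax(dim=-1)      # greedy pick, for illustration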
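
The third and fourth hunks sit in the model's mixing step: hidden states taken before and after the generated thought are combined, and only the mixed states go through lm_head. A schematic of that step with shapes following the diff; the sigmoid weight and the sizes are assumptions standing in for the model's learned mixing head, not code from this repository:

import torch

batch, hidden, vocab = 2, 768, 32000  # sizes are illustrative
hidden_states_before = torch.randn(batch, 1, hidden)  # last position, before the thought
hidden_states_after = torch.randn(batch, 1, hidden)   # last position, after the thought

# Assumed stand-in for the learned mixing head: a per-example weight in
# [0, 1] interpolating between the two hidden states.
w = torch.sigmoid(torch.zeros(batch, 1, 1))
mixed_hidden_states = w * hidden_states_after + (1 - w) * hidden_states_before

# Apply the language model head to get the final logits, as in the diff.
lm_head = torch.nn.Linear(hidden, vocab, bias=False)
logits = lm_head(mixed_hidden_states)
print(logits.shape)  # torch.Size([2, 1, 32000])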