Crystalcareai
committed on
Update modeling_quiet.py
Browse files - modeling_quiet.py +11 -11
modeling_quiet.py
CHANGED
@@ -1402,17 +1402,17 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1402 |
hidden_states_before = outputs_before[0][:, -1:, :]
|
1403 |
|
1404 |
# two new tokens: last continuation token and end thought token
|
1405 |
-
|
1406 |
-
|
1407 |
-
|
1408 |
-
|
1409 |
-
|
1410 |
-
|
1411 |
-
|
1412 |
-
|
1413 |
-
|
1414 |
-
|
1415 |
-
|
1416 |
hidden_states_after = outputs_after[0][:, -1:, :]
|
1417 |
|
1418 |
# Apply the talk head to get the mixing weight
|
|
|
1402 |
hidden_states_before = outputs_before[0][:, -1:, :]
|
1403 |
|
1404 |
# two new tokens: last continuation token and end thought token
|
1405 |
+
outputs_after = self.model(
|
1406 |
+
input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1),
|
1407 |
+
attention_mask=torch.cat([attention_mask[:, -1:], torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1),
|
1408 |
+
position_ids=position_ids,
|
1409 |
+
past_key_values=new_key_values,
|
1410 |
+
inputs_embeds=inputs_embeds,
|
1411 |
+
use_cache=use_cache,
|
1412 |
+
output_attentions=output_attentions,
|
1413 |
+
output_hidden_states=output_hidden_states,
|
1414 |
+
return_dict=return_dict,
|
1415 |
+
)
|
1416 |
hidden_states_after = outputs_after[0][:, -1:, :]
|
1417 |
|
1418 |
# Apply the talk head to get the mixing weight
|