Crystalcareai committed
Commit b47da4b · verified · 1 Parent(s): 8c48db2

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +11 -11
modeling_quiet.py CHANGED
@@ -1402,17 +1402,17 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         hidden_states_before = outputs_before[0][:, -1:, :]
 
         # two new tokens: last continuation token and end thought token
-        outputs_after = self.model(
-            input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1),
-            attention_mask=torch.cat([attention_mask[:, -1:], torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1),
-            position_ids=position_ids,
-            past_key_values=new_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
+        outputs_after = self.model(
+            input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1),
+            attention_mask=torch.cat([attention_mask[:, -1:], torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1),
+            position_ids=position_ids,
+            past_key_values=new_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
         hidden_states_after = outputs_after[0][:, -1:, :]
 
         # Apply the talk head to get the mixing weight
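
For reference, the block touched by this commit runs a second short forward pass right after a thought: it feeds the model the last continuation token together with the end-of-thought token, reusing the cached key/values in new_key_values, and keeps the final hidden state as hidden_states_after. The sketch below only reproduces the tensor plumbing visible in the diff (how the two-token input_ids and the matching attention_mask slice are assembled); the batch size, token ids, and mask length are made-up placeholders, not values taken from the model.

import torch

# Placeholder values standing in for the variables that appear in the diff;
# the real ids and batch size come from the surrounding generation loop.
batch_size = 2
end_thought_token_id = 32001                    # assumed id of the end-of-thought special token
next_token_id = torch.tensor([101, 205])        # last sampled continuation token for each sequence
attention_mask = torch.ones(batch_size, 7)      # mask over the tokens processed so far

# Two new tokens per sequence: the last continuation token and the end-thought token.
two_token_input_ids = torch.cat(
    [
        next_token_id.unsqueeze(-1),                          # shape (batch_size, 1)
        torch.tensor([[end_thought_token_id]] * batch_size),  # shape (batch_size, 1)
    ],
    dim=-1,
)                                                             # shape (batch_size, 2)

# Extend the mask by one position to cover the appended end-thought token.
two_token_attention_mask = torch.cat(
    [attention_mask[:, -1:], torch.ones(batch_size, 1)],
    dim=-1,
)                                                             # shape (batch_size, 2)

print(two_token_input_ids)       # -> [[101, 32001], [205, 32001]]
print(two_token_attention_mask)  # -> [[1., 1.], [1., 1.]]

In the model itself, these two tensors are what get passed to self.model(...) along with past_key_values=new_key_values, so the previously encoded positions are served from the cache and only the two new positions are run through the transformer.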