Crystalcareai committed
Commit d3a9a29 · verified · 1 Parent(s): 6df17e6

Update modeling_quiet.py

Files changed (1):
  modeling_quiet.py +4 -5
modeling_quiet.py CHANGED

@@ -1280,8 +1280,8 @@ class QuietForCausalLM(QuietPreTrainedModel):
         # Generate the continuation
         continuation_length = self.n_ahead - 2
         new_key_values = past_key_values
-
         start_time = time.time()
+
         for continuation_idx in range(continuation_length):
             outputs = self.model(
                 input_ids=input_ids if continuation_idx == 0 else next_token_id.unsqueeze(-1).to(input_ids.device),
@@ -1295,9 +1295,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
                 return_dict=return_dict,
             )
             new_key_values = outputs.past_key_values
-
             hidden_states = outputs[0]
-
             logits = self.lm_head(hidden_states)
             logits = logits[:, -1, :]  # Only consider the last token
 
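The first two hunks are whitespace-only; for context, here is a minimal sketch of what the continuation loop they touch is doing, assuming a Hugging Face-style model whose outputs carry past_key_values. The wrapper name generate_continuation, greedy argmax sampling, and use_cache=True are illustrative assumptions, not the repository's exact code.

import torch

def generate_continuation(model, lm_head, input_ids, past_key_values, n_ahead):
    # Decode (n_ahead - 2) thought tokens one at a time, reusing the KV cache
    # so each later step feeds only the newly sampled token.
    continuation_length = n_ahead - 2
    new_key_values = past_key_values
    next_token_id = None
    for continuation_idx in range(continuation_length):
        step_ids = (
            input_ids
            if continuation_idx == 0
            else next_token_id.unsqueeze(-1).to(input_ids.device)
        )
        outputs = model(
            input_ids=step_ids,
            past_key_values=new_key_values,
            use_cache=True,    # assumption: cache is returned each step
            return_dict=True,
        )
        new_key_values = outputs.past_key_values
        hidden_states = outputs[0]
        logits = lm_head(hidden_states)
        logits = logits[:, -1, :]  # only the last position is sampled
        next_token_id = torch.argmax(logits, dim=-1)  # greedy, for illustration
    return next_token_id, new_key_values

Feeding only the newly sampled token on each later step works because the cache already holds keys and values for every earlier position.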
@@ -1336,9 +1334,9 @@ class QuietForCausalLM(QuietPreTrainedModel):
         )
         hidden_states_before = outputs_before[0][:, -1:, :]
 
-        # two new tokens: last continuation token and end thought token
+        # Get the hidden states after the thought
         outputs_after = self.model(
-            input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor(end_thought_token_id).unsqueeze(-1).unsqueeze(-1).to(input_ids.device)], dim=-1),
+            input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=new_key_values,
@@ -1358,6 +1356,7 @@ class QuietForCausalLM(QuietPreTrainedModel):
 
         # Apply the language model head to get the final logits
         logits = self.lm_head(mixed_hidden_states)
+
         return logits
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
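The substantive change is in the third hunk: outputs_after previously extended the cached thought with two fresh tokens, the last continuation token and the end-of-thought marker, whereas it now re-feeds the original input_ids on top of the thought's KV cache before the hidden states are mixed. A hedged sketch contrasting the two variants (the helper name and trimmed argument list are hypothetical; attention_mask and position_ids from the diff are omitted for brevity):

import torch

def post_thought_hidden_states(model, input_ids, next_token_id,
                               end_thought_token_id, new_key_values,
                               old_behaviour=False):
    if old_behaviour:
        # Old: append the last continuation token plus the end-of-thought
        # token to the cached thought.
        step_ids = torch.cat(
            [
                next_token_id.unsqueeze(-1).to(input_ids.device),
                torch.tensor(end_thought_token_id)
                .unsqueeze(-1).unsqueeze(-1).to(input_ids.device),
            ],
            dim=-1,
        )
    else:
        # New: the original inputs, evaluated on top of the thought's cache.
        step_ids = input_ids
    outputs_after = model(
        input_ids=step_ids,
        past_key_values=new_key_values,
        return_dict=True,
    )
    # Keep only the final position, mirroring hidden_states_before above.
    return outputs_after[0][:, -1:, :]

Note that the old torch.cat only lines up when the batch size is 1, since the end-of-thought tensor is built with shape (1, 1); the new variant appears to sidestep that constraint entirely.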
 