Crystalcareai committed on
Commit
1aa3706
·
verified ·
1 Parent(s): 42469fd

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +21 -23
modeling_quiet.py CHANGED
@@ -1431,30 +1431,28 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1431
  # Apply the language model head to get the final logits
1432
  logits = self.lm_head(mixed_hidden_states)
1433
 
1434
- if not return_dict:
1435
- return logits
1436
-
1437
- return BaseModelOutputWithPast(
1438
- logits=logits,
1439
- past_key_values=new_key_values,
1440
- hidden_states=outputs_after.hidden_states if output_hidden_states else None,
1441
- attentions=outputs_after.attentions if output_attentions else None,
1442
- )
1443
- else:
1444
- # Handle the case where next_token_id is not defined (e.g., continuation_length <= 0)
1445
- # This part of the code needs to be adapted based on how you want to handle this scenario.
1446
- # As a placeholder, returning the logits from the last state of the original input.
1447
- logits = self.lm_head(hidden_states_before)
1448
-
1449
- if not return_dict:
1450
  return logits
1451
-
1452
- return BaseModelOutputWithPast(
1453
- logits=logits,
1454
- past_key_values=past_key_values,
1455
- hidden_states=outputs_before.hidden_states if output_hidden_states else None,
1456
- attentions=outputs_before.attentions if output_attentions else None,
1457
- )
 
 
 
 
 
 
1458
 
1459
  @torch.no_grad()
1460
  def generate(
 
1431
  # Apply the language model head to get the final logits
1432
  logits = self.lm_head(mixed_hidden_states)
1433
 
1434
+ if return_dict:
1435
+ return BaseModelOutputWithPast(
1436
+ logits=logits,
1437
+ past_key_values=new_key_values,
1438
+ hidden_states=outputs_after.hidden_states if output_hidden_states else None,
1439
+ attentions=outputs_after.attentions if output_attentions else None,
1440
+ )
1441
+ else:
 
 
 
 
 
 
 
 
1442
  return logits
1443
+ else:
1444
+ # Handle the case where next_token_id is not defined (e.g., continuation_length <= 0)
1445
+ logits = self.lm_head(hidden_states_before)
1446
+
1447
+ if return_dict:
1448
+ return BaseModelOutputWithPast(
1449
+ logits=logits,
1450
+ past_key_values=past_key_values,
1451
+ hidden_states=outputs_before.hidden_states if output_hidden_states else None,
1452
+ attentions=outputs_before.attentions if output_attentions else None,
1453
+ )
1454
+ else:
1455
+ return logits
1456
 
1457
  @torch.no_grad()
1458
  def generate(