Crystalcareai
commited on
Update modeling_quiet.py
Browse files- modeling_quiet.py +21 -23
modeling_quiet.py
CHANGED
@@ -1431,30 +1431,28 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1431 |
# Apply the language model head to get the final logits
|
1432 |
logits = self.lm_head(mixed_hidden_states)
|
1433 |
|
1434 |
-
if
|
1435 |
-
return
|
1436 |
-
|
1437 |
-
|
1438 |
-
|
1439 |
-
|
1440 |
-
|
1441 |
-
|
1442 |
-
)
|
1443 |
-
else:
|
1444 |
-
# Handle the case where next_token_id is not defined (e.g., continuation_length <= 0)
|
1445 |
-
# This part of the code needs to be adapted based on how you want to handle this scenario.
|
1446 |
-
# As a placeholder, returning the logits from the last state of the original input.
|
1447 |
-
logits = self.lm_head(hidden_states_before)
|
1448 |
-
|
1449 |
-
if not return_dict:
|
1450 |
return logits
|
1451 |
-
|
1452 |
-
|
1453 |
-
logits=
|
1454 |
-
|
1455 |
-
|
1456 |
-
|
1457 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
1458 |
|
1459 |
@torch.no_grad()
|
1460 |
def generate(
|
|
|
1431 |
# Apply the language model head to get the final logits
|
1432 |
logits = self.lm_head(mixed_hidden_states)
|
1433 |
|
1434 |
+
if return_dict:
|
1435 |
+
return BaseModelOutputWithPast(
|
1436 |
+
logits=logits,
|
1437 |
+
past_key_values=new_key_values,
|
1438 |
+
hidden_states=outputs_after.hidden_states if output_hidden_states else None,
|
1439 |
+
attentions=outputs_after.attentions if output_attentions else None,
|
1440 |
+
)
|
1441 |
+
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1442 |
return logits
|
1443 |
+
else:
|
1444 |
+
# Handle the case where next_token_id is not defined (e.g., continuation_length <= 0)
|
1445 |
+
logits = self.lm_head(hidden_states_before)
|
1446 |
+
|
1447 |
+
if return_dict:
|
1448 |
+
return BaseModelOutputWithPast(
|
1449 |
+
logits=logits,
|
1450 |
+
past_key_values=past_key_values,
|
1451 |
+
hidden_states=outputs_before.hidden_states if output_hidden_states else None,
|
1452 |
+
attentions=outputs_before.attentions if output_attentions else None,
|
1453 |
+
)
|
1454 |
+
else:
|
1455 |
+
return logits
|
1456 |
|
1457 |
@torch.no_grad()
|
1458 |
def generate(
|