Crystalcareai
committed on
Update modeling_quiet.py
Browse files- modeling_quiet.py +3 -6
modeling_quiet.py
CHANGED
@@ -1426,15 +1426,12 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1426 |
return logits
|
1427 |
|
1428 |
|
|
|
1429 |
def generate(self, *args, **kwargs):
|
1430 |
-
# Save the original input_ids and attention_mask
|
1431 |
-
original_input_ids = kwargs.pop("input_ids", None)
|
1432 |
-
original_attention_mask = kwargs.pop("attention_mask", None)
|
1433 |
-
|
1434 |
# Call the infer method to get the logits
|
1435 |
logits = self.infer(
|
1436 |
-
input_ids=
|
1437 |
-
attention_mask=
|
1438 |
position_ids=kwargs.pop("position_ids", None),
|
1439 |
past_key_values=kwargs.pop("past_key_values", None),
|
1440 |
inputs_embeds=kwargs.pop("inputs_embeds", None),
|
|
|
1426 |
return logits
|
1427 |
|
1428 |
|
1429 |
+
@torch.no_grad()
|
1430 |
def generate(self, *args, **kwargs):
|
|
|
|
|
|
|
|
|
1431 |
# Call the infer method to get the logits
|
1432 |
logits = self.infer(
|
1433 |
+
input_ids=kwargs.pop("input_ids", None),
|
1434 |
+
attention_mask=kwargs.pop("attention_mask", None),
|
1435 |
position_ids=kwargs.pop("position_ids", None),
|
1436 |
past_key_values=kwargs.pop("past_key_values", None),
|
1437 |
inputs_embeds=kwargs.pop("inputs_embeds", None),
|