Crystalcareai commited on
Commit
b52f8ef
·
verified ·
1 Parent(s): 5ce32b8

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +3 -6
modeling_quiet.py CHANGED
@@ -1426,15 +1426,12 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1426
  return logits
1427
 
1428
 
 
1429
  def generate(self, *args, **kwargs):
1430
- # Save the original input_ids and attention_mask
1431
- original_input_ids = kwargs.pop("input_ids", None)
1432
- original_attention_mask = kwargs.pop("attention_mask", None)
1433
-
1434
  # Call the infer method to get the logits
1435
  logits = self.infer(
1436
- input_ids=original_input_ids,
1437
- attention_mask=original_attention_mask,
1438
  position_ids=kwargs.pop("position_ids", None),
1439
  past_key_values=kwargs.pop("past_key_values", None),
1440
  inputs_embeds=kwargs.pop("inputs_embeds", None),
 
1426
  return logits
1427
 
1428
 
1429
+ @torch.no_grad()
1430
  def generate(self, *args, **kwargs):
 
 
 
 
1431
  # Call the infer method to get the logits
1432
  logits = self.infer(
1433
+ input_ids=kwargs.pop("input_ids", None),
1434
+ attention_mask=kwargs.pop("attention_mask", None),
1435
  position_ids=kwargs.pop("position_ids", None),
1436
  past_key_values=kwargs.pop("past_key_values", None),
1437
  inputs_embeds=kwargs.pop("inputs_embeds", None),