Crystalcareai commited on
Commit
e8b350b
·
verified ·
1 Parent(s): fc5b9b3

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +6 -6
modeling_quiet.py CHANGED
@@ -1427,18 +1427,18 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1427
  @torch.no_grad()
1428
  def generate(
1429
  self,
1430
- input_ids: torch.LongTensor,
1431
- attention_mask: Optional[torch.Tensor] = None,
1432
- max_new_tokens: Optional[int] = None,
1433
- temperature: float = 1.1,
1434
  **kwargs,
1435
- ) -> torch.LongTensor:
1436
  if attention_mask is None:
1437
  # Create a default attention mask if not provided
1438
  attention_mask = torch.ones_like(input_ids)
1439
 
1440
  from .generate import generate
1441
- return generate(self, input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
1442
 
1443
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
1444
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 
1427
  @torch.no_grad()
1428
  def generate(
1429
  self,
1430
+ input_ids,
1431
+ attention_mask=None,
1432
+ max_new_tokens=None,
1433
+ temperature=1.1,
1434
  **kwargs,
1435
+ ):
1436
  if attention_mask is None:
1437
  # Create a default attention mask if not provided
1438
  attention_mask = torch.ones_like(input_ids)
1439
 
1440
  from .generate import generate
1441
+ return generate(self, input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
1442
 
1443
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
1444
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)