Crystalcareai commited on
Commit
b5101f8
·
verified ·
1 Parent(s): bcd8e9c

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +2 -3
modeling_quiet.py CHANGED
@@ -1421,10 +1421,9 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1421
  logits = self.lm_head(mixed_hidden_states)
1422
  return logits
1423
 
1424
- def generate(self, input_ids, attention_mask=None, max_length=None, temperature=1.0, **kwargs):
1425
  from .generate import generate
1426
- return generate(self, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, temperature=temperature, **kwargs)
1427
-
1428
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
1429
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1430
  def forward(
 
1421
  logits = self.lm_head(mixed_hidden_states)
1422
  return logits
1423
 
1424
+ def generate(self, input_ids, attention_mask=None, max_length=None, temperature=1.0, stopping_criteria=None, **kwargs):
1425
  from .generate import generate
1426
+ return generate(self, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, temperature=temperature, stopping_criteria=stopping_criteria, **kwargs)
 
1427
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
1428
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1429
  def forward(