Crystalcareai commited on
Commit
169ec2e
·
verified ·
1 Parent(s): 7874fb0

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +12 -4
modeling_quiet.py CHANGED
@@ -1424,10 +1424,18 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1424
  logits = self.lm_head(mixed_hidden_states)
1425
  return logits
1426
 
1427
- # def generate(self, input_ids, attention_mask=None, max_length=None, temperature=1.0, **kwargs):
1428
- # from .generate import generate
1429
- # return generate(self, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, temperature=temperature, **kwargs)
1430
-
 
 
 
 
 
 
 
 
1431
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
1432
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1433
  def forward(
 
1424
  logits = self.lm_head(mixed_hidden_states)
1425
  return logits
1426
 
1427
+ @torch.no_grad()
1428
+ def generate(
1429
+ self,
1430
+ input_ids: torch.LongTensor,
1431
+ attention_mask: Optional[torch.Tensor] = None,
1432
+ max_length: Optional[int] = None,
1433
+ temperature: float = 1.1,
1434
+ **kwargs,
1435
+ ):
1436
+ from .generate import generate
1437
+ return generate(self, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, temperature=temperature, **kwargs)
1438
+
1439
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
1440
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1441
  def forward(