Crystalcareai
commited on
Update modeling_quiet.py
Browse files- modeling_quiet.py +12 -4
modeling_quiet.py
CHANGED
@@ -1424,10 +1424,18 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1424 |
logits = self.lm_head(mixed_hidden_states)
|
1425 |
return logits
|
1426 |
|
1427 |
-
|
1428 |
-
|
1429 |
-
|
1430 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1431 |
@add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
|
1432 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
1433 |
def forward(
|
|
|
1424 |
logits = self.lm_head(mixed_hidden_states)
|
1425 |
return logits
|
1426 |
|
1427 |
+
@torch.no_grad()
|
1428 |
+
def generate(
|
1429 |
+
self,
|
1430 |
+
input_ids: torch.LongTensor,
|
1431 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1432 |
+
max_length: Optional[int] = None,
|
1433 |
+
temperature: float = 1.1,
|
1434 |
+
**kwargs,
|
1435 |
+
):
|
1436 |
+
from .generate import generate
|
1437 |
+
return generate(self, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length, temperature=temperature, **kwargs)
|
1438 |
+
|
1439 |
@add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
|
1440 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
1441 |
def forward(
|