Crystalcareai
committed
Update modeling_quiet.py
modeling_quiet.py (+6 -6)
@@ -1427,18 +1427,18 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
     @torch.no_grad()
     def generate(
         self,
-        input_ids
-        attention_mask
-        max_new_tokens
-        temperature
+        input_ids,
+        attention_mask=None,
+        max_new_tokens=None,
+        temperature=1.1,
         **kwargs,
-    )
+    ):
         if attention_mask is None:
             # Create a default attention mask if not provided
             attention_mask = torch.ones_like(input_ids)
 
         from .generate import generate
-        return generate(self, input_ids
+        return generate(self, input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
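For reference, a minimal usage sketch of the updated signature, assuming the checkpoint is loaded with trust_remote_code=True so this custom QuietForCausalLM class (and its generate.py helper) is used. The model id below is a placeholder, not the actual repository name, and the decoding step assumes the helper returns generated token ids like the stock transformers generate():

from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model id; substitute the actual Quiet-STaR repository.
model_id = "Crystalcareai/Quiet-Star"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("Hello, world!", return_tensors="pt")

# attention_mask, max_new_tokens, and temperature now have defaults and are
# forwarded to the standalone generate() helper imported from generate.py.
output = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_new_tokens=64,
    temperature=1.1,
)

# Assumes the helper returns token ids; adjust decoding if it returns
# something else.
print(tokenizer.decode(output[0], skip_special_tokens=True))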