Crystalcareai
committed on
Update modeling_quiet.py
Browse files- modeling_quiet.py +21 -7
modeling_quiet.py
CHANGED
@@ -1425,15 +1425,29 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1425 |
logits = self.lm_head(mixed_hidden_states)
|
1426 |
return logits
|
1427 |
|
1428 |
-
|
1429 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1430 |
|
1431 |
-
|
1432 |
-
|
1433 |
|
1434 |
-
|
1435 |
-
def generate(self, **kwargs):
    """Delegate generation to ``infer``, forwarding all keyword arguments unchanged."""
    outputs = self.infer(**kwargs)
    return outputs
|
1437 |
|
1438 |
@add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
|
1439 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
|
1425 |
logits = self.lm_head(mixed_hidden_states)
|
1426 |
return logits
|
1427 |
|
1428 |
+
|
1429 |
+
def generate(self, *args, **kwargs):
    """Single-pass greedy decode: run ``infer`` once and take the argmax token id
    at every position.

    NOTE(review): despite the name, this is not autoregressive generation — it
    performs one forward pass and returns per-position greedy token ids; confirm
    this is the intended contract for callers expecting ``GenerationMixin.generate``.

    Args:
        *args: optional positional ``input_ids`` (and ``attention_mask``),
            mirroring the usual ``model.generate(input_ids)`` call convention.
        **kwargs: forward-pass arguments; recognized keys are popped and passed
            through to ``infer``.

    Returns:
        Tensor of token ids — ``logits.argmax(dim=-1)``.
    """
    # Accept input_ids / attention_mask positionally as well as by keyword.
    # The original discarded *args silently, so model.generate(input_ids)
    # ran with input_ids=None.
    original_input_ids = kwargs.pop("input_ids", args[0] if len(args) > 0 else None)
    original_attention_mask = kwargs.pop("attention_mask", args[1] if len(args) > 1 else None)

    # One forward pass through `infer` to obtain the logits.
    logits = self.infer(
        input_ids=original_input_ids,
        attention_mask=original_attention_mask,
        position_ids=kwargs.pop("position_ids", None),
        past_key_values=kwargs.pop("past_key_values", None),
        inputs_embeds=kwargs.pop("inputs_embeds", None),
        use_cache=kwargs.pop("use_cache", None),
        output_attentions=kwargs.pop("output_attentions", None),
        output_hidden_states=kwargs.pop("output_hidden_states", None),
        return_dict=kwargs.pop("return_dict", None),
    )

    # Greedy decode: highest-scoring vocabulary entry at each position.
    output_ids = torch.argmax(logits, dim=-1)

    return output_ids
|
|
|
|
|
1451 |
|
1452 |
@add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
|
1453 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|