Crystalcareai
committed on
Update modeling_quiet.py
Browse files- modeling_quiet.py +27 -8
modeling_quiet.py
CHANGED
@@ -1127,14 +1127,32 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1127 |
return generated_text
|
1128 |
|
1129 |
@torch.no_grad()
|
1130 |
-
def generate(
|
1131 |
-
|
1132 |
-
|
1133 |
-
|
1134 |
-
|
1135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1136 |
|
1137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1138 |
|
1139 |
@add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
|
1140 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
@@ -2066,4 +2084,5 @@ class QuietForSequenceClassification(QuietPreTrainedModel):
|
|
2066 |
past_key_values=transformer_outputs.past_key_values,
|
2067 |
hidden_states=transformer_outputs.hidden_states,
|
2068 |
attentions=transformer_outputs.attentions,
|
2069 |
-
)
|
|
|
|
1127 |
return generated_text
|
1128 |
|
1129 |
@torch.no_grad()
def generate(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    max_new_tokens: Optional[int] = None,
    temperature: float = 1.1,
    **kwargs,
):
    """Generate a continuation for ``input_ids`` via the project's sampling loop.

    Args:
        input_ids: Prompt token ids, or a raw prompt string that is first
            tokenized with ``self.tokenizer``. When omitted, an empty
            ``LongTensor`` is used (preserves the original default).
        attention_mask: Optional attention mask; an all-ones mask matching
            ``input_ids`` is created when not provided.
        max_new_tokens: Cap on the number of newly generated tokens,
            forwarded to the generation loop.
        temperature: Sampling temperature forwarded to the generation loop.
        **kwargs: Additional options forwarded to ``generate`` in the
            sibling ``.generate`` module.

    Returns:
        The generated text (``str``) when the underlying loop yields a
        string, otherwise a tensor of generated token ids.
    """
    # Fix: the original signature used `input_ids=torch.LongTensor()`, a
    # mutable default evaluated once at function-definition time and shared
    # across all calls. Build the empty-tensor default lazily instead.
    if input_ids is None:
        input_ids = torch.LongTensor()

    # Convenience: accept a raw prompt string in place of token ids.
    if isinstance(input_ids, str):
        input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids

    if attention_mask is None:
        # Create a default (all-ones) attention mask if not provided.
        attention_mask = torch.ones_like(input_ids)

    # Deferred import keeps the module importable even if the helper has
    # its own heavy imports; it also avoids a circular-import at load time.
    from .generate import generate
    generated_token_ids, generated_text = generate(
        self,
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        **kwargs,
    )

    # Fix: `torch.tensor()` on an existing tensor copy-constructs and emits
    # a UserWarning; only convert when the loop returned a plain sequence.
    if not isinstance(generated_token_ids, torch.Tensor):
        generated_token_ids = torch.tensor(generated_token_ids)

    # Return the generated text if it's a string, otherwise the token ids.
    if isinstance(generated_text, str):
        return generated_text
    return generated_token_ids
|
1156 |
|
1157 |
@add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
|
1158 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
|
2084 |
past_key_values=transformer_outputs.past_key_values,
|
2085 |
hidden_states=transformer_outputs.hidden_states,
|
2086 |
attentions=transformer_outputs.attentions,
|
2087 |
+
)
|
2088 |
+
|