Crystalcareai
committed
Update modeling_quiet.py
modeling_quiet.py  CHANGED  (+41 -19)
@@ -1100,32 +1100,54 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
     @torch.no_grad()
     def generate(
         self,
-        input_ids,
-        attention_mask,
-        max_new_tokens,
-
-
+        input_ids=None,
+        attention_mask=None,
+        max_new_tokens=None,
+        min_length=None,
+        do_sample=None,
+        early_stopping=None,
+        num_beams=None,
+        temperature=1.0,
+        top_k=None,
+        top_p=None,
+        repetition_penalty=None,
+        bad_words_ids=None,
+        bos_token_id=None,
+        pad_token_id=None,
+        eos_token_id=None,
+        length_penalty=None,
+        no_repeat_ngram_size=None,
+        num_return_sequences=None,
+        decoder_start_token_id=None,
+        use_cache=None,
+        num_beam_groups=None,
+        diversity_penalty=None,
+        prefix_allowed_tokens_fn=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        output_scores=None,
+        return_dict_in_generate=None,
+        forced_bos_token_id=None,
+        forced_eos_token_id=None,
+        remove_invalid_values=None,
+        synced_gpus=None,
+        **model_kwargs,
     ):
-
-
-
-
-
-
-
-        from .generate import generate
-
-        output = generate(
+        # Prepare the generation process with customized settings
+        model_inputs = self.prepare_inputs_for_generation(
+            input_ids, past_key_values=None, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
+        )
+
+        # Call the external custom generation function, ensuring it's integrated properly
+        return custom_generate(
             self,
-            input_ids,
+            input_ids=input_ids,
             attention_mask=attention_mask,
             max_new_tokens=max_new_tokens,
             temperature=temperature,
-            **
+            **model_kwargs
         )
 
-        return output.sequences
-
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
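
The rewritten body calls `custom_generate`, but the hunk removes the old local import (`from .generate import generate`) without showing a replacement, so the name must already be bound at module scope for this method to run. A minimal sketch of one plausible binding, assuming the helper still lives in the repo's `generate.py` (the diff itself does not confirm this), would be:

    # Hypothetical module-level binding; this diff does not show where
    # custom_generate actually comes from.
    from .generate import generate as custom_generate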
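As a usage sketch, the override keeps the standard `model.generate(...)` call shape while routing decoding to the custom helper. The checkpoint id and prompt below are illustrative only, and the example assumes a repo that ships this `modeling_quiet.py` and is loaded with `trust_remote_code=True`:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Hypothetical checkpoint id; any repo shipping this modeling_quiet.py
    # behaves the same way.
    model_id = "Crystalcareai/quiet-star"
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

    inputs = tokenizer("2 + 2 =", return_tensors="pt")

    # Of the many keywords the new signature accepts, only input_ids,
    # attention_mask, max_new_tokens, temperature, and **model_kwargs are
    # actually forwarded to custom_generate.
    output_ids = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=32,
        temperature=1.0,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

One behavioral difference worth noting: the removed body returned `output.sequences`, while the new body returns whatever `custom_generate` returns, so the `output_ids[0]` indexing above assumes the helper yields a tensor of token ids rather than a `GenerateOutput` object.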
|