Crystalcareai committed on
Commit
38421a3
·
verified ·
1 Parent(s): e5bb001

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +41 -19
modeling_quiet.py CHANGED
@@ -1100,32 +1100,54 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1100
  @torch.no_grad()
1101
  def generate(
1102
  self,
1103
- input_ids: torch.LongTensor = torch.LongTensor(),
1104
- attention_mask: Optional[torch.Tensor] = None,
1105
- max_new_tokens: Optional[int] = None,
1106
- temperature: float = 1.1,
1107
- **kwargs,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1108
  ):
1109
- if isinstance(input_ids, str):
1110
- input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
1111
-
1112
- if attention_mask is None:
1113
- # Create a default attention mask if not provided
1114
- attention_mask = torch.ones_like(input_ids)
1115
-
1116
- from .generate import generate
1117
-
1118
- output = generate(
1119
  self,
1120
- input_ids,
1121
  attention_mask=attention_mask,
1122
  max_new_tokens=max_new_tokens,
1123
  temperature=temperature,
1124
- **kwargs,
1125
  )
1126
 
1127
- return output.sequences
1128
-
1129
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
1130
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1131
  def forward(
 
1100
  @torch.no_grad()
1101
  def generate(
1102
  self,
1103
+ input_ids=None,
1104
+ attention_mask=None,
1105
+ max_new_tokens=None,
1106
+ min_length=None,
1107
+ do_sample=None,
1108
+ early_stopping=None,
1109
+ num_beams=None,
1110
+ temperature=1.0,
1111
+ top_k=None,
1112
+ top_p=None,
1113
+ repetition_penalty=None,
1114
+ bad_words_ids=None,
1115
+ bos_token_id=None,
1116
+ pad_token_id=None,
1117
+ eos_token_id=None,
1118
+ length_penalty=None,
1119
+ no_repeat_ngram_size=None,
1120
+ num_return_sequences=None,
1121
+ decoder_start_token_id=None,
1122
+ use_cache=None,
1123
+ num_beam_groups=None,
1124
+ diversity_penalty=None,
1125
+ prefix_allowed_tokens_fn=None,
1126
+ output_attentions=None,
1127
+ output_hidden_states=None,
1128
+ output_scores=None,
1129
+ return_dict_in_generate=None,
1130
+ forced_bos_token_id=None,
1131
+ forced_eos_token_id=None,
1132
+ remove_invalid_values=None,
1133
+ synced_gpus=None,
1134
+ **model_kwargs,
1135
  ):
1136
+ # Prepare the generation process with customized settings
1137
+ model_inputs = self.prepare_inputs_for_generation(
1138
+ input_ids, past_key_values=None, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
1139
+ )
1140
+
1141
+ # Call the external custom generation function, ensuring it's integrated properly
1142
+ return custom_generate(
 
 
 
1143
  self,
1144
+ input_ids=input_ids,
1145
  attention_mask=attention_mask,
1146
  max_new_tokens=max_new_tokens,
1147
  temperature=temperature,
1148
+ **model_kwargs
1149
  )
1150
 
 
 
1151
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
1152
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1153
  def forward(