Crystalcareai committed
Commit 16432e8 · verified · 1 Parent(s): 29d3cfe

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +27 -8
modeling_quiet.py CHANGED
@@ -1127,14 +1127,32 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         return generated_text

     @torch.no_grad()
-    def generate(self, input_ids: torch.LongTensor = torch.LongTensor(), attention_mask: Optional[torch.Tensor] = None, max_new_tokens: Optional[int] = None, temperature: float = 1.1, **kwargs):
-        return self.generate_with_callback(input_ids, attention_mask, max_new_tokens, temperature, callback=None, **kwargs)
-
-    def generate_with_streaming(self, input_ids: torch.LongTensor = torch.LongTensor(), attention_mask: Optional[torch.Tensor] = None, max_new_tokens: Optional[int] = None, temperature: float = 1.1, **kwargs):
-        def callback(generated_text):
-            yield generated_text
+    def generate(
+        self,
+        input_ids: torch.LongTensor = torch.LongTensor(),
+        attention_mask: Optional[torch.Tensor] = None,
+        max_new_tokens: Optional[int] = None,
+        temperature: float = 1.1,
+        **kwargs,
+    ):
+        if isinstance(input_ids, str):
+            input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
+
+        if attention_mask is None:
+            # Create a default attention mask if not provided
+            attention_mask = torch.ones_like(input_ids)

-        return self.generate_with_callback(input_ids, attention_mask, max_new_tokens, temperature, callback=callback, **kwargs)
+        from .generate import generate
+        generated_token_ids, generated_text = generate(self, input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
+
+        # Convert the generated token IDs to a tensor
+        generated_token_ids = torch.tensor(generated_token_ids)
+
+        # Return the generated text if it's a string, otherwise return the token IDs
+        if isinstance(generated_text, str):
+            return generated_text
+        else:
+            return generated_token_ids

     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
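
For context, a minimal usage sketch of the rewritten generate() wrapper. The repo id below is a placeholder, and attaching the tokenizer to the model is an assumption on the caller's side: the string branch relies on a self.tokenizer attribute that this diff does not set up.

    # Hypothetical usage sketch; repo id is a placeholder, not a real checkpoint name.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "Crystalcareai/quiet-model"  # placeholder
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
    model.tokenizer = tokenizer  # needed for the isinstance(input_ids, str) branch

    # Passing a raw string: tokenized internally, default all-ones attention mask.
    text_out = model.generate("The capital of France is", max_new_tokens=20, temperature=1.1)

    # Passing token ids: a tensor of generated token ids comes back whenever the
    # underlying helper does not hand back a decoded string.
    ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
    out = model.generate(ids, attention_mask=torch.ones_like(ids), max_new_tokens=20)

Note that this hunk removes the generate_with_streaming wrapper and the callback-based delegation, so callers that relied on them need to switch to the new generate().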
@@ -2066,4 +2084,5 @@ class QuietForSequenceClassification(QuietPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )
+
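
The new wrapper delegates to generate from the sibling generate.py module, which is not part of this commit. From the call site one can only infer its contract: it receives the model, the input ids, attention_mask, max_new_tokens and temperature (plus passthrough kwargs) and returns a (token_ids, text) pair, where the ids are something torch.tensor() can wrap and the text may or may not be a decoded string. A purely hypothetical stand-in that satisfies that contract with plain temperature sampling (batch size 1), not the repository's actual implementation:

    # Hypothetical stand-in for .generate.generate, inferred only from the call
    # site above; the real module ships separately and is NOT shown in this commit.
    import torch

    def generate(model, input_ids, attention_mask=None, max_new_tokens=None,
                 temperature=1.1, **kwargs):
        # Plain temperature sampling, batch size 1, purely illustrative.
        device = next(model.parameters()).device
        input_ids = input_ids.to(device)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        attention_mask = attention_mask.to(device)
        eos_id = getattr(model.config, "eos_token_id", None)

        for _ in range(max_new_tokens or 32):
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            probs = torch.softmax(logits[:, -1, :] / max(temperature, 1e-5), dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
            if eos_id is not None and next_token.item() == eos_id:
                break

        # Return a (token_ids, text) pair: ids as a plain list (the wrapper
        # re-wraps them with torch.tensor), text only if a tokenizer is attached.
        tokenizer = getattr(model, "tokenizer", None)
        text = tokenizer.decode(input_ids[0], skip_special_tokens=True) if tokenizer else None
        return input_ids[0].tolist(), text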