Crystalcareai commited on
Commit
6a35495
·
verified ·
1 Parent(s): beb979f

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +2 -27
modeling_quiet.py CHANGED
@@ -1111,21 +1111,6 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1111
  logits = self.lm_head(mixed_hidden_states)
1112
  return logits
1113
 
1114
- def generate_with_callback(self, input_ids: torch.LongTensor = torch.LongTensor(), attention_mask: Optional[torch.Tensor] = None, max_new_tokens: Optional[int] = None, temperature: float = 1.1, callback=None, **kwargs):
1115
- if isinstance(input_ids, str):
1116
- input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
1117
-
1118
- if attention_mask is None:
1119
- attention_mask = torch.ones_like(input_ids)
1120
-
1121
- from .generate import generate
1122
- generated_token_ids, generated_text = generate(self, input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
1123
-
1124
- if callback is not None:
1125
- callback(generated_text)
1126
-
1127
- return generated_text
1128
-
1129
  @torch.no_grad()
1130
  def generate(
1131
  self,
@@ -1143,16 +1128,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1143
  attention_mask = torch.ones_like(input_ids)
1144
 
1145
  from .generate import generate
1146
- generated_token_ids, generated_text = generate(self, input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
1147
-
1148
- # Convert the generated token IDs to a tensor
1149
- generated_token_ids = torch.tensor(generated_token_ids)
1150
-
1151
- # Return the generated text if it's a string, otherwise return the token IDs
1152
- if isinstance(generated_text, str):
1153
- return generated_text
1154
- else:
1155
- return generated_token_ids
1156
 
1157
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
1158
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -2084,5 +2060,4 @@ class QuietForSequenceClassification(QuietPreTrainedModel):
2084
  past_key_values=transformer_outputs.past_key_values,
2085
  hidden_states=transformer_outputs.hidden_states,
2086
  attentions=transformer_outputs.attentions,
2087
- )
2088
-
 
1111
  logits = self.lm_head(mixed_hidden_states)
1112
  return logits
1113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1114
  @torch.no_grad()
1115
  def generate(
1116
  self,
 
1128
  attention_mask = torch.ones_like(input_ids)
1129
 
1130
  from .generate import generate
1131
+ return generate(self, input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
 
 
 
 
 
 
 
 
 
1132
 
1133
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
1134
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 
2060
  past_key_values=transformer_outputs.past_key_values,
2061
  hidden_states=transformer_outputs.hidden_states,
2062
  attentions=transformer_outputs.attentions,
2063
+ )