Crystalcareai commited on
Commit
979e88c
·
verified ·
1 Parent(s): 16432e8

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +6 -6
generate.py CHANGED
@@ -103,10 +103,7 @@ def custom_generate(
103
  if streamer is not None:
104
  streamer.put(new_ids_sampled)
105
 
106
- # Convert generated token IDs to text
107
- generated_text = self.tokenizer.decode(generated_token_ids[0], skip_special_tokens=False)
108
-
109
- return generated_token_ids, generated_text
110
 
111
  def generate(
112
  self,
@@ -193,7 +190,7 @@ def generate(
193
  if attention_mask is not None:
194
  attention_mask = attention_mask.to(self.device)
195
 
196
- generated_token_ids, generated_text = custom_generate(
197
  self,
198
  input_ids=input_ids,
199
  attention_mask=attention_mask,
@@ -230,4 +227,7 @@ def generate(
230
  **model_kwargs,
231
  )
232
 
233
- return generated_token_ids, generated_text
 
 
 
 
103
  if streamer is not None:
104
  streamer.put(new_ids_sampled)
105
 
106
+ return generated_token_ids
 
 
 
107
 
108
  def generate(
109
  self,
 
190
  if attention_mask is not None:
191
  attention_mask = attention_mask.to(self.device)
192
 
193
+ generated_token_ids = custom_generate(
194
  self,
195
  input_ids=input_ids,
196
  attention_mask=attention_mask,
 
227
  **model_kwargs,
228
  )
229
 
230
+ # Convert the generated token IDs tensor to text
231
+ generated_text = self.tokenizer.decode(generated_token_ids[0], skip_special_tokens=False)
232
+
233
+ return generated_token_ids, generated_text