Crystalcareai
commited on
Update generate.py
Browse files- generate.py +6 -6
generate.py
CHANGED
@@ -103,10 +103,7 @@ def custom_generate(
|
|
103 |
if streamer is not None:
|
104 |
streamer.put(new_ids_sampled)
|
105 |
|
106 |
-
|
107 |
-
generated_text = self.tokenizer.decode(generated_token_ids[0], skip_special_tokens=False)
|
108 |
-
|
109 |
-
return generated_token_ids, generated_text
|
110 |
|
111 |
def generate(
|
112 |
self,
|
@@ -193,7 +190,7 @@ def generate(
|
|
193 |
if attention_mask is not None:
|
194 |
attention_mask = attention_mask.to(self.device)
|
195 |
|
196 |
-
generated_token_ids
|
197 |
self,
|
198 |
input_ids=input_ids,
|
199 |
attention_mask=attention_mask,
|
@@ -230,4 +227,7 @@ def generate(
|
|
230 |
**model_kwargs,
|
231 |
)
|
232 |
|
233 |
-
|
|
|
|
|
|
|
|
103 |
if streamer is not None:
|
104 |
streamer.put(new_ids_sampled)
|
105 |
|
106 |
+
return generated_token_ids
|
|
|
|
|
|
|
107 |
|
108 |
def generate(
|
109 |
self,
|
|
|
190 |
if attention_mask is not None:
|
191 |
attention_mask = attention_mask.to(self.device)
|
192 |
|
193 |
+
generated_token_ids = custom_generate(
|
194 |
self,
|
195 |
input_ids=input_ids,
|
196 |
attention_mask=attention_mask,
|
|
|
227 |
**model_kwargs,
|
228 |
)
|
229 |
|
230 |
+
# Convert the generated token IDs tensor to text
|
231 |
+
generated_text = self.tokenizer.decode(generated_token_ids[0], skip_special_tokens=False)
|
232 |
+
|
233 |
+
return generated_token_ids, generated_text
|