Crystalcareai
commited on
Update generate.py
Browse files- generate.py +4 -8
generate.py
CHANGED
@@ -52,7 +52,6 @@ def custom_generate(
|
|
52 |
new_ids = self(
|
53 |
input_ids[~finished_generating],
|
54 |
attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
|
55 |
-
streamer=streamer,
|
56 |
**kwargs
|
57 |
)['logits']
|
58 |
|
@@ -88,6 +87,9 @@ def custom_generate(
|
|
88 |
if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("<|/assistant|>"):
|
89 |
finished_generating[answer_idx] = 1
|
90 |
|
|
|
|
|
|
|
91 |
generated_token_ids = input_ids.tolist()
|
92 |
|
93 |
return generated_token_ids
|
@@ -165,12 +167,6 @@ def generate(
|
|
165 |
self.rm_initialized = True
|
166 |
self.original_mode = False
|
167 |
|
168 |
-
# # Validate stopping criteria
|
169 |
-
# stopping_criteria = self._get_stopping_criteria(generation_config=self.config, stopping_criteria=StoppingCriteriaList())
|
170 |
-
# if stopping_criteria is None:
|
171 |
-
# stopping_criteria = StoppingCriteriaList()
|
172 |
-
# if max_length is not None:
|
173 |
-
# stopping_criteria = validate_stopping_criteria(max_length, stopping_criteria=stopping_criteria)
|
174 |
# Generate using the custom generate function
|
175 |
generated_token_ids = custom_generate(
|
176 |
self,
|
@@ -209,4 +205,4 @@ def generate(
|
|
209 |
**model_kwargs,
|
210 |
)
|
211 |
|
212 |
-
return generated_token_ids
|
|
|
52 |
new_ids = self(
|
53 |
input_ids[~finished_generating],
|
54 |
attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
|
|
|
55 |
**kwargs
|
56 |
)['logits']
|
57 |
|
|
|
87 |
if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("<|/assistant|>"):
|
88 |
finished_generating[answer_idx] = 1
|
89 |
|
90 |
+
if streamer is not None:
|
91 |
+
streamer.put(new_ids_sampled)
|
92 |
+
|
93 |
generated_token_ids = input_ids.tolist()
|
94 |
|
95 |
return generated_token_ids
|
|
|
167 |
self.rm_initialized = True
|
168 |
self.original_mode = False
|
169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
# Generate using the custom generate function
|
171 |
generated_token_ids = custom_generate(
|
172 |
self,
|
|
|
205 |
**model_kwargs,
|
206 |
)
|
207 |
|
208 |
+
return generated_token_ids
|