Crystalcareai commited on
Commit
0d0f81b
·
verified ·
1 Parent(s): 7e816a3

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +4 -8
generate.py CHANGED
@@ -52,7 +52,6 @@ def custom_generate(
52
  new_ids = self(
53
  input_ids[~finished_generating],
54
  attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
55
- streamer=streamer,
56
  **kwargs
57
  )['logits']
58
 
@@ -88,6 +87,9 @@ def custom_generate(
88
  if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("<|/assistant|>"):
89
  finished_generating[answer_idx] = 1
90
 
 
 
 
91
  generated_token_ids = input_ids.tolist()
92
 
93
  return generated_token_ids
@@ -165,12 +167,6 @@ def generate(
165
  self.rm_initialized = True
166
  self.original_mode = False
167
 
168
- # # Validate stopping criteria
169
- # stopping_criteria = self._get_stopping_criteria(generation_config=self.config, stopping_criteria=StoppingCriteriaList())
170
- # if stopping_criteria is None:
171
- # stopping_criteria = StoppingCriteriaList()
172
- # if max_length is not None:
173
- # stopping_criteria = validate_stopping_criteria(max_length, stopping_criteria=stopping_criteria)
174
  # Generate using the custom generate function
175
  generated_token_ids = custom_generate(
176
  self,
@@ -209,4 +205,4 @@ def generate(
209
  **model_kwargs,
210
  )
211
 
212
- return generated_token_ids
 
52
  new_ids = self(
53
  input_ids[~finished_generating],
54
  attention_mask=attention_mask[~finished_generating] if attention_mask is not None else None,
 
55
  **kwargs
56
  )['logits']
57
 
 
87
  if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("<|/assistant|>"):
88
  finished_generating[answer_idx] = 1
89
 
90
+ if streamer is not None:
91
+ streamer.put(new_ids_sampled)
92
+
93
  generated_token_ids = input_ids.tolist()
94
 
95
  return generated_token_ids
 
167
  self.rm_initialized = True
168
  self.original_mode = False
169
 
 
 
 
 
 
 
170
  # Generate using the custom generate function
171
  generated_token_ids = custom_generate(
172
  self,
 
205
  **model_kwargs,
206
  )
207
 
208
+ return generated_token_ids