Crystalcareai committed
Commit 14a93fc · verified · 1 Parent(s): 952f2ad

Update generate.py

Files changed (1)
  1. generate.py +124 -1
generate.py CHANGED
@@ -104,4 +104,127 @@ def custom_generate(
     print("Generated Token IDs shape:", generated_token_ids.shape)
     print("Generated Token IDs:", generated_token_ids)
 
-    return generated_token_ids
+    return generated_token_ids
+
+def generate(
+    self,
+    input_ids,
+    attention_mask=None,
+    max_new_tokens=None,
+    min_length=None,
+    do_sample=None,
+    early_stopping=None,
+    num_beams=None,
+    temperature=1.1,
+    streamer=None,
+    top_k=None,
+    top_p=None,
+    repetition_penalty=None,
+    bad_words_ids=None,
+    bos_token_id=None,
+    pad_token_id=None,
+    eos_token_id=None,
+    length_penalty=None,
+    no_repeat_ngram_size=None,
+    num_return_sequences=None,
+    decoder_start_token_id=None,
+    use_cache=None,
+    num_beam_groups=None,
+    diversity_penalty=None,
+    prefix_allowed_tokens_fn=None,
+    output_attentions=None,
+    output_hidden_states=None,
+    output_scores=None,
+    return_dict_in_generate=None,
+    forced_bos_token_id=None,
+    forced_eos_token_id=None,
+    remove_invalid_values=None,
+    synced_gpus=None,
+    n_ahead=4,
+    n_ahead_talk=4,
+    merged_talk_heads=True,
+    merged_lm_and_talk_heads=False,
+    merged_lm_and_think_heads=True,
+    use_concat_talk_head=True,
+    use_shallow_think=True,
+    use_shallow_talk=False,
+    use_complex_think_head=False,
+    use_complex_talk_head=True,
+    use_weighted_talk_head=True,
+    trust_remote_code=True,
+    torch_dtype=torch.bfloat16,
+    **model_kwargs,
+):
+
+    if max_new_tokens is None:
+        max_new_tokens = 128
+
+    # Set model attributes
+    self.max_thoughts = n_ahead + n_ahead_talk + 1
+    self.merged_talk_heads = merged_talk_heads
+    self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
+    self.merged_lm_and_think_heads = merged_lm_and_think_heads
+    self.use_concat_talk_head = use_concat_talk_head
+    self.use_shallow_think = use_shallow_think
+    self.use_shallow_talk = use_shallow_talk
+    self.use_complex_think_head = use_complex_think_head
+    self.use_complex_talk_head = use_complex_talk_head
+    self.use_weighted_talk_head = use_weighted_talk_head
+
+    # Set model properties
+    self.use_end_thought_token = True
+    self.use_start_thought_token = True
+    self.n_ahead = n_ahead
+    self.n_passes = 1
+    self.eval_mode = True
+    self.first_run = False
+    self.rm_initialized = True
+    self.original_mode = False
+
+    # Check if the input is a string (for compatibility with text-generation-webui)
+    if isinstance(input_ids, str):
+        input_ids = self.tokenizer.encode(input_ids, return_tensors='pt')
+
+    # Move input_ids and attention_mask to the same device as the model
+    input_ids = input_ids.to(self.device)
+    if attention_mask is not None:
+        attention_mask = attention_mask.to(self.device)
+
+    generated_token_ids = custom_generate(
+        self,
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        max_new_tokens=max_new_tokens,
+        min_length=min_length,
+        do_sample=do_sample,
+        early_stopping=early_stopping,
+        num_beams=num_beams,
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        bad_words_ids=bad_words_ids,
+        bos_token_id=bos_token_id,
+        pad_token_id=pad_token_id,
+        eos_token_id=eos_token_id,
+        length_penalty=length_penalty,
+        no_repeat_ngram_size=no_repeat_ngram_size,
+        num_return_sequences=num_return_sequences,
+        decoder_start_token_id=decoder_start_token_id,
+        use_cache=use_cache,
+        num_beam_groups=num_beam_groups,
+        diversity_penalty=diversity_penalty,
+        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        output_scores=output_scores,
+        return_dict_in_generate=return_dict_in_generate,
+        forced_bos_token_id=forced_bos_token_id,
+        forced_eos_token_id=forced_eos_token_id,
+        remove_invalid_values=remove_invalid_values,
+        synced_gpus=synced_gpus,
+        streamer=streamer,
+        **model_kwargs,
+    )
+
+    return generated_token_ids
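
Usage note: the added generate() is a thin module-level wrapper. It applies a default of max_new_tokens=128, sets the thought-related attributes (n_ahead, n_ahead_talk, merged_talk_heads, ...) on the model, accepts either a tensor of token ids or a raw prompt string (for text-generation-webui compatibility), moves inputs to the model's device, and delegates to custom_generate(). Below is a minimal usage sketch; the repository id, the prompt, and the assumption that this module-level function can be imported and called directly with the loaded model passed as self are illustrative only and are not taken from this commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: generate.py is importable from the working directory / remote-code package.
from generate import generate

# Placeholder repo id and prompt, for illustration only.
model_id = "Crystalcareai/Quiet-Star-Custom"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,  # needed so the repo's custom generation code is loaded
)

prompt = "Q: If there are 3 cars and each has 4 wheels, how many wheels are there?\nA:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Call the wrapper directly, passing the model as `self`; sampling arguments and the
# thought-related knobs (n_ahead, n_ahead_talk, ...) are forwarded to custom_generate().
output_ids = generate(
    model,
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=64,
    temperature=1.1,
    n_ahead=4,
    n_ahead_talk=4,
)

# custom_generate() returns the generated token ids; decode the first sequence.
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))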