Crystalcareai commited on
Commit
723bd20
·
verified ·
1 Parent(s): 2f5cda6

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +24 -0
generate.py CHANGED
@@ -129,6 +129,30 @@ def generate(
129
  synced_gpus=None,
130
  **model_kwargs,
131
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  streamer = TextStreamer(self.tokenizer, skip_prompt=False, skip_special_tokens=True)
133
 
134
  generated_token_ids, attention_mask = custom_generate(
 
129
  synced_gpus=None,
130
  **model_kwargs,
131
  ):
132
+
133
+ # Set model attributes
134
+ self.max_thoughts = n_ahead + n_ahead_talk + 1
135
+ self.merged_talk_heads = merged_talk_heads
136
+ self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
137
+ self.merged_lm_and_think_heads = merged_lm_and_think_heads
138
+ self.use_concat_talk_head = use_concat_talk_head
139
+ self.use_shallow_think = use_shallow_think
140
+ self.use_shallow_talk = use_shallow_talk
141
+ self.use_complex_think_head = use_complex_think_head
142
+ self.use_complex_talk_head = use_complex_talk_head
143
+ self.use_weighted_talk_head = use_weighted_talk_head
144
+
145
+ # Set model properties
146
+ self.use_end_thought_token = True
147
+ self.use_start_thought_token = True
148
+ self.n_ahead = n_ahead
149
+ self.n_passes = 1
150
+ self.eval_mode = True
151
+ self.first_run = False
152
+ self.rm_initialized = True
153
+ self.original_mode = False
154
+
155
+
156
  streamer = TextStreamer(self.tokenizer, skip_prompt=False, skip_special_tokens=True)
157
 
158
  generated_token_ids, attention_mask = custom_generate(