Crystalcareai committed on
Commit
94d0604
·
verified ·
1 Parent(s): 8b70e64

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +15 -5
generate.py CHANGED
@@ -101,11 +101,10 @@ def generate(
101
  max_length=None,
102
  min_length=None,
103
  do_sample=None,
104
- n_ahead=12,
105
- n_ahead_talk=4,
106
  early_stopping=None,
107
  num_beams=None,
108
- temperature=0.9,
 
109
  top_k=None,
110
  top_p=None,
111
  repetition_penalty=None,
@@ -129,9 +128,21 @@ def generate(
129
  forced_eos_token_id=None,
130
  remove_invalid_values=None,
131
  synced_gpus=None,
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  **model_kwargs,
133
  ):
134
-
135
  # Set model attributes
136
  self.max_thoughts = n_ahead + n_ahead_talk + 1
137
  self.merged_talk_heads = merged_talk_heads
@@ -154,7 +165,6 @@ def generate(
154
  self.rm_initialized = True
155
  self.original_mode = False
156
 
157
-
158
  streamer = TextStreamer(self.tokenizer, skip_prompt=False, skip_special_tokens=True)
159
 
160
  generated_token_ids, attention_mask = custom_generate(
 
101
  max_length=None,
102
  min_length=None,
103
  do_sample=None,
 
 
104
  early_stopping=None,
105
  num_beams=None,
106
+ temperature=1.1,
107
+ streamer=None,
108
  top_k=None,
109
  top_p=None,
110
  repetition_penalty=None,
 
128
  forced_eos_token_id=None,
129
  remove_invalid_values=None,
130
  synced_gpus=None,
131
+ n_ahead=12,
132
+ n_ahead_talk=4,
133
+ merged_talk_heads=True,
134
+ merged_lm_and_talk_heads=False,
135
+ merged_lm_and_think_heads=True,
136
+ use_concat_talk_head=True,
137
+ use_shallow_think=True,
138
+ use_shallow_talk=False,
139
+ use_complex_think_head=False,
140
+ use_complex_talk_head=True,
141
+ use_weighted_talk_head=True,
142
+ trust_remote_code=True,
143
+ torch_dtype=torch.bfloat16,
144
  **model_kwargs,
145
  ):
 
146
  # Set model attributes
147
  self.max_thoughts = n_ahead + n_ahead_talk + 1
148
  self.merged_talk_heads = merged_talk_heads
 
165
  self.rm_initialized = True
166
  self.original_mode = False
167
 
 
168
  streamer = TextStreamer(self.tokenizer, skip_prompt=False, skip_special_tokens=True)
169
 
170
  generated_token_ids, attention_mask = custom_generate(