Crystalcareai committed
Update generate.py

generate.py CHANGED (+12 -49)
@@ -1,12 +1,10 @@
 import torch
-from transformers.utils import logging
 from transformers.generation.utils import (
     GenerationMixin,
     validate_stopping_criteria,
     StoppingCriteriaList,
 )
-
-logger = logging.get_logger(__name__)
+from transformers import TextStreamer


 def custom_generate(
@@ -45,8 +43,9 @@ def custom_generate(
     synced_gpus=None,
     **kwargs,
 ):
+    device = input_ids.device
     with torch.no_grad():
-        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
+        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=device)

         while not finished_generating.all() and input_ids.shape[1] < max_length:
             # Sample the next token
@@ -72,7 +71,7 @@ def custom_generate(
                 if last_token_idx + 1 >= len(base_answer_ids):
                     # Add padding everywhere
                     new_padding = torch.full((len(input_ids), 1), self.tokenizer.pad_token_id, dtype=torch.long,
-                                             device=input_ids.device)
+                                             device=device)
                     input_ids = torch.cat([input_ids, new_padding], dim=-1)
                     if attention_mask is not None:
                         attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
@@ -92,8 +91,7 @@ def custom_generate(
                 streamer.put(new_ids_sampled)

         generated_token_ids = input_ids.tolist()
-
-        return generated_token_ids
+        return generated_token_ids, attention_mask


 def generate(
@@ -105,8 +103,7 @@ def generate(
     do_sample=None,
     early_stopping=None,
     num_beams=None,
-    temperature=
-    streamer=None,
+    temperature=1.1,
     top_k=None,
     top_p=None,
     repetition_penalty=None,
@@ -126,49 +123,15 @@ def generate(
     output_hidden_states=None,
     output_scores=None,
     return_dict_in_generate=None,
-    forced_bos_token_id=
-    forced_eos_token_id=
+    forced_bos_token_id=None,
+    forced_eos_token_id=None,
     remove_invalid_values=None,
     synced_gpus=None,
-    n_ahead=12,
-    n_ahead_talk=4,
-    merged_talk_heads=True,
-    merged_lm_and_talk_heads=False,
-    merged_lm_and_think_heads=True,
-    use_concat_talk_head=True,
-    use_shallow_think=True,
-    use_shallow_talk=False,
-    use_complex_think_head=False,
-    use_complex_talk_head=True,
-    use_weighted_talk_head=True,
-    trust_remote_code=True,
-    torch_dtype=None,
     **model_kwargs,
 ):
-
-
-
-    self.merged_lm_and_talk_heads = merged_lm_and_talk_heads
-    self.merged_lm_and_think_heads = merged_lm_and_think_heads
-    self.use_concat_talk_head = use_concat_talk_head
-    self.use_shallow_think = use_shallow_think
-    self.use_shallow_talk = use_shallow_talk
-    self.use_complex_think_head = use_complex_think_head
-    self.use_complex_talk_head = use_complex_talk_head
-    self.use_weighted_talk_head = use_weighted_talk_head
-
-    # Set model properties
-    self.use_end_thought_token = True
-    self.use_start_thought_token = True
-    self.n_ahead = n_ahead
-    self.n_passes = 1
-    self.eval_mode = True
-    self.first_run = False
-    self.rm_initialized = True
-    self.original_mode = False
-
-    # Generate using the custom generate function
-    generated_token_ids = custom_generate(
+    streamer = TextStreamer(self.tokenizer, skip_prompt=False, skip_special_tokens=True)
+
+    generated_token_ids, attention_mask = custom_generate(
         self,
         input_ids=input_ids,
         attention_mask=attention_mask,
@@ -205,4 +168,4 @@ def generate(
         **model_kwargs,
     )

-    return generated_token_ids
+    return generated_token_ids, attention_mask
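
After this commit, generate() no longer accepts a streamer argument or the thought-head configuration flags; it builds its own TextStreamer from self.tokenizer and returns both the sampled token ids and the final attention mask. The sketch below is a minimal, hedged example of calling the updated function; the model id and the max_new_tokens argument name are placeholders/assumptions, since neither appears in the hunks above, and it assumes the repo's remote code installs this generate() as the model's generate method with the tokenizer attached as model.tokenizer.

# Minimal usage sketch (not from the commit). Assumptions are marked inline.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your-org/your-quiet-star-model"  # placeholder, not a real checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,   # needed so the repo's custom generate.py is used
    torch_dtype=torch.bfloat16,
)
model.tokenizer = tokenizer  # assumption: custom_generate reads self.tokenizer for padding/streaming

inputs = tokenizer("What is 2 + 2?", return_tensors="pt").to(model.device)

# The updated generate() streams text via TextStreamer as tokens are sampled
# and returns two values instead of one.
token_ids, attention_mask = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=64,  # assumption: the length-control argument is not visible in these hunks
)
# generated_token_ids is a plain Python list of token-id lists (input_ids.tolist()),
# so decode the first sequence directly.
print(tokenizer.decode(token_ids[0], skip_special_tokens=True))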