Crystalcareai committed
Update generate.py

generate.py CHANGED (+19 -26)
@@ -6,6 +6,7 @@ from transformers.generation.utils import (
 )
 from transformers import TextStreamer
 
+
 def custom_generate(
     self,
     input_ids,
@@ -44,19 +45,19 @@ def custom_generate(
 ):
     if input_ids is None or input_ids.nelement() == 0:
         # If input_ids is None or an empty tensor, create a default input tensor
-        input_ids = torch.LongTensor([[self.tokenizer.bos_token_id]])
-        attention_mask = torch.ones_like(input_ids)
+        input_ids = torch.LongTensor([[self.tokenizer.bos_token_id]])
+        attention_mask = torch.ones_like(input_ids)
 
     device = input_ids.device
     with torch.no_grad():
         batch_size = input_ids.shape[0]
+        if max_new_tokens is None:
+            raise ValueError("max_new_tokens must be provided.")
+
         finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=device)
-
-        max_length = input_ids.shape[1] + max_new_tokens
         generated_token_ids = torch.full((batch_size, max_new_tokens), self.tokenizer.pad_token_id, dtype=torch.long, device=device)
 
-        cur_token_idx
-        while cur_token_idx < max_new_tokens:
+        for cur_token_idx in range(max_new_tokens):
             # Sample the next token
             new_ids = self(
                 input_ids[~finished_generating],
@@ -80,7 +81,7 @@ def custom_generate(
                 if last_token_idx + 1 >= len(base_answer_ids):
                     # Add padding everywhere
                     new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long,
-
+                                             device=device)
                     input_ids = torch.cat([input_ids, new_padding], dim=-1)
                     if attention_mask is not None:
                         attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
@@ -88,15 +89,15 @@ def custom_generate(
                 if attention_mask is not None:
                     attention_mask[answer_idx, last_token_idx + 1] = 1
                 input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
+                generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled
 
-                if
-                    generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled
-
-                if new_ids_sampled == self.tokenizer.eos_token_id or cur_token_idx + 1 == max_new_tokens:
+                if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
                     finished_generating[answer_idx] = 1
 
-
-
+                # Check if the end token is generated
+                if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("</s>"):
+                    finished_generating[answer_idx] = 1
+
             if finished_generating.all():
                 break
 
@@ -105,6 +106,7 @@ def custom_generate(
 
     return generated_token_ids
 
+
 def generate(
     self,
     input_ids,
@@ -152,12 +154,11 @@ def generate(
     use_weighted_talk_head=True,
     trust_remote_code=True,
     torch_dtype=torch.bfloat16,
-    dynamic_temperature=None,
     **model_kwargs,
 ):
-
+    # Set default value for max_new_tokens if not provided
     if max_new_tokens is None:
-        max_new_tokens =
+        max_new_tokens = 20  # Set a reasonable default value
 
     # Set model attributes
     self.max_thoughts = n_ahead + n_ahead_talk + 1
@@ -185,16 +186,11 @@ def generate(
     if isinstance(input_ids, str):
         input_ids = self.tokenizer.encode(input_ids, return_tensors='pt')
 
-    # Move input_ids and attention_mask to the same device as the model
-    input_ids = input_ids.to(self.device)
-    if attention_mask is not None:
-        attention_mask = attention_mask.to(self.device)
-
     generated_token_ids = custom_generate(
         self,
-        input_ids=input_ids,
+        input_ids=input_ids,  # Pass input_ids explicitly
         attention_mask=attention_mask,
-        max_new_tokens=max_new_tokens,
+        max_new_tokens=max_new_tokens,  # Pass max_new_tokens explicitly
         min_length=min_length,
         do_sample=do_sample,
         early_stopping=early_stopping,
@@ -227,7 +223,4 @@ def generate(
         **model_kwargs,
     )
 
-    # Convert the generated token IDs tensor to text
-    generated_text = self.tokenizer.decode(generated_token_ids[0], skip_special_tokens=False)
-
     return generated_token_ids
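The core of the fix is the control flow in custom_generate(): the dangling cur_token_idx fragment and the unbounded while loop give way to a bounded for cur_token_idx in range(max_new_tokens) that writes each sampled token into a pre-allocated buffer and marks a sequence finished as soon as it emits an EOS/BOS/PAD (or literal "</s>") token. A self-contained toy of that skeleton, with a random sampler standing in for the model so it runs without any checkpoint:

import torch

def toy_generate_loop(batch_size=3, max_new_tokens=8, eos_token_id=2, pad_token_id=0):
    # Per-sequence "done" flags, as in custom_generate().
    finished_generating = torch.zeros(batch_size, dtype=torch.bool)
    # Pre-allocated output buffer, initialised to the pad id.
    generated_token_ids = torch.full((batch_size, max_new_tokens), pad_token_id, dtype=torch.long)

    for cur_token_idx in range(max_new_tokens):  # bounded loop: no manual counter to maintain
        for answer_idx in (~finished_generating).nonzero(as_tuple=True)[0]:
            new_id = torch.randint(0, 10, ()).item()  # stand-in for sampling from the model
            generated_token_ids[answer_idx, cur_token_idx] = new_id
            if new_id in (eos_token_id, pad_token_id):  # a special token ends this sequence
                finished_generating[answer_idx] = True
        if finished_generating.all():  # whole batch done: exit early
            break
    return generated_token_ids

print(toy_generate_loop())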
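Finally, a minimal sketch of driving the patched generate() end to end. The repo id and the model.tokenizer wiring are assumptions for illustration (the checkpoint must ship this generate.py and load it via trust_remote_code); since the commit drops the final decode step, the caller now decodes the returned ids:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/your-checkpoint"  # placeholder: a checkpoint that ships this generate.py

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # needed so the hub-side generate()/custom_generate() are used
)
model.tokenizer = tokenizer  # assumption: custom_generate() reads self.tokenizer

# generate() accepts a raw string (it calls self.tokenizer.encode internally);
# if max_new_tokens is omitted, it now falls back to the default of 20.
token_ids = model.generate("The capital of France is", max_new_tokens=32)

# The decode step was removed from generate(), so decode the returned ids here.
print(tokenizer.decode(token_ids[0], skip_special_tokens=True))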