Crystalcareai
committed on
Update generate.py
generate.py +20 -9
generate.py CHANGED
@@ -56,9 +56,12 @@ def custom_generate(
     with torch.no_grad():
         batch_size = input_ids.shape[0]
         finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=device)
+
+        max_length = input_ids.shape[1] + max_new_tokens
         generated_token_ids = torch.full((batch_size, max_new_tokens), self.tokenizer.pad_token_id, dtype=torch.long, device=device)
 
-
+        cur_token_idx = 0
+        while cur_token_idx < max_length:
             # Sample the next token
             new_ids = self(
                 input_ids[~finished_generating],
@@ -96,13 +99,11 @@ def custom_generate(
                 input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
                 generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled
 
-                if new_ids_sampled == self.tokenizer.eos_token_id or
+                if new_ids_sampled == self.tokenizer.eos_token_id or cur_token_idx + 1 == max_length:
                     finished_generating[answer_idx] = 1
 
-
-
-                    finished_generating[answer_idx] = 1
-
+            cur_token_idx += 1
+
             if finished_generating.all():
                 break
 
@@ -112,7 +113,13 @@ def custom_generate(
         print("Generated Token IDs shape:", generated_token_ids.shape)
         print("Generated Token IDs:", generated_token_ids)
 
-
+        # Decode the generated token IDs into text
+        generated_text = self.tokenizer.decode(generated_token_ids[0], skip_special_tokens=True)
+        print("Generated Text:")
+        print(generated_text)
+
+        return generated_token_ids, generated_text
+
 
 def generate(
     self,
@@ -161,6 +168,7 @@ def generate(
     use_weighted_talk_head=True,
     trust_remote_code=True,
     torch_dtype=torch.bfloat16,
+    dynamic_temperature=None,
     **model_kwargs,
 ):
 
@@ -198,7 +206,7 @@ def generate(
     if attention_mask is not None:
         attention_mask = attention_mask.to(self.device)
 
-    generated_token_ids = custom_generate(
+    generated_token_ids, generated_text = custom_generate(
         self,
         input_ids=input_ids,
         attention_mask=attention_mask,
@@ -235,4 +243,7 @@ def generate(
         **model_kwargs,
     )
 
-
+    if dynamic_temperature is not None:
+        return generated_text
+    else:
+        return generated_token_ids