Crystalcareai
committed on
Update generate.py
generate.py +5 -3
generate.py
CHANGED
@@ -56,7 +56,7 @@ def custom_generate(
     generated_token_ids = torch.full((batch_size, max_new_tokens), self.tokenizer.pad_token_id, dtype=torch.long, device=device)

     cur_token_idx = 0
-    while cur_token_idx <
+    while cur_token_idx < max_new_tokens:
         # Sample the next token
         new_ids = self(
             input_ids[~finished_generating],
@@ -88,9 +88,11 @@ def custom_generate(
             if attention_mask is not None:
                 attention_mask[answer_idx, last_token_idx + 1] = 1
             input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
-            generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled

-            if new_ids_sampled == self.tokenizer.eos_token_id:
+            if cur_token_idx < max_new_tokens:
+                generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled
+
+            if new_ids_sampled == self.tokenizer.eos_token_id or cur_token_idx + 1 == max_new_tokens:
                 finished_generating[answer_idx] = 1

         cur_token_idx += 1
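For context: generated_token_ids is allocated with exactly max_new_tokens columns, so the write at cur_token_idx must stay in bounds, and a sequence that exhausts its token budget has to be marked finished rather than looping again. Below is a minimal, self-contained sketch of that pattern. The vocab size, the EOS/pad ids, and the random logits standing in for the model's forward pass are illustrative assumptions, not the repo's actual API.

import torch

# Illustrative stand-ins: a tiny vocab and fixed special-token ids.
# The real code uses self.tokenizer and calls the model via self(...).
VOCAB_SIZE, EOS_TOKEN_ID, PAD_TOKEN_ID = 32, 2, 0

def generate(input_ids, max_new_tokens):
    batch_size, device = input_ids.shape[0], input_ids.device
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
    # The buffer is exactly max_new_tokens wide; indexing past
    # max_new_tokens - 1 would raise an out-of-bounds error.
    generated = torch.full((batch_size, max_new_tokens), PAD_TOKEN_ID,
                           dtype=torch.long, device=device)

    cur_token_idx = 0
    while cur_token_idx < max_new_tokens:
        # Fake forward pass: random logits instead of a real model call.
        logits = torch.randn(batch_size, VOCAB_SIZE, device=device)
        sampled = torch.multinomial(torch.softmax(logits, dim=-1), 1).squeeze(-1)
        input_ids = torch.cat([input_ids, sampled.unsqueeze(-1)], dim=-1)

        for answer_idx in range(batch_size):
            if finished[answer_idx]:
                continue
            tok = sampled[answer_idx]
            # Defensive bounds check mirroring the commit (the while
            # condition already guarantees it in this sketch).
            if cur_token_idx < max_new_tokens:
                generated[answer_idx, cur_token_idx] = tok
            # Finish on EOS, or when this was the last allowed token.
            if tok == EOS_TOKEN_ID or cur_token_idx + 1 == max_new_tokens:
                finished[answer_idx] = True

        cur_token_idx += 1
        if finished.all():
            break
    return generated

print(generate(torch.zeros((2, 4), dtype=torch.long), max_new_tokens=8))

The key design point matches the commit: recording the sampled token and deciding whether a sequence is finished are separate checks, so hitting the budget on the same step as an EOS token behaves identically to either condition alone.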