Crystalcareai
committed on
Update generate.py
Browse files- generate.py +6 -3
generate.py
CHANGED
@@ -45,7 +45,9 @@ def custom_generate(
|
|
45 |
):
|
46 |
device = input_ids.device
|
47 |
with torch.no_grad():
|
48 |
-
|
|
|
|
|
49 |
|
50 |
for cur_token_idx in range(max_new_tokens):
|
51 |
# Sample the next token
|
@@ -70,7 +72,7 @@ def custom_generate(
|
|
70 |
# Assign the new id to the last token
|
71 |
if last_token_idx + 1 >= len(base_answer_ids):
|
72 |
# Add padding everywhere
|
73 |
-
new_padding = torch.full((
|
74 |
device=device)
|
75 |
input_ids = torch.cat([input_ids, new_padding], dim=-1)
|
76 |
if attention_mask is not None:
|
@@ -79,6 +81,7 @@ def custom_generate(
|
|
79 |
if attention_mask is not None:
|
80 |
attention_mask[answer_idx, last_token_idx + 1] = 1
|
81 |
input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
|
|
|
82 |
|
83 |
if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
|
84 |
finished_generating[answer_idx] = 1
|
@@ -93,7 +96,7 @@ def custom_generate(
|
|
93 |
if streamer is not None:
|
94 |
streamer.put(new_ids_sampled)
|
95 |
|
96 |
-
return
|
97 |
|
98 |
|
99 |
def generate(
|
|
|
45 |
):
|
46 |
device = input_ids.device
|
47 |
with torch.no_grad():
|
48 |
+
batch_size = input_ids.shape[0]
|
49 |
+
finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=device)
|
50 |
+
generated_token_ids = torch.full((batch_size, max_new_tokens), self.tokenizer.pad_token_id, dtype=torch.long, device=device)
|
51 |
|
52 |
for cur_token_idx in range(max_new_tokens):
|
53 |
# Sample the next token
|
|
|
72 |
# Assign the new id to the last token
|
73 |
if last_token_idx + 1 >= len(base_answer_ids):
|
74 |
# Add padding everywhere
|
75 |
+
new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long,
|
76 |
device=device)
|
77 |
input_ids = torch.cat([input_ids, new_padding], dim=-1)
|
78 |
if attention_mask is not None:
|
|
|
81 |
if attention_mask is not None:
|
82 |
attention_mask[answer_idx, last_token_idx + 1] = 1
|
83 |
input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
|
84 |
+
generated_token_ids[answer_idx, cur_token_idx] = new_ids_sampled
|
85 |
|
86 |
if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
|
87 |
finished_generating[answer_idx] = 1
|
|
|
96 |
if streamer is not None:
|
97 |
streamer.put(new_ids_sampled)
|
98 |
|
99 |
+
return generated_token_ids
|
100 |
|
101 |
|
102 |
def generate(
|