Crystalcareai
committed on
Update generate.py
Browse files- generate.py +7 -4
generate.py
CHANGED
@@ -11,7 +11,7 @@ def custom_generate(
|
|
11 |
self,
|
12 |
input_ids,
|
13 |
attention_mask=None,
|
14 |
-
|
15 |
min_length=None,
|
16 |
do_sample=None,
|
17 |
early_stopping=None,
|
@@ -47,7 +47,7 @@ def custom_generate(
|
|
47 |
with torch.no_grad():
|
48 |
finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=device)
|
49 |
|
50 |
-
|
51 |
# Sample the next token
|
52 |
new_ids = self(
|
53 |
input_ids[~finished_generating],
|
@@ -86,6 +86,9 @@ def custom_generate(
|
|
86 |
# Check if the end token is generated
|
87 |
if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("</s>"):
|
88 |
finished_generating[answer_idx] = 1
|
|
|
|
|
|
|
89 |
|
90 |
if streamer is not None:
|
91 |
streamer.put(new_ids_sampled)
|
@@ -98,7 +101,7 @@ def generate(
|
|
98 |
self,
|
99 |
input_ids,
|
100 |
attention_mask=None,
|
101 |
-
|
102 |
min_length=None,
|
103 |
do_sample=None,
|
104 |
early_stopping=None,
|
@@ -169,7 +172,7 @@ def generate(
|
|
169 |
self,
|
170 |
input_ids=input_ids,
|
171 |
attention_mask=attention_mask,
|
172 |
-
|
173 |
min_length=min_length,
|
174 |
do_sample=do_sample,
|
175 |
early_stopping=early_stopping,
|
|
|
11 |
self,
|
12 |
input_ids,
|
13 |
attention_mask=None,
|
14 |
+
max_new_tokens=None,
|
15 |
min_length=None,
|
16 |
do_sample=None,
|
17 |
early_stopping=None,
|
|
|
47 |
with torch.no_grad():
|
48 |
finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=device)
|
49 |
|
50 |
+
for cur_token_idx in range(max_new_tokens):
|
51 |
# Sample the next token
|
52 |
new_ids = self(
|
53 |
input_ids[~finished_generating],
|
|
|
86 |
# Check if the end token is generated
|
87 |
if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("</s>"):
|
88 |
finished_generating[answer_idx] = 1
|
89 |
+
|
90 |
+
if finished_generating.all():
|
91 |
+
break
|
92 |
|
93 |
if streamer is not None:
|
94 |
streamer.put(new_ids_sampled)
|
|
|
101 |
self,
|
102 |
input_ids,
|
103 |
attention_mask=None,
|
104 |
+
max_new_tokens=None,
|
105 |
min_length=None,
|
106 |
do_sample=None,
|
107 |
early_stopping=None,
|
|
|
172 |
self,
|
173 |
input_ids=input_ids,
|
174 |
attention_mask=attention_mask,
|
175 |
+
max_new_tokens=max_new_tokens,
|
176 |
min_length=min_length,
|
177 |
do_sample=do_sample,
|
178 |
early_stopping=early_stopping,
|