Crystalcareai
committed on
Update generate.py
Browse files- generate.py +3 -6
generate.py
CHANGED
@@ -47,8 +47,6 @@ def custom_generate(
|
|
47 |
with torch.no_grad():
|
48 |
finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=device)
|
49 |
|
50 |
-
if max_new_tokens is None:
|
51 |
-
max_new_tokens = 50 # Default value if not specified
|
52 |
for cur_token_idx in range(max_new_tokens):
|
53 |
# Sample the next token
|
54 |
new_ids = self(
|
@@ -95,8 +93,7 @@ def custom_generate(
|
|
95 |
if streamer is not None:
|
96 |
streamer.put(new_ids_sampled)
|
97 |
|
98 |
-
|
99 |
-
return generated_token_ids, attention_mask
|
100 |
|
101 |
|
102 |
def generate(
|
@@ -170,7 +167,7 @@ def generate(
|
|
170 |
self.rm_initialized = True
|
171 |
self.original_mode = False
|
172 |
|
173 |
-
generated_token_ids
|
174 |
self,
|
175 |
input_ids=input_ids,
|
176 |
attention_mask=attention_mask,
|
@@ -207,4 +204,4 @@ def generate(
|
|
207 |
**model_kwargs,
|
208 |
)
|
209 |
|
210 |
-
return generated_token_ids
|
|
|
47 |
with torch.no_grad():
|
48 |
finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=device)
|
49 |
|
|
|
|
|
50 |
for cur_token_idx in range(max_new_tokens):
|
51 |
# Sample the next token
|
52 |
new_ids = self(
|
|
|
93 |
if streamer is not None:
|
94 |
streamer.put(new_ids_sampled)
|
95 |
|
96 |
+
return input_ids
|
|
|
97 |
|
98 |
|
99 |
def generate(
|
|
|
167 |
self.rm_initialized = True
|
168 |
self.original_mode = False
|
169 |
|
170 |
+
generated_token_ids = custom_generate(
|
171 |
self,
|
172 |
input_ids=input_ids,
|
173 |
attention_mask=attention_mask,
|
|
|
204 |
**model_kwargs,
|
205 |
)
|
206 |
|
207 |
+
return generated_token_ids
|