Crystalcareai
committed on
Update generate.py
generate.py +12 -10
generate.py
CHANGED
@@ -45,15 +45,12 @@ def custom_generate(
 ):
     if input_ids is None or input_ids.nelement() == 0:
         # If input_ids is None or an empty tensor, create a default input tensor
-        input_ids = torch.LongTensor([[self.tokenizer.bos_token_id]])
-        attention_mask = torch.ones_like(input_ids)
+        input_ids = torch.LongTensor([[self.tokenizer.bos_token_id]]).to(self.device)
+        attention_mask = torch.ones_like(input_ids).to(self.device)
 
     device = input_ids.device
     with torch.no_grad():
         batch_size = input_ids.shape[0]
-        if max_new_tokens is None:
-            raise ValueError("max_new_tokens must be provided.")
-
         finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=device)
         generated_token_ids = torch.full((batch_size, max_new_tokens), self.tokenizer.pad_token_id, dtype=torch.long, device=device)
 
@@ -156,10 +153,10 @@ def generate(
     torch_dtype=torch.bfloat16,
     **model_kwargs,
 ):
-    # Set default value for max_new_tokens if not provided
-    if max_new_tokens is None:
-        max_new_tokens = 128  # Set a reasonable default value
 
+    if max_new_tokens is None:
+        max_new_tokens = 128
+
     # Set model attributes
     self.max_thoughts = n_ahead + n_ahead_talk + 1
     self.merged_talk_heads = merged_talk_heads
@@ -186,11 +183,16 @@ def generate(
     if isinstance(input_ids, str):
         input_ids = self.tokenizer.encode(input_ids, return_tensors='pt')
 
+    # Move input_ids and attention_mask to the same device as the model
+    input_ids = input_ids.to(self.device)
+    if attention_mask is not None:
+        attention_mask = attention_mask.to(self.device)
+
     generated_token_ids = custom_generate(
         self,
-        input_ids=input_ids,
+        input_ids=input_ids,
         attention_mask=attention_mask,
-        max_new_tokens=max_new_tokens,
+        max_new_tokens=max_new_tokens,
         min_length=min_length,
         do_sample=do_sample,
         early_stopping=early_stopping,
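Taken together, the three hunks replace the hard requirement on max_new_tokens with a default of 128 and move the default BOS input, input_ids, and attention_mask onto the model's device before custom_generate is called. The snippet below is a minimal standalone sketch of that input-preparation pattern, not the file itself; the helper name prepare_generation_inputs and its argument order are hypothetical.

import torch

def prepare_generation_inputs(tokenizer, device, input_ids=None,
                              attention_mask=None, max_new_tokens=None):
    # Fall back to a default generation budget instead of raising an error.
    if max_new_tokens is None:
        max_new_tokens = 128

    # Accept a raw string prompt, mirroring the isinstance check in generate().
    if isinstance(input_ids, str):
        input_ids = tokenizer.encode(input_ids, return_tensors='pt')

    # If input_ids is None or an empty tensor, start from the BOS token,
    # created directly on the target device.
    if input_ids is None or input_ids.nelement() == 0:
        input_ids = torch.LongTensor([[tokenizer.bos_token_id]]).to(device)
        attention_mask = torch.ones_like(input_ids)

    # Move both tensors to the model's device so inputs and weights match.
    input_ids = input_ids.to(device)
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    return input_ids, attention_mask, max_new_tokens

With a tokenizer loaded via transformers.AutoTokenizer, calling prepare_generation_inputs(tokenizer, model.device) with no prompt should yield a single-BOS batch on the model's device, which is the behavior the commit adds to generate.py.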