Crystalcareai
commited on
Update generate.py
Browse files- generate.py +6 -1
generate.py
CHANGED
@@ -43,6 +43,11 @@ def custom_generate(
|
|
43 |
synced_gpus=None,
|
44 |
**kwargs,
|
45 |
):
|
|
|
|
|
|
|
|
|
|
|
46 |
device = input_ids.device
|
47 |
with torch.no_grad():
|
48 |
batch_size = input_ids.shape[0]
|
@@ -211,4 +216,4 @@ def generate(
|
|
211 |
**model_kwargs,
|
212 |
)
|
213 |
|
214 |
-
return generated_token_ids
|
|
|
43 |
synced_gpus=None,
|
44 |
**kwargs,
|
45 |
):
|
46 |
+
if input_ids is None or input_ids.nelement() == 0:
|
47 |
+
# If input_ids is None or an empty tensor, create a default input tensor
|
48 |
+
input_ids = torch.LongTensor([[self.tokenizer.bos_token_id]])
|
49 |
+
attention_mask = torch.ones_like(input_ids)
|
50 |
+
|
51 |
device = input_ids.device
|
52 |
with torch.no_grad():
|
53 |
batch_size = input_ids.shape[0]
|
|
|
216 |
**model_kwargs,
|
217 |
)
|
218 |
|
219 |
+
return generated_token_ids
|