Crystalcareai commited on
Commit
5af661b
·
verified ·
1 Parent(s): 5bc3c10

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +6 -1
generate.py CHANGED
@@ -43,6 +43,11 @@ def custom_generate(
43
  synced_gpus=None,
44
  **kwargs,
45
  ):
 
 
 
 
 
46
  device = input_ids.device
47
  with torch.no_grad():
48
  batch_size = input_ids.shape[0]
@@ -211,4 +216,4 @@ def generate(
211
  **model_kwargs,
212
  )
213
 
214
- return generated_token_ids
 
43
  synced_gpus=None,
44
  **kwargs,
45
  ):
46
+ if input_ids is None or input_ids.nelement() == 0:
47
+ # If input_ids is None or an empty tensor, create a default input tensor
48
+ input_ids = torch.LongTensor([[self.tokenizer.bos_token_id]])
49
+ attention_mask = torch.ones_like(input_ids)
50
+
51
  device = input_ids.device
52
  with torch.no_grad():
53
  batch_size = input_ids.shape[0]
 
216
  **model_kwargs,
217
  )
218
 
219
+ return generated_token_ids