Crystalcareai commited on
Commit
351d904
·
verified ·
1 Parent(s): b3900b9

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +15 -0
modeling_quiet.py CHANGED
@@ -1236,6 +1236,21 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1236
  if input_ids.dim() == 1:
1237
  input_ids = input_ids.unsqueeze(0)
1238
  attention_mask = attention_mask.unsqueeze(0) if attention_mask is not None else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1239
 
1240
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1241
  output_hidden_states = (
 
1236
  if input_ids.dim() == 1:
1237
  input_ids = input_ids.unsqueeze(0)
1238
  attention_mask = attention_mask.unsqueeze(0) if attention_mask is not None else None
1239
+ position_ids = position_ids.unsqueeze(0) if position_ids is not None else None
1240
+
1241
+ seq_len = input_ids.shape[1]
1242
+
1243
+ if position_ids is None:
1244
+ position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
1245
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
1246
+ else:
1247
+ # Handle the case when position_ids is an empty tensor
1248
+ if position_ids.numel() == 0:
1249
+ position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
1250
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
1251
+ else:
1252
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
1253
+
1254
 
1255
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1256
  output_hidden_states = (