Crystalcareai
commited on
Update modeling_quiet.py
Browse files- modeling_quiet.py +15 -0
modeling_quiet.py
CHANGED
@@ -1236,6 +1236,21 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1236 |
if input_ids.dim() == 1:
|
1237 |
input_ids = input_ids.unsqueeze(0)
|
1238 |
attention_mask = attention_mask.unsqueeze(0) if attention_mask is not None else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1239 |
|
1240 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1241 |
output_hidden_states = (
|
|
|
1236 |
if input_ids.dim() == 1:
|
1237 |
input_ids = input_ids.unsqueeze(0)
|
1238 |
attention_mask = attention_mask.unsqueeze(0) if attention_mask is not None else None
|
1239 |
+
position_ids = position_ids.unsqueeze(0) if position_ids is not None else None
|
1240 |
+
|
1241 |
+
seq_len = input_ids.shape[1]
|
1242 |
+
|
1243 |
+
if position_ids is None:
|
1244 |
+
position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
|
1245 |
+
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
1246 |
+
else:
|
1247 |
+
# Handle the case when position_ids is an empty tensor
|
1248 |
+
if position_ids.numel() == 0:
|
1249 |
+
position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
|
1250 |
+
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
1251 |
+
else:
|
1252 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
|
1253 |
+
|
1254 |
|
1255 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1256 |
output_hidden_states = (
|