Crystalcareai
commited on
Update modeling_quiet.py
Browse files- modeling_quiet.py +11 -1
modeling_quiet.py
CHANGED
@@ -1412,7 +1412,17 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1412 |
past_key_values_length,
|
1413 |
sliding_window=self.config.sliding_window,
|
1414 |
)
|
1415 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1416 |
outputs = self.model(
|
1417 |
# input_ids=input_ids,
|
1418 |
attention_mask=attention_mask,
|
|
|
1412 |
past_key_values_length,
|
1413 |
sliding_window=self.config.sliding_window,
|
1414 |
)
|
1415 |
+
if attention_mask is not None:
|
1416 |
+
if attention_mask.dim() == 2:
|
1417 |
+
# Expand the attention mask to have dimensions (batch_size, 1, 1, seq_length)
|
1418 |
+
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
1419 |
+
elif attention_mask.dim() == 3:
|
1420 |
+
# Expand the attention mask to have dimensions (batch_size, 1, seq_length, seq_length)
|
1421 |
+
attention_mask = attention_mask.unsqueeze(1)
|
1422 |
+
else:
|
1423 |
+
raise ValueError(
|
1424 |
+
f"Attention mask should have 2 or 3 dimensions, but got {attention_mask.dim()} dimensions."
|
1425 |
+
)
|
1426 |
outputs = self.model(
|
1427 |
# input_ids=input_ids,
|
1428 |
attention_mask=attention_mask,
|