Crystalcareai committed
Update modeling_quiet.py
- modeling_quiet.py  +1 -1
modeling_quiet.py
CHANGED
@@ -774,7 +774,7 @@ class QuietSdpaAttention(QuietAttention):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
-
+            attention_mask = attention_mask.to(query_states.dtype)
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
         if query_states.device.type == "cuda" and attention_mask is not None:
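For context, a minimal standalone sketch (plain PyTorch, not the Quiet-STaR module itself; all tensor names and shapes below are illustrative) of why the mask is cast to the query dtype before calling torch.nn.functional.scaled_dot_product_attention: SDPA treats a floating-point attn_mask as additive, and with half-precision query/key/value states a float32 mask can trigger errors or backend fallbacks in some torch versions, which is presumably what the added .to(query_states.dtype) avoids.

# Sketch only: demonstrates the dtype cast pattern from the diff above.
import torch
import torch.nn.functional as F

# Illustrative shapes, not taken from modeling_quiet.py.
bsz, n_heads, q_len, head_dim = 1, 8, 16, 64
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

query_states = torch.randn(bsz, n_heads, q_len, head_dim, dtype=dtype, device=device)
key_states = torch.randn_like(query_states)
value_states = torch.randn_like(query_states)

# Additive causal mask built in float32 (0 where attending, large negative above the diagonal).
attention_mask = torch.full((bsz, 1, q_len, q_len), torch.finfo(torch.float32).min, device=device)
attention_mask = torch.triu(attention_mask, diagonal=1)

# The cast added by this commit: match the mask dtype to the query dtype before SDPA.
attention_mask = attention_mask.to(query_states.dtype)

attn_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states, attn_mask=attention_mask
)
print(attn_output.shape)  # torch.Size([1, 8, 16, 64])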