Crystalcareai commited on
Commit
4eaaef0
·
verified ·
1 Parent(s): 7d4d670

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +1 -1
modeling_quiet.py CHANGED
@@ -774,7 +774,7 @@ class QuietSdpaAttention(QuietAttention):
774
  raise ValueError(
775
  f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
776
  )
777
-
778
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
779
  # Reference: https://github.com/pytorch/pytorch/issues/112577.
780
  if query_states.device.type == "cuda" and attention_mask is not None:
 
774
  raise ValueError(
775
  f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
776
  )
777
+ attention_mask = attention_mask.to(query_states.dtype)
778
  # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
779
  # Reference: https://github.com/pytorch/pytorch/issues/112577.
780
  if query_states.device.type == "cuda" and attention_mask is not None: