Crystalcareai committed on
Update modeling_quiet.py

modeling_quiet.py (+2 -2)
@@ -540,7 +540,7 @@ class QuietFlashAttention2(QuietAttention):
         value_states = value_states.to(target_dtype)

         # Compute the causal mask
-        causal = self.config.
+        causal = self.config.is_causal
         if causal:
             if self._flash_attn_uses_top_left_mask:
                 # Compute the causal mask
@@ -768,7 +768,7 @@ class QuietSdpaAttention(QuietAttention):
             attn_mask=attention_mask.to(query_states.device) if attention_mask is not None else None,
             dropout_p=self.attention_dropout if self.training else 0.0,
             # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
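The first hunk completes the assignment so the Flash Attention 2 path reads its causal flag from the model config (`self.config.is_causal`). The second hunk supplies the `is_causal` keyword to PyTorch's `torch.nn.functional.scaled_dot_product_attention`, which the preceding comment describes but which was previously missing from the call. Below is a minimal sketch of how that guard behaves; the wrapper function `sdpa_attention` and the tensor shapes are illustrative assumptions, not the actual modeling_quiet.py code.

import torch
import torch.nn.functional as F

def sdpa_attention(query, key, value, attention_mask=None, is_causal=True, dropout_p=0.0):
    # query/key/value: (batch, num_heads, seq_len, head_dim) -- illustrative shapes.
    q_len = query.shape[-2]
    return F.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask.to(query.device) if attention_mask is not None else None,
        dropout_p=dropout_p,
        # Only ask SDPA to build its own causal mask when no explicit mask is passed
        # (SDPA rejects attn_mask together with is_causal=True) and when there is
        # more than one query position; with q_len == 1 a causal mask is unnecessary,
        # matching AttentionMaskConverter.to_causal_4d, which skips mask creation
        # in that case.
        is_causal=is_causal and attention_mask is None and q_len > 1,
    )

# Example: batch of 2, 4 heads, 5 query positions, head_dim 8, no explicit mask,
# so SDPA applies causal masking itself.
q = torch.randn(2, 4, 5, 8)
k = torch.randn(2, 4, 5, 8)
v = torch.randn(2, 4, 5, 8)
out = sdpa_attention(q, k, v)
print(out.shape)  # torch.Size([2, 4, 5, 8])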