Crystalcareai committed on
Commit 5049df3 · verified · 1 Parent(s): 11aeabd

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +2 -2
modeling_quiet.py CHANGED
@@ -540,7 +540,7 @@ class QuietFlashAttention2(QuietAttention):
             value_states = value_states.to(target_dtype)
 
         # Compute the causal mask
-        causal = self.config.causal
+        causal = self.config.is_causal
         if causal:
             if self._flash_attn_uses_top_left_mask:
                 # Compute the causal mask
@@ -768,7 +768,7 @@ class QuietSdpaAttention(QuietAttention):
             attn_mask=attention_mask.to(query_states.device) if attention_mask is not None else None,
             dropout_p=self.attention_dropout if self.training else 0.0,
             # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-            causal=self.is_causal and attention_mask is None and q_len > 1,
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
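
Both hunks rename `causal` to `is_causal`. For the SDPA hunk this matches the actual signature of torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False); passing `causal=` instead fails with a TypeError at call time. A minimal standalone sketch of the difference (PyTorch >= 2.0; the tensor shapes are illustrative, not taken from the model):

import torch
import torch.nn.functional as F

# Toy query/key/value tensors: (batch, num_heads, seq_len, head_dim).
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# Before the fix: SDPA has no `causal` keyword, so the call fails.
try:
    F.scaled_dot_product_attention(q, k, v, causal=True)
except TypeError as err:
    print(f"old keyword rejected: {err}")

# After the fix: `is_causal` is the keyword SDPA actually accepts.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 2, 4, 8])

The FlashAttention2 hunk is the config-side counterpart: presumably the Quiet config defines `is_causal` rather than `causal`, so the old attribute read would raise an AttributeError before attention was ever computed.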