Crystalcareai commited on
Commit
1d25390
·
verified ·
1 Parent(s): 974e6b8

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +1 -1
modeling_quiet.py CHANGED
@@ -100,7 +100,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa(attention_mask, input_shape, inpu
100
  # - if the model is a decoder, apply a causal mask in addition to the padding mask
101
  # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
102
  if past_key_values_length > 0:
103
- attention_mask = attention_mask.to(dtype=torch.long)
104
  attention_mask = attention_mask[:, past_key_values_length:]
105
  expanded_attn_mask = attention_mask[:, None, None, :]
106
  combined_attention_mask = expanded_attn_mask
 
100
  # - if the model is a decoder, apply a causal mask in addition to the padding mask
101
  # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
102
  if past_key_values_length > 0:
103
+ attention_mask = attention_mask.to(dtype=torch.bfloat16)
104
  attention_mask = attention_mask[:, past_key_values_length:]
105
  expanded_attn_mask = attention_mask[:, None, None, :]
106
  combined_attention_mask = expanded_attn_mask