Crystalcareai
commited on
Update modeling_quiet.py
Browse files- modeling_quiet.py +1 -1
modeling_quiet.py
CHANGED
@@ -100,7 +100,7 @@ def _prepare_4d_causal_attention_mask_for_sdpa(attention_mask, input_shape, inpu
|
|
100 |
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
101 |
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
102 |
if past_key_values_length > 0:
|
103 |
-
attention_mask = attention_mask.to(dtype=torch.
|
104 |
attention_mask = attention_mask[:, past_key_values_length:]
|
105 |
expanded_attn_mask = attention_mask[:, None, None, :]
|
106 |
combined_attention_mask = expanded_attn_mask
|
|
|
100 |
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
101 |
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
102 |
if past_key_values_length > 0:
|
103 |
+
attention_mask = attention_mask.to(dtype=torch.bfloat16)
|
104 |
attention_mask = attention_mask[:, past_key_values_length:]
|
105 |
expanded_attn_mask = attention_mask[:, None, None, :]
|
106 |
combined_attention_mask = expanded_attn_mask
|