VictorSanh
commited on
fix discrepancy of speed in the case of full attention mask
Browse files- modeling_siglip.py +12 -5
modeling_siglip.py
CHANGED
@@ -1121,14 +1121,21 @@ class SiglipVisionTransformer(nn.Module):
|
|
1121 |
hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
|
1122 |
|
1123 |
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
|
1124 |
-
|
1125 |
-
|
1126 |
-
|
1127 |
-
|
|
|
|
|
|
|
1128 |
_prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
|
1129 |
if not self.config._flash_attn_2_enabled
|
1130 |
else patch_attention_mask
|
1131 |
-
)
|
|
|
|
|
|
|
|
|
1132 |
output_attentions=output_attentions,
|
1133 |
output_hidden_states=output_hidden_states,
|
1134 |
return_dict=return_dict,
|
|
|
1121 |
hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
|
1122 |
|
1123 |
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
|
1124 |
+
# The call to `_upad_input` in `_flash_attention_forward` is expensive
|
1125 |
+
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
|
1126 |
+
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
|
1127 |
+
if not torch.any(~patch_attention_mask):
|
1128 |
+
attention_mask=None
|
1129 |
+
else:
|
1130 |
+
attention_mask = (
|
1131 |
_prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
|
1132 |
if not self.config._flash_attn_2_enabled
|
1133 |
else patch_attention_mask
|
1134 |
+
)
|
1135 |
+
|
1136 |
+
encoder_outputs = self.encoder(
|
1137 |
+
inputs_embeds=hidden_states,
|
1138 |
+
attention_mask=attention_mask,
|
1139 |
output_attentions=output_attentions,
|
1140 |
output_hidden_states=output_hidden_states,
|
1141 |
return_dict=return_dict,
|