Upload modeling_moment.py
Browse files- modeling_moment.py +1 -1
modeling_moment.py
CHANGED
@@ -478,7 +478,7 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
|
|
478 |
input_mask_patch_view_for_hidden_states = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
|
479 |
# [batch_size x n_channels x n_patches x d_model]
|
480 |
print("*input_mask_patch_view_for_hidden_states: ", input_mask_patch_view_for_hidden_states.shape)
|
481 |
-
input_mask_patch_view_for_hidden_states = input_mask_patch_view_for_hidden_states.unsqueeze(-1).repeat(
|
482 |
1, n_channels, 1, self.config.d_model
|
483 |
)
|
484 |
# [batch_size x n_channels x n_patches x d_model]
|
|
|
478 |
input_mask_patch_view_for_hidden_states = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
|
479 |
# [batch_size x n_channels x n_patches x d_model]
|
480 |
print("*input_mask_patch_view_for_hidden_states: ", input_mask_patch_view_for_hidden_states.shape)
|
481 |
+
input_mask_patch_view_for_hidden_states = input_mask_patch_view_for_hidden_states.unsqueeze(1).unsqueeze(-1).repeat(
|
482 |
1, n_channels, 1, self.config.d_model
|
483 |
)
|
484 |
# [batch_size x n_channels x n_patches x d_model]
|