Upload modeling_moment.py
Browse files- modeling_moment.py +2 -0
modeling_moment.py
CHANGED
@@ -485,6 +485,8 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
|
|
485 |
|
486 |
# [batch_size x n_patches]
|
487 |
input_mask_patch_view_for_mists = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
|
|
|
|
|
488 |
|
489 |
return TimeseriesOutputs(
|
490 |
embeddings=enc_out, input_mask=input_mask, metadata=reduction, hidden_states=hidden_states, input_mask_patch_view=input_mask_patch_view_for_mists
|
|
|
485 |
|
486 |
# [batch_size x n_patches]
|
487 |
input_mask_patch_view_for_mists = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
|
488 |
+
# [batch_size, n_channels x n_patches]
|
489 |
+
input_mask_patch_view_for_mists = input_mask_patch_view_for_mists.repeat_interleave(n_channels, dim=1)
|
490 |
|
491 |
return TimeseriesOutputs(
|
492 |
embeddings=enc_out, input_mask=input_mask, metadata=reduction, hidden_states=hidden_states, input_mask_patch_view=input_mask_patch_view_for_mists
|