HachiML commited on
Commit
352385f
·
verified ·
1 Parent(s): 2f63acb

Upload modeling_moment.py

Browse files
Files changed (1) hide show
  1. modeling_moment.py +4 -2
modeling_moment.py CHANGED
@@ -481,10 +481,12 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
481
  hidden_states = input_mask_patch_view_for_hidden_states * hidden_states
482
  # [batch_size, n_channels x n_patches, d_model]
483
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
484
- input_mask_patch_view_for_hidden_states = input_mask_patch_view_for_hidden_states.transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
 
 
485
 
486
  return TimeseriesOutputs(
487
- embeddings=enc_out, input_mask=input_mask, metadata=reduction, hidden_states=hidden_states, input_mask_patch_view=input_mask_patch_view_for_hidden_states,
488
  )
489
 
490
  def forward(
 
481
  hidden_states = input_mask_patch_view_for_hidden_states * hidden_states
482
  # [batch_size, n_channels x n_patches, d_model]
483
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
484
+
485
+ # [batch_size x n_patches]
486
+ input_mask_patch_view_for_mists = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
487
 
488
  return TimeseriesOutputs(
489
+ embeddings=enc_out, input_mask=input_mask, metadata=reduction, hidden_states=hidden_states, input_mask_patch_view=input_mask_patch_view_for_mists
490
  )
491
 
492
  def forward(