HachiML commited on
Commit
0d02263
·
verified ·
1 Parent(s): 055918c

Upload modeling_moment.py

Browse files
Files changed (1) hide show
  1. modeling_moment.py +0 -6
modeling_moment.py CHANGED
@@ -430,8 +430,6 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
430
  if input_mask is None:
431
  input_mask = torch.ones((batch_size, seq_len)).to(x_enc.device)
432
 
433
- print("*input_mask: ", input_mask.shape)
434
-
435
  x_enc = self.normalizer(x=x_enc, mask=input_mask, mode="norm")
436
  x_enc = torch.nan_to_num(x_enc, nan=0, posinf=0, neginf=0)
437
 
@@ -474,17 +472,13 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
474
  # Ensure hidden_states are consistent for both short and long inputs with input_mask specified
475
  # hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model).transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
476
  # [batch_size x n_patches]
477
- print("*input_mask: ", input_mask.shape)
478
  input_mask_patch_view_for_hidden_states = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
479
  # [batch_size x n_channels x n_patches x d_model]
480
- print("*input_mask_patch_view_for_hidden_states: ", input_mask_patch_view_for_hidden_states.shape)
481
  input_mask_patch_view_for_hidden_states = input_mask_patch_view_for_hidden_states.unsqueeze(1).unsqueeze(-1).repeat(
482
  1, n_channels, 1, self.config.d_model
483
  )
484
  # [batch_size x n_channels x n_patches x d_model]
485
  hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model)
486
- print("*input_mask_patch_view_for_hidden_states: ", input_mask_patch_view_for_hidden_states.shape)
487
- print("*hidden_states: ", hidden_states.shape)
488
  hidden_states = input_mask_patch_view_for_hidden_states * hidden_states
489
  # [batch_size, n_channels x n_patches, d_model]
490
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
 
430
  if input_mask is None:
431
  input_mask = torch.ones((batch_size, seq_len)).to(x_enc.device)
432
 
 
 
433
  x_enc = self.normalizer(x=x_enc, mask=input_mask, mode="norm")
434
  x_enc = torch.nan_to_num(x_enc, nan=0, posinf=0, neginf=0)
435
 
 
472
  # Ensure hidden_states are consistent for both short and long inputs with input_mask specified
473
  # hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model).transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
474
  # [batch_size x n_patches]
 
475
  input_mask_patch_view_for_hidden_states = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
476
  # [batch_size x n_channels x n_patches x d_model]
 
477
  input_mask_patch_view_for_hidden_states = input_mask_patch_view_for_hidden_states.unsqueeze(1).unsqueeze(-1).repeat(
478
  1, n_channels, 1, self.config.d_model
479
  )
480
  # [batch_size x n_channels x n_patches x d_model]
481
  hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model)
 
 
482
  hidden_states = input_mask_patch_view_for_hidden_states * hidden_states
483
  # [batch_size, n_channels x n_patches, d_model]
484
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.config.d_model)