Upload modeling_moment.py
modeling_moment.py  CHANGED  (+15 -5)
@@ -432,6 +432,7 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
         x_enc = self.normalizer(x=x_enc, mask=input_mask, mode="norm")
         x_enc = torch.nan_to_num(x_enc, nan=0, posinf=0, neginf=0)
 
+        # [batch_size x n_patches]
         input_mask_patch_view = Masking.convert_seq_to_patch_view(
             input_mask, self.patch_len
         )
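For context, Masking.convert_seq_to_patch_view collapses the per-timestep input_mask to the per-patch view annotated by the new comment. A minimal sketch of the intended semantics, assuming non-overlapping patches and that a patch counts as observed only when every timestep inside it is observed (the function body below is an illustration, not the library source; only the name, arguments, and shapes come from the diff):

    import torch

    def convert_seq_to_patch_view_sketch(input_mask: torch.Tensor, patch_len: int) -> torch.Tensor:
        # input_mask: [batch_size x seq_len], 1 = observed, 0 = padding.
        # Cut the sequence into non-overlapping windows of length patch_len...
        patches = input_mask.unfold(dimension=-1, size=patch_len, step=patch_len)
        # ...and mark a patch observed only if all of its timesteps are observed.
        return (patches.sum(dim=-1) == patch_len).long()  # [batch_size x n_patches]

    mask = torch.tensor([[1, 1, 1, 1, 1, 1, 0, 0]])  # last two timesteps are padding
    print(convert_seq_to_patch_view_sketch(mask, patch_len=4))  # tensor([[1, 0]])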
@@ -453,11 +454,6 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
         enc_out = enc_out.reshape((-1, n_channels, n_patches, self.config.d_model))
         # [batch_size x n_channels x n_patches x d_model]
 
-        # For Mists model
-        # [batch_size, n_channels x n_patches, d_model]
-        # Ensure hidden_states are consistent for both short and long inputs with input_mask specified
-        hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model).transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
-
         if reduction == "mean":
             enc_out = enc_out.mean(dim=1, keepdim=False)  # Mean across channels
             # [batch_size x n_patches x d_model]
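The block removed above flattened hidden_states into a single sequence axis before the reduction branch; the third hunk below reinstates the computation after the reduction, keeping the four-dimensional view and zeroing padded patches instead. A shape trace of the removed flatten, with toy sizes (all dimensions here are illustrative):

    import torch

    batch_size, n_channels, n_patches, d_model = 2, 3, 4, 8
    hidden_states = torch.randn(batch_size, n_channels * n_patches, d_model)

    # Removed variant: interleave channels into one patch-major sequence axis.
    flat = (
        hidden_states.reshape(batch_size, n_channels, n_patches, d_model)
        .transpose(1, 2)                   # [batch_size x n_patches x n_channels x d_model]
        .reshape(batch_size, -1, d_model)  # [batch_size x (n_patches * n_channels) x d_model]
    )
    print(flat.shape)  # torch.Size([2, 12, 8])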
@@ -469,6 +465,20 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
             ) / input_mask_patch_view.sum(dim=1)
         else:
             raise NotImplementedError(f"Reduction method {reduction} not implemented.")
+
+        # For Mists model
+        # [batch_size, n_channels x n_patches, d_model]
+        # Ensure hidden_states are consistent for both short and long inputs with input_mask specified
+        # hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model).transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
+        # [batch_size x n_channels x n_patches x d_model]
+        hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model)
+        # [batch_size x n_patches]
+        input_mask_patch_view_for_hidden_states = Masking.convert_seq_to_patch_view(input_mask, self.patch_len)
+        # [batch_size x n_channels x n_patches x d_model]
+        input_mask_patch_view_for_hidden_states = input_mask_patch_view_for_hidden_states.unsqueeze(1).unsqueeze(-1).repeat(
+            1, n_channels, 1, self.config.d_model
+        )
+        hidden_states = input_mask_patch_view_for_hidden_states * hidden_states
 
         return TimeseriesOutputs(
             embeddings=enc_out, input_mask=input_mask, metadata=reduction, hidden_states=hidden_states
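The mask expansion in this hunk takes the per-patch mask [batch_size x n_patches] up to the hidden-state shape [batch_size x n_channels x n_patches x d_model] before the elementwise product; the unsqueeze at dim 1 (channels) and dim -1 (d_model) is what lines the axes up. A standalone trace with toy sizes (dimensions are illustrative):

    import torch

    batch_size, n_channels, n_patches, d_model = 2, 3, 4, 8
    hidden_states = torch.randn(batch_size, n_channels, n_patches, d_model)
    patch_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])  # [batch_size x n_patches]

    # [batch_size x n_patches] -> [batch_size x 1 x n_patches x 1], then broadcast
    # across channels and d_model during the multiply.
    masked = hidden_states * patch_mask.unsqueeze(1).unsqueeze(-1)
    print(masked[0, :, 3].abs().sum())  # tensor(0.) -- padded patch 3 of sample 0 is zeroed

Broadcasting alone already yields the [batch_size x n_channels x n_patches x d_model] product, so the explicit repeat in the diff is equivalent but materializes the fully expanded mask; its main effect is to make the target shape visible in the code.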