Upload modeling_moment.py
Browse files- modeling_moment.py +1 -4
modeling_moment.py
CHANGED
@@ -449,15 +449,12 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
|
|
449 |
outputs = self.encoder(inputs_embeds=enc_in, attention_mask=attention_mask)
|
450 |
enc_out = outputs.last_hidden_state
|
451 |
|
452 |
-
# For Mists model
|
453 |
-
hidden_states = outputs.last_hidden_state
|
454 |
-
|
455 |
enc_out = enc_out.reshape((-1, n_channels, n_patches, self.config.d_model))
|
456 |
# [batch_size x n_channels x n_patches x d_model]
|
457 |
|
458 |
# For Mists model
|
459 |
# [batch_size, n_channels x n_patches, d_model]
|
460 |
-
|
461 |
|
462 |
if reduction == "mean":
|
463 |
enc_out = enc_out.mean(dim=1, keepdim=False) # Mean across channels
|
|
|
449 |
outputs = self.encoder(inputs_embeds=enc_in, attention_mask=attention_mask)
|
450 |
enc_out = outputs.last_hidden_state
|
451 |
|
|
|
|
|
|
|
452 |
enc_out = enc_out.reshape((-1, n_channels, n_patches, self.config.d_model))
|
453 |
# [batch_size x n_channels x n_patches x d_model]
|
454 |
|
455 |
# For Mists model
|
456 |
# [batch_size, n_channels x n_patches, d_model]
|
457 |
+
hidden_states = enc_out.reshape(batch_size, n_channels * n_patches, self.config.d_model)
|
458 |
|
459 |
if reduction == "mean":
|
460 |
enc_out = enc_out.mean(dim=1, keepdim=False) # Mean across channels
|