Upload modeling_moment.py
Browse files- modeling_moment.py +1 -0
modeling_moment.py
CHANGED
@@ -456,6 +456,7 @@ class MomentEmbeddingModel(MomentPreTrainedModel):
|
|
456 |
# For Mists model
|
457 |
# [batch_size, n_channels x n_patches, d_model]
|
458 |
# hidden_states = enc_out.reshape(batch_size, n_channels * n_patches, self.config.d_model)
|
|
|
459 |
|
460 |
if reduction == "mean":
|
461 |
enc_out = enc_out.mean(dim=1, keepdim=False) # Mean across channels
|
|
|
456 |
# For Mists model
|
457 |
# [batch_size, n_channels x n_patches, d_model]
|
458 |
# hidden_states = enc_out.reshape(batch_size, n_channels * n_patches, self.config.d_model)
|
459 |
+
hidden_states = hidden_states.reshape(batch_size, n_channels, n_patches, self.config.d_model).transpose(1, 2).reshape(batch_size, -1, self.config.d_model)
|
460 |
|
461 |
if reduction == "mean":
|
462 |
enc_out = enc_out.mean(dim=1, keepdim=False) # Mean across channels
|