File size: 831 Bytes
c433e44 b2a8724 c433e44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import torch
from transformers import AutoConfig, AutoModel
from configuration_davit import DaViTConfig
from modeling_davit import DaViTModel
# Register the custom DaViT configuration and model with the transformers
# Auto* factories so that AutoModel.from_config() can resolve a DaViTConfig
# to DaViTModel.
AutoConfig.register("davit", DaViTConfig)
AutoModel.register(DaViTConfig, DaViTModel)

# Step 1: Create a configuration object (DaViTConfig's default hyperparameters).
config = DaViTConfig()

# Step 2: Create a model object. from_config() builds the architecture with
# freshly initialised weights; it does not load a pretrained checkpoint.
model = AutoModel.from_config(config)
# Put the model in inference mode so that any dropout / normalisation layers
# behave as they would at evaluation time during this smoke test.
model.eval()

# Step 3: Run a forward pass on a random batch shaped
# (batch_size, channels, height, width).
batch_size = 2
channels = 3
height = 224
width = 224
sample_input = torch.randn(batch_size, channels, height, width)

# Gradients are not needed for a shape check; disabling autograd avoids
# keeping activations around for a backward pass that never happens.
with torch.no_grad():
    output = model(sample_input)

# Print the output shape.
# NOTE(review): assumes DaViTModel.forward returns a tensor (not a tuple or a
# ModelOutput object) - confirm against modeling_davit.
print(f"Output shape: {output.shape}")
# Expected output shape: (batch_size, projection_dim)
# (projection_dim is presumably a DaViTConfig field - verify.)
|