"""Smoke test for the DaViT model registered with the transformers Auto classes.

Registers ``DaViTConfig`` / ``DaViTModel`` with ``AutoConfig`` / ``AutoModel``,
builds a model from a default-constructed configuration, runs one forward pass
on a random input tensor and prints the shape of the result.
"""
import torch
from transformers import AutoConfig, AutoModel

from .configuration_davit import DaViTConfig
from .modeling_davit import DaViTModel

# Register the configuration and model so the Auto* factories know about the
# "davit" model type and can map DaViTConfig to DaViTModel.
AutoConfig.register("davit", DaViTConfig)
AutoModel.register(DaViTConfig, DaViTModel)

# Step 1: Create a configuration object (constructor defaults only).
config = DaViTConfig()

# Step 2: Create a model object from that configuration.
model = AutoModel.from_config(config)
# Put the model in evaluation mode: this script only checks the output shape,
# so train-only behaviour (e.g. dropout) should not be active.
model.eval()

# Step 3: Run a forward pass.
# Generate a random sample input tensor with shape
# (batch_size, channels, height, width).
batch_size = 2
channels = 3
height = 224
width = 224
sample_input = torch.randn(batch_size, channels, height, width)

# Pass the sample input through the model. Gradients are not needed for a
# shape check, so skip building the autograd graph.
with torch.no_grad():
    output = model(sample_input)

# Print the output shape.
# NOTE(review): this assumes the model returns an object exposing ``.shape``
# (e.g. a tensor) rather than a structured output object — confirm against
# DaViTModel.forward.
print(f"Output shape: {output.shape}")
# Expected output shape: (batch_size, projection_dim)