DaViT / test_davit_model.py
amaye15's picture
Upload folder using huggingface_hub
c433e44 verified
raw
history blame
833 Bytes
import torch
from transformers import AutoConfig, AutoModel
from .configuration_davit import DaViTConfig
from .modeling_davit import DaViTModel
# Register the configuration and model with the transformers Auto* factories
# so that AutoModel.from_config(...) further down can map a DaViTConfig
# instance to the DaViTModel class.
# This runs at import time, as a module-level side effect.
# NOTE(review): "davit" is presumably expected to equal DaViTConfig.model_type
# (transformers validates this on register) — confirm in configuration_davit.
AutoConfig.register("davit", DaViTConfig)
AutoModel.register(DaViTConfig, DaViTModel)
# Smoke test: build a DaViT model from its default configuration and push a
# batch of random images through it, then check the batch dimension survives.

# Step 1: Create a configuration object (DaViTConfig defaults).
config = DaViTConfig()

# Step 2: Create a model object. No pretrained checkpoint is loaded here; the
# model is built from `config` alone. Switch to eval mode because this is an
# inference-only check: layers with train/eval-dependent behaviour (e.g.
# dropout, if the architecture has any) should run in inference mode.
model = AutoModel.from_config(config)
model.eval()

# Step 3: Run a forward pass.
# Generate a random sample input tensor with shape
# (batch_size, channels, height, width).
# NOTE(review): channels/height/width are hard-coded; this assumes the default
# DaViTConfig accepts 3-channel 224x224 images — confirm against the config.
batch_size = 2
channels = 3
height = 224
width = 224
sample_input = torch.randn(batch_size, channels, height, width)

# Pass the sample input through the model. Autograd is disabled because no
# backward pass follows, so no computation graph (or the activation memory
# it would hold) needs to be built.
with torch.no_grad():
    output = model(sample_input)

# Print the output shape.
# Expected output shape: (batch_size, projection_dim)
# NOTE(review): assumes DaViTModel.forward returns a plain tensor rather than a
# ModelOutput/tuple; `.shape` would raise AttributeError otherwise.
print(f"Output shape: {output.shape}")

# Make the smoke test actually check something: the batch dimension must be
# preserved by the forward pass.
assert output.shape[0] == batch_size, (
    f"Expected batch dimension {batch_size}, got shape {tuple(output.shape)}"
)