File size: 831 Bytes
c433e44
 
b2a8724
 
c433e44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from transformers import AutoConfig, AutoModel
from configuration_davit import DaViTConfig
from modeling_davit import DaViTModel

# Register the custom DaViT configuration and model with the Auto* factories so
# that AutoModel.from_config() can map a DaViTConfig instance to DaViTModel.
# NOTE(review): AutoConfig.register expects DaViTConfig.model_type == "davit";
# confirm this in configuration_davit.py.
AutoConfig.register("davit", DaViTConfig)
AutoModel.register(DaViTConfig, DaViTModel)

# Step 1: Create a configuration object (DaViTConfig's default hyperparameters)
config = DaViTConfig()

# Step 2: Create a model object. from_config() builds the architecture with
# freshly initialised weights; nothing is loaded from a checkpoint.
model = AutoModel.from_config(config)
# Unlike from_pretrained(), from_config() leaves the module in training mode.
# Switch to inference mode so train-only behaviour (e.g. dropout) is disabled.
model.eval()

# Step 3: Run a forward pass
# Generate a random sample input tensor with shape (batch_size, channels, height, width)
batch_size = 2
channels = 3
height = 224
width = 224
sample_input = torch.randn(batch_size, channels, height, width)

# Pass the sample input through the model. no_grad() avoids building an
# autograd graph that this smoke test never uses.
with torch.no_grad():
    output = model(sample_input)

# Hugging Face models commonly return a ModelOutput (or tuple) rather than a
# bare tensor. In that case take its first entry so `.shape` is available.
features = output if isinstance(output, torch.Tensor) else output[0]

# Print the output shape
print(f"Output shape: {features.shape}")

# Expected output shape (author's expectation, not checked here):
# (batch_size, projection_dim)