# Eden-Multimodal / models/projection_layer.py
# Author: Himank Jain
# Commit: c6a12ae ("initial commit")
import torch.nn as nn
class ProjectionBlock(nn.Module):
    """Map feature vectors from a CLIP-sized space to a phi-2-sized space.

    The input's last dimension (size ``input_dim_CLIP``) is layer-normalized,
    then passed through ``Linear -> GELU -> Linear``. The output's last
    dimension has size ``input_dim_phi2``. Any leading dimensions pass
    through unchanged, because ``nn.LayerNorm`` and ``nn.Linear`` both act
    only on the last axis.

    NOTE(review): the parameter names suggest CLIP vision features being
    projected into phi-2's embedding space. Confirm this against the caller.

    Args:
        input_dim_CLIP: Size of the last dimension of the input features.
        input_dim_phi2: Size of the hidden layer and of the output's last
            dimension.
    """

    # The constructor parameter names are kept as-is, even though they are
    # not PEP 8 style, because callers may pass them by keyword.
    def __init__(self, input_dim_CLIP: int, input_dim_phi2: int) -> None:
        super().__init__()
        # The attribute names and the Sequential layout determine the
        # state_dict keys (pre_norm.*, proj.0.*, proj.2.*). Keep them stable
        # so that existing checkpoints still load.
        self.pre_norm = nn.LayerNorm(input_dim_CLIP)
        mlp_layers = [
            nn.Linear(input_dim_CLIP, input_dim_phi2),  # proj.0
            nn.GELU(),                                   # proj.1 (no params)
            nn.Linear(input_dim_phi2, input_dim_phi2),  # proj.2
        ]
        self.proj = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Normalize ``x`` over its last axis, then apply the projection MLP."""
        normalized = self.pre_norm(x)
        projected = self.proj(normalized)
        return projected