SG3D-Demo / leo /img_encoder.py
zfzhang-thu
non-LFS commit
9de012e
import functools

import numpy as np
import timm
import torch
import torch.nn as nn
from einops import rearrange
def disabled_train(self, mode=True):
    """Replacement for ``nn.Module.train`` that ignores the requested mode.

    Installing this over a module's ``train`` keeps its current train/eval
    mode from being changed afterwards.

    Args:
        self: the object to hand back.
        mode: accepted for signature compatibility and otherwise ignored.

    Returns:
        ``self``, untouched.
    """
    del mode  # deliberately unused
    return self
def simple_conv_and_linear_weights_init(m):
    """Glorot-uniform initialise plain conv / linear layers and zero their bias.

    Meant for ``nn.Module.apply``. Only modules whose exact type is one of the
    listed conv classes or ``nn.Linear`` are modified (subclasses are
    ignored). ``nn.Linear`` is delegated to ``simple_linear_weights_init``.

    For a conv weight of shape ``(out, in, *kernel)`` the weights are drawn
    from ``U(-b, b)`` with ``b = sqrt(6 / (fan_in + fan_out))``, where
    ``fan_in = in * prod(kernel)`` and ``fan_out = out * prod(kernel)``.
    Transposed convs store ``(in, out, *kernel)``. The two fans swap, but
    their sum (and therefore the bound) is the same.

    Args:
        m: any ``nn.Module``. Non-matching modules are left unchanged.
    """
    if type(m) in (
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
    ):
        weight_shape = list(m.weight.data.size())
        # Use *all* kernel dims. A fixed [1:4]/[2:4] slice silently drops the
        # last kernel dim of 3D convs (5-D weights) and inflates the bound.
        receptive_field = np.prod(weight_shape[2:])
        fan_in = weight_shape[1] * receptive_field
        fan_out = weight_shape[0] * receptive_field
        w_bound = np.sqrt(6.0 / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif type(m) == nn.Linear:
        simple_linear_weights_init(m)
def simple_linear_weights_init(m):
    """Glorot-uniform initialise an exact ``nn.Linear`` and zero its bias.

    Any other module, including subclasses of ``nn.Linear``, is returned
    from untouched, so the function can be passed to ``nn.Module.apply``.

    Args:
        m: module to (possibly) initialise in place.
    """
    if type(m) != nn.Linear:
        return
    # Linear weights are stored as (out_features, in_features).
    fan_out, fan_in = m.weight.data.size()
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    m.weight.data.uniform_(-limit, limit)
    if m.bias is not None:
        m.bias.data.fill_(0)
class Backbone2DWrapper(nn.Module):
    """Wrap a 2D image backbone and expose its spatial feature map.

    Args:
        model: backbone exposing ``forward_features(x)``.
        tag: identifier whose substrings select ``out_channels`` and any
            output re-layout applied in :meth:`forward_normal`.
        freeze: if True, disable gradients for every parameter, switch to
            eval mode, and make later ``train()`` / ``eval()`` calls no-ops.

    Raises:
        NotImplementedError: if ``tag`` contains none of ``convnext``,
            ``swin``, ``vit`` or ``resnet``.
    """

    def __init__(self, model, tag, freeze=True):
        super().__init__()
        self.model = model
        self.tag = tag
        self.freeze = freeze
        # First match wins, so the order of these checks matters for tags
        # that contain more than one family name.
        if 'convnext' in tag:
            self.out_channels = 1024
        elif 'swin' in tag:
            self.out_channels = 1024
        elif 'vit' in tag:
            self.out_channels = 768
        elif 'resnet' in tag:
            self.out_channels = 2048
        else:
            raise NotImplementedError
        if freeze:
            for param in self.parameters():
                param.requires_grad = False
            self.eval()
            # Pin eval mode. ``disabled_train`` takes ``self`` as its first
            # argument, so it must be bound to this instance. Otherwise
            # ``self.train()`` raises TypeError, and ``self.eval()`` (which
            # returns ``self.train(False)``) returns ``False`` instead of the
            # module. ``functools.partial`` keeps the module picklable.
            self.train = functools.partial(disabled_train, self)

    def forward_normal(self, x, flat_output=False):
        """Run ``model.forward_features`` and normalise the output layout.

        Args:
            x: input batch, passed straight to ``forward_features``.
            flat_output: if True, flatten the spatial grid into tokens.

        Returns:
            A channel-first ``(B, C, H, W)`` feature map, or ``(B, H*W, C)``
            tokens when ``flat_output`` is True.
        """
        feat = self.model.forward_features(x)
        if 'swin' in self.tag:
            # Swin features arrive channel-last; move channels to dim 1.
            feat = rearrange(feat, 'b h w c -> b c h w')
        if 'vit_base_32_timm_laion2b' in self.tag or 'vit_base_32_timm_openai' in self.tag:
            # TODO: [CLS] is prepended to the patches.
            # Drop the [CLS] token and fold the patch tokens back into a grid.
            # NOTE(review): h=7 assumes a 7x7 patch grid (e.g. 224x224 input,
            # patch size 32). Confirm before using other input sizes.
            feat = rearrange(feat[:, 1:], 'b (h w) c -> b c h w', h=7)
        if flat_output:
            feat = rearrange(feat, 'b c h w -> b (h w) c')
        return feat

    @torch.no_grad()
    def forward_frozen(self, x, flat_output=False):
        """Same as :meth:`forward_normal` but without building an autograd graph."""
        return self.forward_normal(x, flat_output)

    def forward(self, x, flat_output=False):
        """Dispatch to the no-grad path when the backbone is frozen."""
        if self.freeze:
            return self.forward_frozen(x, flat_output)
        else:
            return self.forward_normal(x, flat_output)
def convnext_base_laion2b(pretrained=False, freeze=True, **kwargs):
    """Build timm's ``convnext_base.clip_laion2b`` inside a ``Backbone2DWrapper``.

    Args:
        pretrained: forwarded to ``timm.create_model``.
        freeze: forwarded to ``Backbone2DWrapper``.
        **kwargs: a truthy ``reset_clip_s2b2`` re-draws every state-dict
            entry under ``stages.3.blocks.2`` whose key contains ``weight``
            or ``bias`` with ``Tensor.normal_()``.

    Returns:
        The wrapped backbone, tagged ``'convnext_base_laion2b'``.
    """
    backbone = timm.create_model('convnext_base.clip_laion2b', pretrained=pretrained)
    if kwargs.get('reset_clip_s2b2'):
        state = backbone.state_dict()
        # Keep state-dict order so the random draws happen in the same sequence.
        reset_keys = [
            key for key in state
            if 'stages.3.blocks.2' in key and ('weight' in key or 'bias' in key)
        ]
        for key in reset_keys:
            state[key].normal_()
        backbone.load_state_dict(state, strict=True)
    return Backbone2DWrapper(backbone, 'convnext_base_laion2b', freeze=freeze)
class GridFeatureExtractor2D(nn.Module):
    """Extract 2D image features with a backbone and optionally pool them.

    The backbone is built by the module-level factory named
    ``f"{backbone_name}_{backbone_pretrain_dataset}"`` (for the defaults,
    ``convnext_base_laion2b``).

    Args:
        backbone_name: prefix of the factory name.
        backbone_pretrain_dataset: suffix of the factory name.
        use_pretrain: forwarded to the factory as ``pretrained``.
        freeze: forwarded to the factory as ``freeze``.
        pooling: ``'avg'``, ``'conv'``, ``'attn'`` / ``'attention'``, or a
            falsy value to skip pooling and return flattened grid tokens.

    Raises:
        NotImplementedError: if no factory matches, or ``pooling`` is unknown.
    """

    def __init__(self, backbone_name='convnext_base', backbone_pretrain_dataset='laion2b', use_pretrain=True, freeze=True, pooling='avg'):
        super().__init__()
        init_func_name = '_'.join([backbone_name, backbone_pretrain_dataset])
        init_func = globals().get(init_func_name)
        if init_func and callable(init_func):
            self.backbone = init_func(pretrained=use_pretrain, freeze=freeze)
        else:
            raise NotImplementedError(f"Backbone2D does not support {init_func_name}")

        self.pooling = pooling
        if not self.pooling:
            self.out_channels = self.backbone.out_channels
        elif self.pooling == 'avg':
            self.pooling_layers = nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten(),
            )
            self.out_channels = self.backbone.out_channels
        elif self.pooling == 'conv':
            self.pooling_layers = nn.Sequential(
                nn.Conv2d(self.backbone.out_channels, 64, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(64, 32, 1),
                nn.Flatten(),
            )
            self.pooling_layers.apply(simple_conv_and_linear_weights_init)
            self.out_channels = 32 * 7 * 7  # hardcode for 224x224
        elif self.pooling in ['attn', 'attention']:
            self.visual_attention = nn.Sequential(
                nn.Conv2d(self.backbone.out_channels, self.backbone.out_channels, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(self.backbone.out_channels, self.backbone.out_channels, 1),
            )
            self.visual_attention.apply(simple_conv_and_linear_weights_init)
            # A bound method (instead of a local closure) keeps the module
            # picklable and lets deepcopy rebind it to the copy.
            self.pooling_layers = self._attention_pooling
            self.out_channels = self.backbone.out_channels
        else:
            raise NotImplementedError(f"Backbone2D does not support {self.pooling} pooling")

    def _attention_pooling(self, x):
        """Pool a ``(B, C, H, W)`` map with per-channel spatial attention.

        ``visual_attention`` scores every spatial position per channel. The
        scores are softmax-normalised over positions and used to take a
        weighted sum of ``x``.

        Returns:
            A ``(B, C)`` tensor.
        """
        B, C, _, _ = x.shape
        # reshape (not view): backbone outputs may be non-contiguous, e.g.
        # after an einops rearrange.
        attn = self.visual_attention(x).reshape(B, C, -1).softmax(dim=-1)
        x = x.reshape(B, C, -1)
        # Weight by the attention scores. Using x twice would discard attn.
        return torch.einsum('b c n, b c n -> b c', attn, x)

    def forward(self, x):
        """Return pooled features ``(B, 1, D)``, or backbone tokens if unpooled."""
        if not self.pooling:
            return self.backbone(x, flat_output=True)
        x = self.backbone(x, flat_output=False)
        return self.pooling_layers(x).unsqueeze(1)