|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from timm.models.layers import trunc_normal_, DropPath, to_2tuple
from timm.models.registry import register_model
|
from functools import partial |
|
import torch.utils.checkpoint as checkpoint |
|
try: |
|
from huggingface_hub import hf_hub_download |
|
except ImportError:
|
hf_hub_download = None |
|
|
|
has_mmdet = False |
|
has_mmseg = False |
|
|
|
|
|
|
|
try: |
|
from mmseg.models.builder import BACKBONES as seg_BACKBONES |
|
from mmseg.utils import get_root_logger |
|
from mmcv.runner import _load_checkpoint |
|
has_mmseg = True |
|
except ImportError: |
|
get_root_logger = None |
|
_load_checkpoint = None |
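# Optional MMDetection support, mirroring the MMSegmentation block above, so that the
# det_BACKBONES registration near the bottom of this file can run when mmdet is installed.
# The import paths assume mmdet 2.x (consistent with the mmcv.runner import used above).
try:
    from mmdet.models.builder import BACKBONES as det_BACKBONES
    from mmdet.utils import get_root_logger
    from mmcv.runner import _load_checkpoint
    has_mmdet = True
except ImportError:
    pass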
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GRNwithNHWC(nn.Module): |
|
""" GRN (Global Response Normalization) layer |
|
Originally proposed in ConvNeXt V2 (https://arxiv.org/abs/2301.00808) |
|
This implementation is more efficient than the original (https://github.com/facebookresearch/ConvNeXt-V2) |
|
We assume the inputs to this layer are (N, H, W, C) |
|
""" |
|
def __init__(self, dim, use_bias=True): |
|
super().__init__() |
|
self.use_bias = use_bias |
|
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) |
|
if self.use_bias: |
|
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) |
|
|
|
def forward(self, x): |
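        # Gx: per-channel L2 norm over the spatial dims; Nx: Gx divided by its mean across channels.
        # The input is then re-scaled by (gamma * Nx + 1), i.e. GRN with an identity shortcut.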
|
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) |
|
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) |
|
if self.use_bias: |
|
return (self.gamma * Nx + 1) * x + self.beta |
|
else: |
|
return (self.gamma * Nx + 1) * x |
|
|
|
|
|
class NCHWtoNHWC(nn.Module): |
|
def __init__(self): |
|
super().__init__() |
|
|
|
def forward(self, x): |
|
return x.permute(0, 2, 3, 1) |
|
|
|
|
|
class NHWCtoNCHW(nn.Module): |
|
def __init__(self): |
|
super().__init__() |
|
|
|
def forward(self, x): |
|
return x.permute(0, 3, 1, 2) |
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, |
|
attempt_use_lk_impl=True): |
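    # Use the efficient iGEMM depth-wise implementation for large square kernels when it is installed
    # and applicable (depth-wise, stride 1, dilation 1, 'same' padding); otherwise fall back to nn.Conv2d.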
|
kernel_size = to_2tuple(kernel_size) |
|
if padding is None: |
|
padding = (kernel_size[0] // 2, kernel_size[1] // 2) |
|
else: |
|
padding = to_2tuple(padding) |
|
need_large_impl = kernel_size[0] == kernel_size[1] and kernel_size[0] > 5 and padding == (kernel_size[0] // 2, kernel_size[1] // 2) |
|
|
|
if attempt_use_lk_impl and need_large_impl: |
|
print('---------------- trying to import iGEMM implementation for large-kernel conv') |
|
try: |
|
from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM |
|
print('---------------- found iGEMM implementation ') |
|
        except ImportError:
            DepthWiseConv2dImplicitGEMM = None
            print('---------------- no iGEMM implementation found. Using the original nn.Conv2d. '
                  'Follow https://github.com/AILab-CVC/UniRepLKNet to install it.')
|
if DepthWiseConv2dImplicitGEMM is not None and need_large_impl and in_channels == out_channels \ |
|
and out_channels == groups and stride == 1 and dilation == 1: |
|
print(f'===== iGEMM Efficient Conv Impl, channels {in_channels}, kernel size {kernel_size} =====') |
|
return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias) |
|
return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, |
|
padding=padding, dilation=dilation, groups=groups, bias=bias) |
|
|
|
|
|
def get_bn(dim, use_sync_bn=False): |
|
if use_sync_bn: |
|
return nn.SyncBatchNorm(dim) |
|
else: |
|
return nn.BatchNorm2d(dim) |
|
|
|
class SEBlock(nn.Module): |
|
""" |
|
Squeeze-and-Excitation Block proposed in SENet (https://arxiv.org/abs/1709.01507) |
|
We assume the inputs to this layer are (N, C, H, W) |
|
""" |
|
def __init__(self, input_channels, internal_neurons): |
|
super(SEBlock, self).__init__() |
|
self.down = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons, |
|
kernel_size=1, stride=1, bias=True) |
|
self.up = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels, |
|
kernel_size=1, stride=1, bias=True) |
|
self.input_channels = input_channels |
|
self.nonlinear = nn.ReLU(inplace=True) |
|
|
|
def forward(self, inputs): |
|
x = F.adaptive_avg_pool2d(inputs, output_size=(1, 1)) |
|
x = self.down(x) |
|
x = self.nonlinear(x) |
|
x = self.up(x) |
|
        x = torch.sigmoid(x)
|
return inputs * x.view(-1, self.input_channels, 1, 1) |
|
|
|
def fuse_bn(conv, bn): |
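    # Fold a BatchNorm that follows a conv into the conv itself: returns an equivalent (weight, bias)
    # pair such that conv followed by BN equals a single conv with the returned parameters.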
|
conv_bias = 0 if conv.bias is None else conv.bias |
|
std = (bn.running_var + bn.eps).sqrt() |
|
return conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1), bn.bias + (conv_bias - bn.running_mean) * bn.weight / std |
|
|
|
def convert_dilated_to_nondilated(kernel, dilate_rate): |
|
identity_kernel = torch.ones((1, 1, 1, 1)).to(kernel.device) |
|
if kernel.size(1) == 1: |
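        # Depth-wise kernel: dilation is equivalent to inserting (dilate_rate - 1) zeros between taps,
        # which a transposed conv with a 1x1 identity kernel and stride=dilate_rate performs directly.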
|
|
|
dilated = F.conv_transpose2d(kernel, identity_kernel, stride=dilate_rate) |
|
return dilated |
|
else: |
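        # Dense or group-wise (non-depth-wise) kernel: expand each input-channel slice separately.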
|
|
|
slices = [] |
|
for i in range(kernel.size(1)): |
|
dilated = F.conv_transpose2d(kernel[:,i:i+1,:,:], identity_kernel, stride=dilate_rate) |
|
slices.append(dilated) |
|
return torch.cat(slices, dim=1) |
|
|
|
def merge_dilated_into_large_kernel(large_kernel, dilated_kernel, dilated_r): |
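    # Convert the dilated small-kernel branch into its equivalent sparse non-dilated kernel and
    # zero-pad it to the size of the large kernel so that the two can be summed element-wise.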
|
large_k = large_kernel.size(2) |
|
dilated_k = dilated_kernel.size(2) |
|
equivalent_kernel_size = dilated_r * (dilated_k - 1) + 1 |
|
equivalent_kernel = convert_dilated_to_nondilated(dilated_kernel, dilated_r) |
|
rows_to_pad = large_k // 2 - equivalent_kernel_size // 2 |
|
merged_kernel = large_kernel + F.pad(equivalent_kernel, [rows_to_pad] * 4) |
|
return merged_kernel |
|
|
|
|
|
class DilatedReparamBlock(nn.Module): |
|
""" |
|
Dilated Reparam Block proposed in UniRepLKNet (https://github.com/AILab-CVC/UniRepLKNet) |
|
We assume the inputs to this block are (N, C, H, W) |
|
""" |
|
def __init__(self, channels, kernel_size, deploy, use_sync_bn=False, attempt_use_lk_impl=True): |
|
super().__init__() |
|
self.lk_origin = get_conv2d(channels, channels, kernel_size, stride=1, |
|
padding=kernel_size//2, dilation=1, groups=channels, bias=deploy, |
|
attempt_use_lk_impl=attempt_use_lk_impl) |
|
self.attempt_use_lk_impl = attempt_use_lk_impl |
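        # Each (kernel_size, dilation) pair below defines a parallel dilated depth-wise branch whose
        # equivalent dense kernel, of size r * (k - 1) + 1, fits inside the large kernel, so every
        # branch can be merged into lk_origin at deploy time (see merge_dilated_branches).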
|
|
|
|
|
if kernel_size == 17: |
|
self.kernel_sizes = [5, 9, 3, 3, 3] |
|
self.dilates = [1, 2, 4, 5, 7] |
|
elif kernel_size == 15: |
|
self.kernel_sizes = [5, 7, 3, 3, 3] |
|
self.dilates = [1, 2, 3, 5, 7] |
|
elif kernel_size == 13: |
|
self.kernel_sizes = [5, 7, 3, 3, 3] |
|
self.dilates = [1, 2, 3, 4, 5] |
|
elif kernel_size == 11: |
|
self.kernel_sizes = [5, 5, 3, 3, 3] |
|
self.dilates = [1, 2, 3, 4, 5] |
|
elif kernel_size == 9: |
|
self.kernel_sizes = [5, 5, 3, 3] |
|
self.dilates = [1, 2, 3, 4] |
|
elif kernel_size == 7: |
|
self.kernel_sizes = [5, 3, 3] |
|
self.dilates = [1, 2, 3] |
|
elif kernel_size == 5: |
|
self.kernel_sizes = [3, 3] |
|
self.dilates = [1, 2] |
|
else: |
|
            raise ValueError(f'Dilated Reparam Block does not support kernel_size {kernel_size}')
|
|
|
if not deploy: |
|
self.origin_bn = get_bn(channels, use_sync_bn) |
|
for k, r in zip(self.kernel_sizes, self.dilates): |
|
self.__setattr__('dil_conv_k{}_{}'.format(k, r), |
|
nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=k, stride=1, |
|
padding=(r * (k - 1) + 1) // 2, dilation=r, groups=channels, |
|
bias=False)) |
|
self.__setattr__('dil_bn_k{}_{}'.format(k, r), get_bn(channels, use_sync_bn=use_sync_bn)) |
|
|
|
def forward(self, x): |
|
if not hasattr(self, 'origin_bn'): |
|
return self.lk_origin(x) |
|
out = self.origin_bn(self.lk_origin(x)) |
|
for k, r in zip(self.kernel_sizes, self.dilates): |
|
conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r)) |
|
bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r)) |
|
out = out + bn(conv(x)) |
|
return out |
|
|
|
def merge_dilated_branches(self): |
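        # Structural re-parameterization: fold each BN into its conv, convert every dilated branch
        # into an equivalent non-dilated kernel, sum all kernels and biases into a single large
        # depth-wise conv with bias, then delete the training-time branches.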
|
if hasattr(self, 'origin_bn'): |
|
origin_k, origin_b = fuse_bn(self.lk_origin, self.origin_bn) |
|
for k, r in zip(self.kernel_sizes, self.dilates): |
|
conv = self.__getattr__('dil_conv_k{}_{}'.format(k, r)) |
|
bn = self.__getattr__('dil_bn_k{}_{}'.format(k, r)) |
|
branch_k, branch_b = fuse_bn(conv, bn) |
|
origin_k = merge_dilated_into_large_kernel(origin_k, branch_k, r) |
|
origin_b += branch_b |
|
merged_conv = get_conv2d(origin_k.size(0), origin_k.size(0), origin_k.size(2), stride=1, |
|
padding=origin_k.size(2)//2, dilation=1, groups=origin_k.size(0), bias=True, |
|
attempt_use_lk_impl=self.attempt_use_lk_impl) |
|
merged_conv.weight.data = origin_k |
|
merged_conv.bias.data = origin_b |
|
self.lk_origin = merged_conv |
|
self.__delattr__('origin_bn') |
|
for k, r in zip(self.kernel_sizes, self.dilates): |
|
self.__delattr__('dil_conv_k{}_{}'.format(k, r)) |
|
self.__delattr__('dil_bn_k{}_{}'.format(k, r)) |
|
|
|
|
|
class UniRepLKNetBlock(nn.Module): |
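    """ UniRepLKNet basic block: a depth-wise conv (a DilatedReparamBlock for large kernels) + BatchNorm,
    a Squeeze-and-Excitation layer, and a GRN-enhanced feed-forward network
    (Linear -> GELU -> GRN -> Linear), wrapped in a residual connection with optional
    layer scale and DropPath. Inputs are assumed to be (N, C, H, W).
    """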
|
|
|
def __init__(self, |
|
dim, |
|
kernel_size, |
|
drop_path=0., |
|
layer_scale_init_value=1e-6, |
|
deploy=False, |
|
attempt_use_lk_impl=True, |
|
with_cp=False, |
|
use_sync_bn=False, |
|
ffn_factor=4): |
|
super().__init__() |
|
self.with_cp = with_cp |
|
if deploy: |
|
print('------------------------------- Note: deploy mode') |
|
if self.with_cp: |
|
            print('****** note: with_cp = True reduces memory consumption but may slow down training ******')
|
|
|
if kernel_size == 0: |
|
self.dwconv = nn.Identity() |
|
elif kernel_size >= 7: |
|
self.dwconv = DilatedReparamBlock(dim, kernel_size, deploy=deploy, |
|
use_sync_bn=use_sync_bn, |
|
attempt_use_lk_impl=attempt_use_lk_impl) |
|
|
|
else: |
|
assert kernel_size in [3, 5] |
|
self.dwconv = get_conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=kernel_size // 2, |
|
dilation=1, groups=dim, bias=deploy, |
|
attempt_use_lk_impl=attempt_use_lk_impl) |
|
|
|
if deploy or kernel_size == 0: |
|
self.norm = nn.Identity() |
|
else: |
|
self.norm = get_bn(dim, use_sync_bn=use_sync_bn) |
|
|
|
self.se = SEBlock(dim, dim // 4) |
|
|
|
ffn_dim = int(ffn_factor * dim) |
|
self.pwconv1 = nn.Sequential( |
|
NCHWtoNHWC(), |
|
nn.Linear(dim, ffn_dim)) |
|
self.act = nn.Sequential( |
|
nn.GELU(), |
|
GRNwithNHWC(ffn_dim, use_bias=not deploy)) |
|
if deploy: |
|
self.pwconv2 = nn.Sequential( |
|
nn.Linear(ffn_dim, dim), |
|
NHWCtoNCHW()) |
|
else: |
|
self.pwconv2 = nn.Sequential( |
|
nn.Linear(ffn_dim, dim, bias=False), |
|
NHWCtoNCHW(), |
|
get_bn(dim, use_sync_bn=use_sync_bn)) |
|
|
|
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), |
|
requires_grad=True) if (not deploy) and layer_scale_init_value is not None \ |
|
and layer_scale_init_value > 0 else None |
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|
|
|
def compute_residual(self, x): |
|
y = self.se(self.norm(self.dwconv(x))) |
|
y = self.pwconv2(self.act(self.pwconv1(y))) |
|
if self.gamma is not None: |
|
y = self.gamma.view(1, -1, 1, 1) * y |
|
return self.drop_path(y) |
|
|
|
def forward(self, inputs): |
|
|
|
def _f(x): |
|
return x + self.compute_residual(x) |
|
|
|
if self.with_cp and inputs.requires_grad: |
|
out = checkpoint.checkpoint(_f, inputs) |
|
else: |
|
out = _f(inputs) |
|
return out |
|
|
|
def reparameterize(self): |
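        # Deploy-time folding: merge the dilated branches of the depth-wise conv, fold the following
        # BN into it, then absorb the layer-scale gamma and the GRN bias into the final Linear + BN
        # of the FFN so that the block matches the deploy=True structure.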
|
if hasattr(self.dwconv, 'merge_dilated_branches'): |
|
self.dwconv.merge_dilated_branches() |
|
if hasattr(self.norm, 'running_var'): |
|
std = (self.norm.running_var + self.norm.eps).sqrt() |
|
if hasattr(self.dwconv, 'lk_origin'): |
|
self.dwconv.lk_origin.weight.data *= (self.norm.weight / std).view(-1, 1, 1, 1) |
|
self.dwconv.lk_origin.bias.data = self.norm.bias + ( |
|
self.dwconv.lk_origin.bias - self.norm.running_mean) * self.norm.weight / std |
|
else: |
|
conv = nn.Conv2d(self.dwconv.in_channels, self.dwconv.out_channels, self.dwconv.kernel_size, |
|
padding=self.dwconv.padding, groups=self.dwconv.groups, bias=True) |
|
conv.weight.data = self.dwconv.weight * (self.norm.weight / std).view(-1, 1, 1, 1) |
|
conv.bias.data = self.norm.bias - self.norm.running_mean * self.norm.weight / std |
|
self.dwconv = conv |
|
self.norm = nn.Identity() |
|
if self.gamma is not None: |
|
final_scale = self.gamma.data |
|
self.gamma = None |
|
else: |
|
final_scale = 1 |
|
if self.act[1].use_bias and len(self.pwconv2) == 3: |
|
grn_bias = self.act[1].beta.data |
|
self.act[1].__delattr__('beta') |
|
self.act[1].use_bias = False |
|
linear = self.pwconv2[0] |
|
grn_bias_projected_bias = (linear.weight.data @ grn_bias.view(-1, 1)).squeeze() |
|
bn = self.pwconv2[2] |
|
std = (bn.running_var + bn.eps).sqrt() |
|
new_linear = nn.Linear(linear.in_features, linear.out_features, bias=True) |
|
new_linear.weight.data = linear.weight * (bn.weight / std * final_scale).view(-1, 1) |
|
linear_bias = 0 if linear.bias is None else linear.bias.data |
|
linear_bias += grn_bias_projected_bias |
|
new_linear.bias.data = (bn.bias + (linear_bias - bn.running_mean) * bn.weight / std) * final_scale |
|
self.pwconv2 = nn.Sequential(new_linear, self.pwconv2[1]) |
|
|
|
|
|
|
|
default_UniRepLKNet_A_F_P_kernel_sizes = ((3, 3), |
|
(13, 13), |
|
(13, 13, 13, 13, 13, 13), |
|
(13, 13)) |
|
default_UniRepLKNet_N_kernel_sizes = ((3, 3), |
|
(13, 13), |
|
(13, 13, 13, 13, 13, 13, 13, 13), |
|
(13, 13)) |
|
default_UniRepLKNet_T_kernel_sizes = ((3, 3, 3), |
|
(13, 13, 13), |
|
(13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3, 13, 3), |
|
(13, 13, 13)) |
|
default_UniRepLKNet_S_B_L_XL_kernel_sizes = ((3, 3, 3), |
|
(13, 13, 13), |
|
(13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3, 13, 3, 3), |
|
(13, 13, 13)) |
|
UniRepLKNet_A_F_P_depths = (2, 2, 6, 2) |
|
UniRepLKNet_N_depths = (2, 2, 8, 2) |
|
UniRepLKNet_T_depths = (3, 3, 18, 3) |
|
UniRepLKNet_S_B_L_XL_depths = (3, 3, 27, 3) |
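# Default per-stage, per-block kernel sizes, looked up from the depth configuration (see UniRepLKNet.__init__).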
|
|
|
default_depths_to_kernel_sizes = { |
|
UniRepLKNet_A_F_P_depths: default_UniRepLKNet_A_F_P_kernel_sizes, |
|
UniRepLKNet_N_depths: default_UniRepLKNet_N_kernel_sizes, |
|
UniRepLKNet_T_depths: default_UniRepLKNet_T_kernel_sizes, |
|
UniRepLKNet_S_B_L_XL_depths: default_UniRepLKNet_S_B_L_XL_kernel_sizes |
|
} |
|
|
|
class UniRepLKNet(nn.Module): |
|
r""" UniRepLKNet |
|
A PyTorch impl of UniRepLKNet |
|
|
|
Args: |
|
in_chans (int): Number of input image channels. Default: 3 |
|
num_classes (int): Number of classes for classification head. Default: 1000 |
|
depths (tuple(int)): Number of blocks at each stage. Default: (3, 3, 27, 3) |
|
        dims (tuple(int)): Feature dimension at each stage. Default: (96, 192, 384, 768)
|
drop_path_rate (float): Stochastic depth rate. Default: 0. |
|
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. |
|
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. |
|
kernel_sizes (tuple(tuple(int))): Kernel size for each block. None means using the default settings. Default: None. |
|
deploy (bool): deploy = True means using the inference structure. Default: False |
|
with_cp (bool): with_cp = True means using torch.utils.checkpoint to save GPU memory. Default: False |
|
        init_cfg (dict): weights to load. The easiest way to use UniRepLKNet with the OpenMMLab family. Default: None
        attempt_use_lk_impl (bool): try to load the efficient iGEMM large-kernel impl. Setting it to False disables the iGEMM impl. Default: True
|
use_sync_bn (bool): use_sync_bn = True means using sync BN. Use it if your batch size is small. Default: False |
|
""" |
|
def __init__(self, |
|
in_chans=3, |
|
num_classes=1000, |
|
depths=(3, 3, 27, 3), |
|
dims=(96, 192, 384, 768), |
|
drop_path_rate=0., |
|
layer_scale_init_value=1e-6, |
|
head_init_scale=1., |
|
kernel_sizes=None, |
|
deploy=False, |
|
with_cp=True, |
|
init_cfg=None, |
|
attempt_use_lk_impl=True, |
|
use_sync_bn=False, |
|
**kwargs |
|
): |
|
super().__init__() |
|
|
|
depths = tuple(depths) |
|
if kernel_sizes is None: |
|
if depths in default_depths_to_kernel_sizes: |
|
print('=========== use default kernel size ') |
|
kernel_sizes = default_depths_to_kernel_sizes[depths] |
|
else: |
|
raise ValueError('no default kernel size settings for the given depths, ' |
|
'please specify kernel sizes for each block, e.g., ' |
|
'((3, 3), (13, 13), (13, 13, 13, 13, 13, 13), (13, 13))') |
|
print(kernel_sizes) |
|
for i in range(4): |
|
assert len(kernel_sizes[i]) == depths[i], 'kernel sizes do not match the depths' |
|
|
|
self.with_cp = with_cp |
|
|
|
dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] |
|
print('=========== drop path rates: ', dp_rates) |
|
|
|
self.downsample_layers = nn.ModuleList() |
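        # Stem: two stride-2 3x3 convs (4x downsampling); each later stage is preceded by a single
        # stride-2 3x3 conv, so the four stages produce features at strides 4, 8, 16 and 32.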
|
self.downsample_layers.append(nn.Sequential( |
|
nn.Conv2d(in_chans, dims[0] // 2, kernel_size=3, stride=2, padding=1), |
|
LayerNorm(dims[0] // 2, eps=1e-6, data_format="channels_first"), |
|
nn.GELU(), |
|
nn.Conv2d(dims[0] // 2, dims[0], kernel_size=3, stride=2, padding=1), |
|
LayerNorm(dims[0], eps=1e-6, data_format="channels_first"))) |
|
|
|
for i in range(3): |
|
self.downsample_layers.append(nn.Sequential( |
|
nn.Conv2d(dims[i], dims[i + 1], kernel_size=3, stride=2, padding=1), |
|
LayerNorm(dims[i + 1], eps=1e-6, data_format="channels_first"))) |
|
|
|
self.stages = nn.ModuleList() |
|
|
|
cur = 0 |
|
for i in range(4): |
|
main_stage = nn.Sequential( |
|
*[UniRepLKNetBlock(dim=dims[i], kernel_size=kernel_sizes[i][j], drop_path=dp_rates[cur + j], |
|
layer_scale_init_value=layer_scale_init_value, deploy=deploy, |
|
attempt_use_lk_impl=attempt_use_lk_impl, |
|
with_cp=with_cp, use_sync_bn=use_sync_bn) for j in |
|
range(depths[i])]) |
|
self.stages.append(main_stage) |
|
cur += depths[i] |
|
|
|
self.last_channels = dims[-1] |
|
|
|
self.for_pretrain = init_cfg is None |
|
self.for_downstream = not self.for_pretrain |
|
if self.for_downstream: |
|
assert num_classes is None |
|
|
|
if self.for_pretrain: |
|
self.init_cfg = None |
|
self.norm = nn.LayerNorm(self.last_channels, eps=1e-6) |
|
|
|
            self.head = nn.Linear(self.last_channels, num_classes)
|
self.apply(self._init_weights) |
|
self.head.weight.data.mul_(head_init_scale) |
|
self.head.bias.data.mul_(head_init_scale) |
|
self.output_mode = 'logits' |
|
else: |
|
self.init_cfg = init_cfg |
|
self.init_weights() |
|
self.output_mode = 'features' |
|
norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first") |
|
for i_layer in range(4): |
|
layer = norm_layer(dims[i_layer]) |
|
layer_name = f'norm{i_layer}' |
|
self.add_module(layer_name, layer) |
|
|
|
|
|
|
|
def init_weights(self): |
|
|
|
def load_state_dict(module, state_dict, strict=False, logger=None): |
|
unexpected_keys = [] |
|
own_state = module.state_dict() |
|
for name, param in state_dict.items(): |
|
if name not in own_state: |
|
unexpected_keys.append(name) |
|
continue |
|
if isinstance(param, torch.nn.Parameter): |
|
|
|
param = param.data |
|
try: |
|
own_state[name].copy_(param) |
|
except Exception: |
|
raise RuntimeError( |
|
'While copying the parameter named {}, ' |
|
'whose dimensions in the model are {} and ' |
|
'whose dimensions in the checkpoint are {}.'.format( |
|
name, own_state[name].size(), param.size())) |
|
missing_keys = set(own_state.keys()) - set(state_dict.keys()) |
|
|
|
err_msg = [] |
|
if unexpected_keys: |
|
err_msg.append('unexpected key in source state_dict: {}\n'.format(', '.join(unexpected_keys))) |
|
if missing_keys: |
|
err_msg.append('missing keys in source state_dict: {}\n'.format(', '.join(missing_keys))) |
|
err_msg = '\n'.join(err_msg) |
|
if err_msg: |
|
if strict: |
|
raise RuntimeError(err_msg) |
|
elif logger is not None: |
|
                    logger.warning(err_msg)
|
else: |
|
print(err_msg) |
|
|
|
logger = get_root_logger() |
|
assert self.init_cfg is not None |
|
ckpt_path = self.init_cfg['checkpoint'] |
|
if ckpt_path is None: |
|
print('================ Note: init_cfg is provided but I got no init ckpt path, so skip initialization') |
|
else: |
|
ckpt = _load_checkpoint(ckpt_path, logger=logger, map_location='cpu') |
|
if 'state_dict' in ckpt: |
|
_state_dict = ckpt['state_dict'] |
|
elif 'model' in ckpt: |
|
_state_dict = ckpt['model'] |
|
else: |
|
_state_dict = ckpt |
|
|
|
load_state_dict(self, _state_dict, strict=False, logger=logger) |
|
|
|
|
|
def _init_weights(self, m): |
|
if isinstance(m, (nn.Conv2d, nn.Linear)): |
|
trunc_normal_(m.weight, std=.02) |
|
if hasattr(m, 'bias') and m.bias is not None: |
|
nn.init.constant_(m.bias, 0) |
|
|
|
def forward(self, x): |
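        # 'logits': ImageNet-style classification (global average pool -> LayerNorm -> Linear head).
        # 'features': return the per-stage normalized feature maps for dense-prediction backbones.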
|
if self.output_mode == 'logits': |
|
for stage_idx in range(4): |
|
x = self.downsample_layers[stage_idx](x) |
|
x = self.stages[stage_idx](x) |
|
x = self.norm(x.mean([-2, -1])) |
|
x = self.head(x) |
|
return x |
|
elif self.output_mode == 'features': |
|
outs = [] |
|
for stage_idx in range(4): |
|
x = self.downsample_layers[stage_idx](x) |
|
x = self.stages[stage_idx](x) |
|
outs.append(self.__getattr__(f'norm{stage_idx}')(x)) |
|
return outs |
|
else: |
|
            raise ValueError(f'unsupported output_mode: {self.output_mode}')
|
|
|
def reparameterize_unireplknet(self): |
|
for m in self.modules(): |
|
if hasattr(m, 'reparameterize'): |
|
m.reparameterize() |
|
|
|
@torch.jit.ignore |
|
def get_classifier(self): |
|
        return self.head
|
|
|
def reset_classifier(self, num_classes=0, global_pool=None): |
|
        # This model uses a plain Linear head on globally pooled features, so global_pool is
        # accepted only for interface compatibility and is ignored here.
        self.head = nn.Linear(self.last_channels, num_classes) if num_classes > 0 else nn.Identity()
|
|
|
|
|
|
|
|
|
|
|
class LayerNorm(nn.Module): |
|
r""" LayerNorm implementation used in ConvNeXt |
|
LayerNorm that supports two data formats: channels_last (default) or channels_first. |
|
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with |
|
shape (batch_size, height, width, channels) while channels_first corresponds to inputs |
|
with shape (batch_size, channels, height, width). |
|
""" |
|
|
|
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", reshape_last_to_first=False): |
|
super().__init__() |
|
self.weight = nn.Parameter(torch.ones(normalized_shape)) |
|
self.bias = nn.Parameter(torch.zeros(normalized_shape)) |
|
self.eps = eps |
|
self.data_format = data_format |
|
if self.data_format not in ["channels_last", "channels_first"]: |
|
raise NotImplementedError |
|
self.normalized_shape = (normalized_shape,) |
|
self.reshape_last_to_first = reshape_last_to_first |
|
|
|
def forward(self, x): |
|
if self.data_format == "channels_last": |
|
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) |
|
elif self.data_format == "channels_first": |
|
u = x.mean(1, keepdim=True) |
|
s = (x - u).pow(2).mean(1, keepdim=True) |
|
x = (x - u) / torch.sqrt(s + self.eps) |
|
x = self.weight[:, None, None] * x + self.bias[:, None, None] |
|
return x |
|
|
|
|
|
|
|
if has_mmdet: |
|
@det_BACKBONES.register_module() |
|
class UniRepLKNetBackbone(UniRepLKNet): |
|
def __init__(self, |
|
depths=(3, 3, 27, 3), |
|
dims=(96, 192, 384, 768), |
|
drop_path_rate=0., |
|
layer_scale_init_value=1e-6, |
|
kernel_sizes=None, |
|
deploy=False, |
|
with_cp=False, |
|
init_cfg=None, |
|
attempt_use_lk_impl=False): |
|
assert init_cfg is not None |
|
super().__init__(in_chans=3, num_classes=None, depths=depths, dims=dims, |
|
drop_path_rate=drop_path_rate, layer_scale_init_value=layer_scale_init_value, |
|
kernel_sizes=kernel_sizes, deploy=deploy, with_cp=with_cp, |
|
init_cfg=init_cfg, attempt_use_lk_impl=attempt_use_lk_impl, use_sync_bn=True) |
|
|
|
|
|
if has_mmseg: |
|
@seg_BACKBONES.register_module() |
|
class UniRepLKNetBackbone(UniRepLKNet): |
|
def __init__(self, |
|
depths=(3, 3, 27, 3), |
|
dims=(96, 192, 384, 768), |
|
drop_path_rate=0., |
|
layer_scale_init_value=1e-6, |
|
kernel_sizes=None, |
|
deploy=False, |
|
with_cp=False, |
|
init_cfg=None, |
|
attempt_use_lk_impl=False): |
|
assert init_cfg is not None |
|
super().__init__(in_chans=3, num_classes=None, depths=depths, dims=dims, |
|
drop_path_rate=drop_path_rate, layer_scale_init_value=layer_scale_init_value, |
|
kernel_sizes=kernel_sizes, deploy=deploy, with_cp=with_cp, |
|
init_cfg=init_cfg, attempt_use_lk_impl=attempt_use_lk_impl, use_sync_bn=True) |
|
|
|
|
|
model_urls = { |
|
|
|
} |
|
|
|
huggingface_file_names = { |
|
"unireplknet_a_1k": "unireplknet_a_in1k_224_acc77.03.pth", |
|
"unireplknet_f_1k": "unireplknet_f_in1k_224_acc78.58.pth", |
|
"unireplknet_p_1k": "unireplknet_p_in1k_224_acc80.23.pth", |
|
"unireplknet_n_1k": "unireplknet_n_in1k_224_acc81.64.pth", |
|
"unireplknet_t_1k": "unireplknet_t_in1k_224_acc83.21.pth", |
|
"unireplknet_s_1k": "unireplknet_s_in1k_224_acc83.91.pth", |
|
"unireplknet_s_22k": "unireplknet_s_in22k_pretrain.pth", |
|
"unireplknet_s_22k_to_1k": "unireplknet_s_in22k_to_in1k_384_acc86.44.pth", |
|
"unireplknet_b_22k": "unireplknet_b_in22k_pretrain.pth", |
|
"unireplknet_b_22k_to_1k": "unireplknet_b_in22k_to_in1k_384_acc87.40.pth", |
|
"unireplknet_l_22k": "unireplknet_l_in22k_pretrain.pth", |
|
"unireplknet_l_22k_to_1k": "unireplknet_l_in22k_to_in1k_384_acc87.88.pth", |
|
"unireplknet_xl_22k": "unireplknet_xl_in22k_pretrain.pth", |
|
"unireplknet_xl_22k_to_1k": "unireplknet_xl_in22k_to_in1k_384_acc87.96.pth" |
|
} |
|
|
|
def load_with_key(model, key): |
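    # Prefer downloading the released checkpoint from the Hugging Face Hub; fall back to
    # torch.hub with model_urls (currently empty) when huggingface_hub is not installed.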
|
|
|
if hf_hub_download is not None: |
|
repo_id = 'DingXiaoH/UniRepLKNet' |
|
cache_file = hf_hub_download(repo_id=repo_id, filename=huggingface_file_names[key]) |
|
checkpoint = torch.load(cache_file, map_location='cpu') |
|
else: |
|
checkpoint = torch.hub.load_state_dict_from_url(url=model_urls[key], map_location="cpu", check_hash=True) |
|
if 'model' in checkpoint: |
|
checkpoint = checkpoint['model'] |
|
model.load_state_dict(checkpoint) |
|
|
|
def initialize_with_pretrained(model, model_name, in_1k_pretrained, in_22k_pretrained, in_22k_to_1k): |
|
if in_1k_pretrained: |
|
key = model_name + '_1k' |
|
elif in_22k_pretrained: |
|
key = model_name + '_22k' |
|
elif in_22k_to_1k: |
|
key = model_name + '_22k_to_1k' |
|
else: |
|
key = None |
|
if key: |
|
load_with_key(model, key) |
|
|
|
@register_model |
|
def unireplknet_a(in_1k_pretrained=False, **kwargs): |
|
model = UniRepLKNet(depths=UniRepLKNet_A_F_P_depths, dims=(40, 80, 160, 320), **kwargs) |
|
initialize_with_pretrained(model, 'unireplknet_a', in_1k_pretrained, False, False) |
|
return model |
|
|
|
@register_model |
|
def unireplknet_f(in_1k_pretrained=False, **kwargs): |
|
model = UniRepLKNet(depths=UniRepLKNet_A_F_P_depths, dims=(48, 96, 192, 384), **kwargs) |
|
initialize_with_pretrained(model, 'unireplknet_f', in_1k_pretrained, False, False) |
|
return model |
|
|
|
@register_model |
|
def unireplknet_p(in_1k_pretrained=False, **kwargs): |
|
model = UniRepLKNet(depths=UniRepLKNet_A_F_P_depths, dims=(64, 128, 256, 512), **kwargs) |
|
initialize_with_pretrained(model, 'unireplknet_p', in_1k_pretrained, False, False) |
|
return model |
|
|
|
@register_model |
|
def unireplknet_n(in_1k_pretrained=False, **kwargs): |
|
model = UniRepLKNet(depths=UniRepLKNet_N_depths, dims=(80, 160, 320, 640), **kwargs) |
|
initialize_with_pretrained(model, 'unireplknet_n', in_1k_pretrained, False, False) |
|
return model |
|
|
|
@register_model |
|
def unireplknet_t(in_1k_pretrained=False, **kwargs): |
|
model = UniRepLKNet(depths=UniRepLKNet_T_depths, dims=(80, 160, 320, 640), **kwargs) |
|
initialize_with_pretrained(model, 'unireplknet_t', in_1k_pretrained, False, False) |
|
return model |
|
|
|
@register_model |
|
def unireplknet_s(in_1k_pretrained=False, in_22k_pretrained=False, in_22k_to_1k=False, **kwargs): |
|
model = UniRepLKNet(depths=UniRepLKNet_S_B_L_XL_depths, dims=(96, 192, 384, 768), **kwargs) |
|
initialize_with_pretrained(model, 'unireplknet_s', in_1k_pretrained, in_22k_pretrained, in_22k_to_1k) |
|
return model |
|
|
|
@register_model |
|
def unireplknet_b(in_22k_pretrained=False, in_22k_to_1k=False, **kwargs): |
|
model = UniRepLKNet(depths=UniRepLKNet_S_B_L_XL_depths, dims=(128, 256, 512, 1024), **kwargs) |
|
initialize_with_pretrained(model, 'unireplknet_b', False, in_22k_pretrained, in_22k_to_1k) |
|
return model |
|
|
|
@register_model |
|
def unireplknet_l(in_22k_pretrained=False, in_22k_to_1k=False, **kwargs): |
|
model = UniRepLKNet(depths=UniRepLKNet_S_B_L_XL_depths, dims=(192, 384, 768, 1536), **kwargs) |
|
initialize_with_pretrained(model, 'unireplknet_l', False, in_22k_pretrained, in_22k_to_1k) |
|
return model |
|
|
|
@register_model |
|
def unireplknet_xl(in_22k_pretrained=False, in_22k_to_1k=False, **kwargs): |
|
model = UniRepLKNet(depths=UniRepLKNet_S_B_L_XL_depths, dims=(256, 512, 1024, 2048), **kwargs) |
|
initialize_with_pretrained(model, 'unireplknet_xl', False, in_22k_pretrained, in_22k_to_1k) |
|
return model |
|
|
|
@register_model |
|
def unireplknet_h(in_22k_pretrained=False, in_22k_to_1k=False, **kwargs): |
|
model = UniRepLKNet(depths=UniRepLKNet_S_B_L_XL_depths, dims=(480, 960, 1920, 3840), **kwargs) |
|
initialize_with_pretrained(model, 'unireplknet_h', False, in_22k_pretrained, in_22k_to_1k) |
|
return model |
|
|
|
|
|
if __name__ == '__main__': |
|
model_large = unireplknet_l() |
|
print(model_large) |
|
    ckpt = torch.load("UniRepLKNet-L-b75k_s10B_CLIP-in1k_75.72.pt", map_location='cpu')
    model_large.load_state_dict(ckpt, strict=False)
|
print("Loaded CLIP Pretrained Models") |