init
Browse files
UniRepLKNet-L-b75k_s10B_CLIP-in1k_75.72.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6f1b36fef7c9a0174b113db0093df8e827b523193e0ed10df788f8def3e64c84
|
3 |
+
size 875658963
|
modeling_UniRepLKNet.py
ADDED
@@ -0,0 +1,784 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# UniRepLKNet: A Universal Perception Large-Kernel ConvNet for Audio, Video, Point Cloud, Time-Series and Image Recognition
|
2 |
+
# Github source: https://github.com/AILab-CVC/UniRepLKNet
|
3 |
+
# Licensed under The Apache License 2.0 License [see LICENSE for details]
|
4 |
+
# Based on RepLKNet, ConvNeXt, timm, DINO and DeiT code bases
|
5 |
+
# https://github.com/DingXiaoH/RepLKNet-pytorch
|
6 |
+
# https://github.com/facebookresearch/ConvNeXt
|
7 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
8 |
+
# https://github.com/facebookresearch/deit/
|
9 |
+
# https://github.com/facebookresearch/dino
|
10 |
+
# --------------------------------------------------------'
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from timm.models.layers import trunc_normal_, DropPath, to_2tuple
|
15 |
+
from timm.models.registry import register_model
|
16 |
+
from timm.models.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \
|
17 |
+
create_conv2d, get_act_layer, make_divisible, to_ntuple
|
18 |
+
from functools import partial
|
19 |
+
import torch.utils.checkpoint as checkpoint
|
20 |
+
try:
|
21 |
+
from huggingface_hub import hf_hub_download
|
22 |
+
except:
|
23 |
+
hf_hub_download = None # install huggingface_hub if you would like to download models conveniently from huggingface
|
24 |
+
|
25 |
+
has_mmdet = False
|
26 |
+
has_mmseg = False
|
27 |
+
# =============== for the ease of directly using this file in MMSegmentation and MMDetection.
|
28 |
+
# =============== ignore the following two segments of code if you do not plan to do so
|
29 |
+
# =============== delete one of the following two segments if you get a conflict
|
30 |
+
try:
|
31 |
+
from mmseg.models.builder import BACKBONES as seg_BACKBONES
|
32 |
+
from mmseg.utils import get_root_logger
|
33 |
+
from mmcv.runner import _load_checkpoint
|
34 |
+
has_mmseg = True
|
35 |
+
except ImportError:
|
36 |
+
get_root_logger = None
|
37 |
+
_load_checkpoint = None
|
38 |
+
|
39 |
+
# try:
|
40 |
+
# from mmdet.models.builder import BACKBONES as det_BACKBONES
|
41 |
+
# from mmdet.utils import get_root_logger
|
42 |
+
# from mmcv.runner import _load_checkpoint
|
43 |
+
# has_mmdet = True
|
44 |
+
# except ImportError:
|
45 |
+
# get_root_logger = None
|
46 |
+
# _load_checkpoint = None
|
47 |
+
# ===========================================================================================
|
48 |
+
|
49 |
+
|
50 |
+
class GRNwithNHWC(nn.Module):
    """GRN (Global Response Normalization) layer for channels-last inputs.

    Originally proposed in ConvNeXt V2 (https://arxiv.org/abs/2301.00808).
    The forward pass treats the input as (N, H, W, C): the L2 norm is taken
    over dims 1 and 2, and ``gamma`` / ``beta`` have shape (1, 1, 1, dim).

    Args:
        dim: Number of channels (size of the last input dimension).
        use_bias: If True, a learnable additive ``beta`` is created and applied.
    """

    def __init__(self, dim, use_bias=True):
        super().__init__()
        self.use_bias = use_bias
        # Zero-initialised, so the layer starts out as an identity mapping.
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        if self.use_bias:
            self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        # Per-channel L2 norm aggregated over the two middle (spatial) dims.
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Divide by the mean norm across channels; 1e-6 guards against /0.
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        out = (self.gamma * Nx + 1) * x
        if self.use_bias:
            out = out + self.beta
        return out
|
70 |
+
|
71 |
+
|
72 |
+
class NCHWtoNHWC(nn.Module):
    """Permute a 4-D tensor from (N, C, H, W) to (N, H, W, C)."""

    def forward(self, x):
        return x.permute(0, 2, 3, 1)
|
78 |
+
|
79 |
+
|
80 |
+
class NHWCtoNCHW(nn.Module):
    """Permute a 4-D tensor from (N, H, W, C) to (N, C, H, W)."""

    def forward(self, x):
        return x.permute(0, 3, 1, 2)
|
86 |
+
|
87 |
+
#================== This function decides which conv implementation (the native or iGEMM) to use
|
88 |
+
# Note that iGEMM large-kernel conv impl will be used if
|
89 |
+
# - you attempt to do so (attempt_use_lk_impl=True), and
|
90 |
+
# - it has been installed (follow https://github.com/AILab-CVC/UniRepLKNet), and
|
91 |
+
# - the conv layer is depth-wise, stride = 1, non-dilated, kernel_size > 5, and padding == kernel_size // 2
|
92 |
+
def get_conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias,
               attempt_use_lk_impl=True):
    """Build a conv layer, preferring the iGEMM depth-wise kernel when eligible.

    The iGEMM implementation (``DepthWiseConv2dImplicitGEMM``) is returned only
    when all of the following hold:
      * ``attempt_use_lk_impl`` is True,
      * the kernel is square, larger than 5, and ``padding == kernel_size // 2``,
      * ``depthwise_conv2d_implicit_gemm`` can be imported,
      * the conv is depth-wise (``in_channels == out_channels == groups``),
      * stride and dilation are 1 (given as an int or as ``(1, 1)``).
    Otherwise a plain ``nn.Conv2d`` with the requested arguments is returned.

    Args:
        in_channels, out_channels, stride, dilation, groups, bias: forwarded to
            ``nn.Conv2d`` on the fallback path.
        kernel_size: int or pair; normalised with ``to_2tuple``.
        padding: int, pair, or None. None means ``kernel_size // 2`` per dim.
        attempt_use_lk_impl: set to False to never try the iGEMM import.
    """
    kernel_size = to_2tuple(kernel_size)
    if padding is None:
        padding = (kernel_size[0] // 2, kernel_size[1] // 2)
    else:
        padding = to_2tuple(padding)
    need_large_impl = (kernel_size[0] == kernel_size[1]
                       and kernel_size[0] > 5
                       and padding == (kernel_size[0] // 2, kernel_size[1] // 2))

    if attempt_use_lk_impl and need_large_impl:
        print('---------------- trying to import iGEMM implementation for large-kernel conv')
        try:
            from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
            print('---------------- found iGEMM implementation ')
        except Exception:
            # Best-effort: any failure while importing the optional extension
            # falls back to nn.Conv2d. Unlike a bare ``except:`` this does not
            # swallow KeyboardInterrupt / SystemExit.
            DepthWiseConv2dImplicitGEMM = None
            print('---------------- found no iGEMM. use original conv. follow https://github.com/AILab-CVC/UniRepLKNet to install it.')
        # Accept stride / dilation written as either 1 or (1, 1).
        is_unit_stride = to_2tuple(stride) == (1, 1)
        is_undilated = to_2tuple(dilation) == (1, 1)
        if DepthWiseConv2dImplicitGEMM is not None and in_channels == out_channels \
                and out_channels == groups and is_unit_stride and is_undilated:
            print(f'===== iGEMM Efficient Conv Impl, channels {in_channels}, kernel size {kernel_size} =====')
            return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)
    return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                     padding=padding, dilation=dilation, groups=groups, bias=bias)
|
115 |
+
|
116 |
+
|
117 |
+
def get_bn(dim, use_sync_bn=False):
    """Return ``nn.SyncBatchNorm(dim)`` if ``use_sync_bn`` else ``nn.BatchNorm2d(dim)``."""
    norm_cls = nn.SyncBatchNorm if use_sync_bn else nn.BatchNorm2d
    return norm_cls(dim)
|
122 |
+
|
123 |
+
class SEBlock(nn.Module):
    """Squeeze-and-Excitation block from SENet (https://arxiv.org/abs/1709.01507).

    The input is pooled to 1x1, passed through two 1x1 convs
    (``input_channels -> internal_neurons -> input_channels``) with a ReLU in
    between and a sigmoid at the end, then used to rescale the input channels.
    Inputs are expected as (N, C, H, W).

    Args:
        input_channels: Number of channels of the input feature map.
        internal_neurons: Width of the bottleneck between the two 1x1 convs.
    """

    def __init__(self, input_channels, internal_neurons):
        super().__init__()
        self.down = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons,
                              kernel_size=1, stride=1, bias=True)
        self.up = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels,
                            kernel_size=1, stride=1, bias=True)
        self.input_channels = input_channels
        self.nonlinear = nn.ReLU(inplace=True)

    def forward(self, inputs):
        x = F.adaptive_avg_pool2d(inputs, output_size=(1, 1))
        x = self.down(x)
        x = self.nonlinear(x)
        x = self.up(x)
        # torch.sigmoid replaces the deprecated F.sigmoid; same result.
        x = torch.sigmoid(x)
        return inputs * x.view(-1, self.input_channels, 1, 1)
|
144 |
+
|
145 |
+
def fuse_bn(conv, bn):
    """Fold a BatchNorm's running statistics into the preceding conv.

    Uses ``bn.running_mean`` / ``bn.running_var`` / ``bn.eps`` and the affine
    ``bn.weight`` / ``bn.bias``. A conv without bias is treated as bias 0.

    Returns:
        (weight, bias) tensors of the equivalent single conv.
    """
    conv_bias = 0 if conv.bias is None else conv.bias
    std = (bn.running_var + bn.eps).sqrt()
    fused_weight = conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1)
    fused_bias = bn.bias + (conv_bias - bn.running_mean) * bn.weight / std
    return fused_weight, fused_bias
|
149 |
+
|
150 |
+
def convert_dilated_to_nondilated(kernel, dilate_rate):
    """Expand a dilated conv kernel into the equivalent dense kernel.

    Zeros are inserted between the kernel taps by running a transposed conv
    with a 1x1 all-ones kernel and ``stride=dilate_rate``; a ``k x k`` kernel
    becomes ``(dilate_rate * (k - 1) + 1)`` wide.

    Args:
        kernel: 4-D weight tensor ``(out_ch, in_ch_per_group, kH, kW)``.
        dilate_rate: Dilation of the conv the kernel belongs to.

    Returns:
        Tensor with the same first two dims and the enlarged spatial size.
    """
    # Match the kernel's dtype as well as its device; a float32 identity
    # kernel makes conv_transpose2d fail for half/double weights.
    identity_kernel = torch.ones((1, 1, 1, 1), dtype=kernel.dtype, device=kernel.device)
    out_ch, in_ch, k_h, k_w = kernel.shape
    # Treat every (out, in) slice as its own single-channel sample so that
    # depth-wise, group-wise and dense kernels all go through one call.
    flat = kernel.reshape(out_ch * in_ch, 1, k_h, k_w)
    expanded = F.conv_transpose2d(flat, identity_kernel, stride=dilate_rate)
    return expanded.reshape(out_ch, in_ch, expanded.size(2), expanded.size(3))
|
163 |
+
|
164 |
+
def merge_dilated_into_large_kernel(large_kernel, dilated_kernel, dilated_r):
    """Add a dilated branch kernel into a (non-dilated) large kernel.

    The dilated kernel is first expanded to its dense equivalent of size
    ``dilated_r * (k - 1) + 1``, zero-padded symmetrically up to the size of
    ``large_kernel`` and then summed with it.

    Args:
        large_kernel: Weight of the large conv; its size is read from dim 2.
        dilated_kernel: Weight of the dilated branch; its size is read from dim 2.
        dilated_r: Dilation rate of the branch.

    Returns:
        The merged kernel (a new tensor; the inputs are not modified).

    Raises:
        ValueError: If the expanded branch kernel is larger than ``large_kernel``.
    """
    large_k = large_kernel.size(2)
    dilated_k = dilated_kernel.size(2)
    equivalent_kernel_size = dilated_r * (dilated_k - 1) + 1
    rows_to_pad = large_k // 2 - equivalent_kernel_size // 2
    if rows_to_pad < 0:
        # F.pad with a negative amount would silently crop the branch kernel.
        raise ValueError(
            'dilated branch (equivalent kernel size {}) does not fit into a large kernel of size {}'.format(
                equivalent_kernel_size, large_k))
    equivalent_kernel = convert_dilated_to_nondilated(dilated_kernel, dilated_r)
    merged_kernel = large_kernel + F.pad(equivalent_kernel, [rows_to_pad] * 4)
    return merged_kernel
|
172 |
+
|
173 |
+
|
174 |
+
class DilatedReparamBlock(nn.Module):
    """Dilated Reparam Block proposed in UniRepLKNet (https://github.com/AILab-CVC/UniRepLKNet).

    In training form the block sums a large depth-wise conv + BN
    (``lk_origin`` / ``origin_bn``) with several small dilated depth-wise
    conv + BN branches (``dil_conv_k{k}_{r}`` / ``dil_bn_k{k}_{r}``).
    ``merge_dilated_branches`` folds everything into ``lk_origin`` alone.
    Inputs are expected as (N, C, H, W).

    Args:
        channels: Number of input/output channels (the convs are depth-wise).
        kernel_size: Size of the large kernel; one of 5, 7, 9, 11, 13, 15, 17.
        deploy: If True, only the single large conv (with bias) is built.
        use_sync_bn: Use ``nn.SyncBatchNorm`` for the BN layers.
        attempt_use_lk_impl: Forwarded to ``get_conv2d``.

    Raises:
        ValueError: If ``kernel_size`` has no branch configuration.
    """

    # large kernel size -> (branch kernel sizes, branch dilation rates).
    # Default settings. We did not tune them carefully. Different settings may work better.
    _BRANCH_SETTINGS = {
        17: ((5, 9, 3, 3, 3), (1, 2, 4, 5, 7)),
        15: ((5, 7, 3, 3, 3), (1, 2, 3, 5, 7)),
        13: ((5, 7, 3, 3, 3), (1, 2, 3, 4, 5)),
        11: ((5, 5, 3, 3, 3), (1, 2, 3, 4, 5)),
        9: ((5, 5, 3, 3), (1, 2, 3, 4)),
        7: ((5, 3, 3), (1, 2, 3)),
        5: ((3, 3), (1, 2)),
    }

    def __init__(self, channels, kernel_size, deploy, use_sync_bn=False, attempt_use_lk_impl=True):
        super().__init__()
        self.lk_origin = get_conv2d(channels, channels, kernel_size, stride=1,
                                    padding=kernel_size // 2, dilation=1, groups=channels, bias=deploy,
                                    attempt_use_lk_impl=attempt_use_lk_impl)
        self.attempt_use_lk_impl = attempt_use_lk_impl

        if kernel_size not in self._BRANCH_SETTINGS:
            raise ValueError('Dilated Reparam Block requires kernel_size >= 5')
        branch_kernel_sizes, branch_dilates = self._BRANCH_SETTINGS[kernel_size]
        self.kernel_sizes = list(branch_kernel_sizes)
        self.dilates = list(branch_dilates)

        if not deploy:
            self.origin_bn = get_bn(channels, use_sync_bn)
            for k, r in zip(self.kernel_sizes, self.dilates):
                # Padding is half of the branch's equivalent dense kernel
                # size, so the branch output keeps the input's spatial size.
                setattr(self, 'dil_conv_k{}_{}'.format(k, r),
                        nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=k, stride=1,
                                  padding=(r * (k - 1) + 1) // 2, dilation=r, groups=channels,
                                  bias=False))
                setattr(self, 'dil_bn_k{}_{}'.format(k, r), get_bn(channels, use_sync_bn=use_sync_bn))

    def forward(self, x):
        if not hasattr(self, 'origin_bn'):  # deploy mode
            return self.lk_origin(x)
        out = self.origin_bn(self.lk_origin(x))
        for k, r in zip(self.kernel_sizes, self.dilates):
            conv = getattr(self, 'dil_conv_k{}_{}'.format(k, r))
            bn = getattr(self, 'dil_bn_k{}_{}'.format(k, r))
            out = out + bn(conv(x))
        return out

    def merge_dilated_branches(self):
        """Fold all BN layers and dilated branches into one large conv.

        Does nothing if the block is already in deploy form (no ``origin_bn``).
        Afterwards ``lk_origin`` is a new conv with bias, and the BN / branch
        submodules are removed.
        """
        if not hasattr(self, 'origin_bn'):
            return
        # Pure weight arithmetic; no autograd graph is needed.
        with torch.no_grad():
            origin_k, origin_b = fuse_bn(self.lk_origin, self.origin_bn)
            for k, r in zip(self.kernel_sizes, self.dilates):
                conv = getattr(self, 'dil_conv_k{}_{}'.format(k, r))
                bn = getattr(self, 'dil_bn_k{}_{}'.format(k, r))
                branch_k, branch_b = fuse_bn(conv, bn)
                origin_k = merge_dilated_into_large_kernel(origin_k, branch_k, r)
                origin_b = origin_b + branch_b
        merged_conv = get_conv2d(origin_k.size(0), origin_k.size(0), origin_k.size(2), stride=1,
                                 padding=origin_k.size(2) // 2, dilation=1, groups=origin_k.size(0), bias=True,
                                 attempt_use_lk_impl=self.attempt_use_lk_impl)
        merged_conv.weight.data = origin_k
        merged_conv.bias.data = origin_b
        self.lk_origin = merged_conv
        delattr(self, 'origin_bn')
        for k, r in zip(self.kernel_sizes, self.dilates):
            delattr(self, 'dil_conv_k{}_{}'.format(k, r))
            delattr(self, 'dil_bn_k{}_{}'.format(k, r))
|
249 |
+
|
250 |
+
|
251 |
+
class UniRepLKNetBlock(nn.Module):
    """One residual block: depth-wise conv -> norm -> SE -> FFN, added to the input.

    The FFN is ``Linear -> GELU -> GRN -> Linear`` applied in channels-last
    layout (the block permutes to NHWC and back internally). In training form
    BatchNorm follows the depth-wise conv and the second Linear; in deploy
    form those are absent and ``reparameterize`` is what folds them away.

    Args:
        dim: Number of channels.
        kernel_size: 0 (no depth-wise conv), 3 or 5 (plain depth-wise conv),
            or >= 7 (``DilatedReparamBlock``).
        drop_path: Stochastic-depth rate for the residual branch.
        layer_scale_init_value: Initial value of the per-channel layer scale;
            None or a non-positive value disables it.
        deploy: Build the inference structure directly.
        attempt_use_lk_impl: Forwarded to the conv builders.
        with_cp: Use ``torch.utils.checkpoint`` in ``forward`` when the input
            requires grad.
        use_sync_bn: Use ``nn.SyncBatchNorm`` for the BN layers.
        ffn_factor: Expansion ratio of the FFN hidden width.
    """

    def __init__(self,
                 dim,
                 kernel_size,
                 drop_path=0.,
                 layer_scale_init_value=1e-6,
                 deploy=False,
                 attempt_use_lk_impl=True,
                 with_cp=False,
                 use_sync_bn=False,
                 ffn_factor=4):
        super().__init__()
        self.with_cp = with_cp
        if deploy:
            print('------------------------------- Note: deploy mode')
        if self.with_cp:
            print('****** note with_cp = True, reduce memory consumption but may slow down training ******')

        if kernel_size == 0:
            self.dwconv = nn.Identity()
        elif kernel_size >= 7:
            self.dwconv = DilatedReparamBlock(dim, kernel_size, deploy=deploy,
                                              use_sync_bn=use_sync_bn,
                                              attempt_use_lk_impl=attempt_use_lk_impl)
        else:
            assert kernel_size in [3, 5], 'kernel_size must be 0, 3, 5 or >= 7'
            self.dwconv = get_conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=kernel_size // 2,
                                     dilation=1, groups=dim, bias=deploy,
                                     attempt_use_lk_impl=attempt_use_lk_impl)

        if deploy or kernel_size == 0:
            self.norm = nn.Identity()
        else:
            self.norm = get_bn(dim, use_sync_bn=use_sync_bn)

        self.se = SEBlock(dim, dim // 4)

        ffn_dim = int(ffn_factor * dim)
        self.pwconv1 = nn.Sequential(
            NCHWtoNHWC(),
            nn.Linear(dim, ffn_dim))
        self.act = nn.Sequential(
            nn.GELU(),
            GRNwithNHWC(ffn_dim, use_bias=not deploy))
        if deploy:
            self.pwconv2 = nn.Sequential(
                nn.Linear(ffn_dim, dim),
                NHWCtoNCHW())
        else:
            self.pwconv2 = nn.Sequential(
                nn.Linear(ffn_dim, dim, bias=False),
                NHWCtoNCHW(),
                get_bn(dim, use_sync_bn=use_sync_bn))

        use_layer_scale = (not deploy) and layer_scale_init_value is not None \
            and layer_scale_init_value > 0
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim),
                                  requires_grad=True) if use_layer_scale else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def compute_residual(self, x):
        """Return the residual branch output for ``x`` (without adding ``x``)."""
        y = self.se(self.norm(self.dwconv(x)))
        y = self.pwconv2(self.act(self.pwconv1(y)))
        if self.gamma is not None:
            y = self.gamma.view(1, -1, 1, 1) * y
        return self.drop_path(y)

    def forward(self, inputs):

        def _f(x):
            return x + self.compute_residual(x)

        if self.with_cp and inputs.requires_grad:
            out = checkpoint.checkpoint(_f, inputs)
        else:
            out = _f(inputs)
        return out

    @torch.no_grad()
    def reparameterize(self):
        """Convert the block in place from training form to inference form.

        Folds, in order: the dilated branches of ``dwconv`` (if any), the BN
        after ``dwconv``, and finally the GRN bias, the BN after ``pwconv2``
        and the layer scale ``gamma`` into a single biased Linear.
        """
        if hasattr(self.dwconv, 'merge_dilated_branches'):
            self.dwconv.merge_dilated_branches()
        if hasattr(self.norm, 'running_var'):
            std = (self.norm.running_var + self.norm.eps).sqrt()
            if hasattr(self.dwconv, 'lk_origin'):
                lk = self.dwconv.lk_origin
                lk.weight.data *= (self.norm.weight / std).view(-1, 1, 1, 1)
                lk.bias.data = self.norm.bias + (
                    lk.bias - self.norm.running_mean) * self.norm.weight / std
            else:
                # Plain depth-wise conv without bias: build a biased twin.
                conv = nn.Conv2d(self.dwconv.in_channels, self.dwconv.out_channels, self.dwconv.kernel_size,
                                 padding=self.dwconv.padding, groups=self.dwconv.groups, bias=True)
                conv.weight.data = self.dwconv.weight * (self.norm.weight / std).view(-1, 1, 1, 1)
                conv.bias.data = self.norm.bias - self.norm.running_mean * self.norm.weight / std
                self.dwconv = conv
            self.norm = nn.Identity()
        if self.gamma is not None:
            final_scale = self.gamma.data
            self.gamma = None
        else:
            final_scale = 1
        if self.act[1].use_bias and len(self.pwconv2) == 3:
            grn = self.act[1]
            grn_bias = grn.beta.data
            delattr(grn, 'beta')
            grn.use_bias = False
            linear = self.pwconv2[0]
            # The GRN bias passes through the Linear, so it becomes W @ beta.
            # squeeze(1) keeps a 1-D result even when the output width is 1.
            grn_bias_projected_bias = (linear.weight.data @ grn_bias.view(-1, 1)).squeeze(1)
            bn = self.pwconv2[2]
            std = (bn.running_var + bn.eps).sqrt()
            new_linear = nn.Linear(linear.in_features, linear.out_features, bias=True)
            new_linear.weight.data = linear.weight * (bn.weight / std * final_scale).view(-1, 1)
            linear_bias = 0 if linear.bias is None else linear.bias.data
            # Out-of-place add: do not mutate the old layer's bias storage.
            linear_bias = linear_bias + grn_bias_projected_bias
            new_linear.bias.data = (bn.bias + (linear_bias - bn.running_mean) * bn.weight / std) * final_scale
            self.pwconv2 = nn.Sequential(new_linear, self.pwconv2[1])
|
365 |
+
|
366 |
+
|
367 |
+
|
368 |
+
# Per-stage kernel sizes for the released model variants. Each entry has one
# tuple per stage and one kernel size per block in that stage.
default_UniRepLKNet_A_F_P_kernel_sizes = ((3, 3),
                                          (13, 13),
                                          (13,) * 6,
                                          (13, 13))
default_UniRepLKNet_N_kernel_sizes = ((3, 3),
                                      (13, 13),
                                      (13,) * 8,
                                      (13, 13))
# Stage 3 alternates one 13x13 block with one 3x3 block (18 blocks).
default_UniRepLKNet_T_kernel_sizes = ((3, 3, 3),
                                      (13, 13, 13),
                                      (13, 3) * 9,
                                      (13, 13, 13))
# Stage 3 repeats one 13x13 block followed by two 3x3 blocks (27 blocks).
default_UniRepLKNet_S_B_L_XL_kernel_sizes = ((3, 3, 3),
                                             (13, 13, 13),
                                             (13, 3, 3) * 9,
                                             (13, 13, 13))
UniRepLKNet_A_F_P_depths = (2, 2, 6, 2)
UniRepLKNet_N_depths = (2, 2, 8, 2)
UniRepLKNet_T_depths = (3, 3, 18, 3)
UniRepLKNet_S_B_L_XL_depths = (3, 3, 27, 3)

# Lookup used when a model is built without explicit kernel sizes.
default_depths_to_kernel_sizes = {
    UniRepLKNet_A_F_P_depths: default_UniRepLKNet_A_F_P_kernel_sizes,
    UniRepLKNet_N_depths: default_UniRepLKNet_N_kernel_sizes,
    UniRepLKNet_T_depths: default_UniRepLKNet_T_kernel_sizes,
    UniRepLKNet_S_B_L_XL_depths: default_UniRepLKNet_S_B_L_XL_kernel_sizes
}
|
395 |
+
|
396 |
+
class UniRepLKNet(nn.Module):
    r""" UniRepLKNet
        A PyTorch impl of UniRepLKNet

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Must be None when ``init_cfg`` is given (downstream use).
            It does not size the head here: the head maps ``dims[-1]`` to ``dims[-1]``. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: (3, 3, 27, 3)
        dims (int): Feature dimension at each stage. Default: (96, 192, 384, 768)
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
        kernel_sizes (tuple(tuple(int))): Kernel size for each block. None means using the default settings. Default: None.
        deploy (bool): deploy = True means using the inference structure. Default: False
        with_cp (bool): with_cp = True means using torch.utils.checkpoint to save GPU memory. Default: True
        init_cfg (dict): weights to load, OpenMMLab style (reads ``init_cfg['checkpoint']``).
            Giving it switches the model to multi-scale feature output. Default: None
        attempt_use_lk_impl (bool): try to load the efficient iGEMM large-kernel impl. Setting it to False disables the iGEMM impl. Default: True
        use_sync_bn (bool): use_sync_bn = True means using sync BN. Use it if your batch size is small. Default: False
    """
    def __init__(self,
                 in_chans=3,
                 num_classes=1000,
                 depths=(3, 3, 27, 3),
                 dims=(96, 192, 384, 768),
                 drop_path_rate=0.,
                 layer_scale_init_value=1e-6,
                 head_init_scale=1.,
                 kernel_sizes=None,
                 deploy=False,
                 with_cp=True,
                 init_cfg=None,
                 attempt_use_lk_impl=True,
                 use_sync_bn=False,
                 **kwargs
                 ):
        super().__init__()

        depths = tuple(depths)
        if kernel_sizes is None:
            if depths in default_depths_to_kernel_sizes:
                print('=========== use default kernel size ')
                kernel_sizes = default_depths_to_kernel_sizes[depths]
            else:
                raise ValueError('no default kernel size settings for the given depths, '
                                 'please specify kernel sizes for each block, e.g., '
                                 '((3, 3), (13, 13), (13, 13, 13, 13, 13, 13), (13, 13))')
        print(kernel_sizes)
        for i in range(4):
            assert len(kernel_sizes[i]) == depths[i], 'kernel sizes do not match the depths'

        self.with_cp = with_cp

        # Drop-path rate grows linearly from 0 to drop_path_rate over all blocks.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        print('=========== drop path rates: ', dp_rates)

        # Stem: two stride-2 convs, then one stride-2 conv in front of
        # each of the remaining three stages.
        self.downsample_layers = nn.ModuleList()
        self.downsample_layers.append(nn.Sequential(
            nn.Conv2d(in_chans, dims[0] // 2, kernel_size=3, stride=2, padding=1),
            LayerNorm(dims[0] // 2, eps=1e-6, data_format="channels_first"),
            nn.GELU(),
            nn.Conv2d(dims[0] // 2, dims[0], kernel_size=3, stride=2, padding=1),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")))

        for i in range(3):
            self.downsample_layers.append(nn.Sequential(
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=3, stride=2, padding=1),
                LayerNorm(dims[i + 1], eps=1e-6, data_format="channels_first")))

        self.stages = nn.ModuleList()

        cur = 0
        for i in range(4):
            main_stage = nn.Sequential(
                *[UniRepLKNetBlock(dim=dims[i], kernel_size=kernel_sizes[i][j], drop_path=dp_rates[cur + j],
                                   layer_scale_init_value=layer_scale_init_value, deploy=deploy,
                                   attempt_use_lk_impl=attempt_use_lk_impl,
                                   with_cp=with_cp, use_sync_bn=use_sync_bn) for j in
                  range(depths[i])])
            self.stages.append(main_stage)
            cur += depths[i]

        self.last_channels = dims[-1]

        self.for_pretrain = init_cfg is None
        self.for_downstream = not self.for_pretrain     # there may be some other scenarios
        if self.for_downstream:
            assert num_classes is None

        if self.for_pretrain:
            self.init_cfg = None
            self.norm = nn.LayerNorm(self.last_channels, eps=1e-6)  # final norm layer
            # self.head = nn.Linear(self.last_channels, num_classes)
            self.head = nn.Linear(self.last_channels, self.last_channels)
            self.apply(self._init_weights)
            self.head.weight.data.mul_(head_init_scale)
            self.head.bias.data.mul_(head_init_scale)
            self.output_mode = 'logits'
        else:
            self.init_cfg = init_cfg        # OpenMMLab style init
            # NOTE(review): weights are loaded before norm0..norm3 exist, so
            # any such keys in the checkpoint are reported as unexpected and
            # not loaded. Order kept as in the original; confirm it is intended.
            self.init_weights()
            self.output_mode = 'features'
            norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
            for i_layer in range(4):
                layer = norm_layer(dims[i_layer])
                layer_name = f'norm{i_layer}'
                self.add_module(layer_name, layer)

    #   load pretrained backbone weights in the OpenMMLab style
    def init_weights(self):
        """Load ``self.init_cfg['checkpoint']`` non-strictly into this model.

        A ``None`` checkpoint path skips loading. Key mismatches are logged
        (or printed when no logger is available) rather than raised.

        Raises:
            ImportError: A checkpoint path is given but mmcv's
                ``_load_checkpoint`` could not be imported.
            RuntimeError: A tensor in the checkpoint cannot be copied into the
                parameter of the same name.
        """

        def load_state_dict(module, state_dict, strict=False, logger=None):
            unexpected_keys = []
            own_state = module.state_dict()
            for name, param in state_dict.items():
                if name not in own_state:
                    unexpected_keys.append(name)
                    continue
                if isinstance(param, torch.nn.Parameter):
                    # backwards compatibility for serialized parameters
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception as err:
                    raise RuntimeError(
                        'While copying the parameter named {}, '
                        'whose dimensions in the model are {} and '
                        'whose dimensions in the checkpoint are {}.'.format(
                            name, own_state[name].size(), param.size())) from err
            # Sorted so the report is stable from run to run.
            missing_keys = sorted(set(own_state.keys()) - set(state_dict.keys()))

            err_msg = []
            if unexpected_keys:
                err_msg.append('unexpected key in source state_dict: {}\n'.format(', '.join(unexpected_keys)))
            if missing_keys:
                err_msg.append('missing keys in source state_dict: {}\n'.format(', '.join(missing_keys)))
            err_msg = '\n'.join(err_msg)
            if err_msg:
                if strict:
                    raise RuntimeError(err_msg)
                elif logger is not None:
                    logger.warning(err_msg)
                else:
                    print(err_msg)

        # get_root_logger / _load_checkpoint are None when the optional
        # OpenMMLab imports at the top of the file failed.
        logger = get_root_logger() if get_root_logger is not None else None
        assert self.init_cfg is not None
        ckpt_path = self.init_cfg['checkpoint']
        if ckpt_path is None:
            print('================ Note: init_cfg is provided but I got no init ckpt path, so skip initialization')
        else:
            if _load_checkpoint is None:
                raise ImportError('loading a checkpoint through init_cfg requires mmcv '
                                  '(mmcv.runner._load_checkpoint could not be imported)')
            ckpt = _load_checkpoint(ckpt_path, logger=logger, map_location='cpu')
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
            else:
                _state_dict = ckpt

            load_state_dict(self, _state_dict, strict=False, logger=logger)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return head output ('logits' mode) or a list of 4 stage feature maps ('features' mode)."""
        if self.output_mode == 'logits':
            for stage_idx in range(4):
                x = self.downsample_layers[stage_idx](x)
                x = self.stages[stage_idx](x)
            # Global average pool over the two spatial dims, then norm + head.
            x = self.norm(x.mean([-2, -1]))
            x = self.head(x)
            return x
        elif self.output_mode == 'features':
            outs = []
            for stage_idx in range(4):
                x = self.downsample_layers[stage_idx](x)
                x = self.stages[stage_idx](x)
                outs.append(getattr(self, f'norm{stage_idx}')(x))
            return outs
        else:
            raise ValueError('Defined new output mode?')

    def reparameterize_unireplknet(self):
        """Call ``reparameterize()`` on every submodule that defines it."""
        for m in self.modules():
            if hasattr(m, 'reparameterize'):
                m.reparameterize()

    @torch.jit.ignore
    def get_classifier(self):
        # The head is a plain module without an ``fc`` child, so return it directly.
        return self.head

    def reset_classifier(self, num_classes=0, global_pool=None):
        """Replace the head used by ``forward``.

        Args:
            num_classes: > 0 installs a fresh ``Linear(last_channels, last_channels)``
                (this model keeps the head width equal to the feature width);
                otherwise the head becomes ``nn.Identity``.
            global_pool: Only None or 'avg' are accepted, because ``forward``
                always mean-pools over the spatial dims.

        Raises:
            ValueError: For any other ``global_pool`` value.
        """
        if global_pool not in (None, 'avg'):
            raise ValueError('only average pooling is supported, got global_pool={!r}'.format(global_pool))
        # self.head = nn.Linear(self.last_channels, num_classes) if num_classes > 0 else nn.Identity()
        self.head = nn.Linear(self.last_channels, self.last_channels) if num_classes > 0 else nn.Identity()
|
597 |
+
|
598 |
+
|
599 |
+
|
600 |
+
|
601 |
+
|
602 |
+
class LayerNorm(nn.Module):
    r""" LayerNorm implementation used in ConvNeXt
    LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).

    Args:
        normalized_shape: Number of channels. It is wrapped into a 1-tuple, so an int is expected.
        eps: Value added to the variance for numerical stability.
        data_format: "channels_last" or "channels_first".
        reshape_last_to_first: Stored on the module; not read by ``forward``.

    Raises:
        NotImplementedError: If ``data_format`` is not one of the two supported values.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", reshape_last_to_first=False):
        super().__init__()
        # Reject a bad format before any parameter is allocated.
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)
        self.reshape_last_to_first = reshape_last_to_first

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # Normalise over dim 1 by hand using the biased variance.
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x
|
630 |
+
|
631 |
+
|
632 |
+
# For easy use as backbone in MMDetection framework. Ignore these lines if you do not use MMDetection
|
633 |
+
if has_mmdet:
    @det_BACKBONES.register_module()
    class UniRepLKNetBackbone(UniRepLKNet):
        """UniRepLKNet variant registered with ``det_BACKBONES``.

        Forwards its arguments to ``UniRepLKNet.__init__`` with
        ``in_chans=3``, ``num_classes=None`` and ``use_sync_bn=True`` fixed.
        ``init_cfg`` is mandatory: construction asserts it is not ``None``.
        """

        def __init__(
            self,
            depths=(3, 3, 27, 3),
            dims=(96, 192, 384, 768),
            drop_path_rate=0.,
            layer_scale_init_value=1e-6,
            kernel_sizes=None,
            deploy=False,
            with_cp=False,
            init_cfg=None,
            attempt_use_lk_impl=False,
        ):
            assert init_cfg is not None
            backbone_kwargs = dict(
                in_chans=3,
                num_classes=None,
                depths=depths,
                dims=dims,
                drop_path_rate=drop_path_rate,
                layer_scale_init_value=layer_scale_init_value,
                kernel_sizes=kernel_sizes,
                deploy=deploy,
                with_cp=with_cp,
                init_cfg=init_cfg,
                attempt_use_lk_impl=attempt_use_lk_impl,
                use_sync_bn=True,
            )
            super().__init__(**backbone_kwargs)
|
651 |
+
|
652 |
+
# For easy use as backbone in MMSegmentation framework. Ignore these lines if you do not use MMSegmentation
|
653 |
+
if has_mmseg:
    @seg_BACKBONES.register_module()
    class UniRepLKNetBackbone(UniRepLKNet):
        """UniRepLKNet variant registered with ``seg_BACKBONES``.

        Forwards its arguments to ``UniRepLKNet.__init__`` with
        ``in_chans=3``, ``num_classes=None`` and ``use_sync_bn=True`` fixed.
        ``init_cfg`` is mandatory: construction asserts it is not ``None``.
        """

        def __init__(
            self,
            depths=(3, 3, 27, 3),
            dims=(96, 192, 384, 768),
            drop_path_rate=0.,
            layer_scale_init_value=1e-6,
            kernel_sizes=None,
            deploy=False,
            with_cp=False,
            init_cfg=None,
            attempt_use_lk_impl=False,
        ):
            assert init_cfg is not None
            backbone_kwargs = dict(
                in_chans=3,
                num_classes=None,
                depths=depths,
                dims=dims,
                drop_path_rate=drop_path_rate,
                layer_scale_init_value=layer_scale_init_value,
                kernel_sizes=kernel_sizes,
                deploy=deploy,
                with_cp=with_cp,
                init_cfg=init_cfg,
                attempt_use_lk_impl=attempt_use_lk_impl,
                use_sync_bn=True,
            )
            super().__init__(**backbone_kwargs)
|
671 |
+
|
672 |
+
|
673 |
+
# Direct-download URLs for pretrained checkpoints, looked up by load_with_key
# when hf_hub_download is unavailable. Currently empty.
# TODO: Google Drive does not appear to support direct downloading by URL, so
# a host other than Hugging Face is still needed for the checkpoints.
# Suggestions are welcome.
model_urls = {}
|
676 |
+
|
677 |
+
# Checkpoint file names passed to hf_hub_download by load_with_key.
# Keys are "<model_name>_<suffix>", where the suffix is one of
# "1k", "22k" or "22k_to_1k" (see initialize_with_pretrained).
huggingface_file_names = {
    # Keys with the "_1k" suffix.
    "unireplknet_a_1k": "unireplknet_a_in1k_224_acc77.03.pth",
    "unireplknet_f_1k": "unireplknet_f_in1k_224_acc78.58.pth",
    "unireplknet_p_1k": "unireplknet_p_in1k_224_acc80.23.pth",
    "unireplknet_n_1k": "unireplknet_n_in1k_224_acc81.64.pth",
    "unireplknet_t_1k": "unireplknet_t_in1k_224_acc83.21.pth",
    "unireplknet_s_1k": "unireplknet_s_in1k_224_acc83.91.pth",
    # Keys with the "_22k" / "_22k_to_1k" suffixes.
    "unireplknet_s_22k": "unireplknet_s_in22k_pretrain.pth",
    "unireplknet_s_22k_to_1k": "unireplknet_s_in22k_to_in1k_384_acc86.44.pth",
    "unireplknet_b_22k": "unireplknet_b_in22k_pretrain.pth",
    "unireplknet_b_22k_to_1k": "unireplknet_b_in22k_to_in1k_384_acc87.40.pth",
    "unireplknet_l_22k": "unireplknet_l_in22k_pretrain.pth",
    "unireplknet_l_22k_to_1k": "unireplknet_l_in22k_to_in1k_384_acc87.88.pth",
    "unireplknet_xl_22k": "unireplknet_xl_in22k_pretrain.pth",
    "unireplknet_xl_22k_to_1k": "unireplknet_xl_in22k_to_in1k_384_acc87.96.pth",
}
|
693 |
+
|
694 |
+
def load_with_key(model, key):
    """Fetch the pretrained checkpoint registered under ``key`` and load it into ``model``.

    When ``hf_hub_download`` is available the file named in
    ``huggingface_file_names[key]`` is downloaded from the
    ``DingXiaoH/UniRepLKNet`` repository; otherwise the URL in
    ``model_urls[key]`` is fetched via ``torch.hub``. If the loaded object
    contains a ``'model'`` entry, that entry is used as the state dict.

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        key: checkpoint identifier such as ``"unireplknet_s_22k"``.

    Raises:
        KeyError: if no checkpoint is registered under ``key`` for the
            download source in use. Previously this surfaced as a bare
            ``KeyError`` carrying only the key.
    """
    # If the Hugging Face hub client was found, download from the Hugging Face repo.
    if hf_hub_download is not None:
        filename = huggingface_file_names.get(key)
        if filename is None:
            raise KeyError(
                f"No Hugging Face checkpoint is registered for '{key}'. "
                f"Known keys: {sorted(huggingface_file_names)}"
            )
        repo_id = 'DingXiaoH/UniRepLKNet'
        cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
        # NOTE(review): torch.load unpickles the file; only use it on
        # checkpoints from a trusted repository.
        checkpoint = torch.load(cache_file, map_location='cpu')
    else:
        url = model_urls.get(key)
        if url is None:
            raise KeyError(
                f"No download URL is registered for '{key}' and huggingface_hub "
                f"is unavailable; install huggingface_hub to fetch this checkpoint."
            )
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
    if 'model' in checkpoint:
        checkpoint = checkpoint['model']
    model.load_state_dict(checkpoint)
|
705 |
+
|
706 |
+
def initialize_with_pretrained(model, model_name, in_1k_pretrained, in_22k_pretrained, in_22k_to_1k):
    """Load at most one pretrained checkpoint into ``model``.

    The flags are examined in the order ``in_1k_pretrained``,
    ``in_22k_pretrained``, ``in_22k_to_1k``. The first truthy one selects the
    checkpoint key ``model_name`` + ``'_1k'`` / ``'_22k'`` / ``'_22k_to_1k'``,
    which is handed to ``load_with_key``. Nothing is loaded when all flags
    are falsy.
    """
    suffix_by_flag = (
        (in_1k_pretrained, '_1k'),
        (in_22k_pretrained, '_22k'),
        (in_22k_to_1k, '_22k_to_1k'),
    )
    for requested, suffix in suffix_by_flag:
        if requested:
            load_with_key(model, model_name + suffix)
            return
|
717 |
+
|
718 |
+
@register_model
def unireplknet_a(in_1k_pretrained=False, **kwargs):
    """Build ``unireplknet_a`` (dims 40/80/160/320, ``UniRepLKNet_A_F_P_depths``).

    Extra keyword arguments are forwarded to ``UniRepLKNet``. When
    ``in_1k_pretrained`` is truthy the ``unireplknet_a_1k`` checkpoint is loaded.
    """
    net = UniRepLKNet(
        depths=UniRepLKNet_A_F_P_depths,
        dims=(40, 80, 160, 320),
        **kwargs
    )
    initialize_with_pretrained(
        net,
        'unireplknet_a',
        in_1k_pretrained=in_1k_pretrained,
        in_22k_pretrained=False,
        in_22k_to_1k=False,
    )
    return net
|
723 |
+
|
724 |
+
@register_model
def unireplknet_f(in_1k_pretrained=False, **kwargs):
    """Build ``unireplknet_f`` (dims 48/96/192/384, ``UniRepLKNet_A_F_P_depths``).

    Extra keyword arguments are forwarded to ``UniRepLKNet``. When
    ``in_1k_pretrained`` is truthy the ``unireplknet_f_1k`` checkpoint is loaded.
    """
    net = UniRepLKNet(
        depths=UniRepLKNet_A_F_P_depths,
        dims=(48, 96, 192, 384),
        **kwargs
    )
    initialize_with_pretrained(
        net,
        'unireplknet_f',
        in_1k_pretrained=in_1k_pretrained,
        in_22k_pretrained=False,
        in_22k_to_1k=False,
    )
    return net
|
729 |
+
|
730 |
+
@register_model
def unireplknet_p(in_1k_pretrained=False, **kwargs):
    """Build ``unireplknet_p`` (dims 64/128/256/512, ``UniRepLKNet_A_F_P_depths``).

    Extra keyword arguments are forwarded to ``UniRepLKNet``. When
    ``in_1k_pretrained`` is truthy the ``unireplknet_p_1k`` checkpoint is loaded.
    """
    net = UniRepLKNet(
        depths=UniRepLKNet_A_F_P_depths,
        dims=(64, 128, 256, 512),
        **kwargs
    )
    initialize_with_pretrained(
        net,
        'unireplknet_p',
        in_1k_pretrained=in_1k_pretrained,
        in_22k_pretrained=False,
        in_22k_to_1k=False,
    )
    return net
|
735 |
+
|
736 |
+
@register_model
def unireplknet_n(in_1k_pretrained=False, **kwargs):
    """Build ``unireplknet_n`` (dims 80/160/320/640, ``UniRepLKNet_N_depths``).

    Extra keyword arguments are forwarded to ``UniRepLKNet``. When
    ``in_1k_pretrained`` is truthy the ``unireplknet_n_1k`` checkpoint is loaded.
    """
    net = UniRepLKNet(
        depths=UniRepLKNet_N_depths,
        dims=(80, 160, 320, 640),
        **kwargs
    )
    initialize_with_pretrained(
        net,
        'unireplknet_n',
        in_1k_pretrained=in_1k_pretrained,
        in_22k_pretrained=False,
        in_22k_to_1k=False,
    )
    return net
|
741 |
+
|
742 |
+
@register_model
def unireplknet_t(in_1k_pretrained=False, **kwargs):
    """Build ``unireplknet_t`` (dims 80/160/320/640, ``UniRepLKNet_T_depths``).

    Extra keyword arguments are forwarded to ``UniRepLKNet``. When
    ``in_1k_pretrained`` is truthy the ``unireplknet_t_1k`` checkpoint is loaded.
    """
    net = UniRepLKNet(
        depths=UniRepLKNet_T_depths,
        dims=(80, 160, 320, 640),
        **kwargs
    )
    initialize_with_pretrained(
        net,
        'unireplknet_t',
        in_1k_pretrained=in_1k_pretrained,
        in_22k_pretrained=False,
        in_22k_to_1k=False,
    )
    return net
|
747 |
+
|
748 |
+
@register_model
def unireplknet_s(in_1k_pretrained=False, in_22k_pretrained=False, in_22k_to_1k=False, **kwargs):
    """Build ``unireplknet_s`` (dims 96/192/384/768, ``UniRepLKNet_S_B_L_XL_depths``).

    Extra keyword arguments are forwarded to ``UniRepLKNet``. The three flags
    are passed to ``initialize_with_pretrained``, which loads the checkpoint
    for the first truthy one (1k, then 22k, then 22k-to-1k).
    """
    net = UniRepLKNet(
        depths=UniRepLKNet_S_B_L_XL_depths,
        dims=(96, 192, 384, 768),
        **kwargs
    )
    initialize_with_pretrained(
        net,
        'unireplknet_s',
        in_1k_pretrained=in_1k_pretrained,
        in_22k_pretrained=in_22k_pretrained,
        in_22k_to_1k=in_22k_to_1k,
    )
    return net
|
753 |
+
|
754 |
+
@register_model
def unireplknet_b(in_22k_pretrained=False, in_22k_to_1k=False, **kwargs):
    """Build ``unireplknet_b`` (dims 128/256/512/1024, ``UniRepLKNet_S_B_L_XL_depths``).

    Extra keyword arguments are forwarded to ``UniRepLKNet``. The
    ``unireplknet_b_22k`` checkpoint is loaded when ``in_22k_pretrained`` is
    truthy, otherwise ``unireplknet_b_22k_to_1k`` when ``in_22k_to_1k`` is truthy.
    """
    net = UniRepLKNet(
        depths=UniRepLKNet_S_B_L_XL_depths,
        dims=(128, 256, 512, 1024),
        **kwargs
    )
    initialize_with_pretrained(
        net,
        'unireplknet_b',
        in_1k_pretrained=False,
        in_22k_pretrained=in_22k_pretrained,
        in_22k_to_1k=in_22k_to_1k,
    )
    return net
|
759 |
+
|
760 |
+
@register_model
def unireplknet_l(in_22k_pretrained=False, in_22k_to_1k=False, **kwargs):
    """Build ``unireplknet_l`` (dims 192/384/768/1536, ``UniRepLKNet_S_B_L_XL_depths``).

    Extra keyword arguments are forwarded to ``UniRepLKNet``. The
    ``unireplknet_l_22k`` checkpoint is loaded when ``in_22k_pretrained`` is
    truthy, otherwise ``unireplknet_l_22k_to_1k`` when ``in_22k_to_1k`` is truthy.
    """
    net = UniRepLKNet(
        depths=UniRepLKNet_S_B_L_XL_depths,
        dims=(192, 384, 768, 1536),
        **kwargs
    )
    initialize_with_pretrained(
        net,
        'unireplknet_l',
        in_1k_pretrained=False,
        in_22k_pretrained=in_22k_pretrained,
        in_22k_to_1k=in_22k_to_1k,
    )
    return net
|
765 |
+
|
766 |
+
@register_model
def unireplknet_xl(in_22k_pretrained=False, in_22k_to_1k=False, **kwargs):
    """Build ``unireplknet_xl`` (dims 256/512/1024/2048, ``UniRepLKNet_S_B_L_XL_depths``).

    Extra keyword arguments are forwarded to ``UniRepLKNet``. The
    ``unireplknet_xl_22k`` checkpoint is loaded when ``in_22k_pretrained`` is
    truthy, otherwise ``unireplknet_xl_22k_to_1k`` when ``in_22k_to_1k`` is truthy.
    """
    net = UniRepLKNet(
        depths=UniRepLKNet_S_B_L_XL_depths,
        dims=(256, 512, 1024, 2048),
        **kwargs
    )
    initialize_with_pretrained(
        net,
        'unireplknet_xl',
        in_1k_pretrained=False,
        in_22k_pretrained=in_22k_pretrained,
        in_22k_to_1k=in_22k_to_1k,
    )
    return net
|
771 |
+
|
772 |
+
@register_model
def unireplknet_h(in_22k_pretrained=False, in_22k_to_1k=False, **kwargs):
    """Build ``unireplknet_h`` (dims 480/960/1920/3840, ``UniRepLKNet_S_B_L_XL_depths``).

    Extra keyword arguments are forwarded to ``UniRepLKNet``. The flags select
    the ``unireplknet_h_22k`` / ``unireplknet_h_22k_to_1k`` checkpoint keys.
    """
    net = UniRepLKNet(
        depths=UniRepLKNet_S_B_L_XL_depths,
        dims=(480, 960, 1920, 3840),
        **kwargs
    )
    # NOTE(review): huggingface_file_names has no 'unireplknet_h_*' entries, so
    # requesting pretrained weights here looks like it fails in load_with_key
    # with a KeyError — confirm whether these checkpoints are published.
    initialize_with_pretrained(
        net,
        'unireplknet_h',
        in_1k_pretrained=False,
        in_22k_pretrained=in_22k_pretrained,
        in_22k_to_1k=in_22k_to_1k,
    )
    return net
|
777 |
+
|
778 |
+
|
779 |
+
if __name__ == '__main__':
    # Smoke test: build unireplknet_l and load the local CLIP-pretrained
    # checkpoint into it.
    model_large = unireplknet_l()
    print(model_large)
    # map_location="cpu" lets a checkpoint saved from GPU tensors load on a
    # machine without CUDA; the freshly built model is on the CPU anyway.
    ckpt = torch.load("UniRepLKNet-L-b75k_s10B_CLIP-in1k_75.72.pt", map_location="cpu")
    # strict=False since we do not need heads in CLIP pretraining. Report the
    # skipped keys so that an unintended mismatch does not go unnoticed.
    load_result = model_large.load_state_dict(ckpt, strict=False)
    if load_result.missing_keys:
        print(f"Missing keys: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"Unexpected keys: {load_result.unexpected_keys}")
    print("Loaded CLIP Pretrained Models")
|