Yiyuan committed on
Commit
d6d798b
·
1 Parent(s): 76668e1
UniRepLKNet-L-b75k_s10B_CLIP-in1k_75.72.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f1b36fef7c9a0174b113db0093df8e827b523193e0ed10df788f8def3e64c84
3
+ size 875658963
modeling_UniRepLKNet.py ADDED
@@ -0,0 +1,784 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # UniRepLKNet: A Universal Perception Large-Kernel ConvNet for Audio, Video, Point Cloud, Time-Series and Image Recognition
2
+ # Github source: https://github.com/AILab-CVC/UniRepLKNet
3
+ # Licensed under the Apache License 2.0 [see LICENSE for details]
4
+ # Based on RepLKNet, ConvNeXt, timm, DINO and DeiT code bases
5
+ # https://github.com/DingXiaoH/RepLKNet-pytorch
6
+ # https://github.com/facebookresearch/ConvNeXt
7
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
8
+ # https://github.com/facebookresearch/deit/
9
+ # https://github.com/facebookresearch/dino
10
+ # --------------------------------------------------------'
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from timm.models.layers import trunc_normal_, DropPath, to_2tuple
15
+ from timm.models.registry import register_model
16
+ from timm.models.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \
17
+ create_conv2d, get_act_layer, make_divisible, to_ntuple
18
+ from functools import partial
19
+ import torch.utils.checkpoint as checkpoint
20
# huggingface_hub is an optional dependency: when it is missing, hf_hub_download
# is set to None and load_with_key() takes its URL-based fallback instead.
try:
    from huggingface_hub import hf_hub_download
except ImportError:
    # Only a failed import is tolerated here; a bare ``except`` would also
    # swallow KeyboardInterrupt / SystemExit raised during the import.
    hf_hub_download = None  # install huggingface_hub if you would like to download models conveniently from huggingface
24
+
25
+ has_mmdet = False
26
+ has_mmseg = False
27
+ # =============== for the ease of directly using this file in MMSegmentation and MMDetection.
28
+ # =============== ignore the following two segments of code if you do not plan to do so
29
+ # =============== delete one of the following two segments if you get a conflict
30
+ try:
31
+ from mmseg.models.builder import BACKBONES as seg_BACKBONES
32
+ from mmseg.utils import get_root_logger
33
+ from mmcv.runner import _load_checkpoint
34
+ has_mmseg = True
35
+ except ImportError:
36
+ get_root_logger = None
37
+ _load_checkpoint = None
38
+
39
+ # try:
40
+ # from mmdet.models.builder import BACKBONES as det_BACKBONES
41
+ # from mmdet.utils import get_root_logger
42
+ # from mmcv.runner import _load_checkpoint
43
+ # has_mmdet = True
44
+ # except ImportError:
45
+ # get_root_logger = None
46
+ # _load_checkpoint = None
47
+ # ===========================================================================================
48
+
49
+
50
class GRNwithNHWC(nn.Module):
    """Global Response Normalization (GRN) for channels-last tensors.

    GRN was originally proposed in ConvNeXt V2 (https://arxiv.org/abs/2301.00808).
    The input is indexed as (N, H, W, C): the L2 norm is taken over dims 1 and 2
    and the learnable parameters have shape (1, 1, 1, dim).

    Args:
        dim: size of the last (channel) dimension.
        use_bias: when True, a learnable additive term ``beta`` is created and
            added to the output.
    """

    def __init__(self, dim, use_bias=True):
        super().__init__()
        self.use_bias = use_bias
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        if use_bias:
            self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        """Return ``(gamma * Nx + 1) * x`` (plus ``beta`` when ``use_bias``)."""
        # L2 norm of every channel, aggregated over the two middle dims.
        spatial_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Normalize by the mean over the channel dim; 1e-6 guards against division by zero.
        relative_norm = spatial_norm / (spatial_norm.mean(dim=-1, keepdim=True) + 1e-6)
        scaled = (self.gamma * relative_norm + 1) * x
        if not self.use_bias:
            return scaled
        return scaled + self.beta
70
+
71
+
72
class NCHWtoNHWC(nn.Module):
    """Permute a 4-D tensor with dim order (0, 1, 2, 3) into (0, 2, 3, 1), i.e. NCHW -> NHWC."""

    def forward(self, x):
        channels_last_order = (0, 2, 3, 1)
        return x.permute(*channels_last_order)
78
+
79
+
80
class NHWCtoNCHW(nn.Module):
    """Permute a 4-D tensor with dim order (0, 1, 2, 3) into (0, 3, 1, 2), i.e. NHWC -> NCHW."""

    def forward(self, x):
        channels_first_order = (0, 3, 1, 2)
        return x.permute(*channels_first_order)
86
+
87
# ================== This function decides which conv implementation (the native or iGEMM) to use
#   Note that the iGEMM large-kernel conv impl will be used if
#       -   you attempt to do so (attempt_use_lk_impl=True), and
#       -   it has been installed (follow https://github.com/AILab-CVC/UniRepLKNet), and
#       -   the conv layer is depth-wise, stride = 1, non-dilated, kernel_size > 5, and padding == kernel_size // 2
def get_conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias,
               attempt_use_lk_impl=True):
    """Build a 2-D convolution, preferring the iGEMM depth-wise kernel when it applies.

    Args:
        in_channels, out_channels, stride, dilation, groups, bias: forwarded to
            ``nn.Conv2d`` on the native path.
        kernel_size: int or pair; normalized with ``to_2tuple``.
        padding: int, pair, or None. None means ``kernel_size // 2`` per dim.
        attempt_use_lk_impl: when True and the kernel is square, larger than 5
            and padded by ``kernel_size // 2``, try to import
            ``DepthWiseConv2dImplicitGEMM``.

    Returns:
        A ``DepthWiseConv2dImplicitGEMM`` when the import succeeds and the layer
        is depth-wise (in == out == groups) with stride 1 and dilation 1;
        otherwise a plain ``nn.Conv2d``.
    """
    kernel_size = to_2tuple(kernel_size)
    if padding is None:
        padding = (kernel_size[0] // 2, kernel_size[1] // 2)
    else:
        padding = to_2tuple(padding)
    need_large_impl = (kernel_size[0] == kernel_size[1] and kernel_size[0] > 5
                       and padding == (kernel_size[0] // 2, kernel_size[1] // 2))

    if attempt_use_lk_impl and need_large_impl:
        print('---------------- trying to import iGEMM implementation for large-kernel conv')
        try:
            from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM
            print('---------------- found iGEMM implementation ')
        except Exception:
            # Best-effort optional dependency: any import-time failure falls back to
            # nn.Conv2d. ``Exception`` (instead of a bare ``except``) lets
            # KeyboardInterrupt / SystemExit propagate.
            DepthWiseConv2dImplicitGEMM = None
            print('---------------- found no iGEMM. use original conv. follow https://github.com/AILab-CVC/UniRepLKNet to install it.')
        if DepthWiseConv2dImplicitGEMM is not None and need_large_impl and in_channels == out_channels \
                and out_channels == groups and stride == 1 and dilation == 1:
            print(f'===== iGEMM Efficient Conv Impl, channels {in_channels}, kernel size {kernel_size} =====')
            return DepthWiseConv2dImplicitGEMM(in_channels, kernel_size, bias=bias)
    return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                     padding=padding, dilation=dilation, groups=groups, bias=bias)
115
+
116
+
117
def get_bn(dim, use_sync_bn=False):
    """Return a batch-norm layer over ``dim`` features.

    ``nn.SyncBatchNorm`` is used when ``use_sync_bn`` is truthy, otherwise
    ``nn.BatchNorm2d``.
    """
    norm_cls = nn.SyncBatchNorm if use_sync_bn else nn.BatchNorm2d
    return norm_cls(dim)
122
+
123
class SEBlock(nn.Module):
    """Squeeze-and-Excitation block (SENet, https://arxiv.org/abs/1709.01507).

    Inputs are treated as (N, C, H, W): the tensor is average-pooled to 1x1,
    passed through two 1x1 convolutions (C -> internal_neurons -> C) with a
    ReLU in between, squashed with a sigmoid and used to rescale the input
    channel-wise.

    Args:
        input_channels: number of channels C of the input.
        internal_neurons: width of the bottleneck between the two 1x1 convs.
    """

    def __init__(self, input_channels, internal_neurons):
        super().__init__()
        self.down = nn.Conv2d(input_channels, internal_neurons, kernel_size=1, stride=1, bias=True)
        self.up = nn.Conv2d(internal_neurons, input_channels, kernel_size=1, stride=1, bias=True)
        self.input_channels = input_channels
        self.nonlinear = nn.ReLU(inplace=True)

    def forward(self, inputs):
        pooled = F.adaptive_avg_pool2d(inputs, output_size=(1, 1))
        squeezed = self.nonlinear(self.down(pooled))
        gate = torch.sigmoid(self.up(squeezed))
        # One gate value per (sample, channel), broadcast over the spatial dims.
        return inputs * gate.view(-1, self.input_channels, 1, 1)
144
+
145
def fuse_bn(conv, bn):
    """Fold a batch-norm layer's running statistics into the preceding conv.

    Returns:
        ``(weight, bias)`` such that a conv using them equals ``bn(conv(x))``
        when ``bn`` normalizes with its running mean / variance.
    """
    std = (bn.running_var + bn.eps).sqrt()
    # A conv without bias contributes 0 to the fused bias.
    bias = conv.bias if conv.bias is not None else 0
    fused_weight = conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1)
    fused_bias = bn.bias + (bias - bn.running_mean) * bn.weight / std
    return fused_weight, fused_bias
149
+
150
def convert_dilated_to_nondilated(kernel, dilate_rate):
    """Expand a dilated conv kernel into the equivalent non-dilated (sparse) kernel.

    Zeros are inserted between the original taps, so a ``k x k`` kernel with
    dilation ``r`` becomes an ``(r * (k - 1) + 1)``-sized kernel whose entries at
    stride ``r`` are the original values.

    Args:
        kernel: 4-D weight tensor (out_channels, in_channels_per_group, kh, kw).
            Depth-wise (dim 1 == 1), group-wise and dense kernels are all handled.
        dilate_rate: dilation of the original convolution.

    Returns:
        Tensor with the same leading two dims, dtype and device as ``kernel``.
    """
    # Build the 1x1 identity kernel from ``kernel`` itself so dtype AND device
    # match; a default float32 tensor breaks conv_transpose2d for e.g. float64 kernels.
    identity_kernel = kernel.new_ones((1, 1, 1, 1))
    out_ch, in_ch, k_h, k_w = kernel.shape
    # Treat every (out, in) slice as an independent single-channel sample: a
    # transposed conv with a 1x1 unit kernel and stride r spreads the taps r apart.
    flat = kernel.reshape(out_ch * in_ch, 1, k_h, k_w)
    dilated = F.conv_transpose2d(flat, identity_kernel, stride=dilate_rate)
    return dilated.reshape(out_ch, in_ch, dilated.size(2), dilated.size(3))
163
+
164
def merge_dilated_into_large_kernel(large_kernel, dilated_kernel, dilated_r):
    """Add a dilated kernel onto a larger non-dilated kernel.

    The dilated kernel is first expanded to its non-dilated equivalent, then
    padded equally on all four sides by the difference of the two half-sizes
    and summed with ``large_kernel``.

    Args:
        large_kernel: 4-D weight of the large conv; its size is read from dim 2.
        dilated_kernel: 4-D weight of the dilated conv; its size is read from dim 2.
        dilated_r: dilation rate of ``dilated_kernel``.

    Returns:
        The merged kernel (a new tensor; ``large_kernel`` is not modified).
    """
    expanded = convert_dilated_to_nondilated(dilated_kernel, dilated_r)
    equivalent_size = dilated_r * (dilated_kernel.size(2) - 1) + 1
    margin = large_kernel.size(2) // 2 - equivalent_size // 2
    return large_kernel + F.pad(expanded, [margin, margin, margin, margin])
172
+
173
+
174
class DilatedReparamBlock(nn.Module):
    """
    Dilated Reparam Block proposed in UniRepLKNet (https://github.com/AILab-CVC/UniRepLKNet)
    We assume the inputs to this block are (N, C, H, W)

    In training structure (``deploy=False``) the block sums a large depth-wise conv
    (``lk_origin`` + ``origin_bn``) with several parallel dilated depth-wise conv + BN
    branches. ``merge_dilated_branches`` folds all of them into ``lk_origin``.

    Args:
        channels: number of input (= output) channels; every conv is depth-wise.
        kernel_size: size of the large kernel; must be one of the keys of ``_BRANCH_TABLE``.
        deploy: if True only ``lk_origin`` (with bias) is created.
        use_sync_bn: forwarded to ``get_bn`` for every BN layer.
        attempt_use_lk_impl: forwarded to ``get_conv2d``.
    """

    # large kernel size -> (branch kernel sizes, branch dilations)
    # Default settings. We did not tune them carefully. Different settings may work better.
    _BRANCH_TABLE = {
        17: ((5, 9, 3, 3, 3), (1, 2, 4, 5, 7)),
        15: ((5, 7, 3, 3, 3), (1, 2, 3, 5, 7)),
        13: ((5, 7, 3, 3, 3), (1, 2, 3, 4, 5)),
        11: ((5, 5, 3, 3, 3), (1, 2, 3, 4, 5)),
        9: ((5, 5, 3, 3), (1, 2, 3, 4)),
        7: ((5, 3, 3), (1, 2, 3)),
        5: ((3, 3), (1, 2)),
    }

    def __init__(self, channels, kernel_size, deploy, use_sync_bn=False, attempt_use_lk_impl=True):
        super().__init__()
        self.lk_origin = get_conv2d(channels, channels, kernel_size, stride=1,
                                    padding=kernel_size // 2, dilation=1, groups=channels, bias=deploy,
                                    attempt_use_lk_impl=attempt_use_lk_impl)
        self.attempt_use_lk_impl = attempt_use_lk_impl

        if kernel_size not in self._BRANCH_TABLE:
            raise ValueError('Dilated Reparam Block requires kernel_size >= 5')
        branch_kernels, branch_dilations = self._BRANCH_TABLE[kernel_size]
        self.kernel_sizes = list(branch_kernels)
        self.dilates = list(branch_dilations)

        if deploy:
            # Inference structure: the single (already merged) conv is all we need.
            return

        self.origin_bn = get_bn(channels, use_sync_bn)
        for k, r in zip(self.kernel_sizes, self.dilates):
            # Padding equals half of the equivalent non-dilated kernel size.
            branch_conv = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=k, stride=1,
                                    padding=(r * (k - 1) + 1) // 2, dilation=r, groups=channels,
                                    bias=False)
            setattr(self, 'dil_conv_k{}_{}'.format(k, r), branch_conv)
            setattr(self, 'dil_bn_k{}_{}'.format(k, r), get_bn(channels, use_sync_bn=use_sync_bn))

    def forward(self, x):
        if not hasattr(self, 'origin_bn'):      # deploy mode
            return self.lk_origin(x)
        total = self.origin_bn(self.lk_origin(x))
        for k, r in zip(self.kernel_sizes, self.dilates):
            branch_conv = getattr(self, 'dil_conv_k{}_{}'.format(k, r))
            branch_bn = getattr(self, 'dil_bn_k{}_{}'.format(k, r))
            total = total + branch_bn(branch_conv(x))
        return total

    def merge_dilated_branches(self):
        """Fold every BN and dilated branch into one biased ``lk_origin`` conv (no-op if already merged)."""
        if not hasattr(self, 'origin_bn'):
            return

        merged_k, merged_b = fuse_bn(self.lk_origin, self.origin_bn)
        for k, r in zip(self.kernel_sizes, self.dilates):
            branch_conv = getattr(self, 'dil_conv_k{}_{}'.format(k, r))
            branch_bn = getattr(self, 'dil_bn_k{}_{}'.format(k, r))
            branch_k, branch_b = fuse_bn(branch_conv, branch_bn)
            merged_k = merge_dilated_into_large_kernel(merged_k, branch_k, r)
            merged_b += branch_b

        num_channels = merged_k.size(0)
        large_size = merged_k.size(2)
        merged_conv = get_conv2d(num_channels, num_channels, large_size, stride=1,
                                 padding=large_size // 2, dilation=1, groups=num_channels, bias=True,
                                 attempt_use_lk_impl=self.attempt_use_lk_impl)
        merged_conv.weight.data = merged_k
        merged_conv.bias.data = merged_b
        self.lk_origin = merged_conv

        # Drop the training-only sub-modules; forward() uses their absence to detect deploy mode.
        delattr(self, 'origin_bn')
        for k, r in zip(self.kernel_sizes, self.dilates):
            delattr(self, 'dil_conv_k{}_{}'.format(k, r))
            delattr(self, 'dil_bn_k{}_{}'.format(k, r))
249
+
250
+
251
class UniRepLKNetBlock(nn.Module):
    """Residual block: depth-wise conv -> norm -> SE -> FFN (Linear, GELU, GRN, Linear), added to the input.

    Args:
        dim: number of channels of the block input and output.
        kernel_size: 0 (no depth-wise conv), 3 or 5 (plain depth-wise conv), or
            >= 7 (``DilatedReparamBlock``).
        drop_path: drop-path rate; ``DropPath`` is only created when it is > 0.
        layer_scale_init_value: initial value of the per-channel scale ``gamma``;
            None or a non-positive value disables it.
        deploy: build the inference structure (no BN layers, biased convs, no ``gamma``).
        attempt_use_lk_impl: forwarded to the conv constructors.
        with_cp: wrap the forward pass in ``torch.utils.checkpoint`` when the input requires grad.
        use_sync_bn: forwarded to ``get_bn``.
        ffn_factor: the FFN hidden width is ``int(ffn_factor * dim)``.
    """

    def __init__(self,
                 dim,
                 kernel_size,
                 drop_path=0.,
                 layer_scale_init_value=1e-6,
                 deploy=False,
                 attempt_use_lk_impl=True,
                 with_cp=False,
                 use_sync_bn=False,
                 ffn_factor=4):
        super().__init__()
        self.with_cp = with_cp
        if deploy:
            print('------------------------------- Note: deploy mode')
        if self.with_cp:
            print('****** note with_cp = True, reduce memory consumption but may slow down training ******')

        # --- spatial mixing ---
        if kernel_size == 0:
            self.dwconv = nn.Identity()
        elif kernel_size >= 7:
            self.dwconv = DilatedReparamBlock(dim, kernel_size, deploy=deploy,
                                              use_sync_bn=use_sync_bn,
                                              attempt_use_lk_impl=attempt_use_lk_impl)
        else:
            assert kernel_size in [3, 5]
            self.dwconv = get_conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=kernel_size // 2,
                                     dilation=1, groups=dim, bias=deploy,
                                     attempt_use_lk_impl=attempt_use_lk_impl)

        needs_norm = not (deploy or kernel_size == 0)
        self.norm = get_bn(dim, use_sync_bn=use_sync_bn) if needs_norm else nn.Identity()

        self.se = SEBlock(dim, dim // 4)

        # --- channel mixing (computed in NHWC so nn.Linear can be used) ---
        ffn_dim = int(ffn_factor * dim)
        self.pwconv1 = nn.Sequential(
            NCHWtoNHWC(),
            nn.Linear(dim, ffn_dim))
        self.act = nn.Sequential(
            nn.GELU(),
            GRNwithNHWC(ffn_dim, use_bias=not deploy))
        if deploy:
            self.pwconv2 = nn.Sequential(
                nn.Linear(ffn_dim, dim),
                NHWCtoNCHW())
        else:
            self.pwconv2 = nn.Sequential(
                nn.Linear(ffn_dim, dim, bias=False),
                NHWCtoNCHW(),
                get_bn(dim, use_sync_bn=use_sync_bn))

        use_layer_scale = (not deploy) and layer_scale_init_value is not None \
            and layer_scale_init_value > 0
        if use_layer_scale:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
        else:
            self.gamma = None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def compute_residual(self, x):
        """Return the residual branch output for ``x`` (before it is added to ``x``)."""
        mixed = self.se(self.norm(self.dwconv(x)))
        mixed = self.pwconv2(self.act(self.pwconv1(mixed)))
        if self.gamma is not None:
            mixed = self.gamma.view(1, -1, 1, 1) * mixed
        return self.drop_path(mixed)

    def forward(self, inputs):

        def _residual_step(x):
            return x + self.compute_residual(x)

        if self.with_cp and inputs.requires_grad:
            return checkpoint.checkpoint(_residual_step, inputs)
        return _residual_step(inputs)

    def reparameterize(self):
        """Convert the block in place from the training structure to the inference structure.

        Steps, in order: merge the dilated branches, fold ``self.norm`` into the
        depth-wise conv, then fold the GRN bias, the final BN and ``gamma`` into
        the second FFN linear layer.
        """
        # 1) Collapse the parallel dilated branches (only DilatedReparamBlock has this method).
        if hasattr(self.dwconv, 'merge_dilated_branches'):
            self.dwconv.merge_dilated_branches()

        # 2) Fold the BN that follows the depth-wise conv (nn.Identity has no running_var).
        if hasattr(self.norm, 'running_var'):
            norm = self.norm
            std = (norm.running_var + norm.eps).sqrt()
            if hasattr(self.dwconv, 'lk_origin'):
                large_conv = self.dwconv.lk_origin
                large_conv.weight.data *= (norm.weight / std).view(-1, 1, 1, 1)
                large_conv.bias.data = norm.bias + (
                    large_conv.bias - norm.running_mean) * norm.weight / std
            else:
                small_conv = self.dwconv
                fused_conv = nn.Conv2d(small_conv.in_channels, small_conv.out_channels, small_conv.kernel_size,
                                       padding=small_conv.padding, groups=small_conv.groups, bias=True)
                fused_conv.weight.data = small_conv.weight * (norm.weight / std).view(-1, 1, 1, 1)
                fused_conv.bias.data = norm.bias - norm.running_mean * norm.weight / std
                self.dwconv = fused_conv
            self.norm = nn.Identity()

        # 3) Take the layer scale out of the forward path; it is absorbed below.
        if self.gamma is None:
            final_scale = 1
        else:
            final_scale = self.gamma.data
            self.gamma = None

        # 4) Fold GRN bias + trailing BN + layer scale into a single biased Linear.
        #    len(pwconv2) == 3 identifies the training structure (Linear, NHWCtoNCHW, BN).
        grn = self.act[1]
        if grn.use_bias and len(self.pwconv2) == 3:
            grn_bias = grn.beta.data
            delattr(grn, 'beta')
            grn.use_bias = False
            linear = self.pwconv2[0]
            # The additive GRN bias passes through the linear layer as W @ beta.
            grn_bias_projected_bias = (linear.weight.data @ grn_bias.view(-1, 1)).squeeze()
            bn = self.pwconv2[2]
            std = (bn.running_var + bn.eps).sqrt()
            new_linear = nn.Linear(linear.in_features, linear.out_features, bias=True)
            new_linear.weight.data = linear.weight * (bn.weight / std * final_scale).view(-1, 1)
            linear_bias = 0 if linear.bias is None else linear.bias.data
            linear_bias += grn_bias_projected_bias
            new_linear.bias.data = (bn.bias + (linear_bias - bn.running_mean) * bn.weight / std) * final_scale
            self.pwconv2 = nn.Sequential(new_linear, self.pwconv2[1])
365
+
366
+
367
+
368
# Default per-block kernel sizes, one inner tuple per stage. The repeated
# patterns are spelled with tuple multiplication; each inner tuple's length
# equals the corresponding entry of the matching ``*_depths`` tuple below.
default_UniRepLKNet_A_F_P_kernel_sizes = ((3,) * 2,
                                          (13,) * 2,
                                          (13,) * 6,
                                          (13,) * 2)
default_UniRepLKNet_N_kernel_sizes = ((3,) * 2,
                                      (13,) * 2,
                                      (13,) * 8,
                                      (13,) * 2)
default_UniRepLKNet_T_kernel_sizes = ((3,) * 3,
                                      (13,) * 3,
                                      (13, 3) * 9,
                                      (13,) * 3)
default_UniRepLKNet_S_B_L_XL_kernel_sizes = ((3,) * 3,
                                             (13,) * 3,
                                             (13, 3, 3) * 9,
                                             (13,) * 3)

# Number of blocks in each of the four stages.
UniRepLKNet_A_F_P_depths = (2, 2, 6, 2)
UniRepLKNet_N_depths = (2, 2, 8, 2)
UniRepLKNet_T_depths = (3, 3, 18, 3)
UniRepLKNet_S_B_L_XL_depths = (3, 3, 27, 3)

# Lookup used when a model is built without explicit kernel sizes.
default_depths_to_kernel_sizes = dict([
    (UniRepLKNet_A_F_P_depths, default_UniRepLKNet_A_F_P_kernel_sizes),
    (UniRepLKNet_N_depths, default_UniRepLKNet_N_kernel_sizes),
    (UniRepLKNet_T_depths, default_UniRepLKNet_T_kernel_sizes),
    (UniRepLKNet_S_B_L_XL_depths, default_UniRepLKNet_S_B_L_XL_kernel_sizes),
])
395
+
396
class UniRepLKNet(nn.Module):
    r""" UniRepLKNet
        A PyTorch impl of UniRepLKNet

    The model is four (downsample layer -> stage of ``UniRepLKNetBlock``) pairs.
    With ``init_cfg=None`` it runs in 'logits' mode: the last feature map is
    averaged over its two trailing dims, layer-normed and passed through ``head``,
    which maps ``dims[-1] -> dims[-1]`` (``num_classes`` does not size the head in
    this file). With an ``init_cfg`` it runs in 'features' mode and returns the four
    normalized stage outputs.

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Kept for interface compatibility; must be None when init_cfg is given. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: (3, 3, 27, 3)
        dims (int): Feature dimension at each stage. Default: (96, 192, 384, 768)
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
        kernel_sizes (tuple(tuple(int))): Kernel size for each block. None means using the default settings. Default: None.
        deploy (bool): deploy = True means using the inference structure. Default: False
        with_cp (bool): with_cp = True means using torch.utils.checkpoint to save GPU memory. Default: True
        init_cfg (dict): weights to load; the easiest way to use UniRepLKNet with the OpenMMLab family. Default: None
        attempt_use_lk_impl (bool): try to load the efficient iGEMM large-kernel impl. Setting it to False disables the iGEMM impl. Default: True
        use_sync_bn (bool): use_sync_bn = True means using sync BN. Use it if your batch size is small. Default: False
    """
    def __init__(self,
                 in_chans=3,
                 num_classes=1000,
                 depths=(3, 3, 27, 3),
                 dims=(96, 192, 384, 768),
                 drop_path_rate=0.,
                 layer_scale_init_value=1e-6,
                 head_init_scale=1.,
                 kernel_sizes=None,
                 deploy=False,
                 with_cp=True,
                 init_cfg=None,
                 attempt_use_lk_impl=True,
                 use_sync_bn=False,
                 **kwargs
                 ):
        super().__init__()

        depths = tuple(depths)
        if kernel_sizes is None:
            if depths in default_depths_to_kernel_sizes:
                print('=========== use default kernel size ')
                kernel_sizes = default_depths_to_kernel_sizes[depths]
            else:
                raise ValueError('no default kernel size settings for the given depths, '
                                 'please specify kernel sizes for each block, e.g., '
                                 '((3, 3), (13, 13), (13, 13, 13, 13, 13, 13), (13, 13))')
        print(kernel_sizes)
        for i in range(4):
            assert len(kernel_sizes[i]) == depths[i], 'kernel sizes do not match the depths'

        self.with_cp = with_cp

        # One drop-path rate per block, increasing linearly from 0 to drop_path_rate.
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        print('=========== drop path rates: ', dp_rates)

        self.downsample_layers = nn.ModuleList()
        # Stem: two stride-2 3x3 convs.
        self.downsample_layers.append(nn.Sequential(
            nn.Conv2d(in_chans, dims[0] // 2, kernel_size=3, stride=2, padding=1),
            LayerNorm(dims[0] // 2, eps=1e-6, data_format="channels_first"),
            nn.GELU(),
            nn.Conv2d(dims[0] // 2, dims[0], kernel_size=3, stride=2, padding=1),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")))

        for i in range(3):
            self.downsample_layers.append(nn.Sequential(
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=3, stride=2, padding=1),
                LayerNorm(dims[i + 1], eps=1e-6, data_format="channels_first")))

        self.stages = nn.ModuleList()

        cur = 0
        for i in range(4):
            main_stage = nn.Sequential(
                *[UniRepLKNetBlock(dim=dims[i], kernel_size=kernel_sizes[i][j], drop_path=dp_rates[cur + j],
                                   layer_scale_init_value=layer_scale_init_value, deploy=deploy,
                                   attempt_use_lk_impl=attempt_use_lk_impl,
                                   with_cp=with_cp, use_sync_bn=use_sync_bn) for j in
                  range(depths[i])])
            self.stages.append(main_stage)
            cur += depths[i]

        self.last_channels = dims[-1]
        # Alias under the attribute name the classifier helpers historically referred to.
        self.num_features = self.last_channels

        self.for_pretrain = init_cfg is None
        self.for_downstream = not self.for_pretrain     # there may be some other scenarios
        if self.for_downstream:
            assert num_classes is None

        if self.for_pretrain:
            self.init_cfg = None
            self.norm = nn.LayerNorm(self.last_channels, eps=1e-6)  # final norm layer
            # self.head = nn.Linear(self.last_channels, num_classes)
            self.head = nn.Linear(self.last_channels, self.last_channels)
            self.apply(self._init_weights)
            self.head.weight.data.mul_(head_init_scale)
            self.head.bias.data.mul_(head_init_scale)
            self.output_mode = 'logits'
        else:
            self.init_cfg = init_cfg        # OpenMMLab style init
            self.init_weights()
            self.output_mode = 'features'
            norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
            for i_layer in range(4):
                layer = norm_layer(dims[i_layer])
                layer_name = f'norm{i_layer}'
                self.add_module(layer_name, layer)

    #   load pretrained backbone weights in the OpenMMLab style
    def init_weights(self):
        """Load ``self.init_cfg['checkpoint']`` non-strictly; a None path skips loading.

        Raises:
            ImportError: a checkpoint path is given but the OpenMMLab helpers
                (``get_root_logger`` / ``_load_checkpoint``) could not be imported.
            RuntimeError: a checkpoint tensor cannot be copied into the model.
        """

        def load_state_dict(module, state_dict, strict=False, logger=None):
            unexpected_keys = []
            own_state = module.state_dict()
            for name, param in state_dict.items():
                if name not in own_state:
                    unexpected_keys.append(name)
                    continue
                if isinstance(param, torch.nn.Parameter):
                    # backwards compatibility for serialized parameters
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception as exc:
                    raise RuntimeError(
                        'While copying the parameter named {}, '
                        'whose dimensions in the model are {} and '
                        'whose dimensions in the checkpoint are {}.'.format(
                            name, own_state[name].size(), param.size())) from exc
            missing_keys = set(own_state.keys()) - set(state_dict.keys())

            err_msg = []
            if unexpected_keys:
                err_msg.append('unexpected key in source state_dict: {}\n'.format(', '.join(unexpected_keys)))
            if missing_keys:
                err_msg.append('missing keys in source state_dict: {}\n'.format(', '.join(missing_keys)))
            err_msg = '\n'.join(err_msg)
            if err_msg:
                if strict:
                    raise RuntimeError(err_msg)
                elif logger is not None:
                    # Logger.warn is a deprecated alias of Logger.warning.
                    logger.warning(err_msg)
                else:
                    print(err_msg)

        assert self.init_cfg is not None
        ckpt_path = self.init_cfg['checkpoint']
        if ckpt_path is None:
            print('================ Note: init_cfg is provided but I got no init ckpt path, so skip initialization')
            return

        # Both helpers are set to None at import time when mmseg / mmcv are missing;
        # fail with an actionable message instead of "'NoneType' object is not callable".
        if get_root_logger is None or _load_checkpoint is None:
            raise ImportError('loading weights through init_cfg requires the OpenMMLab helpers '
                              '(get_root_logger and mmcv.runner._load_checkpoint), which could not be imported')

        logger = get_root_logger()
        ckpt = _load_checkpoint(ckpt_path, logger=logger, map_location='cpu')
        if 'state_dict' in ckpt:
            _state_dict = ckpt['state_dict']
        elif 'model' in ckpt:
            _state_dict = ckpt['model']
        else:
            _state_dict = ckpt

        load_state_dict(self, _state_dict, strict=False, logger=logger)

    def _init_weights(self, m):
        """Truncated-normal (std 0.02) weights and zero bias for every Conv2d / Linear."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        if self.output_mode == 'logits':
            for stage_idx in range(4):
                x = self.downsample_layers[stage_idx](x)
                x = self.stages[stage_idx](x)
            # Global average over the two trailing dims, then norm + head.
            x = self.norm(x.mean([-2, -1]))
            x = self.head(x)
            return x
        elif self.output_mode == 'features':
            outs = []
            for stage_idx in range(4):
                x = self.downsample_layers[stage_idx](x)
                x = self.stages[stage_idx](x)
                outs.append(getattr(self, f'norm{stage_idx}')(x))
            return outs
        else:
            raise ValueError('Defined new output mode?')

    def reparameterize_unireplknet(self):
        """Call ``reparameterize()`` on every sub-module that defines it."""
        for m in self.modules():
            if hasattr(m, 'reparameterize'):
                m.reparameterize()

    @torch.jit.ignore
    def get_classifier(self):
        """Return the head module (``self.head`` is the layer itself; it has no ``.fc`` child)."""
        return self.head

    def reset_classifier(self, num_classes=0, global_pool=None):
        """Replace the head.

        Args:
            num_classes: > 0 installs a fresh ``nn.Linear(last_channels, last_channels)``
                (same shape as the head built in ``__init__``); otherwise ``nn.Identity``.
            global_pool: only None or 'avg' is accepted, because ``forward`` always
                averages over the two trailing dims.

        Raises:
            ValueError: ``global_pool`` requests a pooling ``forward`` cannot provide.
        """
        if global_pool is not None and global_pool != 'avg':
            raise ValueError(f'unsupported global_pool {global_pool!r}: forward() always uses average pooling')
        # self.head = nn.Linear(self.last_channels, num_classes) if num_classes > 0 else nn.Identity()
        if num_classes > 0:
            self.head = nn.Linear(self.last_channels, self.last_channels)
        else:
            self.head = nn.Identity()
597
+
598
+
599
+
600
+
601
+
602
class LayerNorm(nn.Module):
    r""" LayerNorm implementation used in ConvNeXt

    Supports two data formats. ``channels_last`` (default) normalizes the last
    dim with ``F.layer_norm``, i.e. inputs shaped (batch_size, height, width, channels).
    ``channels_first`` normalizes dim 1, i.e. inputs shaped
    (batch_size, channels, height, width).

    Args:
        normalized_shape: number of channels being normalized.
        eps: value added to the variance for numerical stability.
        data_format: "channels_last" or "channels_first"; anything else raises
            NotImplementedError.
        reshape_last_to_first: stored on the instance; not read by ``forward``.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", reshape_last_to_first=False):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)
        self.reshape_last_to_first = reshape_last_to_first

    def forward(self, x):
        if self.data_format == "channels_first":
            # Manual layer norm over dim 1 (biased variance), then per-channel affine.
            mean = x.mean(1, keepdim=True)
            var = (x - mean).pow(2).mean(1, keepdim=True)
            normed = (x - mean) / torch.sqrt(var + self.eps)
            return self.weight[:, None, None] * normed + self.bias[:, None, None]
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
630
+
631
+
632
#   For easy use as backbone in MMDetection framework. Ignore these lines if you do not use MMDetection
#   (only executed when has_mmdet is True, in which case det_BACKBONES must have been imported)
if has_mmdet:
    @det_BACKBONES.register_module()
    class UniRepLKNetBackbone(UniRepLKNet):
        """UniRepLKNet registered as an MMDetection backbone.

        Always built with 3 input channels, ``num_classes=None`` and sync BN;
        ``init_cfg`` is mandatory.
        """

        def __init__(self,
                     depths=(3, 3, 27, 3),
                     dims=(96, 192, 384, 768),
                     drop_path_rate=0.,
                     layer_scale_init_value=1e-6,
                     kernel_sizes=None,
                     deploy=False,
                     with_cp=False,
                     init_cfg=None,
                     attempt_use_lk_impl=False):
            assert init_cfg is not None
            backbone_kwargs = dict(
                in_chans=3,
                num_classes=None,
                depths=depths,
                dims=dims,
                drop_path_rate=drop_path_rate,
                layer_scale_init_value=layer_scale_init_value,
                kernel_sizes=kernel_sizes,
                deploy=deploy,
                with_cp=with_cp,
                init_cfg=init_cfg,
                attempt_use_lk_impl=attempt_use_lk_impl,
                use_sync_bn=True,
            )
            super().__init__(**backbone_kwargs)
651
+
652
#   For easy use as backbone in MMSegmentation framework. Ignore these lines if you do not use MMSegmentation
#   (only executed when the mmseg / mmcv imports at the top of the file succeeded)
if has_mmseg:
    @seg_BACKBONES.register_module()
    class UniRepLKNetBackbone(UniRepLKNet):
        """UniRepLKNet registered as an MMSegmentation backbone.

        Always built with 3 input channels, ``num_classes=None`` and sync BN;
        ``init_cfg`` is mandatory.
        """

        def __init__(self,
                     depths=(3, 3, 27, 3),
                     dims=(96, 192, 384, 768),
                     drop_path_rate=0.,
                     layer_scale_init_value=1e-6,
                     kernel_sizes=None,
                     deploy=False,
                     with_cp=False,
                     init_cfg=None,
                     attempt_use_lk_impl=False):
            assert init_cfg is not None
            backbone_kwargs = dict(
                in_chans=3,
                num_classes=None,
                depths=depths,
                dims=dims,
                drop_path_rate=drop_path_rate,
                layer_scale_init_value=layer_scale_init_value,
                kernel_sizes=kernel_sizes,
                deploy=deploy,
                with_cp=with_cp,
                init_cfg=init_cfg,
                attempt_use_lk_impl=attempt_use_lk_impl,
                use_sync_bn=True,
            )
            super().__init__(**backbone_kwargs)
671
+
672
+
673
# Direct-download URLs keyed like huggingface_file_names; currently empty.
# TODO: it seems that google drive does not support direct downloading with url? so where to upload the checkpoints other than huggingface? any suggestions?
model_urls = dict()

# Checkpoint file name for every "<model name>_<pretraining tag>" key.
huggingface_file_names = dict(
    unireplknet_a_1k="unireplknet_a_in1k_224_acc77.03.pth",
    unireplknet_f_1k="unireplknet_f_in1k_224_acc78.58.pth",
    unireplknet_p_1k="unireplknet_p_in1k_224_acc80.23.pth",
    unireplknet_n_1k="unireplknet_n_in1k_224_acc81.64.pth",
    unireplknet_t_1k="unireplknet_t_in1k_224_acc83.21.pth",
    unireplknet_s_1k="unireplknet_s_in1k_224_acc83.91.pth",
    unireplknet_s_22k="unireplknet_s_in22k_pretrain.pth",
    unireplknet_s_22k_to_1k="unireplknet_s_in22k_to_in1k_384_acc86.44.pth",
    unireplknet_b_22k="unireplknet_b_in22k_pretrain.pth",
    unireplknet_b_22k_to_1k="unireplknet_b_in22k_to_in1k_384_acc87.40.pth",
    unireplknet_l_22k="unireplknet_l_in22k_pretrain.pth",
    unireplknet_l_22k_to_1k="unireplknet_l_in22k_to_in1k_384_acc87.88.pth",
    unireplknet_xl_22k="unireplknet_xl_in22k_pretrain.pth",
    unireplknet_xl_22k_to_1k="unireplknet_xl_in22k_to_in1k_384_acc87.96.pth",
)
693
+
694
def load_with_key(model, key):
    """Download the checkpoint registered under ``key`` and load it into ``model`` (strictly).

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        key: entry of ``huggingface_file_names`` (or of ``model_urls`` when
            huggingface_hub is not installed).

    Raises:
        KeyError: no checkpoint is registered for ``key`` in the table being used.
    """
    # if huggingface hub is found, download from our huggingface repo
    if hf_hub_download is not None:
        if key not in huggingface_file_names:
            raise KeyError(f'no Hugging Face checkpoint is registered for {key!r}; '
                           f'available keys: {sorted(huggingface_file_names)}')
        repo_id = 'DingXiaoH/UniRepLKNet'
        cache_file = hf_hub_download(repo_id=repo_id, filename=huggingface_file_names[key])
        # NOTE(security): torch.load unpickles the file; only load checkpoints from a trusted source.
        state = torch.load(cache_file, map_location='cpu')
    else:
        if key not in model_urls:
            raise KeyError(f'no download URL is registered for {key!r}; '
                           'install huggingface_hub to download the checkpoint from Hugging Face')
        state = torch.hub.load_state_dict_from_url(url=model_urls[key], map_location="cpu", check_hash=True)
    # Some checkpoints wrap the weights in a {'model': state_dict} container.
    if 'model' in state:
        state = state['model']
    model.load_state_dict(state)
705
+
706
def initialize_with_pretrained(model, model_name, in_1k_pretrained, in_22k_pretrained, in_22k_to_1k):
    """Load pretrained weights into ``model`` according to the first truthy flag.

    The flags are checked in the order ``in_1k_pretrained``, ``in_22k_pretrained``,
    ``in_22k_to_1k``; the checkpoint key is ``model_name`` plus ``'_1k'``,
    ``'_22k'`` or ``'_22k_to_1k'`` respectively. Nothing is loaded when every
    flag is falsy.
    """
    suffix_by_flag = (
        (in_1k_pretrained, '_1k'),
        (in_22k_pretrained, '_22k'),
        (in_22k_to_1k, '_22k_to_1k'),
    )
    # Lazily pick the key of the first enabled flag (None when none is set).
    key = next((model_name + suffix for enabled, suffix in suffix_by_flag if enabled), None)
    if key:
        load_with_key(model, key)
717
+
718
@register_model
def unireplknet_a(in_1k_pretrained=False, **kwargs):
    """Build the ``unireplknet_a`` model: depths (2, 2, 6, 2), dims (40, 80, 160, 320).

    ``in_1k_pretrained=True`` loads the checkpoint keyed ``unireplknet_a_1k``;
    extra keyword arguments go to the ``UniRepLKNet`` constructor.
    """
    net = UniRepLKNet(depths=UniRepLKNet_A_F_P_depths, dims=(40, 80, 160, 320), **kwargs)
    initialize_with_pretrained(net, 'unireplknet_a', in_1k_pretrained=in_1k_pretrained,
                               in_22k_pretrained=False, in_22k_to_1k=False)
    return net
723
+
724
@register_model
def unireplknet_f(in_1k_pretrained=False, **kwargs):
    """Build the ``unireplknet_f`` model: depths (2, 2, 6, 2), dims (48, 96, 192, 384).

    ``in_1k_pretrained=True`` loads the checkpoint keyed ``unireplknet_f_1k``;
    extra keyword arguments go to the ``UniRepLKNet`` constructor.
    """
    net = UniRepLKNet(depths=UniRepLKNet_A_F_P_depths, dims=(48, 96, 192, 384), **kwargs)
    initialize_with_pretrained(net, 'unireplknet_f', in_1k_pretrained=in_1k_pretrained,
                               in_22k_pretrained=False, in_22k_to_1k=False)
    return net
729
+
730
@register_model
def unireplknet_p(in_1k_pretrained=False, **kwargs):
    """Build the ``unireplknet_p`` model: depths (2, 2, 6, 2), dims (64, 128, 256, 512).

    ``in_1k_pretrained=True`` loads the checkpoint keyed ``unireplknet_p_1k``;
    extra keyword arguments go to the ``UniRepLKNet`` constructor.
    """
    net = UniRepLKNet(depths=UniRepLKNet_A_F_P_depths, dims=(64, 128, 256, 512), **kwargs)
    initialize_with_pretrained(net, 'unireplknet_p', in_1k_pretrained=in_1k_pretrained,
                               in_22k_pretrained=False, in_22k_to_1k=False)
    return net
735
+
736
@register_model
def unireplknet_n(in_1k_pretrained=False, **kwargs):
    """Build the ``unireplknet_n`` model: depths (2, 2, 8, 2), dims (80, 160, 320, 640).

    ``in_1k_pretrained=True`` loads the checkpoint keyed ``unireplknet_n_1k``;
    extra keyword arguments go to the ``UniRepLKNet`` constructor.
    """
    net = UniRepLKNet(depths=UniRepLKNet_N_depths, dims=(80, 160, 320, 640), **kwargs)
    initialize_with_pretrained(net, 'unireplknet_n', in_1k_pretrained=in_1k_pretrained,
                               in_22k_pretrained=False, in_22k_to_1k=False)
    return net
741
+
742
@register_model
def unireplknet_t(in_1k_pretrained=False, **kwargs):
    """Build the ``unireplknet_t`` model: depths (3, 3, 18, 3), dims (80, 160, 320, 640).

    ``in_1k_pretrained=True`` loads the checkpoint keyed ``unireplknet_t_1k``;
    extra keyword arguments go to the ``UniRepLKNet`` constructor.
    """
    net = UniRepLKNet(depths=UniRepLKNet_T_depths, dims=(80, 160, 320, 640), **kwargs)
    initialize_with_pretrained(net, 'unireplknet_t', in_1k_pretrained=in_1k_pretrained,
                               in_22k_pretrained=False, in_22k_to_1k=False)
    return net
747
+
748
@register_model
def unireplknet_s(in_1k_pretrained=False, in_22k_pretrained=False, in_22k_to_1k=False, **kwargs):
    """Build the ``unireplknet_s`` model: depths (3, 3, 27, 3), dims (96, 192, 384, 768).

    At most one checkpoint is loaded, chosen by the first truthy flag among
    ``in_1k_pretrained``, ``in_22k_pretrained`` and ``in_22k_to_1k``; extra
    keyword arguments go to the ``UniRepLKNet`` constructor.
    """
    net = UniRepLKNet(depths=UniRepLKNet_S_B_L_XL_depths, dims=(96, 192, 384, 768), **kwargs)
    initialize_with_pretrained(net, 'unireplknet_s', in_1k_pretrained=in_1k_pretrained,
                               in_22k_pretrained=in_22k_pretrained, in_22k_to_1k=in_22k_to_1k)
    return net
753
+
754
@register_model
def unireplknet_b(in_22k_pretrained=False, in_22k_to_1k=False, **kwargs):
    """Build the ``unireplknet_b`` model: depths (3, 3, 27, 3), dims (128, 256, 512, 1024).

    ``in_22k_pretrained`` takes precedence over ``in_22k_to_1k`` when choosing
    the checkpoint; extra keyword arguments go to the ``UniRepLKNet`` constructor.
    """
    net = UniRepLKNet(depths=UniRepLKNet_S_B_L_XL_depths, dims=(128, 256, 512, 1024), **kwargs)
    initialize_with_pretrained(net, 'unireplknet_b', in_1k_pretrained=False,
                               in_22k_pretrained=in_22k_pretrained, in_22k_to_1k=in_22k_to_1k)
    return net
759
+
760
@register_model
def unireplknet_l(in_22k_pretrained=False, in_22k_to_1k=False, **kwargs):
    """Build the ``unireplknet_l`` model: depths (3, 3, 27, 3), dims (192, 384, 768, 1536).

    ``in_22k_pretrained`` takes precedence over ``in_22k_to_1k`` when choosing
    the checkpoint; extra keyword arguments go to the ``UniRepLKNet`` constructor.
    """
    net = UniRepLKNet(depths=UniRepLKNet_S_B_L_XL_depths, dims=(192, 384, 768, 1536), **kwargs)
    initialize_with_pretrained(net, 'unireplknet_l', in_1k_pretrained=False,
                               in_22k_pretrained=in_22k_pretrained, in_22k_to_1k=in_22k_to_1k)
    return net
765
+
766
@register_model
def unireplknet_xl(in_22k_pretrained=False, in_22k_to_1k=False, **kwargs):
    """Build the ``unireplknet_xl`` model: depths (3, 3, 27, 3), dims (256, 512, 1024, 2048).

    ``in_22k_pretrained`` takes precedence over ``in_22k_to_1k`` when choosing
    the checkpoint; extra keyword arguments go to the ``UniRepLKNet`` constructor.
    """
    net = UniRepLKNet(depths=UniRepLKNet_S_B_L_XL_depths, dims=(256, 512, 1024, 2048), **kwargs)
    initialize_with_pretrained(net, 'unireplknet_xl', in_1k_pretrained=False,
                               in_22k_pretrained=in_22k_pretrained, in_22k_to_1k=in_22k_to_1k)
    return net
771
+
772
@register_model
def unireplknet_h(in_22k_pretrained=False, in_22k_to_1k=False, **kwargs):
    """Build the ``unireplknet_h`` model: depths (3, 3, 27, 3), dims (480, 960, 1920, 3840).

    ``in_22k_pretrained`` takes precedence over ``in_22k_to_1k`` when choosing
    the checkpoint key (``unireplknet_h_22k`` / ``unireplknet_h_22k_to_1k``);
    extra keyword arguments go to the ``UniRepLKNet`` constructor.
    """
    net = UniRepLKNet(depths=UniRepLKNet_S_B_L_XL_depths, dims=(480, 960, 1920, 3840), **kwargs)
    initialize_with_pretrained(net, 'unireplknet_h', in_1k_pretrained=False,
                               in_22k_pretrained=in_22k_pretrained, in_22k_to_1k=in_22k_to_1k)
    return net
777
+
778
+
779
# Demo: build the L-sized model and load the CLIP-pretrained checkpoint from the working directory.
if __name__ == '__main__':
    model_large = unireplknet_l()
    print(model_large)
    # map_location='cpu' lets the checkpoint load on machines without the device it was saved from.
    ckpt = torch.load("UniRepLKNet-L-b75k_s10B_CLIP-in1k_75.72.pt", map_location="cpu")
    # strict=False: Since we do not need heads in CLIP pretraining.
    load_result = model_large.load_state_dict(ckpt, strict=False)
    # Non-strict loading hides mismatches, so report them explicitly.
    if load_result.missing_keys:
        print("Missing keys:", load_result.missing_keys)
    if load_result.unexpected_keys:
        print("Unexpected keys:", load_result.unexpected_keys)
    print("Loaded CLIP Pretrained Models")