hungdang1610 commited on
Commit
cacd28a
·
verified ·
1 Parent(s): 3df4fe3
Files changed (1) hide show
  1. models/mivolo_model.py +404 -0
models/mivolo_model.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code adapted from timm https://github.com/huggingface/pytorch-image-models
3
+
4
+ Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from cross_bottleneck_attn import CrossBottleneckAttn
10
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
11
+ from timm.layers import trunc_normal_
12
+ from timm.models._builder import build_model_with_cfg
13
+ from timm.models._registry import register_model
14
+ from timm.models.volo import VOLO
15
+
16
+ __all__ = ["MiVOLOModel"] # model_registry will add each entrypoint fn to this
17
+
18
+
19
+ def _cfg(url="", **kwargs):
20
+ return {
21
+ "url": url,
22
+ "num_classes": 1000,
23
+ "input_size": (3, 224, 224),
24
+ "pool_size": None,
25
+ "crop_pct": 0.96,
26
+ "interpolation": "bicubic",
27
+ "fixed_input_size": True,
28
+ "mean": IMAGENET_DEFAULT_MEAN,
29
+ "std": IMAGENET_DEFAULT_STD,
30
+ "first_conv": None,
31
+ "classifier": ("head", "aux_head"),
32
+ **kwargs,
33
+ }
34
+
35
+
36
+ default_cfgs = {
37
+ "mivolo_d1_224": _cfg(
38
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_224_84.2.pth.tar", crop_pct=0.96
39
+ ),
40
+ "mivolo_d1_384": _cfg(
41
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_384_85.2.pth.tar",
42
+ crop_pct=1.0,
43
+ input_size=(3, 384, 384),
44
+ ),
45
+ "mivolo_d2_224": _cfg(
46
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_224_85.2.pth.tar", crop_pct=0.96
47
+ ),
48
+ "mivolo_d2_384": _cfg(
49
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_384_86.0.pth.tar",
50
+ crop_pct=1.0,
51
+ input_size=(3, 384, 384),
52
+ ),
53
+ "mivolo_d3_224": _cfg(
54
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_224_85.4.pth.tar", crop_pct=0.96
55
+ ),
56
+ "mivolo_d3_448": _cfg(
57
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_448_86.3.pth.tar",
58
+ crop_pct=1.0,
59
+ input_size=(3, 448, 448),
60
+ ),
61
+ "mivolo_d4_224": _cfg(
62
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_224_85.7.pth.tar", crop_pct=0.96
63
+ ),
64
+ "mivolo_d4_448": _cfg(
65
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_448_86.79.pth.tar",
66
+ crop_pct=1.15,
67
+ input_size=(3, 448, 448),
68
+ ),
69
+ "mivolo_d5_224": _cfg(
70
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_224_86.10.pth.tar", crop_pct=0.96
71
+ ),
72
+ "mivolo_d5_448": _cfg(
73
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_448_87.0.pth.tar",
74
+ crop_pct=1.15,
75
+ input_size=(3, 448, 448),
76
+ ),
77
+ "mivolo_d5_512": _cfg(
78
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_512_87.07.pth.tar",
79
+ crop_pct=1.15,
80
+ input_size=(3, 512, 512),
81
+ ),
82
+ }
83
+
84
+
85
+ def get_output_size(input_shape, conv_layer):
86
+ padding = conv_layer.padding
87
+ dilation = conv_layer.dilation
88
+ kernel_size = conv_layer.kernel_size
89
+ stride = conv_layer.stride
90
+
91
+ output_size = [
92
+ ((input_shape[i] + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1) // stride[i]) + 1 for i in range(2)
93
+ ]
94
+ return output_size
95
+
96
+
97
+ def get_output_size_module(input_size, stem):
98
+ output_size = input_size
99
+
100
+ for module in stem:
101
+ if isinstance(module, nn.Conv2d):
102
+ output_size = [
103
+ (
104
+ (output_size[i] + 2 * module.padding[i] - module.dilation[i] * (module.kernel_size[i] - 1) - 1)
105
+ // module.stride[i]
106
+ )
107
+ + 1
108
+ for i in range(2)
109
+ ]
110
+
111
+ return output_size
112
+
113
+
114
+ class PatchEmbed(nn.Module):
115
+ """Image to Patch Embedding."""
116
+
117
+ def __init__(
118
+ self, img_size=224, stem_conv=False, stem_stride=1, patch_size=8, in_chans=3, hidden_dim=64, embed_dim=384
119
+ ):
120
+ super().__init__()
121
+ assert patch_size in [4, 8, 16]
122
+ assert in_chans in [3, 6]
123
+ self.with_persons_model = in_chans == 6
124
+ self.use_cross_attn = True
125
+
126
+ if stem_conv:
127
+ if not self.with_persons_model:
128
+ self.conv = self.create_stem(stem_stride, in_chans, hidden_dim)
129
+ else:
130
+ self.conv = True # just to match interface
131
+ # split
132
+ self.conv1 = self.create_stem(stem_stride, 3, hidden_dim)
133
+ self.conv2 = self.create_stem(stem_stride, 3, hidden_dim)
134
+ else:
135
+ self.conv = None
136
+
137
+ if self.with_persons_model:
138
+
139
+ self.proj1 = nn.Conv2d(
140
+ hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
141
+ )
142
+ self.proj2 = nn.Conv2d(
143
+ hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
144
+ )
145
+
146
+ stem_out_shape = get_output_size_module((img_size, img_size), self.conv1)
147
+ self.proj_output_size = get_output_size(stem_out_shape, self.proj1)
148
+
149
+ self.map = CrossBottleneckAttn(embed_dim, dim_out=embed_dim, num_heads=1, feat_size=self.proj_output_size)
150
+
151
+ else:
152
+ self.proj = nn.Conv2d(
153
+ hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
154
+ )
155
+
156
+ self.patch_dim = img_size // patch_size
157
+ self.num_patches = self.patch_dim**2
158
+
159
+ def create_stem(self, stem_stride, in_chans, hidden_dim):
160
+ return nn.Sequential(
161
+ nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3, bias=False), # 112x112
162
+ nn.BatchNorm2d(hidden_dim),
163
+ nn.ReLU(inplace=True),
164
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112
165
+ nn.BatchNorm2d(hidden_dim),
166
+ nn.ReLU(inplace=True),
167
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112
168
+ nn.BatchNorm2d(hidden_dim),
169
+ nn.ReLU(inplace=True),
170
+ )
171
+
172
+ def forward(self, x):
173
+ if self.conv is not None:
174
+ if self.with_persons_model:
175
+ x1 = x[:, :3]
176
+ x2 = x[:, 3:]
177
+
178
+ x1 = self.conv1(x1)
179
+ x1 = self.proj1(x1)
180
+
181
+ x2 = self.conv2(x2)
182
+ x2 = self.proj2(x2)
183
+
184
+ x = torch.cat([x1, x2], dim=1)
185
+ x = self.map(x)
186
+ else:
187
+ x = self.conv(x)
188
+ x = self.proj(x) # B, C, H, W
189
+
190
+ return x
191
+
192
+
193
+ class MiVOLOModel(VOLO):
194
+ """
195
+ Vision Outlooker, the main class of our model
196
+ """
197
+
198
+ def __init__(
199
+ self,
200
+ layers,
201
+ img_size=224,
202
+ in_chans=3,
203
+ num_classes=1000,
204
+ global_pool="token",
205
+ patch_size=8,
206
+ stem_hidden_dim=64,
207
+ embed_dims=None,
208
+ num_heads=None,
209
+ downsamples=(True, False, False, False),
210
+ outlook_attention=(True, False, False, False),
211
+ mlp_ratio=3.0,
212
+ qkv_bias=False,
213
+ drop_rate=0.0,
214
+ attn_drop_rate=0.0,
215
+ drop_path_rate=0.0,
216
+ norm_layer=nn.LayerNorm,
217
+ post_layers=("ca", "ca"),
218
+ use_aux_head=True,
219
+ use_mix_token=False,
220
+ pooling_scale=2,
221
+ ):
222
+ super().__init__(
223
+ layers,
224
+ img_size,
225
+ in_chans,
226
+ num_classes,
227
+ global_pool,
228
+ patch_size,
229
+ stem_hidden_dim,
230
+ embed_dims,
231
+ num_heads,
232
+ downsamples,
233
+ outlook_attention,
234
+ mlp_ratio,
235
+ qkv_bias,
236
+ drop_rate,
237
+ attn_drop_rate,
238
+ drop_path_rate,
239
+ norm_layer,
240
+ post_layers,
241
+ use_aux_head,
242
+ use_mix_token,
243
+ pooling_scale,
244
+ )
245
+
246
+ im_size = img_size[0] if isinstance(img_size, tuple) else img_size
247
+ self.patch_embed = PatchEmbed(
248
+ img_size=im_size,
249
+ stem_conv=True,
250
+ stem_stride=2,
251
+ patch_size=patch_size,
252
+ in_chans=in_chans,
253
+ hidden_dim=stem_hidden_dim,
254
+ embed_dim=embed_dims[0],
255
+ )
256
+
257
+ trunc_normal_(self.pos_embed, std=0.02)
258
+ self.apply(self._init_weights)
259
+
260
+ def forward_features(self, x):
261
+ x = self.patch_embed(x).permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C
262
+
263
+ # step2: tokens learning in the two stages
264
+ x = self.forward_tokens(x)
265
+
266
+ # step3: post network, apply class attention or not
267
+ if self.post_network is not None:
268
+ x = self.forward_cls(x)
269
+ x = self.norm(x)
270
+ return x
271
+
272
+ def forward_head(self, x, pre_logits: bool = False, targets=None, epoch=None):
273
+ if self.global_pool == "avg":
274
+ out = x.mean(dim=1)
275
+ elif self.global_pool == "token":
276
+ out = x[:, 0]
277
+ else:
278
+ out = x
279
+ if pre_logits:
280
+ return out
281
+
282
+ features = out
283
+ fds_enabled = hasattr(self, "_fds_forward")
284
+ if fds_enabled:
285
+ features = self._fds_forward(features, targets, epoch)
286
+
287
+ out = self.head(features)
288
+ if self.aux_head is not None:
289
+ # generate classes in all feature tokens, see token labeling
290
+ aux = self.aux_head(x[:, 1:])
291
+ out = out + 0.5 * aux.max(1)[0]
292
+
293
+ return (out, features) if (fds_enabled and self.training) else out
294
+
295
+ def forward(self, x, targets=None, epoch=None):
296
+ """simplified forward (without mix token training)"""
297
+ x = self.forward_features(x)
298
+ x = self.forward_head(x, targets=targets, epoch=epoch)
299
+ return x
300
+
301
+
302
+ def _create_mivolo(variant, pretrained=False, **kwargs):
303
+ if kwargs.get("features_only", None):
304
+ raise RuntimeError("features_only not implemented for Vision Transformer models.")
305
+ return build_model_with_cfg(MiVOLOModel, variant, pretrained, **kwargs)
306
+
307
+
308
+ @register_model
309
+ def mivolo_d1_224(pretrained=False, **kwargs):
310
+ model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
311
+ model = _create_mivolo("mivolo_d1_224", pretrained=pretrained, **model_args)
312
+ return model
313
+
314
+
315
+ @register_model
316
+ def mivolo_d1_384(pretrained=False, **kwargs):
317
+ model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
318
+ model = _create_mivolo("mivolo_d1_384", pretrained=pretrained, **model_args)
319
+ return model
320
+
321
+
322
+ @register_model
323
+ def mivolo_d2_224(pretrained=False, **kwargs):
324
+ model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
325
+ model = _create_mivolo("mivolo_d2_224", pretrained=pretrained, **model_args)
326
+ return model
327
+
328
+
329
+ @register_model
330
+ def mivolo_d2_384(pretrained=False, **kwargs):
331
+ model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
332
+ model = _create_mivolo("mivolo_d2_384", pretrained=pretrained, **model_args)
333
+ return model
334
+
335
+
336
+ @register_model
337
+ def mivolo_d3_224(pretrained=False, **kwargs):
338
+ model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
339
+ model = _create_mivolo("mivolo_d3_224", pretrained=pretrained, **model_args)
340
+ return model
341
+
342
+
343
+ @register_model
344
+ def mivolo_d3_448(pretrained=False, **kwargs):
345
+ model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
346
+ model = _create_mivolo("mivolo_d3_448", pretrained=pretrained, **model_args)
347
+ return model
348
+
349
+
350
+ @register_model
351
+ def mivolo_d4_224(pretrained=False, **kwargs):
352
+ model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
353
+ model = _create_mivolo("mivolo_d4_224", pretrained=pretrained, **model_args)
354
+ return model
355
+
356
+
357
+ @register_model
358
+ def mivolo_d4_448(pretrained=False, **kwargs):
359
+ """VOLO-D4 model, Params: 193M"""
360
+ model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
361
+ model = _create_mivolo("mivolo_d4_448", pretrained=pretrained, **model_args)
362
+ return model
363
+
364
+
365
+ @register_model
366
+ def mivolo_d5_224(pretrained=False, **kwargs):
367
+ model_args = dict(
368
+ layers=(12, 12, 20, 4),
369
+ embed_dims=(384, 768, 768, 768),
370
+ num_heads=(12, 16, 16, 16),
371
+ mlp_ratio=4,
372
+ stem_hidden_dim=128,
373
+ **kwargs
374
+ )
375
+ model = _create_mivolo("mivolo_d5_224", pretrained=pretrained, **model_args)
376
+ return model
377
+
378
+
379
+ @register_model
380
+ def mivolo_d5_448(pretrained=False, **kwargs):
381
+ model_args = dict(
382
+ layers=(12, 12, 20, 4),
383
+ embed_dims=(384, 768, 768, 768),
384
+ num_heads=(12, 16, 16, 16),
385
+ mlp_ratio=4,
386
+ stem_hidden_dim=128,
387
+ **kwargs
388
+ )
389
+ model = _create_mivolo("mivolo_d5_448", pretrained=pretrained, **model_args)
390
+ return model
391
+
392
+
393
+ @register_model
394
+ def mivolo_d5_512(pretrained=False, **kwargs):
395
+ model_args = dict(
396
+ layers=(12, 12, 20, 4),
397
+ embed_dims=(384, 768, 768, 768),
398
+ num_heads=(12, 16, 16, 16),
399
+ mlp_ratio=4,
400
+ stem_hidden_dim=128,
401
+ **kwargs
402
+ )
403
+ model = _create_mivolo("mivolo_d5_512", pretrained=pretrained, **model_args)
404
+ return model