hungdang1610 committed
Commit c7024d3 · verified · 1 Parent(s): cacd28a

cross_bottleneck_attn

Files changed (1)
  1. models/cross_bottleneck_attn.py +116 -0
models/cross_bottleneck_attn.py ADDED
@@ -0,0 +1,116 @@
"""
Code based on timm https://github.com/huggingface/pytorch-image-models

Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
"""

import torch
import torch.nn as nn
from timm.layers.bottleneck_attn import PosEmbedRel
from timm.layers.helpers import make_divisible
from timm.layers.mlp import Mlp
from timm.layers.trace_utils import _assert
from timm.layers.weight_init import trunc_normal_


class CrossBottleneckAttn(nn.Module):
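    """Cross bottleneck attention between two feature streams (face and person in MiVOLO).

    Each stream has its own qkv projection (`qkv_f` / `qkv_p`); the queries of one stream
    attend over the keys/values of the other, the two attention outputs are concatenated,
    layer-normalized and mixed by a convolutional MLP down to `dim_out` channels.
    Relative position embeddings require a fixed `feat_size` (H, W).
    """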
    def __init__(
        self,
        dim,
        dim_out=None,
        feat_size=None,
        stride=1,
        num_heads=4,
        dim_head=None,
        qk_ratio=1.0,
        qkv_bias=False,
        scale_pos_embed=False,
    ):
        super().__init__()
        assert feat_size is not None, "A concrete feature size matching expected input (H, W) is required"
        dim_out = dim_out or dim
        assert dim_out % num_heads == 0

        self.num_heads = num_heads
        self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
        self.dim_head_v = dim_out // self.num_heads
        self.dim_out_qk = num_heads * self.dim_head_qk
        self.dim_out_v = num_heads * self.dim_head_v
        self.scale = self.dim_head_qk**-0.5
        self.scale_pos_embed = scale_pos_embed

        self.qkv_f = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
        self.qkv_p = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)

        # NOTE I'm only supporting relative pos embedding for now
        self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale)

        self.norm = nn.LayerNorm([self.dim_out_v * 2, *feat_size])
        mlp_ratio = 4
        self.mlp = Mlp(
            in_features=self.dim_out_v * 2,
            hidden_features=int(dim * mlp_ratio),
            act_layer=nn.GELU,
            out_features=dim_out,
            drop=0,
            use_conv=True,
        )

        self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
        self.reset_parameters()

    def reset_parameters(self):
        trunc_normal_(self.qkv_f.weight, std=self.qkv_f.weight.shape[1] ** -0.5)  # fan-in
        trunc_normal_(self.qkv_p.weight, std=self.qkv_p.weight.shape[1] ** -0.5)  # fan-in
        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
        trunc_normal_(self.pos_embed.width_rel, std=self.scale)

    def get_qkv(self, x, qvk_conv):
        B, C, H, W = x.shape

        x = qvk_conv(x)  # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W

        q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1)

        q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2)
        k = k.reshape(B * self.num_heads, self.dim_head_qk, -1)  # no transpose, for q @ k
        v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2)

        return q, k, v

    def apply_attn(self, q, k, v, B, H, W, dropout=None):
        if self.scale_pos_embed:
            attn = (q @ k + self.pos_embed(q)) * self.scale  # B * num_heads, H * W, H * W
        else:
            attn = (q @ k) * self.scale + self.pos_embed(q)
        attn = attn.softmax(dim=-1)
        if dropout:
            attn = dropout(attn)

        out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W)  # B, dim_out, H, W
        return out

    def forward(self, x):
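        # x stacks both streams along channels: the first half goes to the
        # qkv_f (face) branch, the second half to the qkv_p (person) branch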
        B, C, H, W = x.shape

        dim = int(C / 2)
        x1 = x[:, :dim, :, :]
        x2 = x[:, dim:, :, :]

        _assert(H == self.pos_embed.height, "")
        _assert(W == self.pos_embed.width, "")

        q_f, k_f, v_f = self.get_qkv(x1, self.qkv_f)
        q_p, k_p, v_p = self.get_qkv(x2, self.qkv_p)

        # person to face
        out_f = self.apply_attn(q_f, k_p, v_p, B, H, W)
        # face to person
        out_p = self.apply_attn(q_p, k_f, v_f, B, H, W)

        x_pf = torch.cat((out_f, out_p), dim=1)  # B, dim_out * 2, H, W
        x_pf = self.norm(x_pf)
        x_pf = self.mlp(x_pf)  # B, dim_out, H, W

        out = self.pool(x_pf)
        return out
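A minimal usage sketch (not part of the commit; the channel count, feature size and import path are illustrative assumptions based on the file above): the module expects the two feature maps concatenated along the channel dimension, so the input carries 2 * dim channels.

import torch
from models.cross_bottleneck_attn import CrossBottleneckAttn  # path as added in this commit

# hypothetical sizes: 64 channels per stream, 8x8 feature maps
attn = CrossBottleneckAttn(dim=64, feat_size=(8, 8), num_heads=4)

face_feat = torch.randn(2, 64, 8, 8)    # B, dim, H, W
person_feat = torch.randn(2, 64, 8, 8)  # B, dim, H, W

# forward() splits its input in half along channels, so the two streams
# are concatenated before the call; with stride=1 the output keeps H, W
out = attn(torch.cat((face_feat, person_feat), dim=1))
print(out.shape)  # torch.Size([2, 64, 8, 8])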