hungdang1610
committed on
cross_bottleneck_attn
- models/cross_bottleneck_attn.py +116 -0
models/cross_bottleneck_attn.py
ADDED
@@ -0,0 +1,116 @@
"""
Code based on timm https://github.com/huggingface/pytorch-image-models

Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
"""

import torch
import torch.nn as nn
from timm.layers.bottleneck_attn import PosEmbedRel
from timm.layers.helpers import make_divisible
from timm.layers.mlp import Mlp
from timm.layers.trace_utils import _assert
from timm.layers.weight_init import trunc_normal_


class CrossBottleneckAttn(nn.Module):
    def __init__(
        self,
        dim,
        dim_out=None,
        feat_size=None,
        stride=1,
        num_heads=4,
        dim_head=None,
        qk_ratio=1.0,
        qkv_bias=False,
        scale_pos_embed=False,
    ):
        super().__init__()
        assert feat_size is not None, "A concrete feature size matching expected input (H, W) is required"
        dim_out = dim_out or dim
        assert dim_out % num_heads == 0

        self.num_heads = num_heads
        self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
        self.dim_head_v = dim_out // self.num_heads
        self.dim_out_qk = num_heads * self.dim_head_qk
        self.dim_out_v = num_heads * self.dim_head_v
        self.scale = self.dim_head_qk**-0.5
        self.scale_pos_embed = scale_pos_embed

        self.qkv_f = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
        self.qkv_p = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)

        # NOTE I'm only supporting relative pos embedding for now
        self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale)

        self.norm = nn.LayerNorm([self.dim_out_v * 2, *feat_size])
        mlp_ratio = 4
        self.mlp = Mlp(
            in_features=self.dim_out_v * 2,
            hidden_features=int(dim * mlp_ratio),
            act_layer=nn.GELU,
            out_features=dim_out,
            drop=0,
            use_conv=True,
        )

        self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
        self.reset_parameters()

    def reset_parameters(self):
        trunc_normal_(self.qkv_f.weight, std=self.qkv_f.weight.shape[1] ** -0.5)  # fan-in
        trunc_normal_(self.qkv_p.weight, std=self.qkv_p.weight.shape[1] ** -0.5)  # fan-in
        trunc_normal_(self.pos_embed.height_rel, std=self.scale)
        trunc_normal_(self.pos_embed.width_rel, std=self.scale)

    def get_qkv(self, x, qkv_conv):
        B, C, H, W = x.shape

        x = qkv_conv(x)  # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W

        q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1)

        q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2)
        k = k.reshape(B * self.num_heads, self.dim_head_qk, -1)  # no transpose, for q @ k
        v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2)

        return q, k, v

    def apply_attn(self, q, k, v, B, H, W, dropout=None):
        if self.scale_pos_embed:
            attn = (q @ k + self.pos_embed(q)) * self.scale  # B * num_heads, H * W, H * W
        else:
            attn = (q @ k) * self.scale + self.pos_embed(q)
        attn = attn.softmax(dim=-1)
        if dropout is not None:
            attn = dropout(attn)

        out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W)  # B, dim_out_v, H, W
        return out

    def forward(self, x):
        B, C, H, W = x.shape

        # split channels into the two halves (face / person features)
        dim = C // 2
        x1 = x[:, :dim, :, :]
        x2 = x[:, dim:, :, :]

        _assert(H == self.pos_embed.height, "input height must match feat_size")
        _assert(W == self.pos_embed.width, "input width must match feat_size")

        q_f, k_f, v_f = self.get_qkv(x1, self.qkv_f)
        q_p, k_p, v_p = self.get_qkv(x2, self.qkv_p)

        # person to face
        out_f = self.apply_attn(q_f, k_p, v_p, B, H, W)
        # face to person
        out_p = self.apply_attn(q_p, k_f, v_f, B, H, W)

        x_pf = torch.cat((out_f, out_p), dim=1)  # B, dim_out * 2, H, W
        x_pf = self.norm(x_pf)
        x_pf = self.mlp(x_pf)  # B, dim_out, H, W

        out = self.pool(x_pf)
        return out
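
As a quick usage sketch (not part of the commit): the module is constructed with the per-branch channel count and fed the face and person feature maps concatenated along the channel dimension. The sizes below and the import path models.cross_bottleneck_attn are illustrative assumptions based on this repo's layout, not confirmed usage.

import torch

from models.cross_bottleneck_attn import CrossBottleneckAttn  # assumed path, repo root on PYTHONPATH

# Illustrative sizes: 8x8 feature maps, 256 channels per branch.
dim = 256  # channels of each half (face / person)
attn = CrossBottleneckAttn(dim, dim_out=dim, feat_size=(8, 8), num_heads=4)

x = torch.randn(2, 2 * dim, 8, 8)  # B, 2 * dim, H, W: the two branches concatenated
out = attn(x)
print(out.shape)  # torch.Size([2, 256, 8, 8]); stride=1 keeps H and W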