File size: 9,417 Bytes
d4b77ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attention(nn.Module):
    """Scaled dot-product attention with dropout on the attention weights.

    Computes ``softmax(query @ key^T / sqrt(d)) @ value`` where ``d`` is the
    size of the last dimension of ``query``.
    """

    def __init__(self, p=0.1):
        super().__init__()
        # Dropout is applied to the normalised attention weights, not to the output.
        self.dropout = nn.Dropout(p=p)

    def forward(self, query, key, value):
        """Attend ``query`` over ``key``/``value``.

        Args:
            query: tensor shaped ``(..., L_q, d)``.
            key: tensor shaped ``(..., L_k, d)``.
            value: tensor shaped ``(..., L_k, d_v)``.

        Returns:
            A pair ``(output, weights)``: ``output`` is the attention-weighted
            sum of ``value`` and ``weights`` are the attention probabilities
            after dropout.
        """
        scale = math.sqrt(query.size(-1))
        logits = query.matmul(key.transpose(-2, -1)) / scale
        weights = self.dropout(logits.softmax(dim=-1))
        return weights.matmul(value), weights
class SWMHSA_depthGlobalWindowConcatLN_qkFlow_reweightFlow(nn.Module):
    """Windowed multi-head self-attention with global tokens and flow-conditioned Q/K.

    The token grid is padded on the right/bottom to a multiple of ``window_size``
    and split into non-overlapping ``window_size x window_size`` windows.

    * Flow features ``f`` are gated by sigmoid weights predicted from ``[x, f]``.
    * Queries come from ``[x, gated f]`` inside each window.
    * Keys are the window's ``[x, gated f]`` tokens plus global tokens pooled by a
      depthwise strided conv over the whole padded map.
    * Values are the window's ``x`` tokens plus global tokens pooled from ``x``.
    * Q, K and V are LayerNorm-ed before their linear projections.

    Args:
        token_size: ``(h, w)`` default token-grid size used by ``forward``.
        window_size: side length of each attention window.
        kernel_size: kernel size (and stride) of the depthwise global-token convs.
        d_model: channel count of ``x``; also the attention width. Must be
            divisible by ``head``, since it is split as ``d_model // head`` per head.
        flow_dModel: channel count of the flow features ``f``.
        head: number of attention heads.
        p: dropout probability on the attention weights.
    """

    def __init__(self, token_size, window_size, kernel_size, d_model, flow_dModel, head, p=0.1):
        super().__init__()
        self.h, self.w = token_size
        self.head = head
        self.window_size = window_size
        self.d_model = d_model
        self.flow_dModel = flow_dModel
        in_channels = d_model + flow_dModel
        # Queries/keys see appearance + flow channels; values see appearance only.
        self.query_embedding = nn.Linear(in_channels, d_model)
        self.key_embedding = nn.Linear(in_channels, d_model)
        self.value_embedding = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)
        self.attention = Attention(p)
        # Padding is only ever added on the right and bottom edges.
        self.pad_l = self.pad_t = 0
        self.pad_r = (self.window_size - self.w % self.window_size) % self.window_size
        self.pad_b = (self.window_size - self.h % self.window_size) % self.window_size
        self.new_h, self.new_w = self.h + self.pad_b, self.w + self.pad_r
        self.group_h, self.group_w = self.new_h // self.window_size, self.new_w // self.window_size
        # Depthwise (groups == channels) convs with stride == kernel that pool the map into global tokens.
        self.global_extract_v = nn.Conv2d(d_model, d_model, kernel_size=kernel_size, stride=kernel_size,
                                          padding=0, groups=d_model)
        self.global_extract_k = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size,
                                          stride=kernel_size, padding=0, groups=in_channels)
        self.q_norm = nn.LayerNorm(in_channels)
        self.k_norm = nn.LayerNorm(in_channels)
        self.v_norm = nn.LayerNorm(d_model)
        # Predicts per-position, per-channel gates in (0, 1) for the flow features.
        self.reweightFlow = nn.Sequential(
            nn.Linear(in_channels, flow_dModel),
            nn.Sigmoid()
        )

    def _window_partition(self, t, group_h, group_w):
        """Split a padded ``(bt, H, W, C)`` map into ``(bt, group_h * group_w, ws * ws, C)`` windows."""
        bt, c = t.shape[0], t.shape[-1]
        ws = self.window_size
        t = t.reshape(bt, group_h, ws, group_w, ws, c).transpose(2, 3)  # (bt, gh, gw, ws, ws, C)
        return t.reshape(bt, group_h * group_w, ws * ws, c)

    @staticmethod
    def _global_tokens(conv, t, n_windows):
        """Pool a ``(bt, H, W, C)`` map with ``conv`` and share the pooled tokens with every window.

        Returns a ``(bt, n_windows, h' * w', C)`` view. ``expand`` avoids one copy
        per window; the subsequent ``torch.cat`` makes the only real copy.
        """
        bt, c = t.shape[0], t.shape[-1]
        pooled = conv(t.permute(0, 3, 1, 2))  # (bt, C, h', w')
        pooled = pooled.permute(0, 2, 3, 1).reshape(bt, -1, c)  # (bt, h' * w', C)
        return pooled.unsqueeze(1).expand(-1, n_windows, -1, -1)

    def _split_heads(self, t):
        """Reshape ``(bt, G, L, C)`` into ``(bt, G, head, L, C // head)``."""
        bt, g, length, c = t.shape
        return t.reshape(bt, g, length, self.head, c // self.head).permute(0, 1, 3, 2, 4)

    def inference(self, x, f, h, w):
        """Run the attention for an explicit ``(h, w)`` token grid.

        Args:
            x: appearance tokens ``(bt, n, d_model)`` with ``n == h * w``.
            f: flow tokens ``(bt, n, flow_dModel)`` on the same grid.
            h, w: grid height and width.

        Returns:
            ``(bt, n, d_model)`` output tokens.
        """
        ws = self.window_size
        pad_r = (ws - w % ws) % ws
        pad_b = (ws - h % ws) % ws
        group_h, group_w = (h + pad_b) // ws, (w + pad_r) // ws
        n_windows = group_h * group_w
        bt, n, c = x.shape
        cf = f.shape[2]
        x = x.view(bt, h, w, c)
        f = f.view(bt, h, w, cf)
        if pad_r > 0 or pad_b > 0:
            # F.pad pairs run from the last dim backwards: (C: 0, 0), (W: left, right), (H: top, bottom).
            pad = (0, 0, self.pad_l, pad_r, self.pad_t, pad_b)
            x = F.pad(x, pad)
            f = F.pad(f, pad)
        # NOTE(review): zero-padded positions are not masked, so inside edge windows
        # they still act as keys/values; only their query outputs are cropped below.

        # Gate the flow features with weights predicted from [x, f], then build the Q/K source.
        f = f * self.reweightFlow(torch.cat((x, f), dim=-1))
        qk = torch.cat((x, f), dim=-1)  # (bt, H, W, c + cf)

        # Keys = the window's own (not yet normalised) Q/K tokens + pooled global tokens.
        q = self._window_partition(qk, group_h, group_w)  # (bt, G, ws*ws, c + cf)
        k = torch.cat((q, self._global_tokens(self.global_extract_k, qk, n_windows)), dim=2)
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Values = the window's appearance tokens + pooled global appearance tokens.
        v = torch.cat((self._window_partition(x, group_h, group_w),
                       self._global_tokens(self.global_extract_v, x, n_windows)), dim=2)
        v = self.v_norm(v)

        query = self._split_heads(self.query_embedding(q))  # (bt, G, head, ws*ws, c // head)
        key = self._split_heads(self.key_embedding(k))
        value = self._split_heads(self.value_embedding(v))
        attn, _ = self.attention(query, key, value)

        # Merge heads, then windows, back into a (bt, H, W, c) map and drop the padding.
        out = attn.transpose(2, 3).reshape(bt, group_h, group_w, ws, ws, c)
        out = out.transpose(2, 3).reshape(bt, group_h * ws, group_w * ws, c)
        if pad_r > 0 or pad_b > 0:
            out = out[:, :h, :w, :].contiguous()
        return self.output_linear(out.reshape(bt, n, c))

    def forward(self, x, f, t, h=0, w=0):
        """Apply the attention.

        Args:
            x: appearance tokens ``(bt, n, d_model)``.
            f: flow tokens ``(bt, n, flow_dModel)``.
            t: unused here; kept for call-site compatibility.
            h, w: optional grid size. A zero value falls back to the size given
                at construction (``token_size``).

        Returns:
            ``(bt, n, d_model)`` output tokens.
        """
        return self.inference(x, f, h or self.h, w or self.w)
|