Ravi21 commited on
Commit
e832084
·
verified ·
1 Parent(s): 6e858ba

Upload 7 files

Browse files
models/__pycache__/afwm.cpython-310.pyc ADDED
Binary file (6.52 kB). View file
 
models/__pycache__/networks.cpython-310.pyc ADDED
Binary file (4.98 kB). View file
 
models/afwm.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .correlation import correlation
5
+
6
def apply_offset(offset):
    """Turn a per-pixel offset field into a normalized sampling grid.

    Args:
        offset: tensor whose first two dimensions are (batch, channel) and
            whose remaining dimensions are spatial. Channel ``d`` is added to
            the coordinates of the ``d``-th spatial axis counted from the
            LAST one (channel 0 -> last axis, channel 1 -> second-to-last).

    Returns:
        Tensor of shape ``(batch, *spatial, n_spatial)`` where every
        coordinate has been mapped so that index ``0`` becomes ``-1`` and
        index ``size - 1`` becomes ``+1``.
    """
    sizes = list(offset.size()[2:])
    axes = [torch.arange(size, device=offset.device) for size in sizes]
    # 'ij' is what the un-annotated call has always produced; stating it
    # explicitly silences the PyTorch warning and pins the behavior.
    coords = torch.meshgrid(*axes, indexing='ij')

    grids = []
    # Walk the axes last-to-first so that the stacked result is ordered
    # (last axis, ..., first axis), matching the offset channel order.
    for dim, (coord, size) in enumerate(zip(reversed(coords), reversed(sizes))):
        absolute = coord.float().unsqueeze(0) + offset[:, dim, ...]
        # A size-1 axis would otherwise divide by zero and yield inf/nan.
        half_extent = max(size - 1.0, 1.0) / 2.0
        grids.append(absolute / half_extent - 1.0)

    return torch.stack(grids, dim=-1)
19
+
20
+
21
class ResBlock(nn.Module):
    """Residual block: two (BatchNorm -> ReLU -> 3x3 conv) stages plus an identity skip.

    The channel count and spatial size are unchanged (3x3 conv, padding 1).
    """

    def __init__(self, in_channels):
        super().__init__()
        layers = []
        for _ in range(2):
            layers.append(nn.BatchNorm2d(in_channels))
            layers.append(nn.ReLU(inplace=True))
            layers.append(
                nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
            )
        # Attribute name and layer order define the state-dict keys.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``block(x) + x``."""
        transformed = self.block(x)
        return transformed + x
35
+
36
+
37
class DownSample(nn.Module):
    """BatchNorm -> ReLU -> stride-2 3x3 convolution (no bias)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        norm = nn.BatchNorm2d(in_channels)
        activation = nn.ReLU(inplace=True)
        strided_conv = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=2, padding=1, bias=False,
        )
        # Attribute name and layer order define the state-dict keys.
        self.block = nn.Sequential(norm, activation, strided_conv)

    def forward(self, x):
        """Apply the normalise/activate/strided-conv stack to ``x``."""
        return self.block(x)
48
+
49
+
50
+
51
class FeatureEncoder(nn.Module):
    """Stack of encoder stages, each ``DownSample`` followed by two ``ResBlock``s.

    Args:
        in_channels: channel count of the input tensor.
        chns: output channel count of each stage, in order. The default is
            an immutable tuple (the original used a shared mutable list).
    """

    def __init__(self, in_channels, chns=(64, 128, 256, 256, 256)):
        super().__init__()
        stages = []
        previous = in_channels
        for out_chns in chns:
            # Each stage consumes the previous stage's output channels;
            # the first one consumes the raw input.
            stages.append(nn.Sequential(DownSample(previous, out_chns),
                                        ResBlock(out_chns),
                                        ResBlock(out_chns)))
            previous = out_chns

        self.encoders = nn.ModuleList(stages)

    def forward(self, x):
        """Run ``x`` through every stage and return the list of per-stage outputs."""
        encoder_features = []
        for encoder in self.encoders:
            x = encoder(x)
            encoder_features.append(x)
        return encoder_features
76
+
77
class RefinePyramid(nn.Module):
    """Top-down feature pyramid over a list of encoder feature maps.

    Every input level is projected to ``fpn_dim`` channels with a 1x1 conv,
    summed with the upsampled result of the previous (coarser) level, and
    smoothed with a 3x3 conv.

    Args:
        chns: channel count of each input level, in the order the levels are
            passed to ``forward``.
        fpn_dim: channel count of every output level.
    """

    def __init__(self, chns=(64, 128, 256, 256, 256), fpn_dim=256):
        super().__init__()
        # Copy so that later mutation of the caller's sequence cannot change
        # this module's record of its configuration.
        self.chns = list(chns)

        # 1x1 projections, stored in the order forward() visits the levels
        # (last input level first).
        self.adaptive = nn.ModuleList(
            [nn.Conv2d(in_chns, fpn_dim, kernel_size=1) for in_chns in reversed(self.chns)]
        )
        self.smooth = nn.ModuleList(
            [nn.Conv2d(fpn_dim, fpn_dim, kernel_size=3, padding=1) for _ in self.chns]
        )

    def forward(self, x):
        """Return a tuple of refined features, in the same order as the input list ``x``."""
        feature_list = []
        last_feature = None
        for i, conv_ftr in enumerate(reversed(list(x))):
            feature = self.adaptive[i](conv_ftr)

            if last_feature is not None:
                # Upsample to this level's exact size. For levels that differ
                # by exactly 2x this equals scale_factor=2; it additionally
                # works when an odd-sized level was rounded up on the way down.
                feature = feature + F.interpolate(
                    last_feature, size=tuple(feature.shape[2:]), mode='nearest')

            feature = self.smooth[i](feature)
            last_feature = feature
            feature_list.append(feature)

        return tuple(reversed(feature_list))
110
+
111
+
112
class AFlowNet(nn.Module):
    """Coarse-to-fine flow estimator operating on two feature pyramids.

    For every pyramid level, starting from the LAST entry of the pyramids,
    a 49-channel correlation volume feeds ``netMain`` to predict a flow,
    and the concatenated warped/condition features feed ``netRefine`` to
    refine it. The flow is upsampled 2x between levels.

    Args:
        num_pyramid: number of pyramid levels (one main and one refine head each).
        fpn_dim: channel count of every pyramid level.
        align_corners: forwarded to every ``F.grid_sample`` call. The grids
            built by ``apply_offset`` map index 0 to -1 and index ``size-1``
            to +1, which is the ``align_corners=True`` convention; leaving
            the argument out (PyTorch >= 1.3 then uses ``False``) makes an
            identity flow sample slightly shifted pixels. Pass ``False`` to
            reproduce the previous implicit behavior.
            NOTE(review): confirm against the PyTorch version the checkpoint
            in use was trained with.
    """

    def __init__(self, num_pyramid, fpn_dim=256, align_corners=True):
        super().__init__()
        self.align_corners = align_corners

        main_heads = []
        refine_heads = []
        for _ in range(num_pyramid):
            # Heads are created in main/refine order per level so that the
            # seeded weight initialisation sequence is unchanged.
            main_heads.append(self._make_head(49))
            refine_heads.append(self._make_head(2 * fpn_dim))

        self.netMain = nn.ModuleList(main_heads)
        self.netRefine = nn.ModuleList(refine_heads)

    @staticmethod
    def _make_head(in_channels):
        """Build the conv stack mapping ``in_channels`` feature channels to a 2-channel offset."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels=128, kernel_size=3, stride=1, padding=1),
            torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
            torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
            torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
            torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
            torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
            torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1)
        )

    def _sample(self, source, grid):
        """Bilinearly sample ``source`` at ``grid`` (N, H, W, 2) with border padding."""
        return F.grid_sample(source, grid, mode='bilinear', padding_mode='border',
                             # getattr keeps modules unpickled from older saves working.
                             align_corners=getattr(self, 'align_corners', True))

    def forward(self, x, x_warps, x_conds, warp_feature=True):
        """Estimate the flow and warp ``x`` with it.

        Args:
            x: tensor to be warped by the final flow.
            x_warps: pyramid of features belonging to ``x``.
            x_conds: pyramid of condition features, same length as ``x_warps``.
            warp_feature: when true, each level's ``x_warps`` feature is
                pre-warped with the (detached) flow from the previous level
                before the correlation is computed.

        Returns:
            ``(x_warp, last_flow)``: the warped ``x`` and the final flow in
            (N, 2, H, W) layout.
        """
        last_flow = None
        num_levels = len(x_warps)

        for i in range(num_levels):
            # Levels are consumed back-to-front.
            x_warp = x_warps[num_levels - 1 - i]
            x_cond = x_conds[num_levels - 1 - i]

            if last_flow is not None and warp_feature:
                # detach(): the pre-warp must not back-propagate into the
                # previous level's flow.
                x_warp_after = self._sample(x_warp, last_flow.detach().permute(0, 2, 3, 1))
            else:
                x_warp_after = x_warp

            tenCorrelation = F.leaky_relu(
                input=correlation.FunctionCorrelation(
                    tenFirst=x_warp_after, tenSecond=x_cond, intStride=1),
                negative_slope=0.1, inplace=False)
            flow = apply_offset(self.netMain[i](tenCorrelation))

            if last_flow is not None:
                # Compose with the flow carried over from the previous level.
                flow = self._sample(last_flow, flow)
            else:
                flow = flow.permute(0, 3, 1, 2)

            last_flow = flow
            x_warp = self._sample(x_warp, flow.permute(0, 2, 3, 1))
            concat = torch.cat([x_warp, x_cond], 1)
            flow = apply_offset(self.netRefine[i](concat))
            flow = self._sample(last_flow, flow)

            # align_corners=False is what the un-annotated call already did;
            # being explicit only removes the warning.
            last_flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=False)

        x_warp = self._sample(x, last_flow.permute(0, 2, 3, 1))
        return x_warp, last_flow
178
+
179
+
180
class AFWM(nn.Module):
    """Warping module: two feature encoders, two refinement pyramids and a flow network.

    Args:
        opt: accepted for interface compatibility; not read by this class.
        input_nc: channel count of the condition input.
    """

    def __init__(self, opt, input_nc):
        super().__init__()
        widths = [64, 128, 256, 256, 256]
        # The image branch is built for 3-channel input.
        self.image_features = FeatureEncoder(3, widths)
        self.cond_features = FeatureEncoder(input_nc, widths)
        self.image_FPN = RefinePyramid(widths)
        self.cond_FPN = RefinePyramid(widths)
        self.aflow_net = AFlowNet(len(widths))

    def forward(self, cond_input, image_input):
        """Return ``(warped_image, flow)`` for ``image_input`` conditioned on ``cond_input``."""
        cond_levels = self.cond_features(cond_input)
        cond_pyramids = self.cond_FPN(cond_levels)

        image_levels = self.image_features(image_input)
        image_pyramids = self.image_FPN(image_levels)

        warped, flow = self.aflow_net(image_input, image_pyramids, cond_pyramids)
        return warped, flow
198
+
models/correlation/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ This is an adaptation of the <a href="https://github.com/lmb-freiburg/flownet2">FlowNet2 implementation</a> in order to compute cost volumes. Should you be making use of this work, please make sure to adhere to the <a href="https://github.com/lmb-freiburg/flownet2#license-and-citation">licensing terms</a> of the original authors. Should you make use of or modify this particular implementation, please acknowledge it appropriately.
models/correlation/__pycache__/correlation.cpython-310.pyc ADDED
Binary file (13.7 kB). View file
 
models/correlation/correlation.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import torch
4
+
5
+ import cupy
6
+ import math
7
+ import re
8
+
9
+ kernel_Correlation_rearrange = '''
10
+ extern "C" __global__ void kernel_Correlation_rearrange(
11
+ const int n,
12
+ const float* input,
13
+ float* output
14
+ ) {
15
+ int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;
16
+
17
+ if (intIndex >= n) {
18
+ return;
19
+ }
20
+
21
+ int intSample = blockIdx.z;
22
+ int intChannel = blockIdx.y;
23
+
24
+ float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];
25
+
26
+ __syncthreads();
27
+
28
+ int intPaddedY = (intIndex / SIZE_3(input)) + 3*{{intStride}};
29
+ int intPaddedX = (intIndex % SIZE_3(input)) + 3*{{intStride}};
30
+ int intRearrange = ((SIZE_3(input) + 6*{{intStride}}) * intPaddedY) + intPaddedX;
31
+
32
+ output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue;
33
+ }
34
+ '''
35
+
36
+ kernel_Correlation_updateOutput = '''
37
+ extern "C" __global__ void kernel_Correlation_updateOutput(
38
+ const int n,
39
+ const float* rbot0,
40
+ const float* rbot1,
41
+ float* top
42
+ ) {
43
+ extern __shared__ char patch_data_char[];
44
+
45
+ float *patch_data = (float *)patch_data_char;
46
+
47
+ // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
48
+ int x1 = (blockIdx.x + 3) * {{intStride}};
49
+ int y1 = (blockIdx.y + 3) * {{intStride}};
50
+ int item = blockIdx.z;
51
+ int ch_off = threadIdx.x;
52
+
53
+ // Load 3D patch into shared shared memory
54
+ for (int j = 0; j < 1; j++) { // HEIGHT
55
+ for (int i = 0; i < 1; i++) { // WIDTH
56
+ int ji_off = (j + i) * SIZE_3(rbot0);
57
+ for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
58
+ int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;
59
+ int idxPatchData = ji_off + ch;
60
+ patch_data[idxPatchData] = rbot0[idx1];
61
+ }
62
+ }
63
+ }
64
+
65
+ __syncthreads();
66
+
67
+ __shared__ float sum[32];
68
+
69
+ // Compute correlation
70
+ for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {
71
+ sum[ch_off] = 0;
72
+
73
+ int s2o = (top_channel % 7 - 3) * {{intStride}};
74
+ int s2p = (top_channel / 7 - 3) * {{intStride}};
75
+
76
+ for (int j = 0; j < 1; j++) { // HEIGHT
77
+ for (int i = 0; i < 1; i++) { // WIDTH
78
+ int ji_off = (j + i) * SIZE_3(rbot0);
79
+ for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
80
+ int x2 = x1 + s2o;
81
+ int y2 = y1 + s2p;
82
+
83
+ int idxPatchData = ji_off + ch;
84
+ int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;
85
+
86
+ sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];
87
+ }
88
+ }
89
+ }
90
+
91
+ __syncthreads();
92
+
93
+ if (ch_off == 0) {
94
+ float total_sum = 0;
95
+ for (int idx = 0; idx < 32; idx++) {
96
+ total_sum += sum[idx];
97
+ }
98
+ const int sumelems = SIZE_3(rbot0);
99
+ const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;
100
+ top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;
101
+ }
102
+ }
103
+ }
104
+ '''
105
+
106
+ kernel_Correlation_updateGradFirst = '''
107
+ #define ROUND_OFF 50000
108
+
109
+ extern "C" __global__ void kernel_Correlation_updateGradFirst(
110
+ const int n,
111
+ const int intSample,
112
+ const float* rbot0,
113
+ const float* rbot1,
114
+ const float* gradOutput,
115
+ float* gradFirst,
116
+ float* gradSecond
117
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
118
+ int n = intIndex % SIZE_1(gradFirst); // channels
119
+ int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 3*{{intStride}}; // w-pos
120
+ int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 3*{{intStride}}; // h-pos
121
+
122
+ // round_off is a trick to enable integer division with ceil, even for negative numbers
123
+ // We use a large offset, for the inner part not to become negative.
124
+ const int round_off = ROUND_OFF;
125
+ const int round_off_s1 = {{intStride}} * round_off;
126
+
127
+ // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
128
+ int xmin = (l - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}}) / {{intStride}}
129
+ int ymin = (m - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}}) / {{intStride}}
130
+
131
+ // Same here:
132
+ int xmax = (l - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}}) / {{intStride}}
133
+ int ymax = (m - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}}) / {{intStride}}
134
+
135
+ float sum = 0;
136
+ if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
137
+ xmin = max(0,xmin);
138
+ xmax = min(SIZE_3(gradOutput)-1,xmax);
139
+
140
+ ymin = max(0,ymin);
141
+ ymax = min(SIZE_2(gradOutput)-1,ymax);
142
+
143
+ for (int p = -3; p <= 3; p++) {
144
+ for (int o = -3; o <= 3; o++) {
145
+ // Get rbot1 data:
146
+ int s2o = {{intStride}} * o;
147
+ int s2p = {{intStride}} * p;
148
+ int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n;
149
+ float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n]
150
+
151
+ // Index offset for gradOutput in following loops:
152
+ int op = (p+3) * 7 + (o+3); // index[o,p]
153
+ int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
154
+
155
+ for (int y = ymin; y <= ymax; y++) {
156
+ for (int x = xmin; x <= xmax; x++) {
157
+ int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
158
+ sum += gradOutput[idxgradOutput] * bot1tmp;
159
+ }
160
+ }
161
+ }
162
+ }
163
+ }
164
+ const int sumelems = SIZE_1(gradFirst);
165
+ const int bot0index = ((n * SIZE_2(gradFirst)) + (m-3*{{intStride}})) * SIZE_3(gradFirst) + (l-3*{{intStride}});
166
+ gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems;
167
+ } }
168
+ '''
169
+
170
+ kernel_Correlation_updateGradSecond = '''
171
+ #define ROUND_OFF 50000
172
+
173
+ extern "C" __global__ void kernel_Correlation_updateGradSecond(
174
+ const int n,
175
+ const int intSample,
176
+ const float* rbot0,
177
+ const float* rbot1,
178
+ const float* gradOutput,
179
+ float* gradFirst,
180
+ float* gradSecond
181
+ ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
182
+ int n = intIndex % SIZE_1(gradSecond); // channels
183
+ int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 3*{{intStride}}; // w-pos
184
+ int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 3*{{intStride}}; // h-pos
185
+
186
+ // round_off is a trick to enable integer division with ceil, even for negative numbers
187
+ // We use a large offset, for the inner part not to become negative.
188
+ const int round_off = ROUND_OFF;
189
+ const int round_off_s1 = {{intStride}} * round_off;
190
+
191
+ float sum = 0;
192
+ for (int p = -3; p <= 3; p++) {
193
+ for (int o = -3; o <= 3; o++) {
194
+ int s2o = {{intStride}} * o;
195
+ int s2p = {{intStride}} * p;
196
+
197
+ //Get X,Y ranges and clamp
198
+ // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
199
+ int xmin = (l - 3*{{intStride}} - s2o + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}} - s2o) / {{intStride}}
200
+ int ymin = (m - 3*{{intStride}} - s2p + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}} - s2o) / {{intStride}}
201
+
202
+ // Same here:
203
+ int xmax = (l - 3*{{intStride}} - s2o + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}} - s2o) / {{intStride}}
204
+ int ymax = (m - 3*{{intStride}} - s2p + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}} - s2p) / {{intStride}}
205
+
206
+ if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
207
+ xmin = max(0,xmin);
208
+ xmax = min(SIZE_3(gradOutput)-1,xmax);
209
+
210
+ ymin = max(0,ymin);
211
+ ymax = min(SIZE_2(gradOutput)-1,ymax);
212
+
213
+ // Get rbot0 data:
214
+ int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n;
215
+ float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n]
216
+
217
+ // Index offset for gradOutput in following loops:
218
+ int op = (p+3) * 7 + (o+3); // index[o,p]
219
+ int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
220
+
221
+ for (int y = ymin; y <= ymax; y++) {
222
+ for (int x = xmin; x <= xmax; x++) {
223
+ int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
224
+ sum += gradOutput[idxgradOutput] * bot0tmp;
225
+ }
226
+ }
227
+ }
228
+ }
229
+ }
230
+ const int sumelems = SIZE_1(gradSecond);
231
+ const int bot1index = ((n * SIZE_2(gradSecond)) + (m-3*{{intStride}})) * SIZE_3(gradSecond) + (l-3*{{intStride}});
232
+ gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems;
233
+ } }
234
+ '''
235
+
236
def cupy_kernel(strFunction, objVariables):
    """Specialise a CUDA kernel template for concrete tensors.

    Args:
        strFunction: name of a module-level string holding the kernel template.
        objVariables: mapping with an ``'intStride'`` entry plus one entry per
            tensor name referenced by the template's placeholders.

    Returns:
        The kernel source with these placeholders expanded:
        ``{{intStride}}`` -> the stride value;
        ``SIZE_k(name)`` -> ``objVariables[name].size()[k]``;
        ``VALUE_k(name, i0, ..., i{k-1})`` -> an indexing expression
        ``name[(i0)*stride0 + ...]`` built from ``objVariables[name].stride()``.
    """
    strKernel = globals()[strFunction].replace('{{intStride}}', str(objVariables['intStride']))

    # Raw strings: '\(' and '\)' are regex escapes, not valid Python string
    # escapes, and trigger a warning in a normal literal.
    while True:
        objMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)

        if objMatch is None:
            break

        intArg = int(objMatch.group(2))

        strTensor = objMatch.group(4)
        intSizes = objVariables[strTensor].size()

        # Replaces every textual occurrence of this exact placeholder at once.
        strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg]))

    while True:
        objMatch = re.search(r'(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)

        if objMatch is None:
            break

        intArgs = int(objMatch.group(2))
        strArgs = objMatch.group(4).split(',')

        strTensor = strArgs[0]
        intStrides = objVariables[strTensor].stride()
        # Index expressions may use {...} in the template in place of (...),
        # because ')' would terminate the placeholder match.
        strIndex = [
            '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')'
            for intArg in range(intArgs)
        ]

        strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + '+'.join(strIndex) + ']')

    return strKernel
273
+
274
# `cupy.util.memoize` is the legacy location of the decorator; current CuPy
# exposes it as `cupy.memoize` and no longer ships `cupy.util`. Resolve
# whichever one this installation provides.
_cupy_memoize = getattr(cupy, 'memoize', None)
if _cupy_memoize is None:
    _cupy_memoize = cupy.util.memoize


@_cupy_memoize(for_each_device=True)
def cupy_launch(strFunction, strKernel):
    """Compile ``strKernel`` and return the launchable kernel named ``strFunction``.

    Results are memoized per device and per (name, source) pair, so each
    specialised kernel is compiled once.

    The returned object is called with ``grid=``, ``block=``, ``args=`` and
    optionally ``shared_mem=`` keyword arguments.
    """
    if hasattr(cupy.cuda, 'compile_with_cache'):
        return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction)

    # NOTE(review): compile_with_cache is deprecated in favour of RawKernel in
    # newer CuPy; this fallback assumes RawKernel accepts the same launch
    # keywords as above, so confirm on the target CuPy version.
    return cupy.RawKernel(strKernel, strFunction)
278
+
279
class _FunctionCorrelation(torch.autograd.Function):
    """Autograd function producing a 49-channel correlation volume via CuPy-launched CUDA kernels.

    Only CUDA tensors are supported; CPU input raises ``NotImplementedError``.
    The kernels read and write ``float`` buffers, so inputs must be float32.
    """

    @staticmethod
    def forward(self, first, second, intStride):
        """Compute the correlation of ``first`` against ``second``.

        Args:
            self: the autograd context object.
            first, second: 4-D CUDA float32 tensors of identical shape.
            intStride: stride baked into the compiled kernels.

        Returns:
            Tensor of shape ``(N, 49, ceil(H / intStride), ceil(W / intStride))``.

        Raises:
            NotImplementedError: if ``first`` is not a CUDA tensor.
            ValueError: if the inputs differ in shape or device.
            TypeError: if the inputs are not float32.
        """
        # Reject unsupported input before allocating anything.
        if not first.is_cuda:
            raise NotImplementedError()
        if second.device != first.device:
            raise ValueError('first and second must be on the same CUDA device')
        if second.shape != first.shape:
            # Both padded buffers are sized from `first`; a differently sized
            # `second` would make the rearrange kernel write out of bounds.
            raise ValueError('first and second must have the same shape, got {} and {}'.format(
                tuple(first.shape), tuple(second.shape)))
        if first.dtype != torch.float32 or second.dtype != torch.float32:
            raise TypeError('correlation kernels only support float32 tensors')

        # The kernels index raw memory assuming a dense layout.
        first = first.contiguous()
        second = second.contiguous()

        # Zero-padded (by 3 * stride on each side) channel-last copies of the inputs.
        rbot0 = first.new_zeros([ first.shape[0], first.shape[2] + (6 * intStride), first.shape[3] + (6 * intStride), first.shape[1] ])
        rbot1 = first.new_zeros([ first.shape[0], first.shape[2] + (6 * intStride), first.shape[3] + (6 * intStride), first.shape[1] ])

        self.save_for_backward(first, second, rbot0, rbot1)
        self.intStride = intStride

        output = first.new_zeros([ first.shape[0], 49, int(math.ceil(first.shape[2] / intStride)), int(math.ceil(first.shape[3] / intStride)) ])

        for tenInput, tenPadded in ((first, rbot0), (second, rbot1)):
            n = tenInput.shape[2] * tenInput.shape[3]
            cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
                'intStride': self.intStride,
                'input': tenInput,
                'output': tenPadded
            }))(
                grid=tuple([ (n + 16 - 1) // 16, tenInput.shape[1], tenInput.shape[0] ]),
                block=tuple([ 16, 1, 1 ]),
                args=[ n, tenInput.data_ptr(), tenPadded.data_ptr() ]
            )

        n = output.shape[1] * output.shape[2] * output.shape[3]
        cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', {
            'intStride': self.intStride,
            'rbot0': rbot0,
            'rbot1': rbot1,
            'top': output
        }))(
            grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]),
            block=tuple([ 32, 1, 1 ]),
            # One float (4 bytes) of dynamic shared memory per input channel.
            shared_mem=first.shape[1] * 4,
            args=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ]
        )

        return output

    @staticmethod
    def backward(self, gradOutput):
        """Return ``(gradFirst, gradSecond, None)``; a gradient is ``None`` when not required."""
        first, second, rbot0, rbot1 = self.saved_tensors

        if not first.is_cuda:
            raise NotImplementedError()

        # Autograd may hand over a strided gradient; the kernels need it dense.
        gradOutput = gradOutput.contiguous()

        gradFirst = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[0] else None
        gradSecond = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[1] else None

        n = first.shape[1] * first.shape[2] * first.shape[3]

        # Each kernel fills exactly one of the two gradient buffers; the other
        # pointer argument is passed as None, as before.
        objTargets = (
            ('kernel_Correlation_updateGradFirst', gradFirst, None),
            ('kernel_Correlation_updateGradSecond', None, gradSecond),
        )

        for strFunction, tenGradFirst, tenGradSecond in objTargets:
            if tenGradFirst is None and tenGradSecond is None:
                continue

            objLaunch = cupy_launch(strFunction, cupy_kernel(strFunction, {
                'intStride': self.intStride,
                'rbot0': rbot0,
                'rbot1': rbot1,
                'gradOutput': gradOutput,
                'gradFirst': tenGradFirst,
                'gradSecond': tenGradSecond
            }))

            # The gradient kernels process one batch sample per launch.
            for intSample in range(first.shape[0]):
                objLaunch(
                    grid=tuple([ (n + 512 - 1) // 512, 1, 1 ]),
                    block=tuple([ 512, 1, 1 ]),
                    args=[
                        n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(),
                        None if tenGradFirst is None else tenGradFirst.data_ptr(),
                        None if tenGradSecond is None else tenGradSecond.data_ptr()
                    ]
                )

        return gradFirst, gradSecond, None
392
+
393
def FunctionCorrelation(tenFirst, tenSecond, intStride):
    """Functional entry point: apply ``_FunctionCorrelation`` to the two tensors with the given stride."""
    tenOutput = _FunctionCorrelation.apply(tenFirst, tenSecond, intStride)
    return tenOutput
396
+
397
class ModuleCorrelation(torch.nn.Module):
    """``nn.Module`` wrapper around ``_FunctionCorrelation``; it holds no parameters."""

    def __init__(self):
        super().__init__()

    def forward(self, tenFirst, tenSecond, intStride):
        """Apply ``_FunctionCorrelation`` to the two tensors with the given stride."""
        tenOutput = _FunctionCorrelation.apply(tenFirst, tenSecond, intStride)
        return tenOutput
models/networks.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.parallel
4
+ import os
5
+
6
class UnetSkipConnectionBlock(nn.Module):
    """One nesting level of a U-Net: downsample, optional submodule, upsample.

    Args:
        outer_nc: channel count produced by this block's up path.
        inner_nc: channel count produced by this block's down path.
        input_nc: channel count of the input; defaults to ``outer_nc``.
        submodule: the nested block placed between the down and up paths
            (required unless ``innermost``).
        outermost: build the outermost variant (no norms, no skip concat).
        innermost: build the innermost variant (no submodule).
        norm_layer: callable ``channels -> module``, or ``None`` for no norm.
        use_dropout: append ``Dropout(0.5)`` to intermediate blocks.
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super().__init__()
        self.outermost = outermost
        # Convs carry a bias only when paired with InstanceNorm. Unwrap
        # functools.partial so a partially-applied InstanceNorm2d is
        # recognised too.
        use_bias = getattr(norm_layer, 'func', norm_layer) is nn.InstanceNorm2d
        has_norm = norm_layer is not None

        if input_nc is None:
            input_nc = outer_nc

        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        uprelu = nn.ReLU(True)
        if has_norm:
            downnorm = norm_layer(inner_nc)
            upnorm = norm_layer(outer_nc)

        # align_corners=False is what the un-annotated bilinear Upsample
        # already does; being explicit only removes the warning.
        upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        # Blocks wrapping a submodule receive its skip-concatenated output,
        # i.e. twice inner_nc channels.
        upconv_in = inner_nc if (innermost and not outermost) else inner_nc * 2
        upconv = nn.Conv2d(upconv_in, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)

        # Layer order below defines the state-dict keys; do not reorder.
        if outermost:
            model = [downconv, submodule, uprelu, upsample, upconv]
        elif innermost:
            model = [downrelu, downconv, uprelu, upsample, upconv]
            if has_norm:
                model.append(upnorm)
        else:
            down = [downrelu, downconv]
            up = [uprelu, upsample, upconv]
            if has_norm:
                down.append(downnorm)
                up.append(upnorm)

            model = down + [submodule] + up
            if use_dropout:
                model.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Return the block output; non-outermost blocks concatenate the input on the channel axis."""
        if self.outermost:
            return self.model(x)
        return torch.cat([x, self.model(x)], 1)
60
+
61
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions (optionally normalised) with an identity skip and a final ReLU.

    Args:
        in_features: channel count, unchanged by the block.
        norm_layer: callable ``channels -> module`` inserted after each
            convolution, or ``None`` to build the block without normalisation.
    """

    def __init__(self, in_features=64, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.relu = nn.ReLU(True)

        def make_conv():
            # 3x3, stride 1, padding 1, no bias.
            return nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False)

        with_norm = norm_layer is not None

        stages = [make_conv()]
        if with_norm:
            stages.append(norm_layer(in_features))
        stages.append(nn.ReLU(inplace=True))
        stages.append(make_conv())
        if with_norm:
            stages.append(norm_layer(in_features))

        # Attribute name and layer order define the state-dict keys.
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``relu(block(x) + x)``."""
        out = self.block(x)
        out += x
        return self.relu(out)
86
+
87
class ResUnetGenerator(nn.Module):
    """U-Net generator assembled inside-out from ``ResUnetSkipConnectionBlock`` levels.

    Args:
        input_nc: channel count of the input.
        output_nc: channel count of the output.
        num_downs: total number of levels; ``num_downs - 5`` extra
            ``ngf * 8`` levels are inserted next to the innermost one.
        ngf: base channel width.
        norm_layer: forwarded to every level.
        use_dropout: forwarded to the extra ``ngf * 8`` levels only.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super().__init__()
        widest = ngf * 8

        # Start at the bottleneck and wrap outwards.
        core = ResUnetSkipConnectionBlock(widest, widest, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)

        for _ in range(num_downs - 5):
            core = ResUnetSkipConnectionBlock(widest, widest, input_nc=None, submodule=core, norm_layer=norm_layer, use_dropout=use_dropout)

        # Three levels that halve the width on the way out.
        for outer_nc, inner_nc in ((ngf * 4, ngf * 8), (ngf * 2, ngf * 4), (ngf, ngf * 2)):
            core = ResUnetSkipConnectionBlock(outer_nc, inner_nc, input_nc=None, submodule=core, norm_layer=norm_layer)

        self.model = ResUnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=core, outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        """Run ``input`` through the assembled U-Net."""
        return self.model(input)
104
+
105
+
106
class ResUnetSkipConnectionBlock(nn.Module):
    """One U-Net nesting level whose down and up paths each end in two ``ResidualBlock``s.

    Args:
        outer_nc: channel count produced by this block's up path.
        inner_nc: channel count produced by this block's down path.
        input_nc: channel count of the input; defaults to ``outer_nc``.
        submodule: the nested block placed between the down and up paths
            (required unless ``innermost``).
        outermost: build the outermost variant (no norms, no up-path
            residual blocks, no skip concat).
        innermost: build the innermost variant (no submodule).
        norm_layer: callable ``channels -> module``, or ``None`` for no norm.
        use_dropout: append ``Dropout(0.5)`` to intermediate blocks.
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super().__init__()
        self.outermost = outermost
        # Convs carry a bias only when paired with InstanceNorm. Unwrap
        # functools.partial so a partially-applied InstanceNorm2d is
        # recognised too.
        use_bias = getattr(norm_layer, 'func', norm_layer) is nn.InstanceNorm2d
        has_norm = norm_layer is not None

        if input_nc is None:
            input_nc = outer_nc

        # Construction order is kept as-is so seeded weight initialisation
        # is unchanged: down conv, down residuals, up residuals, up conv.
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=3,
                             stride=2, padding=1, bias=use_bias)

        res_downconv = [ResidualBlock(inner_nc, norm_layer), ResidualBlock(inner_nc, norm_layer)]
        res_upconv = [ResidualBlock(outer_nc, norm_layer), ResidualBlock(outer_nc, norm_layer)]

        downrelu = nn.ReLU(True)
        uprelu = nn.ReLU(True)
        if has_norm:
            downnorm = norm_layer(inner_nc)
            upnorm = norm_layer(outer_nc)

        upsample = nn.Upsample(scale_factor=2, mode='nearest')
        # Blocks wrapping a submodule receive its skip-concatenated output,
        # i.e. twice inner_nc channels.
        upconv_in = inner_nc if (innermost and not outermost) else inner_nc * 2
        upconv = nn.Conv2d(upconv_in, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)

        # Layer order below defines the state-dict keys; do not reorder.
        if outermost:
            down = [downconv, downrelu] + res_downconv
            up = [upsample, upconv]
            model = down + [submodule] + up
        elif innermost:
            down = [downconv, downrelu] + res_downconv
            up = [upsample, upconv] + ([upnorm] if has_norm else []) + [uprelu] + res_upconv
            model = down + up
        else:
            down = [downconv] + ([downnorm] if has_norm else []) + [downrelu] + res_downconv
            up = [upsample, upconv] + ([upnorm] if has_norm else []) + [uprelu] + res_upconv

            model = down + [submodule] + up
            if use_dropout:
                model.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Return the block output; non-outermost blocks concatenate the input on the channel axis."""
        if self.outermost:
            return self.model(x)
        return torch.cat([x, self.model(x)], 1)
164
+
165
+
166
def save_checkpoint(model, save_path):
    """Write ``model.state_dict()`` to ``save_path``, creating parent directories as needed.

    Args:
        model: module whose state dict is saved.
        save_path: destination file path; may be a bare filename.
    """
    parent = os.path.dirname(save_path)
    # A bare filename has an empty dirname, and os.makedirs('') raises.
    # exist_ok avoids the check-then-create race.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), save_path)
170
+
171
+
172
def load_checkpoint(model, checkpoint_path):
    """Load the parameters ``model`` expects from the state dict stored at ``checkpoint_path``.

    Entries in the file that the model does not have are ignored. If the
    file does not exist, ``'No checkpoint!'`` is printed and the model is
    left untouched.

    Raises:
        KeyError: if the checkpoint lacks an entry the model expects.
    """
    if not os.path.exists(checkpoint_path):
        print('No checkpoint!')
        return

    # SECURITY: torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
    # load_state_dict below copies the values onto the model's own device.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    checkpoint_new = model.state_dict()

    missing = [param for param in checkpoint_new if param not in checkpoint]
    if missing:
        raise KeyError('checkpoint {!r} is missing entries: {}'.format(
            checkpoint_path, ', '.join(missing)))

    for param in checkpoint_new:
        checkpoint_new[param] = checkpoint[param]

    model.load_state_dict(checkpoint_new)
184
+
185
+
186
+