Upload 7 files
Browse files
models/__pycache__/afwm.cpython-310.pyc
ADDED
Binary file (6.52 kB). View file
|
|
models/__pycache__/networks.cpython-310.pyc
ADDED
Binary file (4.98 kB). View file
|
|
models/afwm.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from .correlation import correlation
|
5 |
+
|
6 |
+
def apply_offset(offset):
    """Turn a dense pixel-offset field into a normalized sampling grid.

    Given ``offset`` of shape (B, 2, H, W) holding per-pixel (x, y)
    displacements in pixels, build the absolute sampling coordinates and
    rescale them to [-1, 1] as expected by ``F.grid_sample``.

    Returns a grid of shape (B, H, W, 2) with the last dim ordered (x, y).
    """
    spatial_sizes = list(offset.size()[2:])  # [H, W]

    # Base integer coordinate grids, one per spatial dim ('ij' order: y first).
    base_grids = torch.meshgrid(
        [torch.arange(size, device=offset.device) for size in spatial_sizes])

    # Reverse so the x grid pairs with offset channel 0 and y with channel 1,
    # then add the predicted displacements.
    absolute = []
    for dim, grid in enumerate(reversed(base_grids)):
        absolute.append(grid.float().unsqueeze(0) + offset[:, dim, ...])

    # Normalize each axis to [-1, 1]; x is scaled by W, y by H.
    normalized = []
    for grid, size in zip(absolute, reversed(spatial_sizes)):
        normalized.append(grid / ((size - 1.0) / 2.0) - 1.0)

    return torch.stack(normalized, dim=-1)
19 |
+
|
20 |
+
|
21 |
+
class ResBlock(nn.Module):
    """Pre-activation residual block: out = x + Conv(BN-ReLU-Conv(BN-ReLU(x)))."""

    def __init__(self, in_channels):
        super(ResBlock, self).__init__()
        layers = [
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        # Identity skip connection around the conv branch.
        residual = x
        return residual + self.block(x)
|
35 |
+
|
36 |
+
|
37 |
+
class DownSample(nn.Module):
    """Halve spatial resolution: BN -> ReLU -> stride-2 3x3 conv."""

    def __init__(self, in_channels, out_channels):
        super(DownSample, self).__init__()
        self.block = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=3, stride=2, padding=1, bias=False),
        )

    def forward(self, x):
        return self.block(x)
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
class FeatureEncoder(nn.Module):
    """Stack of DownSample + 2x ResBlock stages; returns all stage outputs.

    Each stage halves the spatial resolution, so the returned list is a
    feature pyramid from fine (first) to coarse (last).
    """

    def __init__(self, in_channels, chns=[64,128,256,256,256]):
        super(FeatureEncoder, self).__init__()
        stages = []
        prev_chns = in_channels
        for out_chns in chns:
            stages.append(nn.Sequential(DownSample(prev_chns, out_chns),
                                        ResBlock(out_chns),
                                        ResBlock(out_chns)))
            prev_chns = out_chns
        self.encoders = nn.ModuleList(stages)

    def forward(self, x):
        # Collect every intermediate stage output, not just the last.
        encoder_features = []
        out = x
        for encoder in self.encoders:
            out = encoder(out)
            encoder_features.append(out)
        return encoder_features
|
76 |
+
|
77 |
+
class RefinePyramid(nn.Module):
    """FPN-style top-down refinement over an encoder feature pyramid.

    Each level is projected to ``fpn_dim`` channels with a 1x1 conv, merged
    with the 2x-upsampled coarser level, and smoothed with a 3x3 conv.
    Returns the refined pyramid ordered fine-to-coarse (same as the input).
    """

    def __init__(self, chns=[64,128,256,256,256], fpn_dim=256):
        super(RefinePyramid, self).__init__()
        self.chns = chns

        # Lateral 1x1 projections, ordered coarse-to-fine to match forward().
        self.adaptive = nn.ModuleList(
            [nn.Conv2d(in_chns, fpn_dim, kernel_size=1)
             for in_chns in reversed(chns)])

        # One 3x3 smoothing conv per pyramid level.
        self.smooth = nn.ModuleList(
            [nn.Conv2d(fpn_dim, fpn_dim, kernel_size=3, padding=1)
             for _ in range(len(chns))])

    def forward(self, x):
        refined = []
        previous = None
        # Walk the pyramid coarse-to-fine, accumulating the top-down signal.
        for idx, conv_ftr in enumerate(reversed(x)):
            lateral = self.adaptive[idx](conv_ftr)
            if previous is not None:
                lateral = lateral + F.interpolate(previous, scale_factor=2,
                                                  mode='nearest')
            previous = self.smooth[idx](lateral)
            refined.append(previous)

        # Restore the caller's fine-to-coarse ordering.
        return tuple(reversed(refined))
|
110 |
+
|
111 |
+
|
112 |
+
class AFlowNet(nn.Module):
    """Cascaded coarse-to-fine appearance-flow estimator over a feature pyramid.

    Per pyramid level it (1) warps the source features with the flow from the
    coarser level, (2) correlates them with the condition features (7x7 search
    window -> 49 channels) and predicts a flow increment with ``netMain``,
    (3) composes the increment with the previous flow via grid_sample, and
    (4) refines the composed flow from the concatenated warped/condition
    features with ``netRefine``.
    """

    def __init__(self, num_pyramid, fpn_dim=256):
        super(AFlowNet, self).__init__()
        self.netMain = []
        self.netRefine = []
        for i in range(num_pyramid):
            # Flow head applied to the 49-channel correlation cost volume.
            netMain_layer = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=49, out_channels=128, kernel_size=3, stride=1, padding=1),
                torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
                torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
                torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1)
            )

            # Refinement head applied to warped-source + condition features.
            netRefine_layer = torch.nn.Sequential(
                torch.nn.Conv2d(2 * fpn_dim, out_channels=128, kernel_size=3, stride=1, padding=1),
                torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
                torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
                torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1)
            )
            self.netMain.append(netMain_layer)
            self.netRefine.append(netRefine_layer)

        self.netMain = nn.ModuleList(self.netMain)
        self.netRefine = nn.ModuleList(self.netRefine)

    def forward(self, x, x_warps, x_conds, warp_feature=True):
        """Estimate a dense flow and warp ``x`` with it.

        Args:
            x: full-resolution source tensor to warp at the end.
            x_warps: source feature pyramid, fine-to-coarse ordering.
            x_conds: condition feature pyramid, fine-to-coarse ordering.
            warp_feature: if True, pre-warp the source features with the
                previous (detached) flow before computing the correlation.

        Returns:
            Tuple ``(x_warp, last_flow)``: the warped input and the final
            flow as a (B, 2, H, W) tensor of normalized grid coordinates.
        """
        last_flow = None

        # Iterate the pyramid from coarsest to finest level.
        for i in range(len(x_warps)):
            x_warp = x_warps[len(x_warps) - 1 - i]
            x_cond = x_conds[len(x_warps) - 1 - i]

            if last_flow is not None and warp_feature:
                # detach(): the coarser-level flow is treated as a constant
                # when warping features for the correlation input.
                x_warp_after = F.grid_sample(x_warp, last_flow.detach().permute(0, 2, 3, 1),
                                 mode='bilinear', padding_mode='border')
            else:
                x_warp_after = x_warp

            # 7x7 local cost volume between (warped) source and condition.
            # NOTE(review): the grid_sample/interpolate calls in this method
            # pass no align_corners while apply_offset normalizes with
            # (size-1)/2 (an align_corners=True convention) -- confirm against
            # the pinned torch version's defaults.
            tenCorrelation = F.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=x_warp_after, tenSecond=x_cond, intStride=1), negative_slope=0.1, inplace=False)
            flow = self.netMain[i](tenCorrelation)
            flow = apply_offset(flow)  # pixel offsets -> normalized grid

            if last_flow is not None:
                # Compose the new grid with the previous flow field.
                flow = F.grid_sample(last_flow, flow, mode='bilinear', padding_mode='border')
            else:
                flow = flow.permute(0, 3, 1, 2)

            last_flow = flow
            x_warp = F.grid_sample(x_warp, flow.permute(0, 2, 3, 1),mode='bilinear', padding_mode='border')
            concat = torch.cat([x_warp,x_cond],1)
            flow = self.netRefine[i](concat)
            flow = apply_offset(flow)
            flow = F.grid_sample(last_flow, flow, mode='bilinear', padding_mode='border')

            # Upsample the refined flow to the next (finer) pyramid level.
            last_flow = F.interpolate(flow, scale_factor=2, mode='bilinear')

        x_warp = F.grid_sample(x, last_flow.permute(0, 2, 3, 1),
                     mode='bilinear', padding_mode='border')
        return x_warp, last_flow,
|
178 |
+
|
179 |
+
|
180 |
+
class AFWM(nn.Module):
    """Appearance Flow Warping Module.

    Encodes the source image and the condition input into two FPN feature
    pyramids and feeds them to AFlowNet, which estimates a dense flow that
    warps the image toward the condition.
    """

    def __init__(self, opt, input_nc):
        # `opt` is accepted for interface compatibility; it is not read here.
        super(AFWM, self).__init__()
        num_filters = [64,128,256,256,256]
        self.image_features = FeatureEncoder(3, num_filters)
        self.cond_features = FeatureEncoder(input_nc, num_filters)
        self.image_FPN = RefinePyramid(num_filters)
        self.cond_FPN = RefinePyramid(num_filters)
        self.aflow_net = AFlowNet(len(num_filters))

    def forward(self, cond_input, image_input):
        """Return ``(x_warp, last_flow)`` for a condition/image pair."""
        cond_pyramids = self.cond_FPN(self.cond_features(cond_input)) # maybe use nn.Sequential
        image_pyramids = self.image_FPN(self.image_features(image_input))

        x_warp, last_flow = self.aflow_net(image_input, image_pyramids, cond_pyramids)

        return x_warp, last_flow
|
198 |
+
|
models/correlation/README.md
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
This is an adaptation of the <a href="https://github.com/lmb-freiburg/flownet2">FlowNet2 implementation</a> in order to compute cost volumes. Should you be making use of this work, please make sure to adhere to the <a href="https://github.com/lmb-freiburg/flownet2#license-and-citation">licensing terms</a> of the original authors. Should you be making use of or modifying this particular implementation, please acknowledge it appropriately.
|
models/correlation/__pycache__/correlation.cpython-310.pyc
ADDED
Binary file (13.7 kB). View file
|
|
models/correlation/correlation.py
ADDED
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
import cupy
|
6 |
+
import math
|
7 |
+
import re
|
8 |
+
|
9 |
+
kernel_Correlation_rearrange = '''
|
10 |
+
extern "C" __global__ void kernel_Correlation_rearrange(
|
11 |
+
const int n,
|
12 |
+
const float* input,
|
13 |
+
float* output
|
14 |
+
) {
|
15 |
+
int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;
|
16 |
+
|
17 |
+
if (intIndex >= n) {
|
18 |
+
return;
|
19 |
+
}
|
20 |
+
|
21 |
+
int intSample = blockIdx.z;
|
22 |
+
int intChannel = blockIdx.y;
|
23 |
+
|
24 |
+
float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];
|
25 |
+
|
26 |
+
__syncthreads();
|
27 |
+
|
28 |
+
int intPaddedY = (intIndex / SIZE_3(input)) + 3*{{intStride}};
|
29 |
+
int intPaddedX = (intIndex % SIZE_3(input)) + 3*{{intStride}};
|
30 |
+
int intRearrange = ((SIZE_3(input) + 6*{{intStride}}) * intPaddedY) + intPaddedX;
|
31 |
+
|
32 |
+
output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue;
|
33 |
+
}
|
34 |
+
'''
|
35 |
+
|
36 |
+
kernel_Correlation_updateOutput = '''
|
37 |
+
extern "C" __global__ void kernel_Correlation_updateOutput(
|
38 |
+
const int n,
|
39 |
+
const float* rbot0,
|
40 |
+
const float* rbot1,
|
41 |
+
float* top
|
42 |
+
) {
|
43 |
+
extern __shared__ char patch_data_char[];
|
44 |
+
|
45 |
+
float *patch_data = (float *)patch_data_char;
|
46 |
+
|
47 |
+
// First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
|
48 |
+
int x1 = (blockIdx.x + 3) * {{intStride}};
|
49 |
+
int y1 = (blockIdx.y + 3) * {{intStride}};
|
50 |
+
int item = blockIdx.z;
|
51 |
+
int ch_off = threadIdx.x;
|
52 |
+
|
53 |
+
// Load 3D patch into shared shared memory
|
54 |
+
for (int j = 0; j < 1; j++) { // HEIGHT
|
55 |
+
for (int i = 0; i < 1; i++) { // WIDTH
|
56 |
+
int ji_off = (j + i) * SIZE_3(rbot0);
|
57 |
+
for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
|
58 |
+
int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;
|
59 |
+
int idxPatchData = ji_off + ch;
|
60 |
+
patch_data[idxPatchData] = rbot0[idx1];
|
61 |
+
}
|
62 |
+
}
|
63 |
+
}
|
64 |
+
|
65 |
+
__syncthreads();
|
66 |
+
|
67 |
+
__shared__ float sum[32];
|
68 |
+
|
69 |
+
// Compute correlation
|
70 |
+
for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {
|
71 |
+
sum[ch_off] = 0;
|
72 |
+
|
73 |
+
int s2o = (top_channel % 7 - 3) * {{intStride}};
|
74 |
+
int s2p = (top_channel / 7 - 3) * {{intStride}};
|
75 |
+
|
76 |
+
for (int j = 0; j < 1; j++) { // HEIGHT
|
77 |
+
for (int i = 0; i < 1; i++) { // WIDTH
|
78 |
+
int ji_off = (j + i) * SIZE_3(rbot0);
|
79 |
+
for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
|
80 |
+
int x2 = x1 + s2o;
|
81 |
+
int y2 = y1 + s2p;
|
82 |
+
|
83 |
+
int idxPatchData = ji_off + ch;
|
84 |
+
int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;
|
85 |
+
|
86 |
+
sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];
|
87 |
+
}
|
88 |
+
}
|
89 |
+
}
|
90 |
+
|
91 |
+
__syncthreads();
|
92 |
+
|
93 |
+
if (ch_off == 0) {
|
94 |
+
float total_sum = 0;
|
95 |
+
for (int idx = 0; idx < 32; idx++) {
|
96 |
+
total_sum += sum[idx];
|
97 |
+
}
|
98 |
+
const int sumelems = SIZE_3(rbot0);
|
99 |
+
const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;
|
100 |
+
top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;
|
101 |
+
}
|
102 |
+
}
|
103 |
+
}
|
104 |
+
'''
|
105 |
+
|
106 |
+
kernel_Correlation_updateGradFirst = '''
|
107 |
+
#define ROUND_OFF 50000
|
108 |
+
|
109 |
+
extern "C" __global__ void kernel_Correlation_updateGradFirst(
|
110 |
+
const int n,
|
111 |
+
const int intSample,
|
112 |
+
const float* rbot0,
|
113 |
+
const float* rbot1,
|
114 |
+
const float* gradOutput,
|
115 |
+
float* gradFirst,
|
116 |
+
float* gradSecond
|
117 |
+
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
118 |
+
int n = intIndex % SIZE_1(gradFirst); // channels
|
119 |
+
int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 3*{{intStride}}; // w-pos
|
120 |
+
int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 3*{{intStride}}; // h-pos
|
121 |
+
|
122 |
+
// round_off is a trick to enable integer division with ceil, even for negative numbers
|
123 |
+
// We use a large offset, for the inner part not to become negative.
|
124 |
+
const int round_off = ROUND_OFF;
|
125 |
+
const int round_off_s1 = {{intStride}} * round_off;
|
126 |
+
|
127 |
+
// We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
|
128 |
+
int xmin = (l - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}}) / {{intStride}}
|
129 |
+
int ymin = (m - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}}) / {{intStride}}
|
130 |
+
|
131 |
+
// Same here:
|
132 |
+
int xmax = (l - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}}) / {{intStride}}
|
133 |
+
int ymax = (m - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}}) / {{intStride}}
|
134 |
+
|
135 |
+
float sum = 0;
|
136 |
+
if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
|
137 |
+
xmin = max(0,xmin);
|
138 |
+
xmax = min(SIZE_3(gradOutput)-1,xmax);
|
139 |
+
|
140 |
+
ymin = max(0,ymin);
|
141 |
+
ymax = min(SIZE_2(gradOutput)-1,ymax);
|
142 |
+
|
143 |
+
for (int p = -3; p <= 3; p++) {
|
144 |
+
for (int o = -3; o <= 3; o++) {
|
145 |
+
// Get rbot1 data:
|
146 |
+
int s2o = {{intStride}} * o;
|
147 |
+
int s2p = {{intStride}} * p;
|
148 |
+
int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n;
|
149 |
+
float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n]
|
150 |
+
|
151 |
+
// Index offset for gradOutput in following loops:
|
152 |
+
int op = (p+3) * 7 + (o+3); // index[o,p]
|
153 |
+
int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
|
154 |
+
|
155 |
+
for (int y = ymin; y <= ymax; y++) {
|
156 |
+
for (int x = xmin; x <= xmax; x++) {
|
157 |
+
int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
|
158 |
+
sum += gradOutput[idxgradOutput] * bot1tmp;
|
159 |
+
}
|
160 |
+
}
|
161 |
+
}
|
162 |
+
}
|
163 |
+
}
|
164 |
+
const int sumelems = SIZE_1(gradFirst);
|
165 |
+
const int bot0index = ((n * SIZE_2(gradFirst)) + (m-3*{{intStride}})) * SIZE_3(gradFirst) + (l-3*{{intStride}});
|
166 |
+
gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems;
|
167 |
+
} }
|
168 |
+
'''
|
169 |
+
|
170 |
+
kernel_Correlation_updateGradSecond = '''
|
171 |
+
#define ROUND_OFF 50000
|
172 |
+
|
173 |
+
extern "C" __global__ void kernel_Correlation_updateGradSecond(
|
174 |
+
const int n,
|
175 |
+
const int intSample,
|
176 |
+
const float* rbot0,
|
177 |
+
const float* rbot1,
|
178 |
+
const float* gradOutput,
|
179 |
+
float* gradFirst,
|
180 |
+
float* gradSecond
|
181 |
+
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
182 |
+
int n = intIndex % SIZE_1(gradSecond); // channels
|
183 |
+
int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 3*{{intStride}}; // w-pos
|
184 |
+
int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 3*{{intStride}}; // h-pos
|
185 |
+
|
186 |
+
// round_off is a trick to enable integer division with ceil, even for negative numbers
|
187 |
+
// We use a large offset, for the inner part not to become negative.
|
188 |
+
const int round_off = ROUND_OFF;
|
189 |
+
const int round_off_s1 = {{intStride}} * round_off;
|
190 |
+
|
191 |
+
float sum = 0;
|
192 |
+
for (int p = -3; p <= 3; p++) {
|
193 |
+
for (int o = -3; o <= 3; o++) {
|
194 |
+
int s2o = {{intStride}} * o;
|
195 |
+
int s2p = {{intStride}} * p;
|
196 |
+
|
197 |
+
//Get X,Y ranges and clamp
|
198 |
+
// We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
|
199 |
+
int xmin = (l - 3*{{intStride}} - s2o + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}} - s2o) / {{intStride}}
|
200 |
+
int ymin = (m - 3*{{intStride}} - s2p + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}} - s2o) / {{intStride}}
|
201 |
+
|
202 |
+
// Same here:
|
203 |
+
int xmax = (l - 3*{{intStride}} - s2o + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}} - s2o) / {{intStride}}
|
204 |
+
int ymax = (m - 3*{{intStride}} - s2p + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}} - s2p) / {{intStride}}
|
205 |
+
|
206 |
+
if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
|
207 |
+
xmin = max(0,xmin);
|
208 |
+
xmax = min(SIZE_3(gradOutput)-1,xmax);
|
209 |
+
|
210 |
+
ymin = max(0,ymin);
|
211 |
+
ymax = min(SIZE_2(gradOutput)-1,ymax);
|
212 |
+
|
213 |
+
// Get rbot0 data:
|
214 |
+
int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n;
|
215 |
+
float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n]
|
216 |
+
|
217 |
+
// Index offset for gradOutput in following loops:
|
218 |
+
int op = (p+3) * 7 + (o+3); // index[o,p]
|
219 |
+
int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
|
220 |
+
|
221 |
+
for (int y = ymin; y <= ymax; y++) {
|
222 |
+
for (int x = xmin; x <= xmax; x++) {
|
223 |
+
int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
|
224 |
+
sum += gradOutput[idxgradOutput] * bot0tmp;
|
225 |
+
}
|
226 |
+
}
|
227 |
+
}
|
228 |
+
}
|
229 |
+
}
|
230 |
+
const int sumelems = SIZE_1(gradSecond);
|
231 |
+
const int bot1index = ((n * SIZE_2(gradSecond)) + (m-3*{{intStride}})) * SIZE_3(gradSecond) + (l-3*{{intStride}});
|
232 |
+
gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems;
|
233 |
+
} }
|
234 |
+
'''
|
235 |
+
|
236 |
+
def cupy_kernel(strFunction, objVariables):
    """Specialize a CUDA kernel source template for concrete tensors.

    Looks up the kernel source string named ``strFunction`` in this module's
    globals, substitutes the ``{{intStride}}`` placeholder, then expands the
    ``SIZE_i(tensor)`` and ``VALUE_i(tensor, idx...)`` macros using the sizes
    and strides of the tensors supplied in ``objVariables``.

    Args:
        strFunction: name of a module-level kernel source string.
        objVariables: dict mapping tensor names referenced by the kernel to
            tensors, plus the integer entry ``'intStride'``.

    Returns:
        The fully specialized CUDA source code as a string.
    """
    strKernel = globals()[strFunction].replace('{{intStride}}', str(objVariables['intStride']))

    # Expand every SIZE_i(tensor) occurrence into the concrete dimension size.
    # (Raw strings: the original patterns used '\(' in non-raw literals, which
    # is an invalid escape sequence warning/error on modern Python.)
    while True:
        objMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)

        if objMatch is None:
            break
        # end

        intArg = int(objMatch.group(2))

        strTensor = objMatch.group(4)
        intSizes = objVariables[strTensor].size()

        strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg]))
    # end

    # Expand VALUE_i(tensor, idx0, ..., idx{i-1}) into a strided element access.
    while True:
        objMatch = re.search(r'(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)

        if objMatch is None:
            break
        # end

        intArgs = int(objMatch.group(2))
        strArgs = objMatch.group(4).split(',')

        strTensor = strArgs[0]
        intStrides = objVariables[strTensor].stride()
        strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ]

        strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
    # end

    return strKernel
# end
|
272 |
+
# end
|
273 |
+
|
274 |
+
@cupy.util.memoize(for_each_device=True)
def cupy_launch(strFunction, strKernel):
    """Compile ``strKernel`` (memoized per device) and return the named kernel.

    NOTE(review): ``cupy.util.memoize`` and ``cupy.cuda.compile_with_cache``
    were removed in newer CuPy releases (replacements: ``cupy.memoize`` and
    ``cupy.RawKernel``) -- confirm the pinned CuPy version still provides them.
    """
    return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction)
# end
|
277 |
+
# end
|
278 |
+
|
279 |
+
class _FunctionCorrelation(torch.autograd.Function):
    """Autograd function computing a 7x7 correlation cost volume on the GPU.

    ``forward`` returns a tensor of shape (B, 49, ceil(H/stride),
    ceil(W/stride)); each of the 49 channels is the channel-averaged dot
    product between ``first`` and ``second`` at one displacement of the 7x7
    search window. CUDA only -- the CPU path raises NotImplementedError.
    """

    @staticmethod
    def forward(self, first, second, intStride):
        # Channel-last copies of the inputs, zero-padded by 3*intStride on
        # every spatial border (the search radius); filled by the CUDA
        # rearrange kernel below.
        rbot0 = first.new_zeros([ first.shape[0], first.shape[2] + (6 * intStride), first.shape[3] + (6 * intStride), first.shape[1] ])
        rbot1 = first.new_zeros([ first.shape[0], first.shape[2] + (6 * intStride), first.shape[3] + (6 * intStride), first.shape[1] ])

        self.save_for_backward(first, second, rbot0, rbot1)

        # Kept on the ctx so backward() can re-specialize the kernels.
        self.intStride = intStride

        assert(first.is_contiguous() == True)
        assert(second.is_contiguous() == True)

        # 49 output channels = 7x7 candidate displacements.
        output = first.new_zeros([ first.shape[0], 49, int(math.ceil(first.shape[2] / intStride)), int(math.ceil(first.shape[3] / intStride)) ])

        if first.is_cuda == True:
            # Rearrange `first` into the padded channel-last layout (rbot0).
            n = first.shape[2] * first.shape[3]
            cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
                'intStride': self.intStride,
                'input': first,
                'output': rbot0
            }))(
                grid=tuple([ int((n + 16 - 1) / 16), first.shape[1], first.shape[0] ]),
                block=tuple([ 16, 1, 1 ]),
                args=[ n, first.data_ptr(), rbot0.data_ptr() ]
            )

            # Rearrange `second` into the padded channel-last layout (rbot1).
            n = second.shape[2] * second.shape[3]
            cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
                'intStride': self.intStride,
                'input': second,
                'output': rbot1
            }))(
                grid=tuple([ int((n + 16 - 1) / 16), second.shape[1], second.shape[0] ]),
                block=tuple([ 16, 1, 1 ]),
                args=[ n, second.data_ptr(), rbot1.data_ptr() ]
            )

            # Correlation: one 32-thread block per output pixel; shared memory
            # holds one channel vector of rbot0 (4 bytes per input channel).
            n = output.shape[1] * output.shape[2] * output.shape[3]
            cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', {
                'intStride': self.intStride,
                'rbot0': rbot0,
                'rbot1': rbot1,
                'top': output
            }))(
                grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]),
                block=tuple([ 32, 1, 1 ]),
                shared_mem=first.shape[1] * 4,
                args=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ]
            )

        elif first.is_cuda == False:
            raise NotImplementedError()

        # end

        return output
    # end

    @staticmethod
    def backward(self, gradOutput):
        first, second, rbot0, rbot1 = self.saved_tensors

        assert(gradOutput.is_contiguous() == True)

        # Allocate gradient buffers only for the inputs that require grad.
        gradFirst = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[0] == True else None
        gradSecond = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[1] == True else None

        if first.is_cuda == True:#
            if gradFirst is not None:
                # One kernel launch per batch sample.
                for intSample in range(first.shape[0]):
                    n = first.shape[1] * first.shape[2] * first.shape[3]
                    cupy_launch('kernel_Correlation_updateGradFirst', cupy_kernel('kernel_Correlation_updateGradFirst', {
                        'intStride': self.intStride,
                        'rbot0': rbot0,
                        'rbot1': rbot1,
                        'gradOutput': gradOutput,
                        'gradFirst': gradFirst,
                        'gradSecond': None
                    }))(
                        grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
                        block=tuple([ 512, 1, 1 ]),
                        args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradFirst.data_ptr(), None ]
                    )
                # end
            # end

            if gradSecond is not None:
                # One kernel launch per batch sample.
                for intSample in range(first.shape[0]):
                    n = first.shape[1] * first.shape[2] * first.shape[3]
                    cupy_launch('kernel_Correlation_updateGradSecond', cupy_kernel('kernel_Correlation_updateGradSecond', {
                        'intStride': self.intStride,
                        'rbot0': rbot0,
                        'rbot1': rbot1,
                        'gradOutput': gradOutput,
                        'gradFirst': None,
                        'gradSecond': gradSecond
                    }))(
                        grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
                        block=tuple([ 512, 1, 1 ]),
                        args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradSecond.data_ptr() ]
                    )
                # end
            # end

        elif first.is_cuda == False:
            raise NotImplementedError()

        # end

        # Third return slot matches the non-tensor intStride input: no grad.
        return gradFirst, gradSecond, None
    # end
# end
|
390 |
+
# end
|
391 |
+
# end
|
392 |
+
|
393 |
+
def FunctionCorrelation(tenFirst, tenSecond, intStride):
    """Compute the 49-channel correlation cost volume of two feature maps."""
    result = _FunctionCorrelation.apply(tenFirst, tenSecond, intStride)
    return result
|
395 |
+
# end
|
396 |
+
|
397 |
+
class ModuleCorrelation(torch.nn.Module):
    """nn.Module wrapper around the custom correlation autograd function."""

    def __init__(self):
        super(ModuleCorrelation, self).__init__()

    def forward(self, tenFirst, tenSecond, intStride):
        cost_volume = _FunctionCorrelation.apply(tenFirst, tenSecond, intStride)
        return cost_volume
|
404 |
+
# end
|
405 |
+
# end
|
models/networks.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.parallel
|
4 |
+
import os
|
5 |
+
|
6 |
+
class UnetSkipConnectionBlock(nn.Module):
    """U-Net building block with a channel-concatenation skip connection.

    Wraps ``submodule`` with a stride-2 conv on the way down and a bilinear
    upsample + 3x3 conv on the way up. Unless this is the outermost block,
    the block's input is concatenated to its output along the channel dim.

    Args:
        outer_nc: channels of the feature map produced on the way up.
        inner_nc: channels after the downsampling conv.
        input_nc: channels of the incoming tensor; defaults to ``outer_nc``.
        submodule: the nested (deeper) block, or ``None`` for the innermost.
        outermost: top-level block (no skip concat, no norm on the up path).
        innermost: deepest block (no submodule).
        norm_layer: normalization constructor, or ``None`` to disable norms.
        use_dropout: append Dropout(0.5) after the up path (middle blocks).
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # InstanceNorm carries no affine shift by default, so keep a conv bias.
        use_bias = norm_layer == nn.InstanceNorm2d

        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        uprelu = nn.ReLU(True)
        if norm_layer is not None:  # idiom fix: was `norm_layer != None`
            downnorm = norm_layer(inner_nc)
            upnorm = norm_layer(outer_nc)

        if outermost:
            # NOTE(review): Upsample(mode='bilinear') without align_corners
            # warns on modern torch -- confirm intended default.
            upsample = nn.Upsample(scale_factor=2, mode='bilinear')
            upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
            down = [downconv]
            up = [uprelu, upsample, upconv]
            model = down + [submodule] + up
        elif innermost:
            upsample = nn.Upsample(scale_factor=2, mode='bilinear')
            # No submodule below, so the upconv sees only inner_nc channels.
            upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
            down = [downrelu, downconv]
            if norm_layer is None:
                up = [uprelu, upsample, upconv]
            else:
                up = [uprelu, upsample, upconv, upnorm]
            model = down + up
        else:
            upsample = nn.Upsample(scale_factor=2, mode='bilinear')
            upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
            if norm_layer is None:
                down = [downrelu, downconv]
                up = [uprelu, upsample, upconv]
            else:
                down = [downrelu, downconv, downnorm]
                up = [uprelu, upsample, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        # Skip connection: concatenate the input with the block's output.
        return torch.cat([x, self.model(x)], 1)
|
60 |
+
|
61 |
+
class ResidualBlock(nn.Module):
    """Post-activation residual block: out = ReLU(x + conv_branch(x))."""

    def __init__(self, in_features=64, norm_layer=nn.BatchNorm2d):
        super(ResidualBlock, self).__init__()
        self.relu = nn.ReLU(True)
        if norm_layer is None:
            # Norm-free variant: conv -> ReLU -> conv.
            self.block = nn.Sequential(
                nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
            )
        else:
            # Normalized variant: conv -> norm -> ReLU -> conv -> norm.
            self.block = nn.Sequential(
                nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
                norm_layer(in_features),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
                norm_layer(in_features),
            )

    def forward(self, x):
        summed = x + self.block(x)
        return self.relu(summed)
|
86 |
+
|
87 |
+
class ResUnetGenerator(nn.Module):
|
88 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
|
89 |
+
norm_layer=nn.BatchNorm2d, use_dropout=False):
|
90 |
+
super(ResUnetGenerator, self).__init__()
|
91 |
+
unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
|
92 |
+
|
93 |
+
for i in range(num_downs - 5):
|
94 |
+
unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
|
95 |
+
unet_block = ResUnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
96 |
+
unet_block = ResUnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
97 |
+
unet_block = ResUnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
98 |
+
unet_block = ResUnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
|
99 |
+
|
100 |
+
self.model = unet_block
|
101 |
+
|
102 |
+
def forward(self, input):
|
103 |
+
return self.model(input)
|
104 |
+
|
105 |
+
|
106 |
+
class ResUnetSkipConnectionBlock(nn.Module):
|
107 |
+
def __init__(self, outer_nc, inner_nc, input_nc=None,
|
108 |
+
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
109 |
+
super(ResUnetSkipConnectionBlock, self).__init__()
|
110 |
+
self.outermost = outermost
|
111 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
112 |
+
|
113 |
+
if input_nc is None:
|
114 |
+
input_nc = outer_nc
|
115 |
+
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=3,
|
116 |
+
stride=2, padding=1, bias=use_bias)
|
117 |
+
|
118 |
+
res_downconv = [ResidualBlock(inner_nc, norm_layer), ResidualBlock(inner_nc, norm_layer)]
|
119 |
+
res_upconv = [ResidualBlock(outer_nc, norm_layer), ResidualBlock(outer_nc, norm_layer)]
|
120 |
+
|
121 |
+
downrelu = nn.ReLU(True)
|
122 |
+
uprelu = nn.ReLU(True)
|
123 |
+
if norm_layer != None:
|
124 |
+
downnorm = norm_layer(inner_nc)
|
125 |
+
upnorm = norm_layer(outer_nc)
|
126 |
+
|
127 |
+
if outermost:
|
128 |
+
upsample = nn.Upsample(scale_factor=2, mode='nearest')
|
129 |
+
upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
|
130 |
+
down = [downconv, downrelu] + res_downconv
|
131 |
+
up = [upsample, upconv]
|
132 |
+
model = down + [submodule] + up
|
133 |
+
elif innermost:
|
134 |
+
upsample = nn.Upsample(scale_factor=2, mode='nearest')
|
135 |
+
upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
|
136 |
+
down = [downconv, downrelu] + res_downconv
|
137 |
+
if norm_layer == None:
|
138 |
+
up = [upsample, upconv, uprelu] + res_upconv
|
139 |
+
else:
|
140 |
+
up = [upsample, upconv, upnorm, uprelu] + res_upconv
|
141 |
+
model = down + up
|
142 |
+
else:
|
143 |
+
upsample = nn.Upsample(scale_factor=2, mode='nearest')
|
144 |
+
upconv = nn.Conv2d(inner_nc*2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
|
145 |
+
if norm_layer == None:
|
146 |
+
down = [downconv, downrelu] + res_downconv
|
147 |
+
up = [upsample, upconv, uprelu] + res_upconv
|
148 |
+
else:
|
149 |
+
down = [downconv, downnorm, downrelu] + res_downconv
|
150 |
+
up = [upsample, upconv, upnorm, uprelu] + res_upconv
|
151 |
+
|
152 |
+
if use_dropout:
|
153 |
+
model = down + [submodule] + up + [nn.Dropout(0.5)]
|
154 |
+
else:
|
155 |
+
model = down + [submodule] + up
|
156 |
+
|
157 |
+
self.model = nn.Sequential(*model)
|
158 |
+
|
159 |
+
def forward(self, x):
|
160 |
+
if self.outermost:
|
161 |
+
return self.model(x)
|
162 |
+
else:
|
163 |
+
return torch.cat([x, self.model(x)], 1)
|
164 |
+
|
165 |
+
|
166 |
+
def save_checkpoint(model, save_path):
|
167 |
+
if not os.path.exists(os.path.dirname(save_path)):
|
168 |
+
os.makedirs(os.path.dirname(save_path))
|
169 |
+
torch.save(model.state_dict(), save_path)
|
170 |
+
|
171 |
+
|
172 |
+
def load_checkpoint(model, checkpoint_path):
|
173 |
+
|
174 |
+
if not os.path.exists(checkpoint_path):
|
175 |
+
print('No checkpoint!')
|
176 |
+
return
|
177 |
+
|
178 |
+
checkpoint = torch.load(checkpoint_path)
|
179 |
+
checkpoint_new = model.state_dict()
|
180 |
+
for param in checkpoint_new:
|
181 |
+
checkpoint_new[param] = checkpoint[param]
|
182 |
+
|
183 |
+
model.load_state_dict(checkpoint_new)
|
184 |
+
|
185 |
+
|
186 |
+
|