AK391
commited on
Commit
·
ff8c072
1
Parent(s):
b483613
models
Browse files- models/.DS_Store +0 -0
- models/ade20k/.DS_Store +0 -0
- models/ade20k/__init__.py +1 -0
- models/ade20k/base.py +627 -0
- models/ade20k/color150.mat +0 -0
- models/ade20k/mobilenet.py +154 -0
- models/ade20k/object150_info.csv +151 -0
- models/ade20k/resnet.py +181 -0
- models/ade20k/segm_lib/.DS_Store +0 -0
- models/ade20k/segm_lib/nn/.DS_Store +0 -0
- models/ade20k/segm_lib/nn/__init__.py +2 -0
- models/ade20k/segm_lib/nn/modules/__init__.py +12 -0
- models/ade20k/segm_lib/nn/modules/batchnorm.py +329 -0
- models/ade20k/segm_lib/nn/modules/comm.py +131 -0
- models/ade20k/segm_lib/nn/modules/replicate.py +94 -0
- models/ade20k/segm_lib/nn/modules/tests/test_numeric_batchnorm.py +56 -0
- models/ade20k/segm_lib/nn/modules/tests/test_sync_batchnorm.py +111 -0
- models/ade20k/segm_lib/nn/modules/unittest.py +29 -0
- models/ade20k/segm_lib/nn/parallel/__init__.py +1 -0
- models/ade20k/segm_lib/nn/parallel/data_parallel.py +112 -0
- models/ade20k/segm_lib/utils/__init__.py +1 -0
- models/ade20k/segm_lib/utils/data/__init__.py +3 -0
- models/ade20k/segm_lib/utils/data/dataloader.py +425 -0
- models/ade20k/segm_lib/utils/data/dataset.py +118 -0
- models/ade20k/segm_lib/utils/data/distributed.py +58 -0
- models/ade20k/segm_lib/utils/data/sampler.py +131 -0
- models/ade20k/segm_lib/utils/th.py +41 -0
- models/ade20k/utils.py +40 -0
- models/lpips_models/alex.pth +3 -0
- models/lpips_models/squeeze.pth +3 -0
- models/lpips_models/vgg.pth +3 -0
models/.DS_Store
ADDED
Binary file (8.2 kB). View file
|
|
models/ade20k/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
models/ade20k/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .base import *
|
models/ade20k/base.py
ADDED
@@ -0,0 +1,627 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch"""
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
import pandas as pd
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from scipy.io import loadmat
|
10 |
+
from torch.nn.modules import BatchNorm2d
|
11 |
+
|
12 |
+
from . import resnet
|
13 |
+
from . import mobilenet
|
14 |
+
|
15 |
+
|
16 |
+
NUM_CLASS = 150
|
17 |
+
base_path = os.path.dirname(os.path.abspath(__file__)) # current file path
|
18 |
+
colors_path = os.path.join(base_path, 'color150.mat')
|
19 |
+
classes_path = os.path.join(base_path, 'object150_info.csv')
|
20 |
+
|
21 |
+
segm_options = dict(colors=loadmat(colors_path)['colors'],
|
22 |
+
classes=pd.read_csv(classes_path),)
|
23 |
+
|
24 |
+
|
25 |
+
class NormalizeTensor:
|
26 |
+
def __init__(self, mean, std, inplace=False):
|
27 |
+
"""Normalize a tensor image with mean and standard deviation.
|
28 |
+
.. note::
|
29 |
+
This transform acts out of place by default, i.e., it does not mutates the input tensor.
|
30 |
+
See :class:`~torchvision.transforms.Normalize` for more details.
|
31 |
+
Args:
|
32 |
+
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
|
33 |
+
mean (sequence): Sequence of means for each channel.
|
34 |
+
std (sequence): Sequence of standard deviations for each channel.
|
35 |
+
inplace(bool,optional): Bool to make this operation inplace.
|
36 |
+
Returns:
|
37 |
+
Tensor: Normalized Tensor image.
|
38 |
+
"""
|
39 |
+
|
40 |
+
self.mean = mean
|
41 |
+
self.std = std
|
42 |
+
self.inplace = inplace
|
43 |
+
|
44 |
+
def __call__(self, tensor):
|
45 |
+
if not self.inplace:
|
46 |
+
tensor = tensor.clone()
|
47 |
+
|
48 |
+
dtype = tensor.dtype
|
49 |
+
mean = torch.as_tensor(self.mean, dtype=dtype, device=tensor.device)
|
50 |
+
std = torch.as_tensor(self.std, dtype=dtype, device=tensor.device)
|
51 |
+
tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
|
52 |
+
return tensor
|
53 |
+
|
54 |
+
|
55 |
+
# Model Builder
|
56 |
+
class ModelBuilder:
|
57 |
+
# custom weights initialization
|
58 |
+
@staticmethod
|
59 |
+
def weights_init(m):
|
60 |
+
classname = m.__class__.__name__
|
61 |
+
if classname.find('Conv') != -1:
|
62 |
+
nn.init.kaiming_normal_(m.weight.data)
|
63 |
+
elif classname.find('BatchNorm') != -1:
|
64 |
+
m.weight.data.fill_(1.)
|
65 |
+
m.bias.data.fill_(1e-4)
|
66 |
+
|
67 |
+
@staticmethod
|
68 |
+
def build_encoder(arch='resnet50dilated', fc_dim=512, weights=''):
|
69 |
+
pretrained = True if len(weights) == 0 else False
|
70 |
+
arch = arch.lower()
|
71 |
+
if arch == 'mobilenetv2dilated':
|
72 |
+
orig_mobilenet = mobilenet.__dict__['mobilenetv2'](pretrained=pretrained)
|
73 |
+
net_encoder = MobileNetV2Dilated(orig_mobilenet, dilate_scale=8)
|
74 |
+
elif arch == 'resnet18':
|
75 |
+
orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
|
76 |
+
net_encoder = Resnet(orig_resnet)
|
77 |
+
elif arch == 'resnet18dilated':
|
78 |
+
orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
|
79 |
+
net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
|
80 |
+
elif arch == 'resnet50dilated':
|
81 |
+
orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
|
82 |
+
net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
|
83 |
+
elif arch == 'resnet50':
|
84 |
+
orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
|
85 |
+
net_encoder = Resnet(orig_resnet)
|
86 |
+
else:
|
87 |
+
raise Exception('Architecture undefined!')
|
88 |
+
|
89 |
+
# encoders are usually pretrained
|
90 |
+
# net_encoder.apply(ModelBuilder.weights_init)
|
91 |
+
if len(weights) > 0:
|
92 |
+
print('Loading weights for net_encoder')
|
93 |
+
net_encoder.load_state_dict(
|
94 |
+
torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
|
95 |
+
return net_encoder
|
96 |
+
|
97 |
+
@staticmethod
|
98 |
+
def build_decoder(arch='ppm_deepsup',
|
99 |
+
fc_dim=512, num_class=NUM_CLASS,
|
100 |
+
weights='', use_softmax=False, drop_last_conv=False):
|
101 |
+
arch = arch.lower()
|
102 |
+
if arch == 'ppm_deepsup':
|
103 |
+
net_decoder = PPMDeepsup(
|
104 |
+
num_class=num_class,
|
105 |
+
fc_dim=fc_dim,
|
106 |
+
use_softmax=use_softmax,
|
107 |
+
drop_last_conv=drop_last_conv)
|
108 |
+
elif arch == 'c1_deepsup':
|
109 |
+
net_decoder = C1DeepSup(
|
110 |
+
num_class=num_class,
|
111 |
+
fc_dim=fc_dim,
|
112 |
+
use_softmax=use_softmax,
|
113 |
+
drop_last_conv=drop_last_conv)
|
114 |
+
else:
|
115 |
+
raise Exception('Architecture undefined!')
|
116 |
+
|
117 |
+
net_decoder.apply(ModelBuilder.weights_init)
|
118 |
+
if len(weights) > 0:
|
119 |
+
print('Loading weights for net_decoder')
|
120 |
+
net_decoder.load_state_dict(
|
121 |
+
torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
|
122 |
+
return net_decoder
|
123 |
+
|
124 |
+
@staticmethod
|
125 |
+
def get_decoder(weights_path, arch_encoder, arch_decoder, fc_dim, drop_last_conv, *arts, **kwargs):
|
126 |
+
path = os.path.join(weights_path, 'ade20k', f'ade20k-{arch_encoder}-{arch_decoder}/decoder_epoch_20.pth')
|
127 |
+
return ModelBuilder.build_decoder(arch=arch_decoder, fc_dim=fc_dim, weights=path, use_softmax=True, drop_last_conv=drop_last_conv)
|
128 |
+
|
129 |
+
@staticmethod
|
130 |
+
def get_encoder(weights_path, arch_encoder, arch_decoder, fc_dim, segmentation,
|
131 |
+
*arts, **kwargs):
|
132 |
+
if segmentation:
|
133 |
+
path = os.path.join(weights_path, 'ade20k', f'ade20k-{arch_encoder}-{arch_decoder}/encoder_epoch_20.pth')
|
134 |
+
else:
|
135 |
+
path = ''
|
136 |
+
return ModelBuilder.build_encoder(arch=arch_encoder, fc_dim=fc_dim, weights=path)
|
137 |
+
|
138 |
+
|
139 |
+
def conv3x3_bn_relu(in_planes, out_planes, stride=1):
|
140 |
+
return nn.Sequential(
|
141 |
+
nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False),
|
142 |
+
BatchNorm2d(out_planes),
|
143 |
+
nn.ReLU(inplace=True),
|
144 |
+
)
|
145 |
+
|
146 |
+
|
147 |
+
class SegmentationModule(nn.Module):
|
148 |
+
def __init__(self,
|
149 |
+
weights_path,
|
150 |
+
num_classes=150,
|
151 |
+
arch_encoder="resnet50dilated",
|
152 |
+
drop_last_conv=False,
|
153 |
+
net_enc=None, # None for Default encoder
|
154 |
+
net_dec=None, # None for Default decoder
|
155 |
+
encode=None, # {None, 'binary', 'color', 'sky'}
|
156 |
+
use_default_normalization=False,
|
157 |
+
return_feature_maps=False,
|
158 |
+
return_feature_maps_level=3, # {0, 1, 2, 3}
|
159 |
+
return_feature_maps_only=True,
|
160 |
+
**kwargs,
|
161 |
+
):
|
162 |
+
super().__init__()
|
163 |
+
self.weights_path = weights_path
|
164 |
+
self.drop_last_conv = drop_last_conv
|
165 |
+
self.arch_encoder = arch_encoder
|
166 |
+
if self.arch_encoder == "resnet50dilated":
|
167 |
+
self.arch_decoder = "ppm_deepsup"
|
168 |
+
self.fc_dim = 2048
|
169 |
+
elif self.arch_encoder == "mobilenetv2dilated":
|
170 |
+
self.arch_decoder = "c1_deepsup"
|
171 |
+
self.fc_dim = 320
|
172 |
+
else:
|
173 |
+
raise NotImplementedError(f"No such arch_encoder={self.arch_encoder}")
|
174 |
+
model_builder_kwargs = dict(arch_encoder=self.arch_encoder,
|
175 |
+
arch_decoder=self.arch_decoder,
|
176 |
+
fc_dim=self.fc_dim,
|
177 |
+
drop_last_conv=drop_last_conv,
|
178 |
+
weights_path=self.weights_path)
|
179 |
+
|
180 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
181 |
+
self.encoder = ModelBuilder.get_encoder(**model_builder_kwargs) if net_enc is None else net_enc
|
182 |
+
self.decoder = ModelBuilder.get_decoder(**model_builder_kwargs) if net_dec is None else net_dec
|
183 |
+
self.use_default_normalization = use_default_normalization
|
184 |
+
self.default_normalization = NormalizeTensor(mean=[0.485, 0.456, 0.406],
|
185 |
+
std=[0.229, 0.224, 0.225])
|
186 |
+
|
187 |
+
self.encode = encode
|
188 |
+
|
189 |
+
self.return_feature_maps = return_feature_maps
|
190 |
+
|
191 |
+
assert 0 <= return_feature_maps_level <= 3
|
192 |
+
self.return_feature_maps_level = return_feature_maps_level
|
193 |
+
|
194 |
+
def normalize_input(self, tensor):
|
195 |
+
if tensor.min() < 0 or tensor.max() > 1:
|
196 |
+
raise ValueError("Tensor should be 0..1 before using normalize_input")
|
197 |
+
return self.default_normalization(tensor)
|
198 |
+
|
199 |
+
@property
|
200 |
+
def feature_maps_channels(self):
|
201 |
+
return 256 * 2**(self.return_feature_maps_level) # 256, 512, 1024, 2048
|
202 |
+
|
203 |
+
def forward(self, img_data, segSize=None):
|
204 |
+
if segSize is None:
|
205 |
+
raise NotImplementedError("Please pass segSize param. By default: (300, 300)")
|
206 |
+
|
207 |
+
fmaps = self.encoder(img_data, return_feature_maps=True)
|
208 |
+
pred = self.decoder(fmaps, segSize=segSize)
|
209 |
+
|
210 |
+
if self.return_feature_maps:
|
211 |
+
return pred, fmaps
|
212 |
+
# print("BINARY", img_data.shape, pred.shape)
|
213 |
+
return pred
|
214 |
+
|
215 |
+
def multi_mask_from_multiclass(self, pred, classes):
|
216 |
+
def isin(ar1, ar2):
|
217 |
+
return (ar1[..., None] == ar2).any(-1).float()
|
218 |
+
return isin(pred, torch.LongTensor(classes).to(self.device))
|
219 |
+
|
220 |
+
@staticmethod
|
221 |
+
def multi_mask_from_multiclass_probs(scores, classes):
|
222 |
+
res = None
|
223 |
+
for c in classes:
|
224 |
+
if res is None:
|
225 |
+
res = scores[:, c]
|
226 |
+
else:
|
227 |
+
res += scores[:, c]
|
228 |
+
return res
|
229 |
+
|
230 |
+
def predict(self, tensor, imgSizes=(-1,), # (300, 375, 450, 525, 600)
|
231 |
+
segSize=None):
|
232 |
+
"""Entry-point for segmentation. Use this methods instead of forward
|
233 |
+
Arguments:
|
234 |
+
tensor {torch.Tensor} -- BCHW
|
235 |
+
Keyword Arguments:
|
236 |
+
imgSizes {tuple or list} -- imgSizes for segmentation input.
|
237 |
+
default: (300, 450)
|
238 |
+
original implementation: (300, 375, 450, 525, 600)
|
239 |
+
|
240 |
+
"""
|
241 |
+
if segSize is None:
|
242 |
+
segSize = tensor.shape[-2:]
|
243 |
+
segSize = (tensor.shape[2], tensor.shape[3])
|
244 |
+
with torch.no_grad():
|
245 |
+
if self.use_default_normalization:
|
246 |
+
tensor = self.normalize_input(tensor)
|
247 |
+
scores = torch.zeros(1, NUM_CLASS, segSize[0], segSize[1]).to(self.device)
|
248 |
+
features = torch.zeros(1, self.feature_maps_channels, segSize[0], segSize[1]).to(self.device)
|
249 |
+
|
250 |
+
result = []
|
251 |
+
for img_size in imgSizes:
|
252 |
+
if img_size != -1:
|
253 |
+
img_data = F.interpolate(tensor.clone(), size=img_size)
|
254 |
+
else:
|
255 |
+
img_data = tensor.clone()
|
256 |
+
|
257 |
+
if self.return_feature_maps:
|
258 |
+
pred_current, fmaps = self.forward(img_data, segSize=segSize)
|
259 |
+
else:
|
260 |
+
pred_current = self.forward(img_data, segSize=segSize)
|
261 |
+
|
262 |
+
|
263 |
+
result.append(pred_current)
|
264 |
+
scores = scores + pred_current / len(imgSizes)
|
265 |
+
|
266 |
+
# Disclaimer: We use and aggregate only last fmaps: fmaps[3]
|
267 |
+
if self.return_feature_maps:
|
268 |
+
features = features + F.interpolate(fmaps[self.return_feature_maps_level], size=segSize) / len(imgSizes)
|
269 |
+
|
270 |
+
_, pred = torch.max(scores, dim=1)
|
271 |
+
|
272 |
+
if self.return_feature_maps:
|
273 |
+
return features
|
274 |
+
|
275 |
+
return pred, result
|
276 |
+
|
277 |
+
def get_edges(self, t):
|
278 |
+
edge = torch.cuda.ByteTensor(t.size()).zero_()
|
279 |
+
edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
|
280 |
+
edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
|
281 |
+
edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
|
282 |
+
edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
|
283 |
+
|
284 |
+
if True:
|
285 |
+
return edge.half()
|
286 |
+
return edge.float()
|
287 |
+
|
288 |
+
|
289 |
+
# pyramid pooling, deep supervision
|
290 |
+
class PPMDeepsup(nn.Module):
|
291 |
+
def __init__(self, num_class=NUM_CLASS, fc_dim=4096,
|
292 |
+
use_softmax=False, pool_scales=(1, 2, 3, 6),
|
293 |
+
drop_last_conv=False):
|
294 |
+
super().__init__()
|
295 |
+
self.use_softmax = use_softmax
|
296 |
+
self.drop_last_conv = drop_last_conv
|
297 |
+
|
298 |
+
self.ppm = []
|
299 |
+
for scale in pool_scales:
|
300 |
+
self.ppm.append(nn.Sequential(
|
301 |
+
nn.AdaptiveAvgPool2d(scale),
|
302 |
+
nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
|
303 |
+
BatchNorm2d(512),
|
304 |
+
nn.ReLU(inplace=True)
|
305 |
+
))
|
306 |
+
self.ppm = nn.ModuleList(self.ppm)
|
307 |
+
self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)
|
308 |
+
|
309 |
+
self.conv_last = nn.Sequential(
|
310 |
+
nn.Conv2d(fc_dim + len(pool_scales) * 512, 512,
|
311 |
+
kernel_size=3, padding=1, bias=False),
|
312 |
+
BatchNorm2d(512),
|
313 |
+
nn.ReLU(inplace=True),
|
314 |
+
nn.Dropout2d(0.1),
|
315 |
+
nn.Conv2d(512, num_class, kernel_size=1)
|
316 |
+
)
|
317 |
+
self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
|
318 |
+
self.dropout_deepsup = nn.Dropout2d(0.1)
|
319 |
+
|
320 |
+
def forward(self, conv_out, segSize=None):
|
321 |
+
conv5 = conv_out[-1]
|
322 |
+
|
323 |
+
input_size = conv5.size()
|
324 |
+
ppm_out = [conv5]
|
325 |
+
for pool_scale in self.ppm:
|
326 |
+
ppm_out.append(nn.functional.interpolate(
|
327 |
+
pool_scale(conv5),
|
328 |
+
(input_size[2], input_size[3]),
|
329 |
+
mode='bilinear', align_corners=False))
|
330 |
+
ppm_out = torch.cat(ppm_out, 1)
|
331 |
+
|
332 |
+
if self.drop_last_conv:
|
333 |
+
return ppm_out
|
334 |
+
else:
|
335 |
+
x = self.conv_last(ppm_out)
|
336 |
+
|
337 |
+
if self.use_softmax: # is True during inference
|
338 |
+
x = nn.functional.interpolate(
|
339 |
+
x, size=segSize, mode='bilinear', align_corners=False)
|
340 |
+
x = nn.functional.softmax(x, dim=1)
|
341 |
+
return x
|
342 |
+
|
343 |
+
# deep sup
|
344 |
+
conv4 = conv_out[-2]
|
345 |
+
_ = self.cbr_deepsup(conv4)
|
346 |
+
_ = self.dropout_deepsup(_)
|
347 |
+
_ = self.conv_last_deepsup(_)
|
348 |
+
|
349 |
+
x = nn.functional.log_softmax(x, dim=1)
|
350 |
+
_ = nn.functional.log_softmax(_, dim=1)
|
351 |
+
|
352 |
+
return (x, _)
|
353 |
+
|
354 |
+
|
355 |
+
class Resnet(nn.Module):
|
356 |
+
def __init__(self, orig_resnet):
|
357 |
+
super(Resnet, self).__init__()
|
358 |
+
|
359 |
+
# take pretrained resnet, except AvgPool and FC
|
360 |
+
self.conv1 = orig_resnet.conv1
|
361 |
+
self.bn1 = orig_resnet.bn1
|
362 |
+
self.relu1 = orig_resnet.relu1
|
363 |
+
self.conv2 = orig_resnet.conv2
|
364 |
+
self.bn2 = orig_resnet.bn2
|
365 |
+
self.relu2 = orig_resnet.relu2
|
366 |
+
self.conv3 = orig_resnet.conv3
|
367 |
+
self.bn3 = orig_resnet.bn3
|
368 |
+
self.relu3 = orig_resnet.relu3
|
369 |
+
self.maxpool = orig_resnet.maxpool
|
370 |
+
self.layer1 = orig_resnet.layer1
|
371 |
+
self.layer2 = orig_resnet.layer2
|
372 |
+
self.layer3 = orig_resnet.layer3
|
373 |
+
self.layer4 = orig_resnet.layer4
|
374 |
+
|
375 |
+
def forward(self, x, return_feature_maps=False):
|
376 |
+
conv_out = []
|
377 |
+
|
378 |
+
x = self.relu1(self.bn1(self.conv1(x)))
|
379 |
+
x = self.relu2(self.bn2(self.conv2(x)))
|
380 |
+
x = self.relu3(self.bn3(self.conv3(x)))
|
381 |
+
x = self.maxpool(x)
|
382 |
+
|
383 |
+
x = self.layer1(x); conv_out.append(x);
|
384 |
+
x = self.layer2(x); conv_out.append(x);
|
385 |
+
x = self.layer3(x); conv_out.append(x);
|
386 |
+
x = self.layer4(x); conv_out.append(x);
|
387 |
+
|
388 |
+
if return_feature_maps:
|
389 |
+
return conv_out
|
390 |
+
return [x]
|
391 |
+
|
392 |
+
# Resnet Dilated
|
393 |
+
class ResnetDilated(nn.Module):
|
394 |
+
def __init__(self, orig_resnet, dilate_scale=8):
|
395 |
+
super().__init__()
|
396 |
+
from functools import partial
|
397 |
+
|
398 |
+
if dilate_scale == 8:
|
399 |
+
orig_resnet.layer3.apply(
|
400 |
+
partial(self._nostride_dilate, dilate=2))
|
401 |
+
orig_resnet.layer4.apply(
|
402 |
+
partial(self._nostride_dilate, dilate=4))
|
403 |
+
elif dilate_scale == 16:
|
404 |
+
orig_resnet.layer4.apply(
|
405 |
+
partial(self._nostride_dilate, dilate=2))
|
406 |
+
|
407 |
+
# take pretrained resnet, except AvgPool and FC
|
408 |
+
self.conv1 = orig_resnet.conv1
|
409 |
+
self.bn1 = orig_resnet.bn1
|
410 |
+
self.relu1 = orig_resnet.relu1
|
411 |
+
self.conv2 = orig_resnet.conv2
|
412 |
+
self.bn2 = orig_resnet.bn2
|
413 |
+
self.relu2 = orig_resnet.relu2
|
414 |
+
self.conv3 = orig_resnet.conv3
|
415 |
+
self.bn3 = orig_resnet.bn3
|
416 |
+
self.relu3 = orig_resnet.relu3
|
417 |
+
self.maxpool = orig_resnet.maxpool
|
418 |
+
self.layer1 = orig_resnet.layer1
|
419 |
+
self.layer2 = orig_resnet.layer2
|
420 |
+
self.layer3 = orig_resnet.layer3
|
421 |
+
self.layer4 = orig_resnet.layer4
|
422 |
+
|
423 |
+
def _nostride_dilate(self, m, dilate):
|
424 |
+
classname = m.__class__.__name__
|
425 |
+
if classname.find('Conv') != -1:
|
426 |
+
# the convolution with stride
|
427 |
+
if m.stride == (2, 2):
|
428 |
+
m.stride = (1, 1)
|
429 |
+
if m.kernel_size == (3, 3):
|
430 |
+
m.dilation = (dilate // 2, dilate // 2)
|
431 |
+
m.padding = (dilate // 2, dilate // 2)
|
432 |
+
# other convoluions
|
433 |
+
else:
|
434 |
+
if m.kernel_size == (3, 3):
|
435 |
+
m.dilation = (dilate, dilate)
|
436 |
+
m.padding = (dilate, dilate)
|
437 |
+
|
438 |
+
def forward(self, x, return_feature_maps=False):
|
439 |
+
conv_out = []
|
440 |
+
|
441 |
+
x = self.relu1(self.bn1(self.conv1(x)))
|
442 |
+
x = self.relu2(self.bn2(self.conv2(x)))
|
443 |
+
x = self.relu3(self.bn3(self.conv3(x)))
|
444 |
+
x = self.maxpool(x)
|
445 |
+
|
446 |
+
x = self.layer1(x)
|
447 |
+
conv_out.append(x)
|
448 |
+
x = self.layer2(x)
|
449 |
+
conv_out.append(x)
|
450 |
+
x = self.layer3(x)
|
451 |
+
conv_out.append(x)
|
452 |
+
x = self.layer4(x)
|
453 |
+
conv_out.append(x)
|
454 |
+
|
455 |
+
if return_feature_maps:
|
456 |
+
return conv_out
|
457 |
+
return [x]
|
458 |
+
|
459 |
+
class MobileNetV2Dilated(nn.Module):
|
460 |
+
def __init__(self, orig_net, dilate_scale=8):
|
461 |
+
super(MobileNetV2Dilated, self).__init__()
|
462 |
+
from functools import partial
|
463 |
+
|
464 |
+
# take pretrained mobilenet features
|
465 |
+
self.features = orig_net.features[:-1]
|
466 |
+
|
467 |
+
self.total_idx = len(self.features)
|
468 |
+
self.down_idx = [2, 4, 7, 14]
|
469 |
+
|
470 |
+
if dilate_scale == 8:
|
471 |
+
for i in range(self.down_idx[-2], self.down_idx[-1]):
|
472 |
+
self.features[i].apply(
|
473 |
+
partial(self._nostride_dilate, dilate=2)
|
474 |
+
)
|
475 |
+
for i in range(self.down_idx[-1], self.total_idx):
|
476 |
+
self.features[i].apply(
|
477 |
+
partial(self._nostride_dilate, dilate=4)
|
478 |
+
)
|
479 |
+
elif dilate_scale == 16:
|
480 |
+
for i in range(self.down_idx[-1], self.total_idx):
|
481 |
+
self.features[i].apply(
|
482 |
+
partial(self._nostride_dilate, dilate=2)
|
483 |
+
)
|
484 |
+
|
485 |
+
def _nostride_dilate(self, m, dilate):
|
486 |
+
classname = m.__class__.__name__
|
487 |
+
if classname.find('Conv') != -1:
|
488 |
+
# the convolution with stride
|
489 |
+
if m.stride == (2, 2):
|
490 |
+
m.stride = (1, 1)
|
491 |
+
if m.kernel_size == (3, 3):
|
492 |
+
m.dilation = (dilate//2, dilate//2)
|
493 |
+
m.padding = (dilate//2, dilate//2)
|
494 |
+
# other convoluions
|
495 |
+
else:
|
496 |
+
if m.kernel_size == (3, 3):
|
497 |
+
m.dilation = (dilate, dilate)
|
498 |
+
m.padding = (dilate, dilate)
|
499 |
+
|
500 |
+
def forward(self, x, return_feature_maps=False):
|
501 |
+
if return_feature_maps:
|
502 |
+
conv_out = []
|
503 |
+
for i in range(self.total_idx):
|
504 |
+
x = self.features[i](x)
|
505 |
+
if i in self.down_idx:
|
506 |
+
conv_out.append(x)
|
507 |
+
conv_out.append(x)
|
508 |
+
return conv_out
|
509 |
+
|
510 |
+
else:
|
511 |
+
return [self.features(x)]
|
512 |
+
|
513 |
+
|
514 |
+
# last conv, deep supervision
|
515 |
+
class C1DeepSup(nn.Module):
|
516 |
+
def __init__(self, num_class=150, fc_dim=2048, use_softmax=False, drop_last_conv=False):
|
517 |
+
super(C1DeepSup, self).__init__()
|
518 |
+
self.use_softmax = use_softmax
|
519 |
+
self.drop_last_conv = drop_last_conv
|
520 |
+
|
521 |
+
self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
|
522 |
+
self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)
|
523 |
+
|
524 |
+
# last conv
|
525 |
+
self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
|
526 |
+
self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
|
527 |
+
|
528 |
+
def forward(self, conv_out, segSize=None):
|
529 |
+
conv5 = conv_out[-1]
|
530 |
+
|
531 |
+
x = self.cbr(conv5)
|
532 |
+
|
533 |
+
if self.drop_last_conv:
|
534 |
+
return x
|
535 |
+
else:
|
536 |
+
x = self.conv_last(x)
|
537 |
+
|
538 |
+
if self.use_softmax: # is True during inference
|
539 |
+
x = nn.functional.interpolate(
|
540 |
+
x, size=segSize, mode='bilinear', align_corners=False)
|
541 |
+
x = nn.functional.softmax(x, dim=1)
|
542 |
+
return x
|
543 |
+
|
544 |
+
# deep sup
|
545 |
+
conv4 = conv_out[-2]
|
546 |
+
_ = self.cbr_deepsup(conv4)
|
547 |
+
_ = self.conv_last_deepsup(_)
|
548 |
+
|
549 |
+
x = nn.functional.log_softmax(x, dim=1)
|
550 |
+
_ = nn.functional.log_softmax(_, dim=1)
|
551 |
+
|
552 |
+
return (x, _)
|
553 |
+
|
554 |
+
|
555 |
+
# last conv
|
556 |
+
class C1(nn.Module):
|
557 |
+
def __init__(self, num_class=150, fc_dim=2048, use_softmax=False):
|
558 |
+
super(C1, self).__init__()
|
559 |
+
self.use_softmax = use_softmax
|
560 |
+
|
561 |
+
self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
|
562 |
+
|
563 |
+
# last conv
|
564 |
+
self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
|
565 |
+
|
566 |
+
def forward(self, conv_out, segSize=None):
|
567 |
+
conv5 = conv_out[-1]
|
568 |
+
x = self.cbr(conv5)
|
569 |
+
x = self.conv_last(x)
|
570 |
+
|
571 |
+
if self.use_softmax: # is True during inference
|
572 |
+
x = nn.functional.interpolate(
|
573 |
+
x, size=segSize, mode='bilinear', align_corners=False)
|
574 |
+
x = nn.functional.softmax(x, dim=1)
|
575 |
+
else:
|
576 |
+
x = nn.functional.log_softmax(x, dim=1)
|
577 |
+
|
578 |
+
return x
|
579 |
+
|
580 |
+
|
581 |
+
# pyramid pooling
|
582 |
+
class PPM(nn.Module):
|
583 |
+
def __init__(self, num_class=150, fc_dim=4096,
|
584 |
+
use_softmax=False, pool_scales=(1, 2, 3, 6)):
|
585 |
+
super(PPM, self).__init__()
|
586 |
+
self.use_softmax = use_softmax
|
587 |
+
|
588 |
+
self.ppm = []
|
589 |
+
for scale in pool_scales:
|
590 |
+
self.ppm.append(nn.Sequential(
|
591 |
+
nn.AdaptiveAvgPool2d(scale),
|
592 |
+
nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
|
593 |
+
BatchNorm2d(512),
|
594 |
+
nn.ReLU(inplace=True)
|
595 |
+
))
|
596 |
+
self.ppm = nn.ModuleList(self.ppm)
|
597 |
+
|
598 |
+
self.conv_last = nn.Sequential(
|
599 |
+
nn.Conv2d(fc_dim+len(pool_scales)*512, 512,
|
600 |
+
kernel_size=3, padding=1, bias=False),
|
601 |
+
BatchNorm2d(512),
|
602 |
+
nn.ReLU(inplace=True),
|
603 |
+
nn.Dropout2d(0.1),
|
604 |
+
nn.Conv2d(512, num_class, kernel_size=1)
|
605 |
+
)
|
606 |
+
|
607 |
+
def forward(self, conv_out, segSize=None):
|
608 |
+
conv5 = conv_out[-1]
|
609 |
+
|
610 |
+
input_size = conv5.size()
|
611 |
+
ppm_out = [conv5]
|
612 |
+
for pool_scale in self.ppm:
|
613 |
+
ppm_out.append(nn.functional.interpolate(
|
614 |
+
pool_scale(conv5),
|
615 |
+
(input_size[2], input_size[3]),
|
616 |
+
mode='bilinear', align_corners=False))
|
617 |
+
ppm_out = torch.cat(ppm_out, 1)
|
618 |
+
|
619 |
+
x = self.conv_last(ppm_out)
|
620 |
+
|
621 |
+
if self.use_softmax: # is True during inference
|
622 |
+
x = nn.functional.interpolate(
|
623 |
+
x, size=segSize, mode='bilinear', align_corners=False)
|
624 |
+
x = nn.functional.softmax(x, dim=1)
|
625 |
+
else:
|
626 |
+
x = nn.functional.log_softmax(x, dim=1)
|
627 |
+
return x
|
models/ade20k/color150.mat
ADDED
Binary file (502 Bytes). View file
|
|
models/ade20k/mobilenet.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This MobileNetV2 implementation is modified from the following repository:
|
3 |
+
https://github.com/tonylins/pytorch-mobilenet-v2
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch.nn as nn
|
7 |
+
import math
|
8 |
+
from .utils import load_url
|
9 |
+
from .segm_lib.nn import SynchronizedBatchNorm2d
|
10 |
+
|
11 |
+
BatchNorm2d = SynchronizedBatchNorm2d
|
12 |
+
|
13 |
+
|
14 |
+
__all__ = ['mobilenetv2']
|
15 |
+
|
16 |
+
|
17 |
+
model_urls = {
|
18 |
+
'mobilenetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar',
|
19 |
+
}
|
20 |
+
|
21 |
+
|
22 |
+
def conv_bn(inp, oup, stride):
|
23 |
+
return nn.Sequential(
|
24 |
+
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
|
25 |
+
BatchNorm2d(oup),
|
26 |
+
nn.ReLU6(inplace=True)
|
27 |
+
)
|
28 |
+
|
29 |
+
|
30 |
+
def conv_1x1_bn(inp, oup):
|
31 |
+
return nn.Sequential(
|
32 |
+
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
|
33 |
+
BatchNorm2d(oup),
|
34 |
+
nn.ReLU6(inplace=True)
|
35 |
+
)
|
36 |
+
|
37 |
+
|
38 |
+
class InvertedResidual(nn.Module):
|
39 |
+
def __init__(self, inp, oup, stride, expand_ratio):
|
40 |
+
super(InvertedResidual, self).__init__()
|
41 |
+
self.stride = stride
|
42 |
+
assert stride in [1, 2]
|
43 |
+
|
44 |
+
hidden_dim = round(inp * expand_ratio)
|
45 |
+
self.use_res_connect = self.stride == 1 and inp == oup
|
46 |
+
|
47 |
+
if expand_ratio == 1:
|
48 |
+
self.conv = nn.Sequential(
|
49 |
+
# dw
|
50 |
+
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
|
51 |
+
BatchNorm2d(hidden_dim),
|
52 |
+
nn.ReLU6(inplace=True),
|
53 |
+
# pw-linear
|
54 |
+
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
55 |
+
BatchNorm2d(oup),
|
56 |
+
)
|
57 |
+
else:
|
58 |
+
self.conv = nn.Sequential(
|
59 |
+
# pw
|
60 |
+
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
|
61 |
+
BatchNorm2d(hidden_dim),
|
62 |
+
nn.ReLU6(inplace=True),
|
63 |
+
# dw
|
64 |
+
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
|
65 |
+
BatchNorm2d(hidden_dim),
|
66 |
+
nn.ReLU6(inplace=True),
|
67 |
+
# pw-linear
|
68 |
+
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
69 |
+
BatchNorm2d(oup),
|
70 |
+
)
|
71 |
+
|
72 |
+
def forward(self, x):
|
73 |
+
if self.use_res_connect:
|
74 |
+
return x + self.conv(x)
|
75 |
+
else:
|
76 |
+
return self.conv(x)
|
77 |
+
|
78 |
+
|
79 |
+
class MobileNetV2(nn.Module):
|
80 |
+
def __init__(self, n_class=1000, input_size=224, width_mult=1.):
|
81 |
+
super(MobileNetV2, self).__init__()
|
82 |
+
block = InvertedResidual
|
83 |
+
input_channel = 32
|
84 |
+
last_channel = 1280
|
85 |
+
interverted_residual_setting = [
|
86 |
+
# t, c, n, s
|
87 |
+
[1, 16, 1, 1],
|
88 |
+
[6, 24, 2, 2],
|
89 |
+
[6, 32, 3, 2],
|
90 |
+
[6, 64, 4, 2],
|
91 |
+
[6, 96, 3, 1],
|
92 |
+
[6, 160, 3, 2],
|
93 |
+
[6, 320, 1, 1],
|
94 |
+
]
|
95 |
+
|
96 |
+
# building first layer
|
97 |
+
assert input_size % 32 == 0
|
98 |
+
input_channel = int(input_channel * width_mult)
|
99 |
+
self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
|
100 |
+
self.features = [conv_bn(3, input_channel, 2)]
|
101 |
+
# building inverted residual blocks
|
102 |
+
for t, c, n, s in interverted_residual_setting:
|
103 |
+
output_channel = int(c * width_mult)
|
104 |
+
for i in range(n):
|
105 |
+
if i == 0:
|
106 |
+
self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
|
107 |
+
else:
|
108 |
+
self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
|
109 |
+
input_channel = output_channel
|
110 |
+
# building last several layers
|
111 |
+
self.features.append(conv_1x1_bn(input_channel, self.last_channel))
|
112 |
+
# make it nn.Sequential
|
113 |
+
self.features = nn.Sequential(*self.features)
|
114 |
+
|
115 |
+
# building classifier
|
116 |
+
self.classifier = nn.Sequential(
|
117 |
+
nn.Dropout(0.2),
|
118 |
+
nn.Linear(self.last_channel, n_class),
|
119 |
+
)
|
120 |
+
|
121 |
+
self._initialize_weights()
|
122 |
+
|
123 |
+
def forward(self, x):
|
124 |
+
x = self.features(x)
|
125 |
+
x = x.mean(3).mean(2)
|
126 |
+
x = self.classifier(x)
|
127 |
+
return x
|
128 |
+
|
129 |
+
def _initialize_weights(self):
|
130 |
+
for m in self.modules():
|
131 |
+
if isinstance(m, nn.Conv2d):
|
132 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
133 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
134 |
+
if m.bias is not None:
|
135 |
+
m.bias.data.zero_()
|
136 |
+
elif isinstance(m, BatchNorm2d):
|
137 |
+
m.weight.data.fill_(1)
|
138 |
+
m.bias.data.zero_()
|
139 |
+
elif isinstance(m, nn.Linear):
|
140 |
+
n = m.weight.size(1)
|
141 |
+
m.weight.data.normal_(0, 0.01)
|
142 |
+
m.bias.data.zero_()
|
143 |
+
|
144 |
+
|
145 |
+
def mobilenetv2(pretrained=False, **kwargs):
|
146 |
+
"""Constructs a MobileNet_V2 model.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
150 |
+
"""
|
151 |
+
model = MobileNetV2(n_class=1000, **kwargs)
|
152 |
+
if pretrained:
|
153 |
+
model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False)
|
154 |
+
return model
|
models/ade20k/object150_info.csv
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Idx,Ratio,Train,Val,Stuff,Name
|
2 |
+
1,0.1576,11664,1172,1,wall
|
3 |
+
2,0.1072,6046,612,1,building;edifice
|
4 |
+
3,0.0878,8265,796,1,sky
|
5 |
+
4,0.0621,9336,917,1,floor;flooring
|
6 |
+
5,0.0480,6678,641,0,tree
|
7 |
+
6,0.0450,6604,643,1,ceiling
|
8 |
+
7,0.0398,4023,408,1,road;route
|
9 |
+
8,0.0231,1906,199,0,bed
|
10 |
+
9,0.0198,4688,460,0,windowpane;window
|
11 |
+
10,0.0183,2423,225,1,grass
|
12 |
+
11,0.0181,2874,294,0,cabinet
|
13 |
+
12,0.0166,3068,310,1,sidewalk;pavement
|
14 |
+
13,0.0160,5075,526,0,person;individual;someone;somebody;mortal;soul
|
15 |
+
14,0.0151,1804,190,1,earth;ground
|
16 |
+
15,0.0118,6666,796,0,door;double;door
|
17 |
+
16,0.0110,4269,411,0,table
|
18 |
+
17,0.0109,1691,160,1,mountain;mount
|
19 |
+
18,0.0104,3999,441,0,plant;flora;plant;life
|
20 |
+
19,0.0104,2149,217,0,curtain;drape;drapery;mantle;pall
|
21 |
+
20,0.0103,3261,318,0,chair
|
22 |
+
21,0.0098,3164,306,0,car;auto;automobile;machine;motorcar
|
23 |
+
22,0.0074,709,75,1,water
|
24 |
+
23,0.0067,3296,315,0,painting;picture
|
25 |
+
24,0.0065,1191,106,0,sofa;couch;lounge
|
26 |
+
25,0.0061,1516,162,0,shelf
|
27 |
+
26,0.0060,667,69,1,house
|
28 |
+
27,0.0053,651,57,1,sea
|
29 |
+
28,0.0052,1847,224,0,mirror
|
30 |
+
29,0.0046,1158,128,1,rug;carpet;carpeting
|
31 |
+
30,0.0044,480,44,1,field
|
32 |
+
31,0.0044,1172,98,0,armchair
|
33 |
+
32,0.0044,1292,184,0,seat
|
34 |
+
33,0.0033,1386,138,0,fence;fencing
|
35 |
+
34,0.0031,698,61,0,desk
|
36 |
+
35,0.0030,781,73,0,rock;stone
|
37 |
+
36,0.0027,380,43,0,wardrobe;closet;press
|
38 |
+
37,0.0026,3089,302,0,lamp
|
39 |
+
38,0.0024,404,37,0,bathtub;bathing;tub;bath;tub
|
40 |
+
39,0.0024,804,99,0,railing;rail
|
41 |
+
40,0.0023,1453,153,0,cushion
|
42 |
+
41,0.0023,411,37,0,base;pedestal;stand
|
43 |
+
42,0.0022,1440,162,0,box
|
44 |
+
43,0.0022,800,77,0,column;pillar
|
45 |
+
44,0.0020,2650,298,0,signboard;sign
|
46 |
+
45,0.0019,549,46,0,chest;of;drawers;chest;bureau;dresser
|
47 |
+
46,0.0019,367,36,0,counter
|
48 |
+
47,0.0018,311,30,1,sand
|
49 |
+
48,0.0018,1181,122,0,sink
|
50 |
+
49,0.0018,287,23,1,skyscraper
|
51 |
+
50,0.0018,468,38,0,fireplace;hearth;open;fireplace
|
52 |
+
51,0.0018,402,43,0,refrigerator;icebox
|
53 |
+
52,0.0018,130,12,1,grandstand;covered;stand
|
54 |
+
53,0.0018,561,64,1,path
|
55 |
+
54,0.0017,880,102,0,stairs;steps
|
56 |
+
55,0.0017,86,12,1,runway
|
57 |
+
56,0.0017,172,11,0,case;display;case;showcase;vitrine
|
58 |
+
57,0.0017,198,18,0,pool;table;billiard;table;snooker;table
|
59 |
+
58,0.0017,930,109,0,pillow
|
60 |
+
59,0.0015,139,18,0,screen;door;screen
|
61 |
+
60,0.0015,564,52,1,stairway;staircase
|
62 |
+
61,0.0015,320,26,1,river
|
63 |
+
62,0.0015,261,29,1,bridge;span
|
64 |
+
63,0.0014,275,22,0,bookcase
|
65 |
+
64,0.0014,335,60,0,blind;screen
|
66 |
+
65,0.0014,792,75,0,coffee;table;cocktail;table
|
67 |
+
66,0.0014,395,49,0,toilet;can;commode;crapper;pot;potty;stool;throne
|
68 |
+
67,0.0014,1309,138,0,flower
|
69 |
+
68,0.0013,1112,113,0,book
|
70 |
+
69,0.0013,266,27,1,hill
|
71 |
+
70,0.0013,659,66,0,bench
|
72 |
+
71,0.0012,331,31,0,countertop
|
73 |
+
72,0.0012,531,56,0,stove;kitchen;stove;range;kitchen;range;cooking;stove
|
74 |
+
73,0.0012,369,36,0,palm;palm;tree
|
75 |
+
74,0.0012,144,9,0,kitchen;island
|
76 |
+
75,0.0011,265,29,0,computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system
|
77 |
+
76,0.0010,324,33,0,swivel;chair
|
78 |
+
77,0.0009,304,27,0,boat
|
79 |
+
78,0.0009,170,20,0,bar
|
80 |
+
79,0.0009,68,6,0,arcade;machine
|
81 |
+
80,0.0009,65,8,1,hovel;hut;hutch;shack;shanty
|
82 |
+
81,0.0009,248,25,0,bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle
|
83 |
+
82,0.0008,492,49,0,towel
|
84 |
+
83,0.0008,2510,269,0,light;light;source
|
85 |
+
84,0.0008,440,39,0,truck;motortruck
|
86 |
+
85,0.0008,147,18,1,tower
|
87 |
+
86,0.0008,583,56,0,chandelier;pendant;pendent
|
88 |
+
87,0.0007,533,61,0,awning;sunshade;sunblind
|
89 |
+
88,0.0007,1989,239,0,streetlight;street;lamp
|
90 |
+
89,0.0007,71,5,0,booth;cubicle;stall;kiosk
|
91 |
+
90,0.0007,618,53,0,television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box
|
92 |
+
91,0.0007,135,12,0,airplane;aeroplane;plane
|
93 |
+
92,0.0007,83,5,1,dirt;track
|
94 |
+
93,0.0007,178,17,0,apparel;wearing;apparel;dress;clothes
|
95 |
+
94,0.0006,1003,104,0,pole
|
96 |
+
95,0.0006,182,12,1,land;ground;soil
|
97 |
+
96,0.0006,452,50,0,bannister;banister;balustrade;balusters;handrail
|
98 |
+
97,0.0006,42,6,1,escalator;moving;staircase;moving;stairway
|
99 |
+
98,0.0006,307,31,0,ottoman;pouf;pouffe;puff;hassock
|
100 |
+
99,0.0006,965,114,0,bottle
|
101 |
+
100,0.0006,117,13,0,buffet;counter;sideboard
|
102 |
+
101,0.0006,354,35,0,poster;posting;placard;notice;bill;card
|
103 |
+
102,0.0006,108,9,1,stage
|
104 |
+
103,0.0006,557,55,0,van
|
105 |
+
104,0.0006,52,4,0,ship
|
106 |
+
105,0.0005,99,5,0,fountain
|
107 |
+
106,0.0005,57,4,1,conveyer;belt;conveyor;belt;conveyer;conveyor;transporter
|
108 |
+
107,0.0005,292,31,0,canopy
|
109 |
+
108,0.0005,77,9,0,washer;automatic;washer;washing;machine
|
110 |
+
109,0.0005,340,38,0,plaything;toy
|
111 |
+
110,0.0005,66,3,1,swimming;pool;swimming;bath;natatorium
|
112 |
+
111,0.0005,465,49,0,stool
|
113 |
+
112,0.0005,50,4,0,barrel;cask
|
114 |
+
113,0.0005,622,75,0,basket;handbasket
|
115 |
+
114,0.0005,80,9,1,waterfall;falls
|
116 |
+
115,0.0005,59,3,0,tent;collapsible;shelter
|
117 |
+
116,0.0005,531,72,0,bag
|
118 |
+
117,0.0005,282,30,0,minibike;motorbike
|
119 |
+
118,0.0005,73,7,0,cradle
|
120 |
+
119,0.0005,435,44,0,oven
|
121 |
+
120,0.0005,136,25,0,ball
|
122 |
+
121,0.0005,116,24,0,food;solid;food
|
123 |
+
122,0.0004,266,31,0,step;stair
|
124 |
+
123,0.0004,58,12,0,tank;storage;tank
|
125 |
+
124,0.0004,418,83,0,trade;name;brand;name;brand;marque
|
126 |
+
125,0.0004,319,43,0,microwave;microwave;oven
|
127 |
+
126,0.0004,1193,139,0,pot;flowerpot
|
128 |
+
127,0.0004,97,23,0,animal;animate;being;beast;brute;creature;fauna
|
129 |
+
128,0.0004,347,36,0,bicycle;bike;wheel;cycle
|
130 |
+
129,0.0004,52,5,1,lake
|
131 |
+
130,0.0004,246,22,0,dishwasher;dish;washer;dishwashing;machine
|
132 |
+
131,0.0004,108,13,0,screen;silver;screen;projection;screen
|
133 |
+
132,0.0004,201,30,0,blanket;cover
|
134 |
+
133,0.0004,285,21,0,sculpture
|
135 |
+
134,0.0004,268,27,0,hood;exhaust;hood
|
136 |
+
135,0.0003,1020,108,0,sconce
|
137 |
+
136,0.0003,1282,122,0,vase
|
138 |
+
137,0.0003,528,65,0,traffic;light;traffic;signal;stoplight
|
139 |
+
138,0.0003,453,57,0,tray
|
140 |
+
139,0.0003,671,100,0,ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin
|
141 |
+
140,0.0003,397,44,0,fan
|
142 |
+
141,0.0003,92,8,1,pier;wharf;wharfage;dock
|
143 |
+
142,0.0003,228,18,0,crt;screen
|
144 |
+
143,0.0003,570,59,0,plate
|
145 |
+
144,0.0003,217,22,0,monitor;monitoring;device
|
146 |
+
145,0.0003,206,19,0,bulletin;board;notice;board
|
147 |
+
146,0.0003,130,14,0,shower
|
148 |
+
147,0.0003,178,28,0,radiator
|
149 |
+
148,0.0002,504,57,0,glass;drinking;glass
|
150 |
+
149,0.0002,775,96,0,clock
|
151 |
+
150,0.0002,421,56,0,flag
|
models/ade20k/resnet.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch"""
|
2 |
+
|
3 |
+
import math
|
4 |
+
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.nn import BatchNorm2d
|
7 |
+
|
8 |
+
from .utils import load_url
|
9 |
+
|
10 |
+
__all__ = ['ResNet', 'resnet50']
|
11 |
+
|
12 |
+
|
13 |
+
model_urls = {
|
14 |
+
'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth',
|
15 |
+
}
|
16 |
+
|
17 |
+
|
18 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
19 |
+
"3x3 convolution with padding"
|
20 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
21 |
+
padding=1, bias=False)
|
22 |
+
|
23 |
+
|
24 |
+
class BasicBlock(nn.Module):
|
25 |
+
expansion = 1
|
26 |
+
|
27 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
28 |
+
super(BasicBlock, self).__init__()
|
29 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
30 |
+
self.bn1 = BatchNorm2d(planes)
|
31 |
+
self.relu = nn.ReLU(inplace=True)
|
32 |
+
self.conv2 = conv3x3(planes, planes)
|
33 |
+
self.bn2 = BatchNorm2d(planes)
|
34 |
+
self.downsample = downsample
|
35 |
+
self.stride = stride
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
residual = x
|
39 |
+
|
40 |
+
out = self.conv1(x)
|
41 |
+
out = self.bn1(out)
|
42 |
+
out = self.relu(out)
|
43 |
+
|
44 |
+
out = self.conv2(out)
|
45 |
+
out = self.bn2(out)
|
46 |
+
|
47 |
+
if self.downsample is not None:
|
48 |
+
residual = self.downsample(x)
|
49 |
+
|
50 |
+
out += residual
|
51 |
+
out = self.relu(out)
|
52 |
+
|
53 |
+
return out
|
54 |
+
|
55 |
+
|
56 |
+
class Bottleneck(nn.Module):
|
57 |
+
expansion = 4
|
58 |
+
|
59 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
60 |
+
super(Bottleneck, self).__init__()
|
61 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
62 |
+
self.bn1 = BatchNorm2d(planes)
|
63 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
64 |
+
padding=1, bias=False)
|
65 |
+
self.bn2 = BatchNorm2d(planes)
|
66 |
+
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
67 |
+
self.bn3 = BatchNorm2d(planes * 4)
|
68 |
+
self.relu = nn.ReLU(inplace=True)
|
69 |
+
self.downsample = downsample
|
70 |
+
self.stride = stride
|
71 |
+
|
72 |
+
def forward(self, x):
|
73 |
+
residual = x
|
74 |
+
|
75 |
+
out = self.conv1(x)
|
76 |
+
out = self.bn1(out)
|
77 |
+
out = self.relu(out)
|
78 |
+
|
79 |
+
out = self.conv2(out)
|
80 |
+
out = self.bn2(out)
|
81 |
+
out = self.relu(out)
|
82 |
+
|
83 |
+
out = self.conv3(out)
|
84 |
+
out = self.bn3(out)
|
85 |
+
|
86 |
+
if self.downsample is not None:
|
87 |
+
residual = self.downsample(x)
|
88 |
+
|
89 |
+
out += residual
|
90 |
+
out = self.relu(out)
|
91 |
+
|
92 |
+
return out
|
93 |
+
|
94 |
+
|
95 |
+
class ResNet(nn.Module):
|
96 |
+
|
97 |
+
def __init__(self, block, layers, num_classes=1000):
|
98 |
+
self.inplanes = 128
|
99 |
+
super(ResNet, self).__init__()
|
100 |
+
self.conv1 = conv3x3(3, 64, stride=2)
|
101 |
+
self.bn1 = BatchNorm2d(64)
|
102 |
+
self.relu1 = nn.ReLU(inplace=True)
|
103 |
+
self.conv2 = conv3x3(64, 64)
|
104 |
+
self.bn2 = BatchNorm2d(64)
|
105 |
+
self.relu2 = nn.ReLU(inplace=True)
|
106 |
+
self.conv3 = conv3x3(64, 128)
|
107 |
+
self.bn3 = BatchNorm2d(128)
|
108 |
+
self.relu3 = nn.ReLU(inplace=True)
|
109 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
110 |
+
|
111 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
112 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
113 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
114 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
115 |
+
self.avgpool = nn.AvgPool2d(7, stride=1)
|
116 |
+
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
117 |
+
|
118 |
+
for m in self.modules():
|
119 |
+
if isinstance(m, nn.Conv2d):
|
120 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
121 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
122 |
+
elif isinstance(m, BatchNorm2d):
|
123 |
+
m.weight.data.fill_(1)
|
124 |
+
m.bias.data.zero_()
|
125 |
+
|
126 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
127 |
+
downsample = None
|
128 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
129 |
+
downsample = nn.Sequential(
|
130 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
131 |
+
kernel_size=1, stride=stride, bias=False),
|
132 |
+
BatchNorm2d(planes * block.expansion),
|
133 |
+
)
|
134 |
+
|
135 |
+
layers = []
|
136 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
137 |
+
self.inplanes = planes * block.expansion
|
138 |
+
for i in range(1, blocks):
|
139 |
+
layers.append(block(self.inplanes, planes))
|
140 |
+
|
141 |
+
return nn.Sequential(*layers)
|
142 |
+
|
143 |
+
def forward(self, x):
|
144 |
+
x = self.relu1(self.bn1(self.conv1(x)))
|
145 |
+
x = self.relu2(self.bn2(self.conv2(x)))
|
146 |
+
x = self.relu3(self.bn3(self.conv3(x)))
|
147 |
+
x = self.maxpool(x)
|
148 |
+
|
149 |
+
x = self.layer1(x)
|
150 |
+
x = self.layer2(x)
|
151 |
+
x = self.layer3(x)
|
152 |
+
x = self.layer4(x)
|
153 |
+
|
154 |
+
x = self.avgpool(x)
|
155 |
+
x = x.view(x.size(0), -1)
|
156 |
+
x = self.fc(x)
|
157 |
+
|
158 |
+
return x
|
159 |
+
|
160 |
+
|
161 |
+
def resnet50(pretrained=False, **kwargs):
|
162 |
+
"""Constructs a ResNet-50 model.
|
163 |
+
|
164 |
+
Args:
|
165 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
166 |
+
"""
|
167 |
+
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
168 |
+
if pretrained:
|
169 |
+
model.load_state_dict(load_url(model_urls['resnet50']), strict=False)
|
170 |
+
return model
|
171 |
+
|
172 |
+
|
173 |
+
def resnet18(pretrained=False, **kwargs):
|
174 |
+
"""Constructs a ResNet-18 model.
|
175 |
+
Args:
|
176 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
177 |
+
"""
|
178 |
+
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
|
179 |
+
if pretrained:
|
180 |
+
model.load_state_dict(load_url(model_urls['resnet18']))
|
181 |
+
return model
|
models/ade20k/segm_lib/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
models/ade20k/segm_lib/nn/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
models/ade20k/segm_lib/nn/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .modules import *
|
2 |
+
from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
|
models/ade20k/segm_lib/nn/modules/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : __init__.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
|
12 |
+
from .replicate import DataParallelWithCallback, patch_replication_callback
|
models/ade20k/segm_lib/nn/modules/batchnorm.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : batchnorm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import collections
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
17 |
+
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
|
18 |
+
|
19 |
+
from .comm import SyncMaster
|
20 |
+
|
21 |
+
__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
|
22 |
+
|
23 |
+
|
24 |
+
def _sum_ft(tensor):
|
25 |
+
"""sum over the first and last dimention"""
|
26 |
+
return tensor.sum(dim=0).sum(dim=-1)
|
27 |
+
|
28 |
+
|
29 |
+
def _unsqueeze_ft(tensor):
|
30 |
+
"""add new dementions at the front and the tail"""
|
31 |
+
return tensor.unsqueeze(0).unsqueeze(-1)
|
32 |
+
|
33 |
+
|
34 |
+
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
|
35 |
+
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
|
36 |
+
|
37 |
+
|
38 |
+
class _SynchronizedBatchNorm(_BatchNorm):
|
39 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
|
40 |
+
super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
|
41 |
+
|
42 |
+
self._sync_master = SyncMaster(self._data_parallel_master)
|
43 |
+
|
44 |
+
self._is_parallel = False
|
45 |
+
self._parallel_id = None
|
46 |
+
self._slave_pipe = None
|
47 |
+
|
48 |
+
# customed batch norm statistics
|
49 |
+
self._moving_average_fraction = 1. - momentum
|
50 |
+
self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features))
|
51 |
+
self.register_buffer('_tmp_running_var', torch.ones(self.num_features))
|
52 |
+
self.register_buffer('_running_iter', torch.ones(1))
|
53 |
+
self._tmp_running_mean = self.running_mean.clone() * self._running_iter
|
54 |
+
self._tmp_running_var = self.running_var.clone() * self._running_iter
|
55 |
+
|
56 |
+
def forward(self, input):
|
57 |
+
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
|
58 |
+
if not (self._is_parallel and self.training):
|
59 |
+
return F.batch_norm(
|
60 |
+
input, self.running_mean, self.running_var, self.weight, self.bias,
|
61 |
+
self.training, self.momentum, self.eps)
|
62 |
+
|
63 |
+
# Resize the input to (B, C, -1).
|
64 |
+
input_shape = input.size()
|
65 |
+
input = input.view(input.size(0), self.num_features, -1)
|
66 |
+
|
67 |
+
# Compute the sum and square-sum.
|
68 |
+
sum_size = input.size(0) * input.size(2)
|
69 |
+
input_sum = _sum_ft(input)
|
70 |
+
input_ssum = _sum_ft(input ** 2)
|
71 |
+
|
72 |
+
# Reduce-and-broadcast the statistics.
|
73 |
+
if self._parallel_id == 0:
|
74 |
+
mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
|
75 |
+
else:
|
76 |
+
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
|
77 |
+
|
78 |
+
# Compute the output.
|
79 |
+
if self.affine:
|
80 |
+
# MJY:: Fuse the multiplication for speed.
|
81 |
+
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
|
82 |
+
else:
|
83 |
+
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
|
84 |
+
|
85 |
+
# Reshape it.
|
86 |
+
return output.view(input_shape)
|
87 |
+
|
88 |
+
def __data_parallel_replicate__(self, ctx, copy_id):
|
89 |
+
self._is_parallel = True
|
90 |
+
self._parallel_id = copy_id
|
91 |
+
|
92 |
+
# parallel_id == 0 means master device.
|
93 |
+
if self._parallel_id == 0:
|
94 |
+
ctx.sync_master = self._sync_master
|
95 |
+
else:
|
96 |
+
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
|
97 |
+
|
98 |
+
def _data_parallel_master(self, intermediates):
|
99 |
+
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
|
100 |
+
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
|
101 |
+
|
102 |
+
to_reduce = [i[1][:2] for i in intermediates]
|
103 |
+
to_reduce = [j for i in to_reduce for j in i] # flatten
|
104 |
+
target_gpus = [i[1].sum.get_device() for i in intermediates]
|
105 |
+
|
106 |
+
sum_size = sum([i[1].sum_size for i in intermediates])
|
107 |
+
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
|
108 |
+
|
109 |
+
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
|
110 |
+
|
111 |
+
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
|
112 |
+
|
113 |
+
outputs = []
|
114 |
+
for i, rec in enumerate(intermediates):
|
115 |
+
outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
|
116 |
+
|
117 |
+
return outputs
|
118 |
+
|
119 |
+
def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0):
|
120 |
+
"""return *dest* by `dest := dest*alpha + delta*beta + bias`"""
|
121 |
+
return dest * alpha + delta * beta + bias
|
122 |
+
|
123 |
+
def _compute_mean_std(self, sum_, ssum, size):
|
124 |
+
"""Compute the mean and standard-deviation with sum and square-sum. This method
|
125 |
+
also maintains the moving average on the master device."""
|
126 |
+
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
|
127 |
+
mean = sum_ / size
|
128 |
+
sumvar = ssum - sum_ * mean
|
129 |
+
unbias_var = sumvar / (size - 1)
|
130 |
+
bias_var = sumvar / size
|
131 |
+
|
132 |
+
self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction)
|
133 |
+
self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction)
|
134 |
+
self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction)
|
135 |
+
|
136 |
+
self.running_mean = self._tmp_running_mean / self._running_iter
|
137 |
+
self.running_var = self._tmp_running_var / self._running_iter
|
138 |
+
|
139 |
+
return mean, bias_var.clamp(self.eps) ** -0.5
|
140 |
+
|
141 |
+
|
142 |
+
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
|
143 |
+
r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
|
144 |
+
mini-batch.
|
145 |
+
|
146 |
+
.. math::
|
147 |
+
|
148 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
149 |
+
|
150 |
+
This module differs from the built-in PyTorch BatchNorm1d as the mean and
|
151 |
+
standard-deviation are reduced across all devices during training.
|
152 |
+
|
153 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
154 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
155 |
+
the statistics only on that device, which accelerated the computation and
|
156 |
+
is also easy to implement, but the statistics might be inaccurate.
|
157 |
+
Instead, in this synchronized version, the statistics will be computed
|
158 |
+
over all training samples distributed on multiple devices.
|
159 |
+
|
160 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
161 |
+
as the built-in PyTorch implementation.
|
162 |
+
|
163 |
+
The mean and standard-deviation are calculated per-dimension over
|
164 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
165 |
+
of size C (where C is the input size).
|
166 |
+
|
167 |
+
During training, this layer keeps a running estimate of its computed mean
|
168 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
169 |
+
|
170 |
+
During evaluation, this running mean/variance is used for normalization.
|
171 |
+
|
172 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
173 |
+
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
|
174 |
+
|
175 |
+
Args:
|
176 |
+
num_features: num_features from an expected input of size
|
177 |
+
`batch_size x num_features [x width]`
|
178 |
+
eps: a value added to the denominator for numerical stability.
|
179 |
+
Default: 1e-5
|
180 |
+
momentum: the value used for the running_mean and running_var
|
181 |
+
computation. Default: 0.1
|
182 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
183 |
+
affine parameters. Default: ``True``
|
184 |
+
|
185 |
+
Shape:
|
186 |
+
- Input: :math:`(N, C)` or :math:`(N, C, L)`
|
187 |
+
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
|
188 |
+
|
189 |
+
Examples:
|
190 |
+
>>> # With Learnable Parameters
|
191 |
+
>>> m = SynchronizedBatchNorm1d(100)
|
192 |
+
>>> # Without Learnable Parameters
|
193 |
+
>>> m = SynchronizedBatchNorm1d(100, affine=False)
|
194 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100))
|
195 |
+
>>> output = m(input)
|
196 |
+
"""
|
197 |
+
|
198 |
+
def _check_input_dim(self, input):
|
199 |
+
if input.dim() != 2 and input.dim() != 3:
|
200 |
+
raise ValueError('expected 2D or 3D input (got {}D input)'
|
201 |
+
.format(input.dim()))
|
202 |
+
super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
|
203 |
+
|
204 |
+
|
205 |
+
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
|
206 |
+
r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
|
207 |
+
of 3d inputs
|
208 |
+
|
209 |
+
.. math::
|
210 |
+
|
211 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
212 |
+
|
213 |
+
This module differs from the built-in PyTorch BatchNorm2d as the mean and
|
214 |
+
standard-deviation are reduced across all devices during training.
|
215 |
+
|
216 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
217 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
218 |
+
the statistics only on that device, which accelerated the computation and
|
219 |
+
is also easy to implement, but the statistics might be inaccurate.
|
220 |
+
Instead, in this synchronized version, the statistics will be computed
|
221 |
+
over all training samples distributed on multiple devices.
|
222 |
+
|
223 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
224 |
+
as the built-in PyTorch implementation.
|
225 |
+
|
226 |
+
The mean and standard-deviation are calculated per-dimension over
|
227 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
228 |
+
of size C (where C is the input size).
|
229 |
+
|
230 |
+
During training, this layer keeps a running estimate of its computed mean
|
231 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
232 |
+
|
233 |
+
During evaluation, this running mean/variance is used for normalization.
|
234 |
+
|
235 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
236 |
+
on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
|
237 |
+
|
238 |
+
Args:
|
239 |
+
num_features: num_features from an expected input of
|
240 |
+
size batch_size x num_features x height x width
|
241 |
+
eps: a value added to the denominator for numerical stability.
|
242 |
+
Default: 1e-5
|
243 |
+
momentum: the value used for the running_mean and running_var
|
244 |
+
computation. Default: 0.1
|
245 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
246 |
+
affine parameters. Default: ``True``
|
247 |
+
|
248 |
+
Shape:
|
249 |
+
- Input: :math:`(N, C, H, W)`
|
250 |
+
- Output: :math:`(N, C, H, W)` (same shape as input)
|
251 |
+
|
252 |
+
Examples:
|
253 |
+
>>> # With Learnable Parameters
|
254 |
+
>>> m = SynchronizedBatchNorm2d(100)
|
255 |
+
>>> # Without Learnable Parameters
|
256 |
+
>>> m = SynchronizedBatchNorm2d(100, affine=False)
|
257 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
|
258 |
+
>>> output = m(input)
|
259 |
+
"""
|
260 |
+
|
261 |
+
def _check_input_dim(self, input):
|
262 |
+
if input.dim() != 4:
|
263 |
+
raise ValueError('expected 4D input (got {}D input)'
|
264 |
+
.format(input.dim()))
|
265 |
+
super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
|
266 |
+
|
267 |
+
|
268 |
+
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
|
269 |
+
r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
|
270 |
+
of 4d inputs
|
271 |
+
|
272 |
+
.. math::
|
273 |
+
|
274 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
275 |
+
|
276 |
+
This module differs from the built-in PyTorch BatchNorm3d as the mean and
|
277 |
+
standard-deviation are reduced across all devices during training.
|
278 |
+
|
279 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
280 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
281 |
+
the statistics only on that device, which accelerated the computation and
|
282 |
+
is also easy to implement, but the statistics might be inaccurate.
|
283 |
+
Instead, in this synchronized version, the statistics will be computed
|
284 |
+
over all training samples distributed on multiple devices.
|
285 |
+
|
286 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
287 |
+
as the built-in PyTorch implementation.
|
288 |
+
|
289 |
+
The mean and standard-deviation are calculated per-dimension over
|
290 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
291 |
+
of size C (where C is the input size).
|
292 |
+
|
293 |
+
During training, this layer keeps a running estimate of its computed mean
|
294 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
295 |
+
|
296 |
+
During evaluation, this running mean/variance is used for normalization.
|
297 |
+
|
298 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
299 |
+
on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
|
300 |
+
or Spatio-temporal BatchNorm
|
301 |
+
|
302 |
+
Args:
|
303 |
+
num_features: num_features from an expected input of
|
304 |
+
size batch_size x num_features x depth x height x width
|
305 |
+
eps: a value added to the denominator for numerical stability.
|
306 |
+
Default: 1e-5
|
307 |
+
momentum: the value used for the running_mean and running_var
|
308 |
+
computation. Default: 0.1
|
309 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
310 |
+
affine parameters. Default: ``True``
|
311 |
+
|
312 |
+
Shape:
|
313 |
+
- Input: :math:`(N, C, D, H, W)`
|
314 |
+
- Output: :math:`(N, C, D, H, W)` (same shape as input)
|
315 |
+
|
316 |
+
Examples:
|
317 |
+
>>> # With Learnable Parameters
|
318 |
+
>>> m = SynchronizedBatchNorm3d(100)
|
319 |
+
>>> # Without Learnable Parameters
|
320 |
+
>>> m = SynchronizedBatchNorm3d(100, affine=False)
|
321 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
|
322 |
+
>>> output = m(input)
|
323 |
+
"""
|
324 |
+
|
325 |
+
def _check_input_dim(self, input):
|
326 |
+
if input.dim() != 5:
|
327 |
+
raise ValueError('expected 5D input (got {}D input)'
|
328 |
+
.format(input.dim()))
|
329 |
+
super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
|
models/ade20k/segm_lib/nn/modules/comm.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : comm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import queue
|
12 |
+
import collections
|
13 |
+
import threading
|
14 |
+
|
15 |
+
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
|
16 |
+
|
17 |
+
|
18 |
+
class FutureResult(object):
|
19 |
+
"""A thread-safe future implementation. Used only as one-to-one pipe."""
|
20 |
+
|
21 |
+
def __init__(self):
|
22 |
+
self._result = None
|
23 |
+
self._lock = threading.Lock()
|
24 |
+
self._cond = threading.Condition(self._lock)
|
25 |
+
|
26 |
+
def put(self, result):
|
27 |
+
with self._lock:
|
28 |
+
assert self._result is None, 'Previous result has\'t been fetched.'
|
29 |
+
self._result = result
|
30 |
+
self._cond.notify()
|
31 |
+
|
32 |
+
def get(self):
|
33 |
+
with self._lock:
|
34 |
+
if self._result is None:
|
35 |
+
self._cond.wait()
|
36 |
+
|
37 |
+
res = self._result
|
38 |
+
self._result = None
|
39 |
+
return res
|
40 |
+
|
41 |
+
|
42 |
+
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
|
43 |
+
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
|
44 |
+
|
45 |
+
|
46 |
+
class SlavePipe(_SlavePipeBase):
|
47 |
+
"""Pipe for master-slave communication."""
|
48 |
+
|
49 |
+
def run_slave(self, msg):
|
50 |
+
self.queue.put((self.identifier, msg))
|
51 |
+
ret = self.result.get()
|
52 |
+
self.queue.put(True)
|
53 |
+
return ret
|
54 |
+
|
55 |
+
|
56 |
+
class SyncMaster(object):
|
57 |
+
"""An abstract `SyncMaster` object.
|
58 |
+
|
59 |
+
- During the replication, as the data parallel will trigger an callback of each module, all slave devices should
|
60 |
+
call `register(id)` and obtain an `SlavePipe` to communicate with the master.
|
61 |
+
- During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
|
62 |
+
and passed to a registered callback.
|
63 |
+
- After receiving the messages, the master device should gather the information and determine to message passed
|
64 |
+
back to each slave devices.
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(self, master_callback):
|
68 |
+
"""
|
69 |
+
|
70 |
+
Args:
|
71 |
+
master_callback: a callback to be invoked after having collected messages from slave devices.
|
72 |
+
"""
|
73 |
+
self._master_callback = master_callback
|
74 |
+
self._queue = queue.Queue()
|
75 |
+
self._registry = collections.OrderedDict()
|
76 |
+
self._activated = False
|
77 |
+
|
78 |
+
def register_slave(self, identifier):
|
79 |
+
"""
|
80 |
+
Register an slave device.
|
81 |
+
|
82 |
+
Args:
|
83 |
+
identifier: an identifier, usually is the device id.
|
84 |
+
|
85 |
+
Returns: a `SlavePipe` object which can be used to communicate with the master device.
|
86 |
+
|
87 |
+
"""
|
88 |
+
if self._activated:
|
89 |
+
assert self._queue.empty(), 'Queue is not clean before next initialization.'
|
90 |
+
self._activated = False
|
91 |
+
self._registry.clear()
|
92 |
+
future = FutureResult()
|
93 |
+
self._registry[identifier] = _MasterRegistry(future)
|
94 |
+
return SlavePipe(identifier, self._queue, future)
|
95 |
+
|
96 |
+
def run_master(self, master_msg):
|
97 |
+
"""
|
98 |
+
Main entry for the master device in each forward pass.
|
99 |
+
The messages were first collected from each devices (including the master device), and then
|
100 |
+
an callback will be invoked to compute the message to be sent back to each devices
|
101 |
+
(including the master device).
|
102 |
+
|
103 |
+
Args:
|
104 |
+
master_msg: the message that the master want to send to itself. This will be placed as the first
|
105 |
+
message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
|
106 |
+
|
107 |
+
Returns: the message to be sent back to the master device.
|
108 |
+
|
109 |
+
"""
|
110 |
+
self._activated = True
|
111 |
+
|
112 |
+
intermediates = [(0, master_msg)]
|
113 |
+
for i in range(self.nr_slaves):
|
114 |
+
intermediates.append(self._queue.get())
|
115 |
+
|
116 |
+
results = self._master_callback(intermediates)
|
117 |
+
assert results[0][0] == 0, 'The first result should belongs to the master.'
|
118 |
+
|
119 |
+
for i, res in results:
|
120 |
+
if i == 0:
|
121 |
+
continue
|
122 |
+
self._registry[i].result.put(res)
|
123 |
+
|
124 |
+
for i in range(self.nr_slaves):
|
125 |
+
assert self._queue.get() is True
|
126 |
+
|
127 |
+
return results[0][1]
|
128 |
+
|
129 |
+
@property
|
130 |
+
def nr_slaves(self):
|
131 |
+
return len(self._registry)
|
models/ade20k/segm_lib/nn/modules/replicate.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : replicate.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import functools
|
12 |
+
|
13 |
+
from torch.nn.parallel.data_parallel import DataParallel
|
14 |
+
|
15 |
+
__all__ = [
|
16 |
+
'CallbackContext',
|
17 |
+
'execute_replication_callbacks',
|
18 |
+
'DataParallelWithCallback',
|
19 |
+
'patch_replication_callback'
|
20 |
+
]
|
21 |
+
|
22 |
+
|
23 |
+
class CallbackContext(object):
|
24 |
+
pass
|
25 |
+
|
26 |
+
|
27 |
+
def execute_replication_callbacks(modules):
|
28 |
+
"""
|
29 |
+
Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
|
30 |
+
|
31 |
+
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
|
32 |
+
|
33 |
+
Note that, as all modules are isomorphism, we assign each sub-module with a context
|
34 |
+
(shared among multiple copies of this module on different devices).
|
35 |
+
Through this context, different copies can share some information.
|
36 |
+
|
37 |
+
We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
|
38 |
+
of any slave copies.
|
39 |
+
"""
|
40 |
+
master_copy = modules[0]
|
41 |
+
nr_modules = len(list(master_copy.modules()))
|
42 |
+
ctxs = [CallbackContext() for _ in range(nr_modules)]
|
43 |
+
|
44 |
+
for i, module in enumerate(modules):
|
45 |
+
for j, m in enumerate(module.modules()):
|
46 |
+
if hasattr(m, '__data_parallel_replicate__'):
|
47 |
+
m.__data_parallel_replicate__(ctxs[j], i)
|
48 |
+
|
49 |
+
|
50 |
+
class DataParallelWithCallback(DataParallel):
|
51 |
+
"""
|
52 |
+
Data Parallel with a replication callback.
|
53 |
+
|
54 |
+
An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
|
55 |
+
original `replicate` function.
|
56 |
+
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
|
57 |
+
|
58 |
+
Examples:
|
59 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
60 |
+
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
61 |
+
# sync_bn.__data_parallel_replicate__ will be invoked.
|
62 |
+
"""
|
63 |
+
|
64 |
+
def replicate(self, module, device_ids):
|
65 |
+
modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
|
66 |
+
execute_replication_callbacks(modules)
|
67 |
+
return modules
|
68 |
+
|
69 |
+
|
70 |
+
def patch_replication_callback(data_parallel):
|
71 |
+
"""
|
72 |
+
Monkey-patch an existing `DataParallel` object. Add the replication callback.
|
73 |
+
Useful when you have customized `DataParallel` implementation.
|
74 |
+
|
75 |
+
Examples:
|
76 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
77 |
+
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
|
78 |
+
> patch_replication_callback(sync_bn)
|
79 |
+
# this is equivalent to
|
80 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
81 |
+
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
82 |
+
"""
|
83 |
+
|
84 |
+
assert isinstance(data_parallel, DataParallel)
|
85 |
+
|
86 |
+
old_replicate = data_parallel.replicate
|
87 |
+
|
88 |
+
@functools.wraps(old_replicate)
|
89 |
+
def new_replicate(module, device_ids):
|
90 |
+
modules = old_replicate(module, device_ids)
|
91 |
+
execute_replication_callbacks(modules)
|
92 |
+
return modules
|
93 |
+
|
94 |
+
data_parallel.replicate = new_replicate
|
models/ade20k/segm_lib/nn/modules/tests/test_numeric_batchnorm.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : test_numeric_batchnorm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
|
9 |
+
import unittest
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch.autograd import Variable
|
14 |
+
|
15 |
+
from sync_batchnorm.unittest import TorchTestCase
|
16 |
+
|
17 |
+
|
18 |
+
def handy_var(a, unbias=True):
|
19 |
+
n = a.size(0)
|
20 |
+
asum = a.sum(dim=0)
|
21 |
+
as_sum = (a ** 2).sum(dim=0) # a square sum
|
22 |
+
sumvar = as_sum - asum * asum / n
|
23 |
+
if unbias:
|
24 |
+
return sumvar / (n - 1)
|
25 |
+
else:
|
26 |
+
return sumvar / n
|
27 |
+
|
28 |
+
|
29 |
+
class NumericTestCase(TorchTestCase):
|
30 |
+
def testNumericBatchNorm(self):
|
31 |
+
a = torch.rand(16, 10)
|
32 |
+
bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False)
|
33 |
+
bn.train()
|
34 |
+
|
35 |
+
a_var1 = Variable(a, requires_grad=True)
|
36 |
+
b_var1 = bn(a_var1)
|
37 |
+
loss1 = b_var1.sum()
|
38 |
+
loss1.backward()
|
39 |
+
|
40 |
+
a_var2 = Variable(a, requires_grad=True)
|
41 |
+
a_mean2 = a_var2.mean(dim=0, keepdim=True)
|
42 |
+
a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
|
43 |
+
# a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
|
44 |
+
b_var2 = (a_var2 - a_mean2) / a_std2
|
45 |
+
loss2 = b_var2.sum()
|
46 |
+
loss2.backward()
|
47 |
+
|
48 |
+
self.assertTensorClose(bn.running_mean, a.mean(dim=0))
|
49 |
+
self.assertTensorClose(bn.running_var, handy_var(a))
|
50 |
+
self.assertTensorClose(a_var1.data, a_var2.data)
|
51 |
+
self.assertTensorClose(b_var1.data, b_var2.data)
|
52 |
+
self.assertTensorClose(a_var1.grad, a_var2.grad)
|
53 |
+
|
54 |
+
|
55 |
+
if __name__ == '__main__':
|
56 |
+
unittest.main()
|
models/ade20k/segm_lib/nn/modules/tests/test_sync_batchnorm.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : test_sync_batchnorm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
|
9 |
+
import unittest
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch.autograd import Variable
|
14 |
+
|
15 |
+
from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
|
16 |
+
from sync_batchnorm.unittest import TorchTestCase
|
17 |
+
|
18 |
+
|
19 |
+
def handy_var(a, unbias=True):
|
20 |
+
n = a.size(0)
|
21 |
+
asum = a.sum(dim=0)
|
22 |
+
as_sum = (a ** 2).sum(dim=0) # a square sum
|
23 |
+
sumvar = as_sum - asum * asum / n
|
24 |
+
if unbias:
|
25 |
+
return sumvar / (n - 1)
|
26 |
+
else:
|
27 |
+
return sumvar / n
|
28 |
+
|
29 |
+
|
30 |
+
def _find_bn(module):
|
31 |
+
for m in module.modules():
|
32 |
+
if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
|
33 |
+
return m
|
34 |
+
|
35 |
+
|
36 |
+
class SyncTestCase(TorchTestCase):
|
37 |
+
def _syncParameters(self, bn1, bn2):
|
38 |
+
bn1.reset_parameters()
|
39 |
+
bn2.reset_parameters()
|
40 |
+
if bn1.affine and bn2.affine:
|
41 |
+
bn2.weight.data.copy_(bn1.weight.data)
|
42 |
+
bn2.bias.data.copy_(bn1.bias.data)
|
43 |
+
|
44 |
+
def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
|
45 |
+
"""Check the forward and backward for the customized batch normalization."""
|
46 |
+
bn1.train(mode=is_train)
|
47 |
+
bn2.train(mode=is_train)
|
48 |
+
|
49 |
+
if cuda:
|
50 |
+
input = input.cuda()
|
51 |
+
|
52 |
+
self._syncParameters(_find_bn(bn1), _find_bn(bn2))
|
53 |
+
|
54 |
+
input1 = Variable(input, requires_grad=True)
|
55 |
+
output1 = bn1(input1)
|
56 |
+
output1.sum().backward()
|
57 |
+
input2 = Variable(input, requires_grad=True)
|
58 |
+
output2 = bn2(input2)
|
59 |
+
output2.sum().backward()
|
60 |
+
|
61 |
+
self.assertTensorClose(input1.data, input2.data)
|
62 |
+
self.assertTensorClose(output1.data, output2.data)
|
63 |
+
self.assertTensorClose(input1.grad, input2.grad)
|
64 |
+
self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
|
65 |
+
self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
|
66 |
+
|
67 |
+
def testSyncBatchNormNormalTrain(self):
|
68 |
+
bn = nn.BatchNorm1d(10)
|
69 |
+
sync_bn = SynchronizedBatchNorm1d(10)
|
70 |
+
|
71 |
+
self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)
|
72 |
+
|
73 |
+
def testSyncBatchNormNormalEval(self):
|
74 |
+
bn = nn.BatchNorm1d(10)
|
75 |
+
sync_bn = SynchronizedBatchNorm1d(10)
|
76 |
+
|
77 |
+
self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)
|
78 |
+
|
79 |
+
def testSyncBatchNormSyncTrain(self):
|
80 |
+
bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
|
81 |
+
sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
82 |
+
sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
83 |
+
|
84 |
+
bn.cuda()
|
85 |
+
sync_bn.cuda()
|
86 |
+
|
87 |
+
self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)
|
88 |
+
|
89 |
+
def testSyncBatchNormSyncEval(self):
|
90 |
+
bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
|
91 |
+
sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
92 |
+
sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
93 |
+
|
94 |
+
bn.cuda()
|
95 |
+
sync_bn.cuda()
|
96 |
+
|
97 |
+
self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)
|
98 |
+
|
99 |
+
def testSyncBatchNorm2DSyncTrain(self):
|
100 |
+
bn = nn.BatchNorm2d(10)
|
101 |
+
sync_bn = SynchronizedBatchNorm2d(10)
|
102 |
+
sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
103 |
+
|
104 |
+
bn.cuda()
|
105 |
+
sync_bn.cuda()
|
106 |
+
|
107 |
+
self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)
|
108 |
+
|
109 |
+
|
110 |
+
if __name__ == '__main__':
|
111 |
+
unittest.main()
|
models/ade20k/segm_lib/nn/modules/unittest.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : unittest.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import unittest
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
from torch.autograd import Variable
|
15 |
+
|
16 |
+
|
17 |
+
def as_numpy(v):
|
18 |
+
if isinstance(v, Variable):
|
19 |
+
v = v.data
|
20 |
+
return v.cpu().numpy()
|
21 |
+
|
22 |
+
|
23 |
+
class TorchTestCase(unittest.TestCase):
|
24 |
+
def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
|
25 |
+
npa, npb = as_numpy(a), as_numpy(b)
|
26 |
+
self.assertTrue(
|
27 |
+
np.allclose(npa, npb, atol=atol),
|
28 |
+
'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
|
29 |
+
)
|
models/ade20k/segm_lib/nn/parallel/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
|
models/ade20k/segm_lib/nn/parallel/data_parallel.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf8 -*-
|
2 |
+
|
3 |
+
import torch.cuda as cuda
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch
|
6 |
+
import collections
|
7 |
+
from torch.nn.parallel._functions import Gather
|
8 |
+
|
9 |
+
|
10 |
+
__all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to']
|
11 |
+
|
12 |
+
|
13 |
+
def async_copy_to(obj, dev, main_stream=None):
|
14 |
+
if torch.is_tensor(obj):
|
15 |
+
v = obj.cuda(dev, non_blocking=True)
|
16 |
+
if main_stream is not None:
|
17 |
+
v.data.record_stream(main_stream)
|
18 |
+
return v
|
19 |
+
elif isinstance(obj, collections.Mapping):
|
20 |
+
return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
|
21 |
+
elif isinstance(obj, collections.Sequence):
|
22 |
+
return [async_copy_to(o, dev, main_stream) for o in obj]
|
23 |
+
else:
|
24 |
+
return obj
|
25 |
+
|
26 |
+
|
27 |
+
def dict_gather(outputs, target_device, dim=0):
|
28 |
+
"""
|
29 |
+
Gathers variables from different GPUs on a specified device
|
30 |
+
(-1 means the CPU), with dictionary support.
|
31 |
+
"""
|
32 |
+
def gather_map(outputs):
|
33 |
+
out = outputs[0]
|
34 |
+
if torch.is_tensor(out):
|
35 |
+
# MJY(20180330) HACK:: force nr_dims > 0
|
36 |
+
if out.dim() == 0:
|
37 |
+
outputs = [o.unsqueeze(0) for o in outputs]
|
38 |
+
return Gather.apply(target_device, dim, *outputs)
|
39 |
+
elif out is None:
|
40 |
+
return None
|
41 |
+
elif isinstance(out, collections.Mapping):
|
42 |
+
return {k: gather_map([o[k] for o in outputs]) for k in out}
|
43 |
+
elif isinstance(out, collections.Sequence):
|
44 |
+
return type(out)(map(gather_map, zip(*outputs)))
|
45 |
+
return gather_map(outputs)
|
46 |
+
|
47 |
+
|
48 |
+
class DictGatherDataParallel(nn.DataParallel):
|
49 |
+
def gather(self, outputs, output_device):
|
50 |
+
return dict_gather(outputs, output_device, dim=self.dim)
|
51 |
+
|
52 |
+
|
53 |
+
class UserScatteredDataParallel(DictGatherDataParallel):
|
54 |
+
def scatter(self, inputs, kwargs, device_ids):
|
55 |
+
assert len(inputs) == 1
|
56 |
+
inputs = inputs[0]
|
57 |
+
inputs = _async_copy_stream(inputs, device_ids)
|
58 |
+
inputs = [[i] for i in inputs]
|
59 |
+
assert len(kwargs) == 0
|
60 |
+
kwargs = [{} for _ in range(len(inputs))]
|
61 |
+
|
62 |
+
return inputs, kwargs
|
63 |
+
|
64 |
+
|
65 |
+
def user_scattered_collate(batch):
|
66 |
+
return batch
|
67 |
+
|
68 |
+
|
69 |
+
def _async_copy(inputs, device_ids):
|
70 |
+
nr_devs = len(device_ids)
|
71 |
+
assert type(inputs) in (tuple, list)
|
72 |
+
assert len(inputs) == nr_devs
|
73 |
+
|
74 |
+
outputs = []
|
75 |
+
for i, dev in zip(inputs, device_ids):
|
76 |
+
with cuda.device(dev):
|
77 |
+
outputs.append(async_copy_to(i, dev))
|
78 |
+
|
79 |
+
return tuple(outputs)
|
80 |
+
|
81 |
+
|
82 |
+
def _async_copy_stream(inputs, device_ids):
|
83 |
+
nr_devs = len(device_ids)
|
84 |
+
assert type(inputs) in (tuple, list)
|
85 |
+
assert len(inputs) == nr_devs
|
86 |
+
|
87 |
+
outputs = []
|
88 |
+
streams = [_get_stream(d) for d in device_ids]
|
89 |
+
for i, dev, stream in zip(inputs, device_ids, streams):
|
90 |
+
with cuda.device(dev):
|
91 |
+
main_stream = cuda.current_stream()
|
92 |
+
with cuda.stream(stream):
|
93 |
+
outputs.append(async_copy_to(i, dev, main_stream=main_stream))
|
94 |
+
main_stream.wait_stream(stream)
|
95 |
+
|
96 |
+
return outputs
|
97 |
+
|
98 |
+
|
99 |
+
"""Adapted from: torch/nn/parallel/_functions.py"""
|
100 |
+
# background streams used for copying
|
101 |
+
_streams = None
|
102 |
+
|
103 |
+
|
104 |
+
def _get_stream(device):
|
105 |
+
"""Gets a background stream for copying between CPU and GPU"""
|
106 |
+
global _streams
|
107 |
+
if device == -1:
|
108 |
+
return None
|
109 |
+
if _streams is None:
|
110 |
+
_streams = [None] * cuda.device_count()
|
111 |
+
if _streams[device] is None: _streams[device] = cuda.Stream(device)
|
112 |
+
return _streams[device]
|
models/ade20k/segm_lib/utils/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .th import *
|
models/ade20k/segm_lib/utils/data/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from .dataset import Dataset, TensorDataset, ConcatDataset
|
3 |
+
from .dataloader import DataLoader
|
models/ade20k/segm_lib/utils/data/dataloader.py
ADDED
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.multiprocessing as multiprocessing
|
3 |
+
from torch._C import _set_worker_signal_handlers, \
|
4 |
+
_remove_worker_pids, _error_if_any_worker_fails
|
5 |
+
try:
|
6 |
+
from torch._C import _set_worker_pids
|
7 |
+
except:
|
8 |
+
from torch._C import _update_worker_pids as _set_worker_pids
|
9 |
+
from .sampler import SequentialSampler, RandomSampler, BatchSampler
|
10 |
+
import signal
|
11 |
+
import collections
|
12 |
+
import re
|
13 |
+
import sys
|
14 |
+
import threading
|
15 |
+
import traceback
|
16 |
+
from torch._six import string_classes, int_classes
|
17 |
+
import numpy as np
|
18 |
+
|
19 |
+
if sys.version_info[0] == 2:
|
20 |
+
import Queue as queue
|
21 |
+
else:
|
22 |
+
import queue
|
23 |
+
|
24 |
+
|
25 |
+
class ExceptionWrapper(object):
|
26 |
+
r"Wraps an exception plus traceback to communicate across threads"
|
27 |
+
|
28 |
+
def __init__(self, exc_info):
|
29 |
+
self.exc_type = exc_info[0]
|
30 |
+
self.exc_msg = "".join(traceback.format_exception(*exc_info))
|
31 |
+
|
32 |
+
|
33 |
+
_use_shared_memory = False
|
34 |
+
"""Whether to use shared memory in default_collate"""
|
35 |
+
|
36 |
+
|
37 |
+
def _worker_loop(dataset, index_queue, data_queue, collate_fn, seed, init_fn, worker_id):
|
38 |
+
global _use_shared_memory
|
39 |
+
_use_shared_memory = True
|
40 |
+
|
41 |
+
# Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
|
42 |
+
# module's handlers are executed after Python returns from C low-level
|
43 |
+
# handlers, likely when the same fatal signal happened again already.
|
44 |
+
# https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
|
45 |
+
_set_worker_signal_handlers()
|
46 |
+
|
47 |
+
torch.set_num_threads(1)
|
48 |
+
torch.manual_seed(seed)
|
49 |
+
np.random.seed(seed)
|
50 |
+
|
51 |
+
if init_fn is not None:
|
52 |
+
init_fn(worker_id)
|
53 |
+
|
54 |
+
while True:
|
55 |
+
r = index_queue.get()
|
56 |
+
if r is None:
|
57 |
+
break
|
58 |
+
idx, batch_indices = r
|
59 |
+
try:
|
60 |
+
samples = collate_fn([dataset[i] for i in batch_indices])
|
61 |
+
except Exception:
|
62 |
+
data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
|
63 |
+
else:
|
64 |
+
data_queue.put((idx, samples))
|
65 |
+
|
66 |
+
|
67 |
+
def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
|
68 |
+
if pin_memory:
|
69 |
+
torch.cuda.set_device(device_id)
|
70 |
+
|
71 |
+
while True:
|
72 |
+
try:
|
73 |
+
r = in_queue.get()
|
74 |
+
except Exception:
|
75 |
+
if done_event.is_set():
|
76 |
+
return
|
77 |
+
raise
|
78 |
+
if r is None:
|
79 |
+
break
|
80 |
+
if isinstance(r[1], ExceptionWrapper):
|
81 |
+
out_queue.put(r)
|
82 |
+
continue
|
83 |
+
idx, batch = r
|
84 |
+
try:
|
85 |
+
if pin_memory:
|
86 |
+
batch = pin_memory_batch(batch)
|
87 |
+
except Exception:
|
88 |
+
out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
|
89 |
+
else:
|
90 |
+
out_queue.put((idx, batch))
|
91 |
+
|
92 |
+
numpy_type_map = {
|
93 |
+
'float64': torch.DoubleTensor,
|
94 |
+
'float32': torch.FloatTensor,
|
95 |
+
'float16': torch.HalfTensor,
|
96 |
+
'int64': torch.LongTensor,
|
97 |
+
'int32': torch.IntTensor,
|
98 |
+
'int16': torch.ShortTensor,
|
99 |
+
'int8': torch.CharTensor,
|
100 |
+
'uint8': torch.ByteTensor,
|
101 |
+
}
|
102 |
+
|
103 |
+
|
104 |
+
def default_collate(batch):
|
105 |
+
"Puts each data field into a tensor with outer dimension batch size"
|
106 |
+
|
107 |
+
error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
|
108 |
+
elem_type = type(batch[0])
|
109 |
+
if torch.is_tensor(batch[0]):
|
110 |
+
out = None
|
111 |
+
if _use_shared_memory:
|
112 |
+
# If we're in a background process, concatenate directly into a
|
113 |
+
# shared memory tensor to avoid an extra copy
|
114 |
+
numel = sum([x.numel() for x in batch])
|
115 |
+
storage = batch[0].storage()._new_shared(numel)
|
116 |
+
out = batch[0].new(storage)
|
117 |
+
return torch.stack(batch, 0, out=out)
|
118 |
+
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
|
119 |
+
and elem_type.__name__ != 'string_':
|
120 |
+
elem = batch[0]
|
121 |
+
if elem_type.__name__ == 'ndarray':
|
122 |
+
# array of string classes and object
|
123 |
+
if re.search('[SaUO]', elem.dtype.str) is not None:
|
124 |
+
raise TypeError(error_msg.format(elem.dtype))
|
125 |
+
|
126 |
+
return torch.stack([torch.from_numpy(b) for b in batch], 0)
|
127 |
+
if elem.shape == (): # scalars
|
128 |
+
py_type = float if elem.dtype.name.startswith('float') else int
|
129 |
+
return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
|
130 |
+
elif isinstance(batch[0], int_classes):
|
131 |
+
return torch.LongTensor(batch)
|
132 |
+
elif isinstance(batch[0], float):
|
133 |
+
return torch.DoubleTensor(batch)
|
134 |
+
elif isinstance(batch[0], string_classes):
|
135 |
+
return batch
|
136 |
+
elif isinstance(batch[0], collections.Mapping):
|
137 |
+
return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
|
138 |
+
elif isinstance(batch[0], collections.Sequence):
|
139 |
+
transposed = zip(*batch)
|
140 |
+
return [default_collate(samples) for samples in transposed]
|
141 |
+
|
142 |
+
raise TypeError((error_msg.format(type(batch[0]))))
|
143 |
+
|
144 |
+
|
145 |
+
def pin_memory_batch(batch):
|
146 |
+
if torch.is_tensor(batch):
|
147 |
+
return batch.pin_memory()
|
148 |
+
elif isinstance(batch, string_classes):
|
149 |
+
return batch
|
150 |
+
elif isinstance(batch, collections.Mapping):
|
151 |
+
return {k: pin_memory_batch(sample) for k, sample in batch.items()}
|
152 |
+
elif isinstance(batch, collections.Sequence):
|
153 |
+
return [pin_memory_batch(sample) for sample in batch]
|
154 |
+
else:
|
155 |
+
return batch
|
156 |
+
|
157 |
+
|
158 |
+
_SIGCHLD_handler_set = False
|
159 |
+
"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
|
160 |
+
handler needs to be set for all DataLoaders in a process."""
|
161 |
+
|
162 |
+
|
163 |
+
def _set_SIGCHLD_handler():
|
164 |
+
# Windows doesn't support SIGCHLD handler
|
165 |
+
if sys.platform == 'win32':
|
166 |
+
return
|
167 |
+
# can't set signal in child threads
|
168 |
+
if not isinstance(threading.current_thread(), threading._MainThread):
|
169 |
+
return
|
170 |
+
global _SIGCHLD_handler_set
|
171 |
+
if _SIGCHLD_handler_set:
|
172 |
+
return
|
173 |
+
previous_handler = signal.getsignal(signal.SIGCHLD)
|
174 |
+
if not callable(previous_handler):
|
175 |
+
previous_handler = None
|
176 |
+
|
177 |
+
def handler(signum, frame):
|
178 |
+
# This following call uses `waitid` with WNOHANG from C side. Therefore,
|
179 |
+
# Python can still get and update the process status successfully.
|
180 |
+
_error_if_any_worker_fails()
|
181 |
+
if previous_handler is not None:
|
182 |
+
previous_handler(signum, frame)
|
183 |
+
|
184 |
+
signal.signal(signal.SIGCHLD, handler)
|
185 |
+
_SIGCHLD_handler_set = True
|
186 |
+
|
187 |
+
|
188 |
+
class DataLoaderIter(object):
|
189 |
+
"Iterates once over the DataLoader's dataset, as specified by the sampler"
|
190 |
+
|
191 |
+
def __init__(self, loader):
|
192 |
+
self.dataset = loader.dataset
|
193 |
+
self.collate_fn = loader.collate_fn
|
194 |
+
self.batch_sampler = loader.batch_sampler
|
195 |
+
self.num_workers = loader.num_workers
|
196 |
+
self.pin_memory = loader.pin_memory and torch.cuda.is_available()
|
197 |
+
self.timeout = loader.timeout
|
198 |
+
self.done_event = threading.Event()
|
199 |
+
|
200 |
+
self.sample_iter = iter(self.batch_sampler)
|
201 |
+
|
202 |
+
if self.num_workers > 0:
|
203 |
+
self.worker_init_fn = loader.worker_init_fn
|
204 |
+
self.index_queue = multiprocessing.SimpleQueue()
|
205 |
+
self.worker_result_queue = multiprocessing.SimpleQueue()
|
206 |
+
self.batches_outstanding = 0
|
207 |
+
self.worker_pids_set = False
|
208 |
+
self.shutdown = False
|
209 |
+
self.send_idx = 0
|
210 |
+
self.rcvd_idx = 0
|
211 |
+
self.reorder_dict = {}
|
212 |
+
|
213 |
+
base_seed = torch.LongTensor(1).random_(0, 2**31-1)[0]
|
214 |
+
self.workers = [
|
215 |
+
multiprocessing.Process(
|
216 |
+
target=_worker_loop,
|
217 |
+
args=(self.dataset, self.index_queue, self.worker_result_queue, self.collate_fn,
|
218 |
+
base_seed + i, self.worker_init_fn, i))
|
219 |
+
for i in range(self.num_workers)]
|
220 |
+
|
221 |
+
if self.pin_memory or self.timeout > 0:
|
222 |
+
self.data_queue = queue.Queue()
|
223 |
+
if self.pin_memory:
|
224 |
+
maybe_device_id = torch.cuda.current_device()
|
225 |
+
else:
|
226 |
+
# do not initialize cuda context if not necessary
|
227 |
+
maybe_device_id = None
|
228 |
+
self.worker_manager_thread = threading.Thread(
|
229 |
+
target=_worker_manager_loop,
|
230 |
+
args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory,
|
231 |
+
maybe_device_id))
|
232 |
+
self.worker_manager_thread.daemon = True
|
233 |
+
self.worker_manager_thread.start()
|
234 |
+
else:
|
235 |
+
self.data_queue = self.worker_result_queue
|
236 |
+
|
237 |
+
for w in self.workers:
|
238 |
+
w.daemon = True # ensure that the worker exits on process exit
|
239 |
+
w.start()
|
240 |
+
|
241 |
+
_set_worker_pids(id(self), tuple(w.pid for w in self.workers))
|
242 |
+
_set_SIGCHLD_handler()
|
243 |
+
self.worker_pids_set = True
|
244 |
+
|
245 |
+
# prime the prefetch loop
|
246 |
+
for _ in range(2 * self.num_workers):
|
247 |
+
self._put_indices()
|
248 |
+
|
249 |
+
def __len__(self):
|
250 |
+
return len(self.batch_sampler)
|
251 |
+
|
252 |
+
def _get_batch(self):
|
253 |
+
if self.timeout > 0:
|
254 |
+
try:
|
255 |
+
return self.data_queue.get(timeout=self.timeout)
|
256 |
+
except queue.Empty:
|
257 |
+
raise RuntimeError('DataLoader timed out after {} seconds'.format(self.timeout))
|
258 |
+
else:
|
259 |
+
return self.data_queue.get()
|
260 |
+
|
261 |
+
def __next__(self):
|
262 |
+
if self.num_workers == 0: # same-process loading
|
263 |
+
indices = next(self.sample_iter) # may raise StopIteration
|
264 |
+
batch = self.collate_fn([self.dataset[i] for i in indices])
|
265 |
+
if self.pin_memory:
|
266 |
+
batch = pin_memory_batch(batch)
|
267 |
+
return batch
|
268 |
+
|
269 |
+
# check if the next sample has already been generated
|
270 |
+
if self.rcvd_idx in self.reorder_dict:
|
271 |
+
batch = self.reorder_dict.pop(self.rcvd_idx)
|
272 |
+
return self._process_next_batch(batch)
|
273 |
+
|
274 |
+
if self.batches_outstanding == 0:
|
275 |
+
self._shutdown_workers()
|
276 |
+
raise StopIteration
|
277 |
+
|
278 |
+
while True:
|
279 |
+
assert (not self.shutdown and self.batches_outstanding > 0)
|
280 |
+
idx, batch = self._get_batch()
|
281 |
+
self.batches_outstanding -= 1
|
282 |
+
if idx != self.rcvd_idx:
|
283 |
+
# store out-of-order samples
|
284 |
+
self.reorder_dict[idx] = batch
|
285 |
+
continue
|
286 |
+
return self._process_next_batch(batch)
|
287 |
+
|
288 |
+
next = __next__ # Python 2 compatibility
|
289 |
+
|
290 |
+
def __iter__(self):
|
291 |
+
return self
|
292 |
+
|
293 |
+
def _put_indices(self):
|
294 |
+
assert self.batches_outstanding < 2 * self.num_workers
|
295 |
+
indices = next(self.sample_iter, None)
|
296 |
+
if indices is None:
|
297 |
+
return
|
298 |
+
self.index_queue.put((self.send_idx, indices))
|
299 |
+
self.batches_outstanding += 1
|
300 |
+
self.send_idx += 1
|
301 |
+
|
302 |
+
def _process_next_batch(self, batch):
|
303 |
+
self.rcvd_idx += 1
|
304 |
+
self._put_indices()
|
305 |
+
if isinstance(batch, ExceptionWrapper):
|
306 |
+
raise batch.exc_type(batch.exc_msg)
|
307 |
+
return batch
|
308 |
+
|
309 |
+
def __getstate__(self):
|
310 |
+
# TODO: add limited pickling support for sharing an iterator
|
311 |
+
# across multiple threads for HOGWILD.
|
312 |
+
# Probably the best way to do this is by moving the sample pushing
|
313 |
+
# to a separate thread and then just sharing the data queue
|
314 |
+
# but signalling the end is tricky without a non-blocking API
|
315 |
+
raise NotImplementedError("DataLoaderIterator cannot be pickled")
|
316 |
+
|
317 |
+
def _shutdown_workers(self):
|
318 |
+
try:
|
319 |
+
if not self.shutdown:
|
320 |
+
self.shutdown = True
|
321 |
+
self.done_event.set()
|
322 |
+
# if worker_manager_thread is waiting to put
|
323 |
+
while not self.data_queue.empty():
|
324 |
+
self.data_queue.get()
|
325 |
+
for _ in self.workers:
|
326 |
+
self.index_queue.put(None)
|
327 |
+
# done_event should be sufficient to exit worker_manager_thread,
|
328 |
+
# but be safe here and put another None
|
329 |
+
self.worker_result_queue.put(None)
|
330 |
+
finally:
|
331 |
+
# removes pids no matter what
|
332 |
+
if self.worker_pids_set:
|
333 |
+
_remove_worker_pids(id(self))
|
334 |
+
self.worker_pids_set = False
|
335 |
+
|
336 |
+
def __del__(self):
|
337 |
+
if self.num_workers > 0:
|
338 |
+
self._shutdown_workers()
|
339 |
+
|
340 |
+
|
341 |
+
class DataLoader(object):
|
342 |
+
"""
|
343 |
+
Data loader. Combines a dataset and a sampler, and provides
|
344 |
+
single- or multi-process iterators over the dataset.
|
345 |
+
|
346 |
+
Arguments:
|
347 |
+
dataset (Dataset): dataset from which to load the data.
|
348 |
+
batch_size (int, optional): how many samples per batch to load
|
349 |
+
(default: 1).
|
350 |
+
shuffle (bool, optional): set to ``True`` to have the data reshuffled
|
351 |
+
at every epoch (default: False).
|
352 |
+
sampler (Sampler, optional): defines the strategy to draw samples from
|
353 |
+
the dataset. If specified, ``shuffle`` must be False.
|
354 |
+
batch_sampler (Sampler, optional): like sampler, but returns a batch of
|
355 |
+
indices at a time. Mutually exclusive with batch_size, shuffle,
|
356 |
+
sampler, and drop_last.
|
357 |
+
num_workers (int, optional): how many subprocesses to use for data
|
358 |
+
loading. 0 means that the data will be loaded in the main process.
|
359 |
+
(default: 0)
|
360 |
+
collate_fn (callable, optional): merges a list of samples to form a mini-batch.
|
361 |
+
pin_memory (bool, optional): If ``True``, the data loader will copy tensors
|
362 |
+
into CUDA pinned memory before returning them.
|
363 |
+
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
|
364 |
+
if the dataset size is not divisible by the batch size. If ``False`` and
|
365 |
+
the size of dataset is not divisible by the batch size, then the last batch
|
366 |
+
will be smaller. (default: False)
|
367 |
+
timeout (numeric, optional): if positive, the timeout value for collecting a batch
|
368 |
+
from workers. Should always be non-negative. (default: 0)
|
369 |
+
worker_init_fn (callable, optional): If not None, this will be called on each
|
370 |
+
worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
|
371 |
+
input, after seeding and before data loading. (default: None)
|
372 |
+
|
373 |
+
.. note:: By default, each worker will have its PyTorch seed set to
|
374 |
+
``base_seed + worker_id``, where ``base_seed`` is a long generated
|
375 |
+
by main process using its RNG. You may use ``torch.initial_seed()`` to access
|
376 |
+
this value in :attr:`worker_init_fn`, which can be used to set other seeds
|
377 |
+
(e.g. NumPy) before data loading.
|
378 |
+
|
379 |
+
.. warning:: If ``spawn'' start method is used, :attr:`worker_init_fn` cannot be an
|
380 |
+
unpicklable object, e.g., a lambda function.
|
381 |
+
"""
|
382 |
+
|
383 |
+
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
|
384 |
+
num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
|
385 |
+
timeout=0, worker_init_fn=None):
|
386 |
+
self.dataset = dataset
|
387 |
+
self.batch_size = batch_size
|
388 |
+
self.num_workers = num_workers
|
389 |
+
self.collate_fn = collate_fn
|
390 |
+
self.pin_memory = pin_memory
|
391 |
+
self.drop_last = drop_last
|
392 |
+
self.timeout = timeout
|
393 |
+
self.worker_init_fn = worker_init_fn
|
394 |
+
|
395 |
+
if timeout < 0:
|
396 |
+
raise ValueError('timeout option should be non-negative')
|
397 |
+
|
398 |
+
if batch_sampler is not None:
|
399 |
+
if batch_size > 1 or shuffle or sampler is not None or drop_last:
|
400 |
+
raise ValueError('batch_sampler is mutually exclusive with '
|
401 |
+
'batch_size, shuffle, sampler, and drop_last')
|
402 |
+
|
403 |
+
if sampler is not None and shuffle:
|
404 |
+
raise ValueError('sampler is mutually exclusive with shuffle')
|
405 |
+
|
406 |
+
if self.num_workers < 0:
|
407 |
+
raise ValueError('num_workers cannot be negative; '
|
408 |
+
'use num_workers=0 to disable multiprocessing.')
|
409 |
+
|
410 |
+
if batch_sampler is None:
|
411 |
+
if sampler is None:
|
412 |
+
if shuffle:
|
413 |
+
sampler = RandomSampler(dataset)
|
414 |
+
else:
|
415 |
+
sampler = SequentialSampler(dataset)
|
416 |
+
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
|
417 |
+
|
418 |
+
self.sampler = sampler
|
419 |
+
self.batch_sampler = batch_sampler
|
420 |
+
|
421 |
+
def __iter__(self):
|
422 |
+
return DataLoaderIter(self)
|
423 |
+
|
424 |
+
def __len__(self):
|
425 |
+
return len(self.batch_sampler)
|
models/ade20k/segm_lib/utils/data/dataset.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import bisect
|
2 |
+
import warnings
|
3 |
+
|
4 |
+
from torch._utils import _accumulate
|
5 |
+
from torch import randperm
|
6 |
+
|
7 |
+
|
8 |
+
class Dataset(object):
|
9 |
+
"""An abstract class representing a Dataset.
|
10 |
+
|
11 |
+
All other datasets should subclass it. All subclasses should override
|
12 |
+
``__len__``, that provides the size of the dataset, and ``__getitem__``,
|
13 |
+
supporting integer indexing in range from 0 to len(self) exclusive.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __getitem__(self, index):
|
17 |
+
raise NotImplementedError
|
18 |
+
|
19 |
+
def __len__(self):
|
20 |
+
raise NotImplementedError
|
21 |
+
|
22 |
+
def __add__(self, other):
|
23 |
+
return ConcatDataset([self, other])
|
24 |
+
|
25 |
+
|
26 |
+
class TensorDataset(Dataset):
|
27 |
+
"""Dataset wrapping data and target tensors.
|
28 |
+
|
29 |
+
Each sample will be retrieved by indexing both tensors along the first
|
30 |
+
dimension.
|
31 |
+
|
32 |
+
Arguments:
|
33 |
+
data_tensor (Tensor): contains sample data.
|
34 |
+
target_tensor (Tensor): contains sample targets (labels).
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(self, data_tensor, target_tensor):
|
38 |
+
assert data_tensor.size(0) == target_tensor.size(0)
|
39 |
+
self.data_tensor = data_tensor
|
40 |
+
self.target_tensor = target_tensor
|
41 |
+
|
42 |
+
def __getitem__(self, index):
|
43 |
+
return self.data_tensor[index], self.target_tensor[index]
|
44 |
+
|
45 |
+
def __len__(self):
|
46 |
+
return self.data_tensor.size(0)
|
47 |
+
|
48 |
+
|
49 |
+
class ConcatDataset(Dataset):
|
50 |
+
"""
|
51 |
+
Dataset to concatenate multiple datasets.
|
52 |
+
Purpose: useful to assemble different existing datasets, possibly
|
53 |
+
large-scale datasets as the concatenation operation is done in an
|
54 |
+
on-the-fly manner.
|
55 |
+
|
56 |
+
Arguments:
|
57 |
+
datasets (iterable): List of datasets to be concatenated
|
58 |
+
"""
|
59 |
+
|
60 |
+
@staticmethod
|
61 |
+
def cumsum(sequence):
|
62 |
+
r, s = [], 0
|
63 |
+
for e in sequence:
|
64 |
+
l = len(e)
|
65 |
+
r.append(l + s)
|
66 |
+
s += l
|
67 |
+
return r
|
68 |
+
|
69 |
+
def __init__(self, datasets):
|
70 |
+
super(ConcatDataset, self).__init__()
|
71 |
+
assert len(datasets) > 0, 'datasets should not be an empty iterable'
|
72 |
+
self.datasets = list(datasets)
|
73 |
+
self.cumulative_sizes = self.cumsum(self.datasets)
|
74 |
+
|
75 |
+
def __len__(self):
|
76 |
+
return self.cumulative_sizes[-1]
|
77 |
+
|
78 |
+
def __getitem__(self, idx):
|
79 |
+
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
80 |
+
if dataset_idx == 0:
|
81 |
+
sample_idx = idx
|
82 |
+
else:
|
83 |
+
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
84 |
+
return self.datasets[dataset_idx][sample_idx]
|
85 |
+
|
86 |
+
@property
|
87 |
+
def cummulative_sizes(self):
|
88 |
+
warnings.warn("cummulative_sizes attribute is renamed to "
|
89 |
+
"cumulative_sizes", DeprecationWarning, stacklevel=2)
|
90 |
+
return self.cumulative_sizes
|
91 |
+
|
92 |
+
|
93 |
+
class Subset(Dataset):
|
94 |
+
def __init__(self, dataset, indices):
|
95 |
+
self.dataset = dataset
|
96 |
+
self.indices = indices
|
97 |
+
|
98 |
+
def __getitem__(self, idx):
|
99 |
+
return self.dataset[self.indices[idx]]
|
100 |
+
|
101 |
+
def __len__(self):
|
102 |
+
return len(self.indices)
|
103 |
+
|
104 |
+
|
105 |
+
def random_split(dataset, lengths):
|
106 |
+
"""
|
107 |
+
Randomly split a dataset into non-overlapping new datasets of given lengths
|
108 |
+
ds
|
109 |
+
|
110 |
+
Arguments:
|
111 |
+
dataset (Dataset): Dataset to be split
|
112 |
+
lengths (iterable): lengths of splits to be produced
|
113 |
+
"""
|
114 |
+
if sum(lengths) != len(dataset):
|
115 |
+
raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
|
116 |
+
|
117 |
+
indices = randperm(sum(lengths))
|
118 |
+
return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]
|
models/ade20k/segm_lib/utils/data/distributed.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from .sampler import Sampler
|
4 |
+
from torch.distributed import get_world_size, get_rank
|
5 |
+
|
6 |
+
|
7 |
+
class DistributedSampler(Sampler):
|
8 |
+
"""Sampler that restricts data loading to a subset of the dataset.
|
9 |
+
|
10 |
+
It is especially useful in conjunction with
|
11 |
+
:class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
|
12 |
+
process can pass a DistributedSampler instance as a DataLoader sampler,
|
13 |
+
and load a subset of the original dataset that is exclusive to it.
|
14 |
+
|
15 |
+
.. note::
|
16 |
+
Dataset is assumed to be of constant size.
|
17 |
+
|
18 |
+
Arguments:
|
19 |
+
dataset: Dataset used for sampling.
|
20 |
+
num_replicas (optional): Number of processes participating in
|
21 |
+
distributed training.
|
22 |
+
rank (optional): Rank of the current process within num_replicas.
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self, dataset, num_replicas=None, rank=None):
|
26 |
+
if num_replicas is None:
|
27 |
+
num_replicas = get_world_size()
|
28 |
+
if rank is None:
|
29 |
+
rank = get_rank()
|
30 |
+
self.dataset = dataset
|
31 |
+
self.num_replicas = num_replicas
|
32 |
+
self.rank = rank
|
33 |
+
self.epoch = 0
|
34 |
+
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
|
35 |
+
self.total_size = self.num_samples * self.num_replicas
|
36 |
+
|
37 |
+
def __iter__(self):
|
38 |
+
# deterministically shuffle based on epoch
|
39 |
+
g = torch.Generator()
|
40 |
+
g.manual_seed(self.epoch)
|
41 |
+
indices = list(torch.randperm(len(self.dataset), generator=g))
|
42 |
+
|
43 |
+
# add extra samples to make it evenly divisible
|
44 |
+
indices += indices[:(self.total_size - len(indices))]
|
45 |
+
assert len(indices) == self.total_size
|
46 |
+
|
47 |
+
# subsample
|
48 |
+
offset = self.num_samples * self.rank
|
49 |
+
indices = indices[offset:offset + self.num_samples]
|
50 |
+
assert len(indices) == self.num_samples
|
51 |
+
|
52 |
+
return iter(indices)
|
53 |
+
|
54 |
+
def __len__(self):
|
55 |
+
return self.num_samples
|
56 |
+
|
57 |
+
def set_epoch(self, epoch):
|
58 |
+
self.epoch = epoch
|
models/ade20k/segm_lib/utils/data/sampler.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class Sampler(object):
|
5 |
+
"""Base class for all Samplers.
|
6 |
+
|
7 |
+
Every Sampler subclass has to provide an __iter__ method, providing a way
|
8 |
+
to iterate over indices of dataset elements, and a __len__ method that
|
9 |
+
returns the length of the returned iterators.
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(self, data_source):
|
13 |
+
pass
|
14 |
+
|
15 |
+
def __iter__(self):
|
16 |
+
raise NotImplementedError
|
17 |
+
|
18 |
+
def __len__(self):
|
19 |
+
raise NotImplementedError
|
20 |
+
|
21 |
+
|
22 |
+
class SequentialSampler(Sampler):
|
23 |
+
"""Samples elements sequentially, always in the same order.
|
24 |
+
|
25 |
+
Arguments:
|
26 |
+
data_source (Dataset): dataset to sample from
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(self, data_source):
|
30 |
+
self.data_source = data_source
|
31 |
+
|
32 |
+
def __iter__(self):
|
33 |
+
return iter(range(len(self.data_source)))
|
34 |
+
|
35 |
+
def __len__(self):
|
36 |
+
return len(self.data_source)
|
37 |
+
|
38 |
+
|
39 |
+
class RandomSampler(Sampler):
|
40 |
+
"""Samples elements randomly, without replacement.
|
41 |
+
|
42 |
+
Arguments:
|
43 |
+
data_source (Dataset): dataset to sample from
|
44 |
+
"""
|
45 |
+
|
46 |
+
def __init__(self, data_source):
|
47 |
+
self.data_source = data_source
|
48 |
+
|
49 |
+
def __iter__(self):
|
50 |
+
return iter(torch.randperm(len(self.data_source)).long())
|
51 |
+
|
52 |
+
def __len__(self):
|
53 |
+
return len(self.data_source)
|
54 |
+
|
55 |
+
|
56 |
+
class SubsetRandomSampler(Sampler):
|
57 |
+
"""Samples elements randomly from a given list of indices, without replacement.
|
58 |
+
|
59 |
+
Arguments:
|
60 |
+
indices (list): a list of indices
|
61 |
+
"""
|
62 |
+
|
63 |
+
def __init__(self, indices):
|
64 |
+
self.indices = indices
|
65 |
+
|
66 |
+
def __iter__(self):
|
67 |
+
return (self.indices[i] for i in torch.randperm(len(self.indices)))
|
68 |
+
|
69 |
+
def __len__(self):
|
70 |
+
return len(self.indices)
|
71 |
+
|
72 |
+
|
73 |
+
class WeightedRandomSampler(Sampler):
|
74 |
+
"""Samples elements from [0,..,len(weights)-1] with given probabilities (weights).
|
75 |
+
|
76 |
+
Arguments:
|
77 |
+
weights (list) : a list of weights, not necessary summing up to one
|
78 |
+
num_samples (int): number of samples to draw
|
79 |
+
replacement (bool): if ``True``, samples are drawn with replacement.
|
80 |
+
If not, they are drawn without replacement, which means that when a
|
81 |
+
sample index is drawn for a row, it cannot be drawn again for that row.
|
82 |
+
"""
|
83 |
+
|
84 |
+
def __init__(self, weights, num_samples, replacement=True):
|
85 |
+
self.weights = torch.DoubleTensor(weights)
|
86 |
+
self.num_samples = num_samples
|
87 |
+
self.replacement = replacement
|
88 |
+
|
89 |
+
def __iter__(self):
|
90 |
+
return iter(torch.multinomial(self.weights, self.num_samples, self.replacement))
|
91 |
+
|
92 |
+
def __len__(self):
|
93 |
+
return self.num_samples
|
94 |
+
|
95 |
+
|
96 |
+
class BatchSampler(object):
|
97 |
+
"""Wraps another sampler to yield a mini-batch of indices.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
sampler (Sampler): Base sampler.
|
101 |
+
batch_size (int): Size of mini-batch.
|
102 |
+
drop_last (bool): If ``True``, the sampler will drop the last batch if
|
103 |
+
its size would be less than ``batch_size``
|
104 |
+
|
105 |
+
Example:
|
106 |
+
>>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
|
107 |
+
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
|
108 |
+
>>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
|
109 |
+
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
|
110 |
+
"""
|
111 |
+
|
112 |
+
def __init__(self, sampler, batch_size, drop_last):
|
113 |
+
self.sampler = sampler
|
114 |
+
self.batch_size = batch_size
|
115 |
+
self.drop_last = drop_last
|
116 |
+
|
117 |
+
def __iter__(self):
|
118 |
+
batch = []
|
119 |
+
for idx in self.sampler:
|
120 |
+
batch.append(idx)
|
121 |
+
if len(batch) == self.batch_size:
|
122 |
+
yield batch
|
123 |
+
batch = []
|
124 |
+
if len(batch) > 0 and not self.drop_last:
|
125 |
+
yield batch
|
126 |
+
|
127 |
+
def __len__(self):
|
128 |
+
if self.drop_last:
|
129 |
+
return len(self.sampler) // self.batch_size
|
130 |
+
else:
|
131 |
+
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
|
models/ade20k/segm_lib/utils/th.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.autograd import Variable
|
3 |
+
import numpy as np
|
4 |
+
import collections
|
5 |
+
|
6 |
+
__all__ = ['as_variable', 'as_numpy', 'mark_volatile']
|
7 |
+
|
8 |
+
def as_variable(obj):
|
9 |
+
if isinstance(obj, Variable):
|
10 |
+
return obj
|
11 |
+
if isinstance(obj, collections.Sequence):
|
12 |
+
return [as_variable(v) for v in obj]
|
13 |
+
elif isinstance(obj, collections.Mapping):
|
14 |
+
return {k: as_variable(v) for k, v in obj.items()}
|
15 |
+
else:
|
16 |
+
return Variable(obj)
|
17 |
+
|
18 |
+
def as_numpy(obj):
|
19 |
+
if isinstance(obj, collections.Sequence):
|
20 |
+
return [as_numpy(v) for v in obj]
|
21 |
+
elif isinstance(obj, collections.Mapping):
|
22 |
+
return {k: as_numpy(v) for k, v in obj.items()}
|
23 |
+
elif isinstance(obj, Variable):
|
24 |
+
return obj.data.cpu().numpy()
|
25 |
+
elif torch.is_tensor(obj):
|
26 |
+
return obj.cpu().numpy()
|
27 |
+
else:
|
28 |
+
return np.array(obj)
|
29 |
+
|
30 |
+
def mark_volatile(obj):
|
31 |
+
if torch.is_tensor(obj):
|
32 |
+
obj = Variable(obj)
|
33 |
+
if isinstance(obj, Variable):
|
34 |
+
obj.no_grad = True
|
35 |
+
return obj
|
36 |
+
elif isinstance(obj, collections.Mapping):
|
37 |
+
return {k: mark_volatile(o) for k, o in obj.items()}
|
38 |
+
elif isinstance(obj, collections.Sequence):
|
39 |
+
return [mark_volatile(o) for o in obj]
|
40 |
+
else:
|
41 |
+
return obj
|
models/ade20k/utils.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch"""
|
2 |
+
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
|
9 |
+
try:
|
10 |
+
from urllib import urlretrieve
|
11 |
+
except ImportError:
|
12 |
+
from urllib.request import urlretrieve
|
13 |
+
|
14 |
+
|
15 |
+
def load_url(url, model_dir='./pretrained', map_location=None):
|
16 |
+
if not os.path.exists(model_dir):
|
17 |
+
os.makedirs(model_dir)
|
18 |
+
filename = url.split('/')[-1]
|
19 |
+
cached_file = os.path.join(model_dir, filename)
|
20 |
+
if not os.path.exists(cached_file):
|
21 |
+
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
|
22 |
+
urlretrieve(url, cached_file)
|
23 |
+
return torch.load(cached_file, map_location=map_location)
|
24 |
+
|
25 |
+
|
26 |
+
def color_encode(labelmap, colors, mode='RGB'):
|
27 |
+
labelmap = labelmap.astype('int')
|
28 |
+
labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
|
29 |
+
dtype=np.uint8)
|
30 |
+
for label in np.unique(labelmap):
|
31 |
+
if label < 0:
|
32 |
+
continue
|
33 |
+
labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \
|
34 |
+
np.tile(colors[label],
|
35 |
+
(labelmap.shape[0], labelmap.shape[1], 1))
|
36 |
+
|
37 |
+
if mode == 'BGR':
|
38 |
+
return labelmap_rgb[:, :, ::-1]
|
39 |
+
else:
|
40 |
+
return labelmap_rgb
|
models/lpips_models/alex.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:df73285e35b22355a2df87cdb6b70b343713b667eddbda73e1977e0c860835c0
|
3 |
+
size 6009
|
models/lpips_models/squeeze.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4a5350f23600cb79923ce65bb07cbf57dca461329894153e05a1346bd531cf76
|
3 |
+
size 10811
|
models/lpips_models/vgg.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
|
3 |
+
size 7289
|