Paolo-Fraccaro
commited on
Commit
·
f95bf59
1
Parent(s):
8594822
Update multi_temporal_crop_classification_Prithvi_100M.py
Browse files
multi_temporal_crop_classification_Prithvi_100M.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
dist_params = dict(backend='nccl')
|
2 |
log_level = 'INFO'
|
3 |
load_from = None
|
@@ -7,20 +9,50 @@ custom_imports = dict(imports=['geospatial_fm'])
|
|
7 |
num_frames = 3
|
8 |
img_size = 224
|
9 |
num_workers = 2
|
10 |
-
|
|
|
|
|
|
|
11 |
num_layers = 6
|
12 |
patch_size = 16
|
13 |
embed_dim = 768
|
14 |
num_heads = 8
|
15 |
tubelet_size = 1
|
16 |
-
|
17 |
-
eval_epoch_interval =
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
gpu_ids = range(0, 1)
|
22 |
dataset_type = 'GeospatialDataset'
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
img_norm_cfg = dict(
|
25 |
means=[
|
26 |
494.905781, 815.239594, 924.335066, 2968.881459, 2634.621962,
|
@@ -33,261 +65,95 @@ img_norm_cfg = dict(
|
|
33 |
284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808,
|
34 |
284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808
|
35 |
])
|
36 |
-
|
37 |
-
train=
|
38 |
-
'/home/ubuntu/hls-foundation-os/fine-tuning-examples/data_splits/crop_classification/training_data.txt',
|
39 |
-
val=
|
40 |
-
'/home/ubuntu/hls-foundation-os/fine-tuning-examples/data_splits/crop_classification/validation_data.txt',
|
41 |
-
test=
|
42 |
-
'/home/ubuntu/hls-foundation-os/fine-tuning-examples/data_splits/crop_classification/validation_data.txt'
|
43 |
-
)
|
44 |
bands = [0, 1, 2, 3, 4, 5]
|
|
|
45 |
tile_size = 224
|
46 |
orig_nsize = 512
|
47 |
-
crop_size = (
|
48 |
train_pipeline = [
|
49 |
-
dict(type='LoadGeospatialImageFromFile', to_float32=True),
|
50 |
dict(type='LoadGeospatialAnnotations', reduce_zero_label=True),
|
51 |
dict(type='RandomFlip', prob=0.5),
|
52 |
dict(type='ToTensor', keys=['img', 'gt_semantic_seg']),
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
stds=[
|
62 |
-
284.925432, 357.84876, 575.566823, 896.601013, 951.900334,
|
63 |
-
921.407808, 284.925432, 357.84876, 575.566823, 896.601013,
|
64 |
-
951.900334, 921.407808, 284.925432, 357.84876, 575.566823,
|
65 |
-
896.601013, 951.900334, 921.407808
|
66 |
-
]),
|
67 |
-
dict(type='TorchRandomCrop', crop_size=(224, 224)),
|
68 |
-
dict(type='Reshape', keys=['img'], new_shape=(6, 3, 224, 224)),
|
69 |
-
dict(type='Reshape', keys=['gt_semantic_seg'], new_shape=(1, 224, 224)),
|
70 |
-
dict(
|
71 |
-
type='CastTensor',
|
72 |
-
keys=['gt_semantic_seg'],
|
73 |
-
new_type='torch.LongTensor'),
|
74 |
-
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
|
75 |
-
]
|
76 |
-
val_pipeline = [
|
77 |
-
dict(type='LoadGeospatialImageFromFile', to_float32=True),
|
78 |
-
dict(type='LoadGeospatialAnnotations', reduce_zero_label=True),
|
79 |
-
dict(type='ToTensor', keys=['img', 'gt_semantic_seg']),
|
80 |
-
dict(
|
81 |
-
type='TorchNormalize',
|
82 |
-
means=[
|
83 |
-
494.905781, 815.239594, 924.335066, 2968.881459, 2634.621962,
|
84 |
-
1739.579917, 494.905781, 815.239594, 924.335066, 2968.881459,
|
85 |
-
2634.621962, 1739.579917, 494.905781, 815.239594, 924.335066,
|
86 |
-
2968.881459, 2634.621962, 1739.579917
|
87 |
-
],
|
88 |
-
stds=[
|
89 |
-
284.925432, 357.84876, 575.566823, 896.601013, 951.900334,
|
90 |
-
921.407808, 284.925432, 357.84876, 575.566823, 896.601013,
|
91 |
-
951.900334, 921.407808, 284.925432, 357.84876, 575.566823,
|
92 |
-
896.601013, 951.900334, 921.407808
|
93 |
-
]),
|
94 |
-
dict(type='TorchRandomCrop', crop_size=(224, 224)),
|
95 |
-
dict(type='Reshape', keys=['img'], new_shape=(6, 3, 224, 224)),
|
96 |
-
dict(type='Reshape', keys=['gt_semantic_seg'], new_shape=(1, 224, 224)),
|
97 |
-
dict(
|
98 |
-
type='CastTensor',
|
99 |
-
keys=['gt_semantic_seg'],
|
100 |
-
new_type='torch.LongTensor'),
|
101 |
-
dict(
|
102 |
-
type='Collect',
|
103 |
-
keys=['img', 'gt_semantic_seg'],
|
104 |
-
meta_keys=[
|
105 |
-
'img_info', 'ann_info', 'seg_fields', 'img_prefix', 'seg_prefix',
|
106 |
-
'filename', 'ori_filename', 'img', 'img_shape', 'ori_shape',
|
107 |
-
'pad_shape', 'scale_factor', 'img_norm_cfg', 'gt_semantic_seg'
|
108 |
-
])
|
109 |
]
|
|
|
110 |
test_pipeline = [
|
111 |
-
dict(type='LoadGeospatialImageFromFile', to_float32=True),
|
112 |
dict(type='ToTensor', keys=['img']),
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
stds=[
|
122 |
-
284.925432, 357.84876, 575.566823, 896.601013, 951.900334,
|
123 |
-
921.407808, 284.925432, 357.84876, 575.566823, 896.601013,
|
124 |
-
951.900334, 921.407808, 284.925432, 357.84876, 575.566823,
|
125 |
-
896.601013, 951.900334, 921.407808
|
126 |
-
]),
|
127 |
-
dict(
|
128 |
-
type='Reshape',
|
129 |
-
keys=['img'],
|
130 |
-
new_shape=(6, 3, -1, -1),
|
131 |
-
look_up=dict({
|
132 |
-
'2': 1,
|
133 |
-
'3': 2
|
134 |
-
})),
|
135 |
-
dict(type='CastTensor', keys=['img'], new_type='torch.FloatTensor'),
|
136 |
-
dict(
|
137 |
-
type='CollectTestList',
|
138 |
-
keys=['img'],
|
139 |
-
meta_keys=[
|
140 |
-
'img_info', 'seg_fields', 'img_prefix', 'seg_prefix', 'filename',
|
141 |
-
'ori_filename', 'img', 'img_shape', 'ori_shape', 'pad_shape',
|
142 |
-
'scale_factor', 'img_norm_cfg'
|
143 |
-
])
|
144 |
]
|
145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
data = dict(
|
147 |
-
samples_per_gpu=
|
148 |
-
workers_per_gpu=
|
149 |
train=dict(
|
150 |
-
type=
|
151 |
-
CLASSES=
|
152 |
reduce_zero_label=True,
|
153 |
-
data_root=
|
154 |
-
img_dir='
|
155 |
-
ann_dir='
|
156 |
-
pipeline=
|
157 |
-
dict(type='LoadGeospatialImageFromFile', to_float32=True),
|
158 |
-
dict(type='LoadGeospatialAnnotations', reduce_zero_label=True),
|
159 |
-
dict(type='RandomFlip', prob=0.5),
|
160 |
-
dict(type='ToTensor', keys=['img', 'gt_semantic_seg']),
|
161 |
-
dict(
|
162 |
-
type='TorchNormalize',
|
163 |
-
means=[
|
164 |
-
494.905781, 815.239594, 924.335066, 2968.881459,
|
165 |
-
2634.621962, 1739.579917, 494.905781, 815.239594,
|
166 |
-
924.335066, 2968.881459, 2634.621962, 1739.579917,
|
167 |
-
494.905781, 815.239594, 924.335066, 2968.881459,
|
168 |
-
2634.621962, 1739.579917
|
169 |
-
],
|
170 |
-
stds=[
|
171 |
-
284.925432, 357.84876, 575.566823, 896.601013, 951.900334,
|
172 |
-
921.407808, 284.925432, 357.84876, 575.566823, 896.601013,
|
173 |
-
951.900334, 921.407808, 284.925432, 357.84876, 575.566823,
|
174 |
-
896.601013, 951.900334, 921.407808
|
175 |
-
]),
|
176 |
-
dict(type='TorchRandomCrop', crop_size=(224, 224)),
|
177 |
-
dict(type='Reshape', keys=['img'], new_shape=(6, 3, 224, 224)),
|
178 |
-
dict(
|
179 |
-
type='Reshape',
|
180 |
-
keys=['gt_semantic_seg'],
|
181 |
-
new_shape=(1, 224, 224)),
|
182 |
-
dict(
|
183 |
-
type='CastTensor',
|
184 |
-
keys=['gt_semantic_seg'],
|
185 |
-
new_type='torch.LongTensor'),
|
186 |
-
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
|
187 |
-
],
|
188 |
img_suffix='_merged.tif',
|
189 |
seg_map_suffix='.mask.tif',
|
190 |
-
split=
|
191 |
-
'/home/ubuntu/hls-foundation-os/fine-tuning-examples/data_splits/crop_classification/training_data.txt'
|
192 |
-
),
|
193 |
val=dict(
|
194 |
-
type=
|
195 |
-
CLASSES=
|
196 |
reduce_zero_label=True,
|
197 |
-
data_root=
|
198 |
-
img_dir='
|
199 |
-
ann_dir='
|
200 |
-
pipeline=
|
201 |
-
dict(type='LoadGeospatialImageFromFile', to_float32=True),
|
202 |
-
dict(type='ToTensor', keys=['img']),
|
203 |
-
dict(
|
204 |
-
type='TorchNormalize',
|
205 |
-
means=[
|
206 |
-
494.905781, 815.239594, 924.335066, 2968.881459,
|
207 |
-
2634.621962, 1739.579917, 494.905781, 815.239594,
|
208 |
-
924.335066, 2968.881459, 2634.621962, 1739.579917,
|
209 |
-
494.905781, 815.239594, 924.335066, 2968.881459,
|
210 |
-
2634.621962, 1739.579917
|
211 |
-
],
|
212 |
-
stds=[
|
213 |
-
284.925432, 357.84876, 575.566823, 896.601013, 951.900334,
|
214 |
-
921.407808, 284.925432, 357.84876, 575.566823, 896.601013,
|
215 |
-
951.900334, 921.407808, 284.925432, 357.84876, 575.566823,
|
216 |
-
896.601013, 951.900334, 921.407808
|
217 |
-
]),
|
218 |
-
dict(
|
219 |
-
type='Reshape',
|
220 |
-
keys=['img'],
|
221 |
-
new_shape=(6, 3, -1, -1),
|
222 |
-
look_up=dict({
|
223 |
-
'2': 1,
|
224 |
-
'3': 2
|
225 |
-
})),
|
226 |
-
dict(
|
227 |
-
type='CastTensor', keys=['img'], new_type='torch.FloatTensor'),
|
228 |
-
dict(
|
229 |
-
type='CollectTestList',
|
230 |
-
keys=['img'],
|
231 |
-
meta_keys=[
|
232 |
-
'img_info', 'seg_fields', 'img_prefix', 'seg_prefix',
|
233 |
-
'filename', 'ori_filename', 'img', 'img_shape',
|
234 |
-
'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg'
|
235 |
-
])
|
236 |
-
],
|
237 |
img_suffix='_merged.tif',
|
238 |
seg_map_suffix='.mask.tif',
|
239 |
-
split=
|
240 |
-
'/home/ubuntu/hls-foundation-os/fine-tuning-examples/data_splits/crop_classification/validation_data.txt'
|
241 |
),
|
242 |
test=dict(
|
243 |
-
type=
|
244 |
-
CLASSES=
|
245 |
reduce_zero_label=True,
|
246 |
-
data_root=
|
247 |
-
img_dir='
|
248 |
-
ann_dir='
|
249 |
-
pipeline=
|
250 |
-
dict(type='LoadGeospatialImageFromFile', to_float32=True),
|
251 |
-
dict(type='ToTensor', keys=['img']),
|
252 |
-
dict(
|
253 |
-
type='TorchNormalize',
|
254 |
-
means=[
|
255 |
-
494.905781, 815.239594, 924.335066, 2968.881459,
|
256 |
-
2634.621962, 1739.579917, 494.905781, 815.239594,
|
257 |
-
924.335066, 2968.881459, 2634.621962, 1739.579917,
|
258 |
-
494.905781, 815.239594, 924.335066, 2968.881459,
|
259 |
-
2634.621962, 1739.579917
|
260 |
-
],
|
261 |
-
stds=[
|
262 |
-
284.925432, 357.84876, 575.566823, 896.601013, 951.900334,
|
263 |
-
921.407808, 284.925432, 357.84876, 575.566823, 896.601013,
|
264 |
-
951.900334, 921.407808, 284.925432, 357.84876, 575.566823,
|
265 |
-
896.601013, 951.900334, 921.407808
|
266 |
-
]),
|
267 |
-
dict(
|
268 |
-
type='Reshape',
|
269 |
-
keys=['img'],
|
270 |
-
new_shape=(6, 3, -1, -1),
|
271 |
-
look_up=dict({
|
272 |
-
'2': 1,
|
273 |
-
'3': 2
|
274 |
-
})),
|
275 |
-
dict(
|
276 |
-
type='CastTensor', keys=['img'], new_type='torch.FloatTensor'),
|
277 |
-
dict(
|
278 |
-
type='CollectTestList',
|
279 |
-
keys=['img'],
|
280 |
-
meta_keys=[
|
281 |
-
'img_info', 'seg_fields', 'img_prefix', 'seg_prefix',
|
282 |
-
'filename', 'ori_filename', 'img', 'img_shape',
|
283 |
-
'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg'
|
284 |
-
])
|
285 |
-
],
|
286 |
img_suffix='_merged.tif',
|
287 |
seg_map_suffix='.mask.tif',
|
288 |
-
split=
|
289 |
-
'/home/ubuntu/hls-foundation-os/fine-tuning-examples/data_splits/crop_classification/validation_data.txt'
|
290 |
))
|
|
|
291 |
optimizer = dict(
|
292 |
type='Adam', lr=1.5e-05, betas=(0.9, 0.999), weight_decay=0.05)
|
293 |
optimizer_config = dict(grad_clip=None)
|
@@ -303,55 +169,45 @@ log_config = dict(
|
|
303 |
interval=10,
|
304 |
hooks=[dict(type='TextLoggerHook'),
|
305 |
dict(type='TensorboardLoggerHook')])
|
|
|
306 |
checkpoint_config = dict(
|
307 |
by_epoch=True,
|
308 |
-
interval=
|
309 |
-
out_dir=
|
310 |
-
|
|
|
311 |
reduce_train_set = dict(reduce_train_set=False)
|
312 |
reduce_factor = dict(reduce_factor=1)
|
313 |
-
runner = dict(type='EpochBasedRunner', max_epochs=
|
314 |
-
workflow = [('train', 1)
|
315 |
norm_cfg = dict(type='BN', requires_grad=True)
|
316 |
-
|
317 |
-
0.386375, 0.661126, 0.548184, 0.640482, 0.876862, 0.925186, 3.249462,
|
318 |
-
1.542289, 2.175141, 2.272419, 3.062762, 3.626097, 1.198702
|
319 |
-
]
|
320 |
-
loss_func = dict(
|
321 |
-
type='CrossEntropyLoss',
|
322 |
-
use_sigmoid=False,
|
323 |
-
class_weight=[
|
324 |
-
0.386375, 0.661126, 0.548184, 0.640482, 0.876862, 0.925186, 3.249462,
|
325 |
-
1.542289, 2.175141, 2.272419, 3.062762, 3.626097, 1.198702
|
326 |
-
],
|
327 |
-
avg_non_ignore=True)
|
328 |
-
output_embed_dim = 2304
|
329 |
model = dict(
|
330 |
type='TemporalEncoderDecoder',
|
331 |
frozen_backbone=False,
|
332 |
backbone=dict(
|
333 |
type='TemporalViTEncoder',
|
334 |
-
pretrained=
|
335 |
-
img_size=
|
336 |
-
patch_size=
|
337 |
-
num_frames=
|
338 |
tubelet_size=1,
|
339 |
-
in_chans=
|
340 |
-
embed_dim=
|
341 |
depth=6,
|
342 |
-
num_heads=
|
343 |
mlp_ratio=4.0,
|
344 |
norm_pix_loss=False),
|
345 |
neck=dict(
|
346 |
type='ConvTransformerTokensToEmbeddingNeck',
|
347 |
-
embed_dim=
|
348 |
-
output_embed_dim=
|
349 |
drop_cls_token=True,
|
350 |
Hp=14,
|
351 |
Wp=14),
|
352 |
decode_head=dict(
|
353 |
-
num_classes=
|
354 |
-
in_channels=
|
355 |
type='FCNHead',
|
356 |
in_index=-1,
|
357 |
channels=256,
|
@@ -360,18 +216,10 @@ model = dict(
|
|
360 |
dropout_ratio=0.1,
|
361 |
norm_cfg=dict(type='BN', requires_grad=True),
|
362 |
align_corners=False,
|
363 |
-
loss_decode=
|
364 |
-
type='CrossEntropyLoss',
|
365 |
-
use_sigmoid=False,
|
366 |
-
class_weight=[
|
367 |
-
0.386375, 0.661126, 0.548184, 0.640482, 0.876862, 0.925186,
|
368 |
-
3.249462, 1.542289, 2.175141, 2.272419, 3.062762, 3.626097,
|
369 |
-
1.198702
|
370 |
-
],
|
371 |
-
avg_non_ignore=True)),
|
372 |
auxiliary_head=dict(
|
373 |
-
num_classes=
|
374 |
-
in_channels=
|
375 |
type='FCNHead',
|
376 |
in_index=-1,
|
377 |
channels=256,
|
@@ -380,15 +228,7 @@ model = dict(
|
|
380 |
dropout_ratio=0.1,
|
381 |
norm_cfg=dict(type='BN', requires_grad=True),
|
382 |
align_corners=False,
|
383 |
-
loss_decode=
|
384 |
-
type='CrossEntropyLoss',
|
385 |
-
use_sigmoid=False,
|
386 |
-
class_weight=[
|
387 |
-
0.386375, 0.661126, 0.548184, 0.640482, 0.876862, 0.925186,
|
388 |
-
3.249462, 1.542289, 2.175141, 2.272419, 3.062762, 3.626097,
|
389 |
-
1.198702
|
390 |
-
],
|
391 |
-
avg_non_ignore=True)),
|
392 |
train_cfg=dict(),
|
393 |
-
test_cfg=dict(mode='slide', stride=(
|
394 |
auto_resume = False
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
dist_params = dict(backend='nccl')
|
4 |
log_level = 'INFO'
|
5 |
load_from = None
|
|
|
9 |
num_frames = 3
|
10 |
img_size = 224
|
11 |
num_workers = 2
|
12 |
+
|
13 |
+
# model
|
14 |
+
# TO BE DEFINED BY USER: model path
|
15 |
+
pretrained_weights_path = '<path to pretrained weights>'
|
16 |
num_layers = 6
|
17 |
patch_size = 16
|
18 |
embed_dim = 768
|
19 |
num_heads = 8
|
20 |
tubelet_size = 1
|
21 |
+
max_epochs = 80
|
22 |
+
eval_epoch_interval = 5
|
23 |
+
|
24 |
+
loss_weights_multi = [
|
25 |
+
0.386375, 0.661126, 0.548184, 0.640482, 0.876862, 0.925186, 3.249462,
|
26 |
+
1.542289, 2.175141, 2.272419, 3.062762, 3.626097, 1.198702
|
27 |
+
]
|
28 |
+
loss_func = dict(
|
29 |
+
type='CrossEntropyLoss',
|
30 |
+
use_sigmoid=False,
|
31 |
+
class_weight=loss_weights_multi,
|
32 |
+
avg_non_ignore=True)
|
33 |
+
output_embed_dim = embed_dim*num_frames
|
34 |
+
|
35 |
+
|
36 |
+
# TO BE DEFINED BY USER: Save directory
|
37 |
+
experiment = '<experiment name>'
|
38 |
+
project_dir = '<project directory name>'
|
39 |
+
work_dir = os.path.join(project_dir, experiment)
|
40 |
+
save_path = work_dir
|
41 |
+
|
42 |
+
|
43 |
gpu_ids = range(0, 1)
|
44 |
dataset_type = 'GeospatialDataset'
|
45 |
+
|
46 |
+
# TO BE DEFINED BY USER: data directory
|
47 |
+
data_root = '<path to data root>'
|
48 |
+
|
49 |
+
splits = dict(
|
50 |
+
train='<path to train split>',
|
51 |
+
val= '<path to val split>',
|
52 |
+
test= '<path to test split>'
|
53 |
+
)
|
54 |
+
|
55 |
+
|
56 |
img_norm_cfg = dict(
|
57 |
means=[
|
58 |
494.905781, 815.239594, 924.335066, 2968.881459, 2634.621962,
|
|
|
65 |
284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808,
|
66 |
284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808
|
67 |
])
|
68 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
bands = [0, 1, 2, 3, 4, 5]
|
70 |
+
|
71 |
tile_size = 224
|
72 |
orig_nsize = 512
|
73 |
+
crop_size = (tile_size, tile_size)
|
74 |
train_pipeline = [
|
75 |
+
dict(type='LoadGeospatialImageFromFile', to_float32=True, channels_last=True),
|
76 |
dict(type='LoadGeospatialAnnotations', reduce_zero_label=True),
|
77 |
dict(type='RandomFlip', prob=0.5),
|
78 |
dict(type='ToTensor', keys=['img', 'gt_semantic_seg']),
|
79 |
+
# to channels first
|
80 |
+
dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
|
81 |
+
dict(type='TorchNormalize', **img_norm_cfg),
|
82 |
+
dict(type='TorchRandomCrop', crop_size=crop_size),
|
83 |
+
dict(type='Reshape', keys=['img'], new_shape=(len(bands), num_frames, tile_size, tile_size)),
|
84 |
+
dict(type='Reshape', keys=['gt_semantic_seg'], new_shape=(1, tile_size, tile_size)),
|
85 |
+
dict(type='CastTensor', keys=['gt_semantic_seg'], new_type="torch.LongTensor"),
|
86 |
+
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
]
|
88 |
+
|
89 |
test_pipeline = [
|
90 |
+
dict(type='LoadGeospatialImageFromFile', to_float32=True, channels_last=True),
|
91 |
dict(type='ToTensor', keys=['img']),
|
92 |
+
# to channels first
|
93 |
+
dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
|
94 |
+
dict(type='TorchNormalize', **img_norm_cfg),
|
95 |
+
dict(type='Reshape', keys=['img'], new_shape=(len(bands), num_frames, -1, -1), look_up = {'2': 1, '3': 2}),
|
96 |
+
dict(type='CastTensor', keys=['img'], new_type="torch.FloatTensor"),
|
97 |
+
dict(type='CollectTestList', keys=['img'],
|
98 |
+
meta_keys=['img_info', 'seg_fields', 'img_prefix', 'seg_prefix', 'filename', 'ori_filename', 'img',
|
99 |
+
'img_shape', 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg']),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
]
|
101 |
+
|
102 |
+
CLASSES = ('Natural Vegetation',
|
103 |
+
'Forest',
|
104 |
+
'Corn',
|
105 |
+
'Soybeans',
|
106 |
+
'Wetlands',
|
107 |
+
'Developed/Barren',
|
108 |
+
'Open Water',
|
109 |
+
'Winter Wheat',
|
110 |
+
'Alfalfa',
|
111 |
+
'Fallow/Idle Cropland',
|
112 |
+
'Cotton',
|
113 |
+
'Sorghum',
|
114 |
+
'Other')
|
115 |
+
|
116 |
+
dataset = 'GeospatialDataset'
|
117 |
+
|
118 |
data = dict(
|
119 |
+
samples_per_gpu=8,
|
120 |
+
workers_per_gpu=4,
|
121 |
train=dict(
|
122 |
+
type=dataset,
|
123 |
+
CLASSES=CLASSES,
|
124 |
reduce_zero_label=True,
|
125 |
+
data_root=data_root,
|
126 |
+
img_dir='training_chips',
|
127 |
+
ann_dir='training_chips',
|
128 |
+
pipeline=train_pipeline,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
img_suffix='_merged.tif',
|
130 |
seg_map_suffix='.mask.tif',
|
131 |
+
split=splits['train']),
|
|
|
|
|
132 |
val=dict(
|
133 |
+
type=dataset,
|
134 |
+
CLASSES=CLASSES,
|
135 |
reduce_zero_label=True,
|
136 |
+
data_root=data_root,
|
137 |
+
img_dir='validation_chips',
|
138 |
+
ann_dir='validation_chips',
|
139 |
+
pipeline=test_pipeline,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
img_suffix='_merged.tif',
|
141 |
seg_map_suffix='.mask.tif',
|
142 |
+
split=splits['val']
|
|
|
143 |
),
|
144 |
test=dict(
|
145 |
+
type=dataset,
|
146 |
+
CLASSES=CLASSES,
|
147 |
reduce_zero_label=True,
|
148 |
+
data_root=data_root,
|
149 |
+
img_dir='validation_chips',
|
150 |
+
ann_dir='validation_chips',
|
151 |
+
pipeline=test_pipeline,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
img_suffix='_merged.tif',
|
153 |
seg_map_suffix='.mask.tif',
|
154 |
+
split=splits['val']
|
|
|
155 |
))
|
156 |
+
|
157 |
optimizer = dict(
|
158 |
type='Adam', lr=1.5e-05, betas=(0.9, 0.999), weight_decay=0.05)
|
159 |
optimizer_config = dict(grad_clip=None)
|
|
|
169 |
interval=10,
|
170 |
hooks=[dict(type='TextLoggerHook'),
|
171 |
dict(type='TensorboardLoggerHook')])
|
172 |
+
|
173 |
checkpoint_config = dict(
|
174 |
by_epoch=True,
|
175 |
+
interval=100,
|
176 |
+
out_dir=save_path)
|
177 |
+
|
178 |
+
evaluation = dict(interval=eval_epoch_interval, metric='mIoU', pre_eval=True, save_best='mIoU', by_epoch=True)
|
179 |
reduce_train_set = dict(reduce_train_set=False)
|
180 |
reduce_factor = dict(reduce_factor=1)
|
181 |
+
runner = dict(type='EpochBasedRunner', max_epochs=max_epochs)
|
182 |
+
workflow = [('train', 1)]
|
183 |
norm_cfg = dict(type='BN', requires_grad=True)
|
184 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
model = dict(
|
186 |
type='TemporalEncoderDecoder',
|
187 |
frozen_backbone=False,
|
188 |
backbone=dict(
|
189 |
type='TemporalViTEncoder',
|
190 |
+
pretrained=pretrained_weights_path,
|
191 |
+
img_size=img_size,
|
192 |
+
patch_size=patch_size,
|
193 |
+
num_frames=num_frames,
|
194 |
tubelet_size=1,
|
195 |
+
in_chans=len(bands),
|
196 |
+
embed_dim=embed_dim,
|
197 |
depth=6,
|
198 |
+
num_heads=num_heads,
|
199 |
mlp_ratio=4.0,
|
200 |
norm_pix_loss=False),
|
201 |
neck=dict(
|
202 |
type='ConvTransformerTokensToEmbeddingNeck',
|
203 |
+
embed_dim=embed_dim*num_frames,
|
204 |
+
output_embed_dim=output_embed_dim,
|
205 |
drop_cls_token=True,
|
206 |
Hp=14,
|
207 |
Wp=14),
|
208 |
decode_head=dict(
|
209 |
+
num_classes=len(CLASSES),
|
210 |
+
in_channels=output_embed_dim,
|
211 |
type='FCNHead',
|
212 |
in_index=-1,
|
213 |
channels=256,
|
|
|
216 |
dropout_ratio=0.1,
|
217 |
norm_cfg=dict(type='BN', requires_grad=True),
|
218 |
align_corners=False,
|
219 |
+
loss_decode=loss_func),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
auxiliary_head=dict(
|
221 |
+
num_classes=len(CLASSES),
|
222 |
+
in_channels=output_embed_dim,
|
223 |
type='FCNHead',
|
224 |
in_index=-1,
|
225 |
channels=256,
|
|
|
228 |
dropout_ratio=0.1,
|
229 |
norm_cfg=dict(type='BN', requires_grad=True),
|
230 |
align_corners=False,
|
231 |
+
loss_decode=loss_func),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
train_cfg=dict(),
|
233 |
+
test_cfg=dict(mode='slide', stride=(int(tile_size/2), int(tile_size/2)), crop_size=(tile_size, tile_size)))
|
234 |
auto_resume = False
|