Paolo-Fraccaro commited on
Commit
f95bf59
·
1 Parent(s): 8594822

Update multi_temporal_crop_classification_Prithvi_100M.py

Browse files
multi_temporal_crop_classification_Prithvi_100M.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  dist_params = dict(backend='nccl')
2
  log_level = 'INFO'
3
  load_from = None
@@ -7,20 +9,50 @@ custom_imports = dict(imports=['geospatial_fm'])
7
  num_frames = 3
8
  img_size = 224
9
  num_workers = 2
10
- pretrained_weights_path = '/home/ubuntu/hls-loss-weights/Prithvi_100M.pt'
 
 
 
11
  num_layers = 6
12
  patch_size = 16
13
  embed_dim = 768
14
  num_heads = 8
15
  tubelet_size = 1
16
- epochs = 80
17
- eval_epoch_interval = 2
18
- experiment = 'multiclass_exp_newSplit'
19
- work_dir = '/home/ubuntu/clark_gfm_eval/multiclass_exp_newSplit'
20
- save_path = '/home/ubuntu/clark_gfm_eval/multiclass_exp_newSplit'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  gpu_ids = range(0, 1)
22
  dataset_type = 'GeospatialDataset'
23
- data_root = '/home/ubuntu/hls_cdl_reclassed/'
 
 
 
 
 
 
 
 
 
 
24
  img_norm_cfg = dict(
25
  means=[
26
  494.905781, 815.239594, 924.335066, 2968.881459, 2634.621962,
@@ -33,261 +65,95 @@ img_norm_cfg = dict(
33
  284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808,
34
  284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808
35
  ])
36
- splits = dict(
37
- train=
38
- '/home/ubuntu/hls-foundation-os/fine-tuning-examples/data_splits/crop_classification/training_data.txt',
39
- val=
40
- '/home/ubuntu/hls-foundation-os/fine-tuning-examples/data_splits/crop_classification/validation_data.txt',
41
- test=
42
- '/home/ubuntu/hls-foundation-os/fine-tuning-examples/data_splits/crop_classification/validation_data.txt'
43
- )
44
  bands = [0, 1, 2, 3, 4, 5]
 
45
  tile_size = 224
46
  orig_nsize = 512
47
- crop_size = (224, 224)
48
  train_pipeline = [
49
- dict(type='LoadGeospatialImageFromFile', to_float32=True),
50
  dict(type='LoadGeospatialAnnotations', reduce_zero_label=True),
51
  dict(type='RandomFlip', prob=0.5),
52
  dict(type='ToTensor', keys=['img', 'gt_semantic_seg']),
53
- dict(
54
- type='TorchNormalize',
55
- means=[
56
- 494.905781, 815.239594, 924.335066, 2968.881459, 2634.621962,
57
- 1739.579917, 494.905781, 815.239594, 924.335066, 2968.881459,
58
- 2634.621962, 1739.579917, 494.905781, 815.239594, 924.335066,
59
- 2968.881459, 2634.621962, 1739.579917
60
- ],
61
- stds=[
62
- 284.925432, 357.84876, 575.566823, 896.601013, 951.900334,
63
- 921.407808, 284.925432, 357.84876, 575.566823, 896.601013,
64
- 951.900334, 921.407808, 284.925432, 357.84876, 575.566823,
65
- 896.601013, 951.900334, 921.407808
66
- ]),
67
- dict(type='TorchRandomCrop', crop_size=(224, 224)),
68
- dict(type='Reshape', keys=['img'], new_shape=(6, 3, 224, 224)),
69
- dict(type='Reshape', keys=['gt_semantic_seg'], new_shape=(1, 224, 224)),
70
- dict(
71
- type='CastTensor',
72
- keys=['gt_semantic_seg'],
73
- new_type='torch.LongTensor'),
74
- dict(type='Collect', keys=['img', 'gt_semantic_seg'])
75
- ]
76
- val_pipeline = [
77
- dict(type='LoadGeospatialImageFromFile', to_float32=True),
78
- dict(type='LoadGeospatialAnnotations', reduce_zero_label=True),
79
- dict(type='ToTensor', keys=['img', 'gt_semantic_seg']),
80
- dict(
81
- type='TorchNormalize',
82
- means=[
83
- 494.905781, 815.239594, 924.335066, 2968.881459, 2634.621962,
84
- 1739.579917, 494.905781, 815.239594, 924.335066, 2968.881459,
85
- 2634.621962, 1739.579917, 494.905781, 815.239594, 924.335066,
86
- 2968.881459, 2634.621962, 1739.579917
87
- ],
88
- stds=[
89
- 284.925432, 357.84876, 575.566823, 896.601013, 951.900334,
90
- 921.407808, 284.925432, 357.84876, 575.566823, 896.601013,
91
- 951.900334, 921.407808, 284.925432, 357.84876, 575.566823,
92
- 896.601013, 951.900334, 921.407808
93
- ]),
94
- dict(type='TorchRandomCrop', crop_size=(224, 224)),
95
- dict(type='Reshape', keys=['img'], new_shape=(6, 3, 224, 224)),
96
- dict(type='Reshape', keys=['gt_semantic_seg'], new_shape=(1, 224, 224)),
97
- dict(
98
- type='CastTensor',
99
- keys=['gt_semantic_seg'],
100
- new_type='torch.LongTensor'),
101
- dict(
102
- type='Collect',
103
- keys=['img', 'gt_semantic_seg'],
104
- meta_keys=[
105
- 'img_info', 'ann_info', 'seg_fields', 'img_prefix', 'seg_prefix',
106
- 'filename', 'ori_filename', 'img', 'img_shape', 'ori_shape',
107
- 'pad_shape', 'scale_factor', 'img_norm_cfg', 'gt_semantic_seg'
108
- ])
109
  ]
 
110
  test_pipeline = [
111
- dict(type='LoadGeospatialImageFromFile', to_float32=True),
112
  dict(type='ToTensor', keys=['img']),
113
- dict(
114
- type='TorchNormalize',
115
- means=[
116
- 494.905781, 815.239594, 924.335066, 2968.881459, 2634.621962,
117
- 1739.579917, 494.905781, 815.239594, 924.335066, 2968.881459,
118
- 2634.621962, 1739.579917, 494.905781, 815.239594, 924.335066,
119
- 2968.881459, 2634.621962, 1739.579917
120
- ],
121
- stds=[
122
- 284.925432, 357.84876, 575.566823, 896.601013, 951.900334,
123
- 921.407808, 284.925432, 357.84876, 575.566823, 896.601013,
124
- 951.900334, 921.407808, 284.925432, 357.84876, 575.566823,
125
- 896.601013, 951.900334, 921.407808
126
- ]),
127
- dict(
128
- type='Reshape',
129
- keys=['img'],
130
- new_shape=(6, 3, -1, -1),
131
- look_up=dict({
132
- '2': 1,
133
- '3': 2
134
- })),
135
- dict(type='CastTensor', keys=['img'], new_type='torch.FloatTensor'),
136
- dict(
137
- type='CollectTestList',
138
- keys=['img'],
139
- meta_keys=[
140
- 'img_info', 'seg_fields', 'img_prefix', 'seg_prefix', 'filename',
141
- 'ori_filename', 'img', 'img_shape', 'ori_shape', 'pad_shape',
142
- 'scale_factor', 'img_norm_cfg'
143
- ])
144
  ]
145
- CLASSES = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  data = dict(
147
- samples_per_gpu=2,
148
- workers_per_gpu=1,
149
  train=dict(
150
- type='GeospatialDataset',
151
- CLASSES=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13),
152
  reduce_zero_label=True,
153
- data_root='/home/ubuntu/hls_cdl_reclassed/',
154
- img_dir='/home/ubuntu/hls_cdl_reclassed/training_chips',
155
- ann_dir='/home/ubuntu/hls_cdl_reclassed/training_chips',
156
- pipeline=[
157
- dict(type='LoadGeospatialImageFromFile', to_float32=True),
158
- dict(type='LoadGeospatialAnnotations', reduce_zero_label=True),
159
- dict(type='RandomFlip', prob=0.5),
160
- dict(type='ToTensor', keys=['img', 'gt_semantic_seg']),
161
- dict(
162
- type='TorchNormalize',
163
- means=[
164
- 494.905781, 815.239594, 924.335066, 2968.881459,
165
- 2634.621962, 1739.579917, 494.905781, 815.239594,
166
- 924.335066, 2968.881459, 2634.621962, 1739.579917,
167
- 494.905781, 815.239594, 924.335066, 2968.881459,
168
- 2634.621962, 1739.579917
169
- ],
170
- stds=[
171
- 284.925432, 357.84876, 575.566823, 896.601013, 951.900334,
172
- 921.407808, 284.925432, 357.84876, 575.566823, 896.601013,
173
- 951.900334, 921.407808, 284.925432, 357.84876, 575.566823,
174
- 896.601013, 951.900334, 921.407808
175
- ]),
176
- dict(type='TorchRandomCrop', crop_size=(224, 224)),
177
- dict(type='Reshape', keys=['img'], new_shape=(6, 3, 224, 224)),
178
- dict(
179
- type='Reshape',
180
- keys=['gt_semantic_seg'],
181
- new_shape=(1, 224, 224)),
182
- dict(
183
- type='CastTensor',
184
- keys=['gt_semantic_seg'],
185
- new_type='torch.LongTensor'),
186
- dict(type='Collect', keys=['img', 'gt_semantic_seg'])
187
- ],
188
  img_suffix='_merged.tif',
189
  seg_map_suffix='.mask.tif',
190
- split=
191
- '/home/ubuntu/hls-foundation-os/fine-tuning-examples/data_splits/crop_classification/training_data.txt'
192
- ),
193
  val=dict(
194
- type='GeospatialDataset',
195
- CLASSES=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13),
196
  reduce_zero_label=True,
197
- data_root='/home/ubuntu/hls_cdl_reclassed/',
198
- img_dir='/home/ubuntu/hls_cdl_reclassed/validation_chips',
199
- ann_dir='/home/ubuntu/hls_cdl_reclassed/validation_chips',
200
- pipeline=[
201
- dict(type='LoadGeospatialImageFromFile', to_float32=True),
202
- dict(type='ToTensor', keys=['img']),
203
- dict(
204
- type='TorchNormalize',
205
- means=[
206
- 494.905781, 815.239594, 924.335066, 2968.881459,
207
- 2634.621962, 1739.579917, 494.905781, 815.239594,
208
- 924.335066, 2968.881459, 2634.621962, 1739.579917,
209
- 494.905781, 815.239594, 924.335066, 2968.881459,
210
- 2634.621962, 1739.579917
211
- ],
212
- stds=[
213
- 284.925432, 357.84876, 575.566823, 896.601013, 951.900334,
214
- 921.407808, 284.925432, 357.84876, 575.566823, 896.601013,
215
- 951.900334, 921.407808, 284.925432, 357.84876, 575.566823,
216
- 896.601013, 951.900334, 921.407808
217
- ]),
218
- dict(
219
- type='Reshape',
220
- keys=['img'],
221
- new_shape=(6, 3, -1, -1),
222
- look_up=dict({
223
- '2': 1,
224
- '3': 2
225
- })),
226
- dict(
227
- type='CastTensor', keys=['img'], new_type='torch.FloatTensor'),
228
- dict(
229
- type='CollectTestList',
230
- keys=['img'],
231
- meta_keys=[
232
- 'img_info', 'seg_fields', 'img_prefix', 'seg_prefix',
233
- 'filename', 'ori_filename', 'img', 'img_shape',
234
- 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg'
235
- ])
236
- ],
237
  img_suffix='_merged.tif',
238
  seg_map_suffix='.mask.tif',
239
- split=
240
- '/home/ubuntu/hls-foundation-os/fine-tuning-examples/data_splits/crop_classification/validation_data.txt'
241
  ),
242
  test=dict(
243
- type='GeospatialDataset',
244
- CLASSES=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13),
245
  reduce_zero_label=True,
246
- data_root='/home/ubuntu/hls_cdl_reclassed/',
247
- img_dir='/home/ubuntu/hls_cdl_reclassed/validation_chips',
248
- ann_dir='/home/ubuntu/hls_cdl_reclassed/validation_chips',
249
- pipeline=[
250
- dict(type='LoadGeospatialImageFromFile', to_float32=True),
251
- dict(type='ToTensor', keys=['img']),
252
- dict(
253
- type='TorchNormalize',
254
- means=[
255
- 494.905781, 815.239594, 924.335066, 2968.881459,
256
- 2634.621962, 1739.579917, 494.905781, 815.239594,
257
- 924.335066, 2968.881459, 2634.621962, 1739.579917,
258
- 494.905781, 815.239594, 924.335066, 2968.881459,
259
- 2634.621962, 1739.579917
260
- ],
261
- stds=[
262
- 284.925432, 357.84876, 575.566823, 896.601013, 951.900334,
263
- 921.407808, 284.925432, 357.84876, 575.566823, 896.601013,
264
- 951.900334, 921.407808, 284.925432, 357.84876, 575.566823,
265
- 896.601013, 951.900334, 921.407808
266
- ]),
267
- dict(
268
- type='Reshape',
269
- keys=['img'],
270
- new_shape=(6, 3, -1, -1),
271
- look_up=dict({
272
- '2': 1,
273
- '3': 2
274
- })),
275
- dict(
276
- type='CastTensor', keys=['img'], new_type='torch.FloatTensor'),
277
- dict(
278
- type='CollectTestList',
279
- keys=['img'],
280
- meta_keys=[
281
- 'img_info', 'seg_fields', 'img_prefix', 'seg_prefix',
282
- 'filename', 'ori_filename', 'img', 'img_shape',
283
- 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg'
284
- ])
285
- ],
286
  img_suffix='_merged.tif',
287
  seg_map_suffix='.mask.tif',
288
- split=
289
- '/home/ubuntu/hls-foundation-os/fine-tuning-examples/data_splits/crop_classification/validation_data.txt'
290
  ))
 
291
  optimizer = dict(
292
  type='Adam', lr=1.5e-05, betas=(0.9, 0.999), weight_decay=0.05)
293
  optimizer_config = dict(grad_clip=None)
@@ -303,55 +169,45 @@ log_config = dict(
303
  interval=10,
304
  hooks=[dict(type='TextLoggerHook'),
305
  dict(type='TensorboardLoggerHook')])
 
306
  checkpoint_config = dict(
307
  by_epoch=True,
308
- interval=10,
309
- out_dir='/home/ubuntu/clark_gfm_eval/multiclass_exp_newSplit')
310
- evaluation = dict(interval=2, metric='mIoU', pre_eval=True, save_best='mIoU')
 
311
  reduce_train_set = dict(reduce_train_set=False)
312
  reduce_factor = dict(reduce_factor=1)
313
- runner = dict(type='EpochBasedRunner', max_epochs=80)
314
- workflow = [('train', 1), ('val', 1)]
315
  norm_cfg = dict(type='BN', requires_grad=True)
316
- loss_weights_multi = [
317
- 0.386375, 0.661126, 0.548184, 0.640482, 0.876862, 0.925186, 3.249462,
318
- 1.542289, 2.175141, 2.272419, 3.062762, 3.626097, 1.198702
319
- ]
320
- loss_func = dict(
321
- type='CrossEntropyLoss',
322
- use_sigmoid=False,
323
- class_weight=[
324
- 0.386375, 0.661126, 0.548184, 0.640482, 0.876862, 0.925186, 3.249462,
325
- 1.542289, 2.175141, 2.272419, 3.062762, 3.626097, 1.198702
326
- ],
327
- avg_non_ignore=True)
328
- output_embed_dim = 2304
329
  model = dict(
330
  type='TemporalEncoderDecoder',
331
  frozen_backbone=False,
332
  backbone=dict(
333
  type='TemporalViTEncoder',
334
- pretrained='/home/ubuntu/hls-loss-weights/Prithvi_100M.pt',
335
- img_size=224,
336
- patch_size=16,
337
- num_frames=3,
338
  tubelet_size=1,
339
- in_chans=6,
340
- embed_dim=768,
341
  depth=6,
342
- num_heads=8,
343
  mlp_ratio=4.0,
344
  norm_pix_loss=False),
345
  neck=dict(
346
  type='ConvTransformerTokensToEmbeddingNeck',
347
- embed_dim=2304,
348
- output_embed_dim=2304,
349
  drop_cls_token=True,
350
  Hp=14,
351
  Wp=14),
352
  decode_head=dict(
353
- num_classes=13,
354
- in_channels=2304,
355
  type='FCNHead',
356
  in_index=-1,
357
  channels=256,
@@ -360,18 +216,10 @@ model = dict(
360
  dropout_ratio=0.1,
361
  norm_cfg=dict(type='BN', requires_grad=True),
362
  align_corners=False,
363
- loss_decode=dict(
364
- type='CrossEntropyLoss',
365
- use_sigmoid=False,
366
- class_weight=[
367
- 0.386375, 0.661126, 0.548184, 0.640482, 0.876862, 0.925186,
368
- 3.249462, 1.542289, 2.175141, 2.272419, 3.062762, 3.626097,
369
- 1.198702
370
- ],
371
- avg_non_ignore=True)),
372
  auxiliary_head=dict(
373
- num_classes=13,
374
- in_channels=2304,
375
  type='FCNHead',
376
  in_index=-1,
377
  channels=256,
@@ -380,15 +228,7 @@ model = dict(
380
  dropout_ratio=0.1,
381
  norm_cfg=dict(type='BN', requires_grad=True),
382
  align_corners=False,
383
- loss_decode=dict(
384
- type='CrossEntropyLoss',
385
- use_sigmoid=False,
386
- class_weight=[
387
- 0.386375, 0.661126, 0.548184, 0.640482, 0.876862, 0.925186,
388
- 3.249462, 1.542289, 2.175141, 2.272419, 3.062762, 3.626097,
389
- 1.198702
390
- ],
391
- avg_non_ignore=True)),
392
  train_cfg=dict(),
393
- test_cfg=dict(mode='slide', stride=(112, 112), crop_size=(224, 224)))
394
  auto_resume = False
 
1
+ import os
2
+
3
  dist_params = dict(backend='nccl')
4
  log_level = 'INFO'
5
  load_from = None
 
9
  num_frames = 3
10
  img_size = 224
11
  num_workers = 2
12
+
13
+ # model
14
+ # TO BE DEFINED BY USER: model path
15
+ pretrained_weights_path = '<path to pretrained weights>'
16
  num_layers = 6
17
  patch_size = 16
18
  embed_dim = 768
19
  num_heads = 8
20
  tubelet_size = 1
21
+ max_epochs = 80
22
+ eval_epoch_interval = 5
23
+
24
+ loss_weights_multi = [
25
+ 0.386375, 0.661126, 0.548184, 0.640482, 0.876862, 0.925186, 3.249462,
26
+ 1.542289, 2.175141, 2.272419, 3.062762, 3.626097, 1.198702
27
+ ]
28
+ loss_func = dict(
29
+ type='CrossEntropyLoss',
30
+ use_sigmoid=False,
31
+ class_weight=loss_weights_multi,
32
+ avg_non_ignore=True)
33
+ output_embed_dim = embed_dim*num_frames
34
+
35
+
36
+ # TO BE DEFINED BY USER: Save directory
37
+ experiment = '<experiment name>'
38
+ project_dir = '<project directory name>'
39
+ work_dir = os.path.join(project_dir, experiment)
40
+ save_path = work_dir
41
+
42
+
43
  gpu_ids = range(0, 1)
44
  dataset_type = 'GeospatialDataset'
45
+
46
+ # TO BE DEFINED BY USER: data directory
47
+ data_root = '<path to data root>'
48
+
49
+ splits = dict(
50
+ train='<path to train split>',
51
+ val= '<path to val split>',
52
+ test= '<path to test split>'
53
+ )
54
+
55
+
56
  img_norm_cfg = dict(
57
  means=[
58
  494.905781, 815.239594, 924.335066, 2968.881459, 2634.621962,
 
65
  284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808,
66
  284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808
67
  ])
68
+
 
 
 
 
 
 
 
69
  bands = [0, 1, 2, 3, 4, 5]
70
+
71
  tile_size = 224
72
  orig_nsize = 512
73
+ crop_size = (tile_size, tile_size)
74
  train_pipeline = [
75
+ dict(type='LoadGeospatialImageFromFile', to_float32=True, channels_last=True),
76
  dict(type='LoadGeospatialAnnotations', reduce_zero_label=True),
77
  dict(type='RandomFlip', prob=0.5),
78
  dict(type='ToTensor', keys=['img', 'gt_semantic_seg']),
79
+ # to channels first
80
+ dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
81
+ dict(type='TorchNormalize', **img_norm_cfg),
82
+ dict(type='TorchRandomCrop', crop_size=crop_size),
83
+ dict(type='Reshape', keys=['img'], new_shape=(len(bands), num_frames, tile_size, tile_size)),
84
+ dict(type='Reshape', keys=['gt_semantic_seg'], new_shape=(1, tile_size, tile_size)),
85
+ dict(type='CastTensor', keys=['gt_semantic_seg'], new_type="torch.LongTensor"),
86
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  ]
88
+
89
  test_pipeline = [
90
+ dict(type='LoadGeospatialImageFromFile', to_float32=True, channels_last=True),
91
  dict(type='ToTensor', keys=['img']),
92
+ # to channels first
93
+ dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
94
+ dict(type='TorchNormalize', **img_norm_cfg),
95
+ dict(type='Reshape', keys=['img'], new_shape=(len(bands), num_frames, -1, -1), look_up = {'2': 1, '3': 2}),
96
+ dict(type='CastTensor', keys=['img'], new_type="torch.FloatTensor"),
97
+ dict(type='CollectTestList', keys=['img'],
98
+ meta_keys=['img_info', 'seg_fields', 'img_prefix', 'seg_prefix', 'filename', 'ori_filename', 'img',
99
+ 'img_shape', 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg']),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  ]
101
+
102
+ CLASSES = ('Natural Vegetation',
103
+ 'Forest',
104
+ 'Corn',
105
+ 'Soybeans',
106
+ 'Wetlands',
107
+ 'Developed/Barren',
108
+ 'Open Water',
109
+ 'Winter Wheat',
110
+ 'Alfalfa',
111
+ 'Fallow/Idle Cropland',
112
+ 'Cotton',
113
+ 'Sorghum',
114
+ 'Other')
115
+
116
+ dataset = 'GeospatialDataset'
117
+
118
  data = dict(
119
+ samples_per_gpu=8,
120
+ workers_per_gpu=4,
121
  train=dict(
122
+ type=dataset,
123
+ CLASSES=CLASSES,
124
  reduce_zero_label=True,
125
+ data_root=data_root,
126
+ img_dir='training_chips',
127
+ ann_dir='training_chips',
128
+ pipeline=train_pipeline,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  img_suffix='_merged.tif',
130
  seg_map_suffix='.mask.tif',
131
+ split=splits['train']),
 
 
132
  val=dict(
133
+ type=dataset,
134
+ CLASSES=CLASSES,
135
  reduce_zero_label=True,
136
+ data_root=data_root,
137
+ img_dir='validation_chips',
138
+ ann_dir='validation_chips',
139
+ pipeline=test_pipeline,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  img_suffix='_merged.tif',
141
  seg_map_suffix='.mask.tif',
142
+ split=splits['val']
 
143
  ),
144
  test=dict(
145
+ type=dataset,
146
+ CLASSES=CLASSES,
147
  reduce_zero_label=True,
148
+ data_root=data_root,
149
+ img_dir='validation_chips',
150
+ ann_dir='validation_chips',
151
+ pipeline=test_pipeline,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  img_suffix='_merged.tif',
153
  seg_map_suffix='.mask.tif',
154
+ split=splits['val']
 
155
  ))
156
+
157
  optimizer = dict(
158
  type='Adam', lr=1.5e-05, betas=(0.9, 0.999), weight_decay=0.05)
159
  optimizer_config = dict(grad_clip=None)
 
169
  interval=10,
170
  hooks=[dict(type='TextLoggerHook'),
171
  dict(type='TensorboardLoggerHook')])
172
+
173
  checkpoint_config = dict(
174
  by_epoch=True,
175
+ interval=100,
176
+ out_dir=save_path)
177
+
178
+ evaluation = dict(interval=eval_epoch_interval, metric='mIoU', pre_eval=True, save_best='mIoU', by_epoch=True)
179
  reduce_train_set = dict(reduce_train_set=False)
180
  reduce_factor = dict(reduce_factor=1)
181
+ runner = dict(type='EpochBasedRunner', max_epochs=max_epochs)
182
+ workflow = [('train', 1)]
183
  norm_cfg = dict(type='BN', requires_grad=True)
184
+
 
 
 
 
 
 
 
 
 
 
 
 
185
  model = dict(
186
  type='TemporalEncoderDecoder',
187
  frozen_backbone=False,
188
  backbone=dict(
189
  type='TemporalViTEncoder',
190
+ pretrained=pretrained_weights_path,
191
+ img_size=img_size,
192
+ patch_size=patch_size,
193
+ num_frames=num_frames,
194
  tubelet_size=1,
195
+ in_chans=len(bands),
196
+ embed_dim=embed_dim,
197
  depth=6,
198
+ num_heads=num_heads,
199
  mlp_ratio=4.0,
200
  norm_pix_loss=False),
201
  neck=dict(
202
  type='ConvTransformerTokensToEmbeddingNeck',
203
+ embed_dim=embed_dim*num_frames,
204
+ output_embed_dim=output_embed_dim,
205
  drop_cls_token=True,
206
  Hp=14,
207
  Wp=14),
208
  decode_head=dict(
209
+ num_classes=len(CLASSES),
210
+ in_channels=output_embed_dim,
211
  type='FCNHead',
212
  in_index=-1,
213
  channels=256,
 
216
  dropout_ratio=0.1,
217
  norm_cfg=dict(type='BN', requires_grad=True),
218
  align_corners=False,
219
+ loss_decode=loss_func),
 
 
 
 
 
 
 
 
220
  auxiliary_head=dict(
221
+ num_classes=len(CLASSES),
222
+ in_channels=output_embed_dim,
223
  type='FCNHead',
224
  in_index=-1,
225
  channels=256,
 
228
  dropout_ratio=0.1,
229
  norm_cfg=dict(type='BN', requires_grad=True),
230
  align_corners=False,
231
+ loss_decode=loss_func),
 
 
 
 
 
 
 
 
232
  train_cfg=dict(),
233
+ test_cfg=dict(mode='slide', stride=(int(tile_size/2), int(tile_size/2)), crop_size=(tile_size, tile_size)))
234
  auto_resume = False