rybavery commited on
Commit
c597257
·
1 Parent(s): 996303b

update inf script to use correct import, isntructions for dependencies and data

Browse files
Files changed (3) hide show
  1. README.md +31 -0
  2. burn_scar_batch_inference_script.py +1 -1
  3. custom.py +191 -0
README.md CHANGED
@@ -33,6 +33,37 @@ Code for Finetuning is available through [github](https://github.com/NASA-IMPACT
33
  Configuration used for finetuning is available through [config](https://github.com/NASA-IMPACT/hls-foundation-os/blob/main/fine-tuning-examples/configs/firescars_config.py
34
  )
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  ### Results
38
 
 
33
  Configuration used for finetuning is available through [config](https://github.com/NASA-IMPACT/hls-foundation-os/blob/main/fine-tuning-examples/configs/firescars_config.py
34
  )
35
 
36
+ To run inference, first install dependencies
37
+
38
+ ```
39
+ mamba create -n prithvi-burn-scar python=3.10 torchvision numpy matplotlib rasterio torchmetrics openmim
40
+ mamba activate prithvi-burn-scar
41
+ mim install mmcv-full==1.5
42
+ ```
43
+
44
+ #### Instructions for downloading from [HuggingFace datasets](https://huggingface.co/datasets)
45
+
46
+ 1. Create account on https://huggingface.co/join
47
+ 2. Install `git` following https://git-scm.com/downloads
48
+ 3. Install git-lfs with `sudo apt install git-lfs` and `git lfs install`
49
+ 4. Run the following command to download the HLS datasets. You may need to
50
+ enter your HuggingFace username/password to do the `git clone`.
51
+
52
+ ```
53
+ mkdir data
54
+ cd data/
55
+ git clone https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars burn_scars
56
+ tar -xzvf burn_scars/hls_burn_scars.tar.gz -C data/
57
+ ls -lh data/
58
+ ```
59
+
60
+
61
+ With the datasets and the environment, you can now run the inference script.
62
+
63
+ ```
64
+
65
+
66
+ ```
67
 
68
  ### Results
69
 
burn_scar_batch_inference_script.py CHANGED
@@ -21,7 +21,7 @@ from torchvision import transforms
21
  from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
22
 
23
  from mmseg.apis import multi_gpu_test, single_gpu_test, init_segmentor
24
- from mmseg.utils import custom # custom preprocessing for hls
25
  import pdb
26
 
27
  import numpy as np
 
21
  from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
22
 
23
  from mmseg.apis import multi_gpu_test, single_gpu_test, init_segmentor
24
+ from . import custom # custom preprocessing for hls
25
  import pdb
26
 
27
  import numpy as np
custom.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils.py
2
+
3
+ import numpy as np
4
+ import glob
5
+ import rasterio
6
+ from torchvision import transforms
7
+ import torch
8
+ import re
9
+ from torchmetrics import Dice
10
+ import os
11
+
12
+ def calculate_band_statistics(image_directory, image_pattern, bands=[0, 1, 2, 3, 4, 5]):
13
+ """
14
+ Calculate the mean and standard deviation of each band in a folder of GeoTIFF files.
15
+
16
+ Args:
17
+ image_directory (str): Directory where the source GeoTIFF files are stored that are passed to model for training.
18
+ image_pattern (str): Pattern of the GeoTIFF file names that globs files for computing stats.
19
+ bands (list, optional): List of bands to calculate statistics for. Defaults to [0, 1, 2, 3, 4, 5].
20
+
21
+ Raises:
22
+ Exception: If no images are found in the given directory.
23
+
24
+ Returns:
25
+ tuple: Two lists containing the means and standard deviations of each band.
26
+ """
27
+ # Initialize lists to store the means and standard deviations
28
+ all_means = []
29
+ all_stds = []
30
+
31
+ # Use glob to get a list of all .tif images in the directory
32
+ all_images = glob.glob(f"{image_directory}/{image_pattern}.tif")
33
+
34
+ # Make sure there are images to process
35
+ if not all_images:
36
+ raise Exception("No images found")
37
+
38
+ # Get the number of bands
39
+ num_bands = len(bands)
40
+
41
+ # Initialize arrays to hold sums and sum of squares for each band
42
+ band_sums = np.zeros(num_bands)
43
+ band_sq_sums = np.zeros(num_bands)
44
+ pixel_counts = np.zeros(num_bands)
45
+
46
+ # Iterate over each image
47
+ for image_file in all_images:
48
+ with rasterio.open(image_file) as src:
49
+ # For each band, calculate the sum, square sum, and pixel count
50
+ for band in bands:
51
+ data = src.read(band + 1) # rasterio band index starts from 1
52
+ band_sums[band] += np.nansum(data)
53
+ band_sq_sums[band] += np.nansum(data**2)
54
+ pixel_counts[band] += np.count_nonzero(~np.isnan(data))
55
+
56
+ # Calculate means and standard deviations for each band
57
+ for i in bands:
58
+ mean = band_sums[i] / pixel_counts[i]
59
+ std = np.sqrt((band_sq_sums[i] / pixel_counts[i]) - (mean**2))
60
+ all_means.append(mean)
61
+ all_stds.append(std)
62
+
63
+ return all_means, all_stds
64
+
65
+
66
+ def split_and_pad(array, target_shape):
67
+ """
68
+ Splits the input array into smaller arrays of the target shape, padding if necessary.
69
+
70
+ Args:
71
+ array (numpy.ndarray): The input array. Must be shape (batch, band, time, height, width)
72
+ target_shape (tuple): The target shape of the smaller arrays. Must be of shape
73
+ (batch, band, time, height, width)
74
+
75
+ Raises:
76
+ ValueError: If target shape is larger than the array shape.
77
+
78
+ Returns:
79
+ list[numpy.ndarray]: A list of the smaller arrays.
80
+ """
81
+ # Check if the target shape is smaller or equal to the array shape
82
+ if target_shape[-2:] > array.shape[-2:]:
83
+ raise ValueError('Target shape must be smaller or equal to the array shape.')
84
+
85
+ # Calculate how much padding is needed
86
+ pad_h = (target_shape[-2] - array.shape[-2] % target_shape[-2]) % target_shape[-2]
87
+ pad_w = (target_shape[-1] - array.shape[-1] % target_shape[-1]) % target_shape[-1]
88
+
89
+ # Apply padding to the array
90
+ padded_array = np.pad(array, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)))
91
+
92
+ # Split the array into smaller arrays of the target shape
93
+ result = []
94
+ for i in range(0, padded_array.shape[-2], target_shape[-2]):
95
+ for j in range(0, padded_array.shape[-1], target_shape[-1]):
96
+ result.append(padded_array[..., i:i+target_shape[-2], j:j+target_shape[-1]])
97
+
98
+ return result
99
+
100
+ def merge_and_unpad(np_array_list, original_shape, target_shape):
101
+ """
102
+ Assembles smaller numpy arrays back into the original larger numpy array, removing padding if necessary.
103
+
104
+ Args:
105
+ np_array_list (list[numpy.ndarray]): The list of smaller numpy arrays derived from split_and_pad.
106
+ original_shape (tuple): The original shape of the larger numpy array. Must be shape (Height, Width).
107
+ target_shape (tuple): The target shape of the smaller numpy arrays. Must be shape (Height, Width).
108
+
109
+ Returns:
110
+ numpy.ndarray: The original larger numpy array.
111
+ """
112
+ # Calculate how much padding was added
113
+ pad_h = (target_shape[0] - original_shape[0] % target_shape[0]) % target_shape[0]
114
+ pad_w = (target_shape[1] - original_shape[1] % target_shape[1]) % target_shape[1]
115
+
116
+ # Calculate the shape of the padded larger array
117
+ padded_shape = (original_shape[0] + pad_h, original_shape[1] + pad_w)
118
+
119
+ # Calculate the number of smaller arrays in each dimension
120
+ num_arrays_h = padded_shape[0] // target_shape[0]
121
+ num_arrays_w = padded_shape[1] // target_shape[1]
122
+
123
+ # Reshape the list of smaller arrays back into the shape of the padded larger array
124
+ merged_array = np.stack(np_array_list).reshape(num_arrays_h, num_arrays_w, *target_shape)
125
+
126
+ # Rearrange the array dimensions
127
+ merged_array = merged_array.transpose(0, 2, 1, 3).reshape(*padded_shape)
128
+
129
+ # Remove the padding
130
+ unpadded_array = merged_array[:original_shape[0], :original_shape[1]]
131
+
132
+ return unpadded_array
133
+
134
+ def compute_metrics(gt_dir, pred_dir):
135
+ """
136
+ Compute the Dice similarity coefficient between the predicted and ground truth images.
137
+
138
+ Args:
139
+ gt_dir (str): Directory where the ground truth images are stored.
140
+ pred_dir (str): Directory where the predicted images are stored.
141
+
142
+ Returns:
143
+ Tensor: Dice similarity coefficient score.
144
+ """
145
+ dice_metric = Dice()
146
+
147
+ # find all .tif files in the prediction directory
148
+ pred_files = glob.glob(os.path.join(pred_dir, "*.tif"))
149
+
150
+ # iterate over each prediction file
151
+ for pred_file in pred_files:
152
+ # extract the unique_id from the file name
153
+ unique_id = re.search('HLS\..*\.v1\.4', os.path.basename(pred_file))
154
+
155
+ if unique_id is not None:
156
+ unique_id = unique_id.group()
157
+
158
+ # create the unique pattern for the gt directory
159
+ gt_file_pattern = os.path.join(gt_dir, f"*{unique_id}*mask.tif")
160
+
161
+ # glob the file pattern
162
+ gt_files = glob.glob(gt_file_pattern)
163
+
164
+ # if we found a matching gt file
165
+ if len(gt_files) == 1:
166
+ gt_file = gt_files[0]
167
+
168
+ # read the .tif files
169
+ with rasterio.open(gt_file) as src:
170
+ gt_img = src.read(1) # ground truth image
171
+
172
+ with rasterio.open(pred_file) as src:
173
+ pred_img = src.read(1) # predicted image
174
+
175
+ # make sure the images are binary (values are 0 or 1)
176
+ gt_img = (gt_img > 0).astype(np.uint8)
177
+ pred_img = (pred_img > 0).astype(np.uint8)
178
+
179
+ # convert numpy arrays to PyTorch tensors
180
+ gt_img_tensor = torch.from_numpy(gt_img).long().flatten()
181
+ pred_img_tensor = torch.from_numpy(pred_img).long().flatten()
182
+
183
+ # update dice_metric
184
+ dice_metric.update(pred_img_tensor, gt_img_tensor)
185
+
186
+ else:
187
+ print(f"No matching ground truth file for prediction file {pred_file}.")
188
+
189
+ # compute the dice score
190
+ dice_score = dice_metric.compute()
191
+ return dice_score