Clean code
app.py CHANGED
@@ -4,6 +4,7 @@ import torch.nn.functional as F
 from torchvision import transforms
 import segmentation_models_pytorch as smp
 import gradio as gr
+import numpy as np
 import os
 
 config = {
@@ -31,7 +32,9 @@ class_rgb_colors = [
 ]
 
 
-def label_to_onehot(mask, num_classes):
+def label_to_onehot(mask: torch.Tensor, num_classes: int) -> torch.Tensor:
+    """Transforms a tensor from label encoding to one-hot encoding in boolean dtype."""
+
     dims_p = (2, 0, 1) if mask.ndim == 2 else (0, 3, 1, 2)
     return torch.permute(
         F.one_hot(mask.type(torch.long), num_classes=num_classes).type(torch.bool),
@@ -50,43 +53,36 @@ model.to(device)
 model.eval()
 
 
-# mean = [0.4085, 0.3798, 0.2822]
-# std = [0.1410, 0.1051, 0.0927]
 # transforms
 downsize_t = transforms.Resize(config["downsize_res"], antialias=True)
 transform = transforms.Compose(
     [
         transforms.ToTensor(),
-        # transforms.Normalize(mean, std),
     ]
 )
 
 
-def get_overlay(image, preds, alpha):
-        (255, 0, 255),
-        (0, 255, 0),
-        (0, 0, 255),
-        (255, 255, 255),
-        (0, 0, 0),
-    ]
-    masks = preds.squeeze()
-    masks = label_to_onehot(masks, 7)
+def get_overlay(image: torch.Tensor, preds: torch.Tensor, alpha: float) -> torch.Tensor:
+    """Generates the segmentation overlay for a satellite image."""
+
+    masks = label_to_onehot(preds.squeeze(), 7)
     overlay = draw_segmentation_masks(
+        image, masks=masks, alpha=alpha, colors=class_rgb_colors
     )
     return overlay
 
 
-def segment(numpy_arr):
+def hwc_to_chw(image_tensor: torch.Tensor) -> torch.Tensor:
+    return torch.permute(image_tensor, (2, 0, 1))
+
+
+def chw_to_hwc(image_tensor: torch.Tensor) -> torch.Tensor:
+    return torch.permute(image_tensor, (1, 2, 0))
+
+
+def segment(satellite_image: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+    image_tensor = torch.from_numpy(satellite_image)
+    image_tensor = hwc_to_chw(image_tensor)
+    pil_image = transforms.functional.to_pil_image(image_tensor)
 
     # preprocess image
-    X = transform(
+    X = transform(pil_image).unsqueeze(0)
     X = X.to(device)
     X_down = downsize_t(X)
     # forward pass
@@ -94,26 +90,23 @@ def segment(numpy_arr):
     preds = torch.argmax(logits, 1).detach()
     # resize to evaluate with the original image
     preds = transforms.functional.resize(preds, X.shape[-2:], antialias=True)
+    # get RGB formatted images
+    segmentation_overlay = chw_to_hwc(get_overlay(image_tensor, preds, 0.2)).numpy()
+    raw_segmentation = chw_to_hwc(
+        get_overlay(torch.zeros_like(image_tensor), preds, 1)
+    ).numpy()
 
-    overlay = get_overlay(sat_img_arr, preds, 0.2)
-    raw_masks = get_overlay(torch.zeros_like(sat_img_arr), preds, 1)
-    raw_masks = torch.permute(raw_masks, (1, 2, 0))
-    overlay = torch.permute(overlay, (1, 2, 0))
-
-    return raw_masks.numpy(), overlay.numpy()
+    return raw_segmentation, segmentation_overlay
 
 
 i = gr.inputs.Image()
 o = [gr.Image(), gr.Image()]
 images_dir = "sample_sat_images/"
-
-image_ids = os.listdir(images_dir)
-examples = [f"{images_dir}/{image_id}" for image_id in image_ids]
+examples = [f"{images_dir}/{image_id}" for image_id in os.listdir(images_dir)]
 title = "Satellite Images Landcover Segmentation"
 description = "Upload an image or select from examples to segment"
 
 iface = gr.Interface(
-    segment, i, o, examples=examples, title=title, description=description
+    segment, i, o, examples=examples, title=title, description=description, cache_examples=True
 )
-iface.launch(
+iface.launch()
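The core trick this commit keeps is converting an argmax label map into boolean channel masks for torchvision's `draw_segmentation_masks`. Here is a minimal sketch of that round trip, runnable outside the app; the 4x4 prediction, the 3-class count, and the colors are made up for illustration:

```python
import torch
import torch.nn.functional as F
from torchvision.utils import draw_segmentation_masks


def label_to_onehot(mask: torch.Tensor, num_classes: int) -> torch.Tensor:
    # (H, W) label map -> (num_classes, H, W) boolean masks
    dims_p = (2, 0, 1) if mask.ndim == 2 else (0, 3, 1, 2)
    return torch.permute(
        F.one_hot(mask.type(torch.long), num_classes=num_classes).type(torch.bool),
        dims_p,
    )


# synthetic 4x4 "prediction" covering 3 classes (hypothetical values)
preds = torch.tensor([[0, 1, 2, 2]] * 4)
colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]

masks = label_to_onehot(preds, num_classes=3)
assert masks.shape == (3, 4, 4) and masks.dtype == torch.bool

# draw onto a black uint8 CHW image, then permute back to HWC for display
image = torch.zeros((3, 4, 4), dtype=torch.uint8)
overlay = draw_segmentation_masks(image, masks=masks, alpha=1.0, colors=colors)
print(torch.permute(overlay, (1, 2, 0)).shape)  # torch.Size([4, 4, 3])
```

The final permute mirrors the `chw_to_hwc` helper in app.py: `draw_segmentation_masks` operates on CHW tensors, while Gradio's `Image` components expect HWC arrays.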
|