File size: 3,216 Bytes
3ae192a
 
 
 
 
3a7ac79
 
 
3ae192a
3a7ac79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc4c15d
 
 
3a7ac79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc4c15d
 
 
d38fb36
3a7ac79
bc4c15d
3a7ac79
 
 
 
bc4c15d
 
 
 
 
 
 
 
3a7ac79
 
bc4c15d
3a7ac79
 
 
 
 
 
 
bc4c15d
 
 
 
 
3a7ac79
bc4c15d
3a7ac79
 
 
 
 
bc4c15d
3a7ac79
 
 
 
bc4c15d
3a7ac79
bc4c15d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os

import gradio as gr
import numpy as np
import segmentation_models_pytorch as smp
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import draw_segmentation_masks

# Settings shared by the code below. In the code visible here only
# "downsize_res", "model_architecture" and "model_config" are read;
# "batch_size", "epochs" and "lr" are not referenced (presumably kept from the
# training run -- confirm before removing).
config = {
    "downsize_res": 512,  # size handed to transforms.Resize before the forward pass
    "batch_size": 6,
    "epochs": 30,
    "lr": 3e-4,
    "model_architecture": "Unet",  # attribute name looked up on segmentation_models_pytorch
    # keyword arguments forwarded verbatim to the model constructor
    "model_config": {
        "encoder_name": "resnet34",
        "encoder_weights": "imagenet",
        "in_channels": 3,
        "classes": 7,
    },
}

# One RGB colour per class index; passed as `colors` to draw_segmentation_masks,
# so entry i paints the mask of class i.
# NOTE(review): presumably the DeepGlobe land-cover palette -- confirm the
# class names against the training dataset.
class_rgb_colors = [
    (0, 255, 255),  # class 0: cyan
    (255, 255, 0),  # class 1: yellow
    (255, 0, 255),  # class 2: magenta
    (0, 255, 0),  # class 3: green
    (0, 0, 255),  # class 4: blue
    (255, 255, 255),  # class 5: white
    (0, 0, 0),  # class 6: black
]


def label_to_onehot(mask: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Converts a label-encoded mask into a channel-first boolean one-hot mask.

    A 2-D mask (H, W) becomes (num_classes, H, W); any other input is treated
    as a batch (N, H, W) and becomes (N, num_classes, H, W).
    """

    labels = mask.to(torch.long)
    # F.one_hot appends the class axis last; it is moved in front of H and W below.
    one_hot = F.one_hot(labels, num_classes=num_classes).to(torch.bool)
    if mask.ndim == 2:
        return one_hot.permute(2, 0, 1)
    return one_hot.permute(0, 3, 1, 2)


# Checkpoint holding the trained weights, and the device inference runs on.
cp_path = "CP_epoch20.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"

# load model
model_architecture = getattr(smp, config["model_architecture"])
# load_state_dict below is strict, so the checkpoint replaces every parameter
# and buffer. Pretrained encoder weights would be fetched only to be
# overwritten, so the network is built without them.
model = model_architecture(**{**config["model_config"], "encoder_weights": None})
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load(cp_path, map_location=torch.device(device)))
model.to(device)
# Inference only: switches dropout / batch-norm layers to evaluation behaviour.
model.eval()


# transforms
# Resizes a batch to the configured working resolution before the forward pass.
downsize_t = transforms.Resize(size=config["downsize_res"], antialias=True)
# Converts the PIL image built from each upload into a tensor.
transform = transforms.Compose([transforms.ToTensor()])


def get_overlay(image: torch.Tensor, preds: torch.Tensor, alpha: float) -> torch.Tensor:
    """Generates the segmentation overlay for a satellite image.

    Args:
        image: image tensor handed to ``draw_segmentation_masks`` as the canvas.
        preds: label-encoded class predictions; singleton dimensions are
            squeezed away before one-hot encoding.
        alpha: opacity forwarded to ``draw_segmentation_masks``.

    Returns:
        The tensor produced by ``draw_segmentation_masks``: ``image`` with one
        coloured mask per class blended on top.
    """

    # Take the class count from the model config instead of repeating it here.
    num_classes = config["model_config"]["classes"]
    masks = label_to_onehot(preds.squeeze(), num_classes)
    # Predictions may come from a GPU model while the image is on the CPU; the
    # masks index into the image, so they must live on the same device.
    masks = masks.to(image.device)
    return draw_segmentation_masks(
        image, masks=masks, alpha=alpha, colors=class_rgb_colors
    )


def hwc_to_chw(image_tensor: torch.Tensor) -> torch.Tensor:
    """Moves the last axis of a 3-D tensor to the front (HWC -> CHW)"""
    return image_tensor.permute(2, 0, 1)


def segment(satellite_image: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Segments one image and renders the prediction in two ways.

    Args:
        satellite_image: image array in height-width-channel layout
            (presumably uint8 RGB as delivered by the Gradio input -- confirm).

    Returns:
        ``(raw_segmentation, segmentation_overlay)``, both in
        height-width-channel layout: the class colours painted on a black
        canvas, and the class colours blended over the input at alpha 0.2.
    """
    image_tensor = hwc_to_chw(torch.from_numpy(satellite_image))
    pil_image = transforms.functional.to_pil_image(image_tensor)

    # preprocess image
    X = transform(pil_image).unsqueeze(0).to(device)
    X_down = downsize_t(X)

    # forward pass; no autograd graph is needed for inference
    with torch.no_grad():
        logits = model(X_down)
    preds = torch.argmax(logits, 1)

    # Resize back to the input resolution. Class ids are categorical, so use
    # nearest-neighbour: interpolating between ids would invent wrong classes.
    preds = transforms.functional.resize(
        preds,
        list(X.shape[-2:]),
        interpolation=transforms.InterpolationMode.NEAREST,
    )
    # The image tensors live on the CPU, so the predictions must as well.
    preds = preds.cpu()

    # get rgb formatted images; overlays are channel-first, so move the channel
    # axis last to obtain height-width-channel arrays.
    segmentation_overlay = (
        get_overlay(image_tensor, preds, 0.2).permute(1, 2, 0).numpy()
    )
    raw_segmentation = (
        get_overlay(torch.zeros_like(image_tensor), preds, 1).permute(1, 2, 0).numpy()
    )

    return raw_segmentation, segmentation_overlay


# Input and outputs use the same component API; `gr.inputs.Image` is the
# deprecated spelling of the component already used for the outputs.
i = gr.Image()
# Order matches the tuple returned by `segment`.
o = [gr.Image(label="Segmentation"), gr.Image(label="Overlay")]
images_dir = "sample_sat_images/"
# Sorted so the example gallery has a stable order; os.path.join avoids the
# doubled slash an f-string would produce with the trailing "/" above.
examples = sorted(
    os.path.join(images_dir, image_id) for image_id in os.listdir(images_dir)
)
title = "Satellite Images Landcover Segmentation"
description = "Upload an image or select from examples to segment"

iface = gr.Interface(
    segment,
    i,
    o,
    examples=examples,
    title=title,
    description=description,
    cache_examples=True,  # runs `segment` on every example up front
)
iface.launch()