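"""Gradio demo: satellite image land-cover segmentation.

Loads a segmentation model from segmentation_models_pytorch with trained
weights and serves an interface that returns color-coded class masks and an
overlay on the input image.
"""
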
import torch
from torchvision.utils import draw_segmentation_masks
import torch.nn.functional as F
from torchvision import transforms
import segmentation_models_pytorch as smp
import gradio as gr
import os

# training-time settings kept with the model; "model_config" is passed
# directly to the smp constructor named by "model_architecture"
config = {
    "downsize_res": 512,
    "batch_size": 6,
    "epochs": 30,
    "lr": 3e-4,
    "model_architecture": "Unet",
    "model_config": {
        "encoder_name": "resnet34",
        "encoder_weights": "imagenet",
        "in_channels": 3,
        "classes": 7,
    },
}

# one RGB color per class index (7 classes)
class_rgb_colors = [
    (0, 255, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 255),
    (0, 0, 0),
]


def label_to_onehot(mask, num_classes):
    """Convert an integer label mask (H, W) or batch (N, H, W) into a
    channels-first boolean one-hot tensor (C, H, W) or (N, C, H, W)."""
    dims_p = (2, 0, 1) if mask.ndim == 2 else (0, 3, 1, 2)
    return torch.permute(
        F.one_hot(mask.type(torch.long), num_classes=num_classes).type(torch.bool),
        dims_p,
    )
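# e.g. label_to_onehot(torch.tensor([[0, 2], [1, 0]]), num_classes=3)
# -> boolean tensor of shape (3, 2, 2): one channel per class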


# path to the trained checkpoint (state_dict), resolved relative to the cwd
cp_path = "CP_epoch20.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"

# load model
model_architecture = getattr(smp, config["model_architecture"])
model = model_architecture(**config["model_config"])
model.load_state_dict(torch.load(cp_path, map_location=torch.device(device)))
model.to(device)
model.eval()


# mean = [0.4085, 0.3798, 0.2822]
# std = [0.1410, 0.1051, 0.0927]
# preprocessing: PIL image -> float tensor in [0, 1] (normalization disabled)
downsize_t = transforms.Resize(config["downsize_res"], antialias=True)
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # transforms.Normalize(mean, std),
    ]
)
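# e.g. a 1024x1024 RGB PIL image becomes a (3, 1024, 1024) float tensor in
# [0, 1]; downsize_t later shrinks the shorter side to 512 before inference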


def get_overlay(sat_img, preds, alpha):
    """Blend color-coded class masks over `sat_img` (uint8, C x H x W).

    alpha=1 draws the masks fully opaque; a small alpha keeps the
    underlying image visible through them.
    """
    masks = label_to_onehot(preds.squeeze(), config["model_config"]["classes"])
    return draw_segmentation_masks(
        sat_img, masks=masks, alpha=alpha, colors=class_rgb_colors
    )


def segment(numpy_arr):
    """Segment an (H, W, 3) uint8 RGB array; returns (raw masks, overlay)."""
    sat_img_arr = torch.from_numpy(numpy_arr)
    sat_img_arr = torch.permute(sat_img_arr, (2, 0, 1))
    sat_img_pil = transforms.functional.to_pil_image(sat_img_arr)

    # preprocess image
    X = transform(sat_img_pil).unsqueeze(0)
    X = X.to(device)
    X_down = downsize_t(X)
    # forward pass (no gradients needed at inference time)
    with torch.no_grad():
        logits = model(X_down)
    preds = torch.argmax(logits, 1).cpu()
    # resize predictions back to the original resolution; nearest-neighbor
    # keeps the integer class labels intact
    preds = transforms.functional.resize(
        preds, X.shape[-2:], interpolation=transforms.InterpolationMode.NEAREST
    )

    # get RGB-formatted images: raw color-coded masks and a blended overlay
    overlay = get_overlay(sat_img_arr, preds, 0.2)
    raw_masks = get_overlay(torch.zeros_like(sat_img_arr), preds, 1)
    raw_masks = torch.permute(raw_masks, (1, 2, 0))
    overlay = torch.permute(overlay, (1, 2, 0))

    return raw_masks.numpy(), overlay.numpy()
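
# segment() can also be used standalone, e.g. (with numpy and PIL imported):
#   masks, overlay = segment(np.asarray(Image.open("sample.png").convert("RGB")))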


i = gr.Image()
o = [gr.Image(label="Predicted masks"), gr.Image(label="Overlay")]
images_dir = "sample_sat_images"

image_ids = os.listdir(images_dir)
examples = [os.path.join(images_dir, image_id) for image_id in image_ids]
title = "Satellite Images Landcover Segmentation"
description = "Upload a satellite image or pick one of the examples to segment it"

iface = gr.Interface(
    segment, i, o, examples=examples, title=title, description=description
)
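# debug=True blocks the main thread and prints server errors to the console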
iface.launch(debug=True)