import os
import gradio as gr
import numpy as np
import segmentation_models_pytorch as smp
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import draw_segmentation_masks
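# inference configuration; "model_config" mirrors the segmentation_models_pytorch
# constructor arguments, with 7 output classes (one per landcover type)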
config = {
    "downsize_res": 512,
    "batch_size": 6,
    "epochs": 30,
    "lr": 3e-4,
    "model_architecture": "Unet",
    "model_config": {
        "encoder_name": "resnet34",
        "encoder_weights": "imagenet",
        "in_channels": 3,
        "classes": 7,
    },
}
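# one RGB color per class, consumed by draw_segmentation_masks when coloring the masks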
colors = [
    (0, 255, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 255),
    (0, 0, 0),
]
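# checkpoint to restore; inference runs on GPU when available, otherwise CPU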
cp_path = "CP_epoch20.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"
# load model
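# resolve the architecture class by name (e.g. smp.Unet) and restore the trained weights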
model_architecture = getattr(smp, config["model_architecture"])
model = model_architecture(**config["model_config"])
model.load_state_dict(torch.load(cp_path, map_location=torch.device(device)))
model.to(device)
model.eval()
# transforms
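# images are downsized to downsize_res x downsize_res for the forward pass
# and the predictions are resized back to the original resolution afterwards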
downsize_t = transforms.Resize((config["downsize_res"], config["downsize_res"]), antialias=True)
transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)
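# draw_segmentation_masks expects boolean masks of shape (num_classes, H, W),
# hence the label -> one-hot conversion below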
def label_to_onehot(mask: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Converts a tensor from label encoding to one-hot encoding in boolean dtype."""
    # move the new class dimension in front of the spatial dimensions
    dims_p = (2, 0, 1) if mask.ndim == 2 else (0, 3, 1, 2)
    return torch.permute(
        F.one_hot(mask.type(torch.long), num_classes=num_classes).type(torch.bool),
        dims_p,
    )
def get_overlay(image: torch.Tensor, preds: torch.Tensor, alpha: float) -> torch.Tensor:
    """Generates the segmentation overlay for a satellite image."""
    masks = label_to_onehot(preds.squeeze(), 7)
    overlay = draw_segmentation_masks(image, masks=masks, alpha=alpha, colors=colors)
    return overlay
def hwc_to_chw(image_tensor: torch.Tensor) -> torch.Tensor:
    """Reorders a (H, W, C) image tensor to (C, H, W)."""
    return torch.permute(image_tensor, (2, 0, 1))

def chw_to_hwc(image_tensor: torch.Tensor) -> torch.Tensor:
    """Reorders a (C, H, W) image tensor to (H, W, C)."""
    return torch.permute(image_tensor, (1, 2, 0))
def segment(satellite_image: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Segments a satellite image and returns the raw masks and the overlay."""
    image_tensor = torch.from_numpy(satellite_image)
    image_tensor = hwc_to_chw(image_tensor)
    pil_image = transforms.functional.to_pil_image(image_tensor)
    # preprocess image
    X = transform(pil_image).unsqueeze(0)
    X = X.to(device)
    X_down = downsize_t(X)
    # forward pass (no gradients needed at inference time)
    with torch.no_grad():
        logits = model(X_down)
    preds = torch.argmax(logits, 1)
    # resize to evaluate with the original image
    # nearest-neighbor interpolation keeps the class labels categorical when upscaling
    preds = transforms.functional.resize(
        preds, X.shape[-2:], interpolation=transforms.InterpolationMode.NEAREST
    )
    # get rgb formatted images
    segmentation_overlay = chw_to_hwc(get_overlay(image_tensor, preds, 0.2)).numpy()
    raw_segmentation = chw_to_hwc(
        get_overlay(torch.zeros_like(image_tensor), preds, 1)
    ).numpy()
    return raw_segmentation, segmentation_overlay
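# gradio UI: the app shows the raw class-colored masks next to the overlay
# blended onto the input image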
inputs = gr.Image(label="Input Image")
outputs = [gr.Image(label="Raw Segmentation"), gr.Image(label="Segmentation Overlay")]
images_dir = "sample_sat_images"
examples = [os.path.join(images_dir, image_id) for image_id in os.listdir(images_dir)]
title = "Satellite Images Landcover Classification"
description = (
    "Upload a satellite image from your computer or select one from"
    " the examples. The model will automatically segment the landcover"
    " types from a preselected set of possible types."
)
with open("article.md", "r") as f:
    article = f.read()
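# cache_examples precomputes the model outputs for the sample images at startup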
iface = gr.Interface(
    segment,
    inputs,
    outputs,
    examples=examples,
    title=title,
    description=description,
    cache_examples=True,
    article=article,
)
iface.launch()