|
import os |
|
|
|
import gradio as gr |
|
import numpy as np |
|
import segmentation_models_pytorch as smp |
|
import torch |
|
import torch.nn.functional as F |
|
from torchvision import transforms |
|
from torchvision.utils import draw_segmentation_masks |
|
|
|
# Model / pipeline settings. This script reads "downsize_res", "model_architecture"
# and "model_config"; "batch_size", "epochs" and "lr" are not read anywhere here
# (NOTE(review): presumably leftovers from the training setup — confirm before removing).
config = dict(
    downsize_res=512,
    batch_size=6,
    epochs=30,
    lr=3e-4,
    model_architecture="Unet",
    model_config=dict(
        encoder_name="resnet34",
        encoder_weights="imagenet",
        in_channels=3,
        classes=7,
    ),
)

# One RGB color per class index: colors[i] paints the one-hot mask of class i when
# passed to draw_segmentation_masks. NOTE(review): this looks like the DeepGlobe
# land-cover palette (urban, agriculture, rangeland, forest, water, barren, unknown)
# — confirm against the labels the checkpoint was trained on.
colors = [
    (0, 255, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 255),
    (0, 0, 0),
]
|
|
|
|
|
cp_path = "CP_epoch20.pth"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Resolve the architecture class by name, e.g. "Unet" -> smp.Unet.
model_architecture = getattr(smp, config["model_architecture"])
# Every parameter is overwritten by the checkpoint below (load_state_dict is strict by
# default), so don't fetch the pretrained encoder weights named in the config: that is
# a wasted download at startup and fails on machines without internet access.
model = model_architecture(**{**config["model_config"], "encoder_weights": None})
# NOTE: torch.load unpickles the file; only point cp_path at trusted checkpoints.
model.load_state_dict(torch.load(cp_path, map_location=torch.device(device)))
model.to(device)
# Switch layers such as dropout / batch norm to inference behavior.
model.eval()
|
|
|
|
|
|
|
# Square resize of the model input to downsize_res x downsize_res pixels.
downsize_t = transforms.Resize(size=(config["downsize_res"],) * 2, antialias=True)

# PIL image -> float (C, H, W) tensor scaled to [0, 1] (torchvision ToTensor).
transform = transforms.Compose([transforms.ToTensor()])
|
|
|
|
|
def label_to_onehot(mask: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Transform a label-encoded mask into a boolean one-hot encoding.

    The class axis is placed just before the two trailing spatial axes, so
    ``(H, W)`` becomes ``(num_classes, H, W)`` and ``(N, H, W)`` becomes
    ``(N, num_classes, H, W)``. Any number of leading batch dims is supported.

    Args:
        mask: tensor of class ids with at least two dims; it is cast to
            ``long`` before encoding.
        num_classes: number of classes; every id in ``mask`` must lie in
            ``[0, num_classes)``, otherwise ``F.one_hot`` raises.

    Returns:
        Boolean tensor with one channel per class.
    """
    onehot = F.one_hot(mask.long(), num_classes=num_classes).bool()
    # one_hot appends the class axis last; move it in front of the (H, W) axes.
    return onehot.movedim(-1, -3)
|
|
|
|
|
def get_overlay(image: torch.Tensor, preds: torch.Tensor, alpha: float) -> torch.Tensor:
    """Generate the segmentation overlay for a satellite image.

    Args:
        image: channel-first ``(3, H, W)`` image tensor in a format accepted by
            torchvision's ``draw_segmentation_masks`` (uint8 in this app — presumably
            always; TODO confirm for custom callers).
        preds: class-label map whose last two dims are ``(H, W)``; any leading dims
            must be of size 1 (e.g. a batch of one).
        alpha: opacity of the colored masks, from 0 (image only) to 1 (masks only).

    Returns:
        ``(3, H, W)`` tensor with class ``i`` painted in ``colors[i]``.
    """
    num_classes = config["model_config"]["classes"]
    # Keep exactly the spatial dims; unlike squeeze(), this never drops an H or W of 1.
    labels = preds.reshape(preds.shape[-2:])
    # Predictions may live on the model's device (e.g. CUDA); the masks must be on the
    # same device as the image they are drawn on.
    masks = label_to_onehot(labels, num_classes).to(image.device)
    return draw_segmentation_masks(image, masks=masks, alpha=alpha, colors=colors)
|
|
|
|
|
def hwc_to_chw(image_tensor: torch.Tensor) -> torch.Tensor:
    """Reorder a 3-D channel-last (H, W, C) tensor to channel-first (C, H, W).

    The result is a permuted view of the input, not a copy.
    """
    height, width, channels = 0, 1, 2
    return image_tensor.permute(channels, height, width)
|
|
|
|
|
def chw_to_hwc(image_tensor: torch.Tensor) -> torch.Tensor:
    """Reorder a 3-D channel-first (C, H, W) tensor to channel-last (H, W, C).

    The result is a permuted view of the input, not a copy.
    """
    channels, height, width = 0, 1, 2
    return image_tensor.permute(height, width, channels)
|
|
|
|
|
def segment(satellite_image: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Segment the land-cover classes of a satellite image.

    Args:
        satellite_image: channel-last ``(H, W, C)`` array as delivered by the Gradio
            image input (presumably uint8 RGB — TODO confirm).

    Returns:
        ``(raw_segmentation, segmentation_overlay)``, both ``(H, W, 3)`` arrays at the
        input resolution: the class colors on a black background, and the class
        colors blended over the input image with alpha 0.2.
    """
    image_tensor = hwc_to_chw(torch.from_numpy(satellite_image))
    pil_image = transforms.functional.to_pil_image(image_tensor)

    # (1, C, H, W) float batch for the model; inference runs at a fixed reduced size.
    X = transform(pil_image).unsqueeze(0).to(device)
    X_down = downsize_t(X)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = model(X_down)
        # Upsample the class scores (not the labels) back to the input size, then take
        # the argmax there. Interpolating integer labels would blend unrelated class
        # ids at region boundaries.
        logits = F.interpolate(
            logits, size=tuple(X.shape[-2:]), mode="bilinear", align_corners=False
        )
        preds = torch.argmax(logits, dim=1)

    # The overlays are drawn on the CPU image tensor, so bring the labels to the CPU.
    preds = preds.cpu()

    segmentation_overlay = chw_to_hwc(get_overlay(image_tensor, preds, 0.2)).numpy()
    raw_segmentation = chw_to_hwc(
        get_overlay(torch.zeros_like(image_tensor), preds, 1)
    ).numpy()

    return raw_segmentation, segmentation_overlay
|
|
|
|
|
# Gradio UI: one image in, two images out (in the order segment returns them).
# gr.Image replaces the deprecated gr.inputs.Image, matching the output components.
inputs = gr.Image(label="Input Image")
outputs = [gr.Image(label="Raw Segmentation"), gr.Image(label="Segmentation Overlay")]

images_dir = "sample_sat_images/"
# Every file in images_dir becomes an example. Sorted because os.listdir order is
# arbitrary; os.path.join avoids the "dir//file" double slash.
examples = [
    os.path.join(images_dir, image_id) for image_id in sorted(os.listdir(images_dir))
]

title = "Satellite Images Landcover Classification"

description = (
    "Upload a satellite image from your computer or select one from"
    " the examples. The model will automatically segment the landcover"
    " types from a preselected set of possible types."
)

with open("article.md", "r", encoding="utf-8") as article_file:
    article = article_file.read()

iface = gr.Interface(
    segment,
    inputs,
    outputs,
    examples=examples,
    title=title,
    description=description,
    # Precomputes segment() on the examples so they are served from cache.
    cache_examples=True,
    article=article,
)

iface.launch()
|
|