DavidFM43's picture
Fix height and width resize
2fda342
raw
history blame
3.65 kB
import os
import gradio as gr
import numpy as np
import segmentation_models_pytorch as smp
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import draw_segmentation_masks
# Hyperparameters of the segmentation model.
# Only "downsize_res", "model_architecture" and "model_config" are read in this
# file; batch_size / epochs / lr are unused here (presumably a record of the
# training run — confirm).
config = dict(
    downsize_res=512,  # side length of the square input the model receives
    batch_size=6,
    epochs=30,
    lr=3e-4,
    model_architecture="Unet",  # attribute name looked up on segmentation_models_pytorch
    model_config=dict(  # keyword arguments passed to the architecture constructor
        encoder_name="resnet34",
        encoder_weights="imagenet",
        in_channels=3,
        classes=7,
    ),
)
# RGB colour for each class: draw_segmentation_masks pairs the i-th one-hot
# mask (class index i) with colors[i], so this list needs one entry per class
# (config["model_config"]["classes"] == 7).
colors = [
    (0, 255, 255),  # class 0 - cyan
    (255, 255, 0),  # class 1 - yellow
    (255, 0, 255),  # class 2 - magenta
    (0, 255, 0),  # class 3 - green
    (0, 0, 255),  # class 4 - blue
    (255, 255, 255),  # class 5 - white
    (0, 0, 0),  # class 6 - black
]
cp_path = "CP_epoch20.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"

# load model: build the architecture named in config, then restore the trained
# weights from the checkpoint, mapping its tensors onto the selected device.
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
model_architecture = getattr(smp, config["model_architecture"])
model = model_architecture(**config["model_config"])
model.load_state_dict(
    torch.load(cp_path, map_location=torch.device(device))
)
model.to(device)
model.eval()  # evaluation mode (affects e.g. dropout and batch-norm layers)

# transforms
# Square resize to (downsize_res, downsize_res), the resolution fed to the model.
downsize_t = transforms.Resize(
    size=(config["downsize_res"],) * 2,
    antialias=True,
)
# PIL image -> float tensor (ToTensor scales uint8 pixels into [0, 1]).
transform = transforms.Compose([transforms.ToTensor()])
def label_to_onehot(mask: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Convert a label-encoded mask into a boolean one-hot mask.

    The class axis is inserted just before the last two (spatial) axes, so a
    ``(H, W)`` mask becomes ``(num_classes, H, W)`` and an ``(N, H, W)`` mask
    becomes ``(N, num_classes, H, W)``. Any number of leading batch axes is
    supported.

    Args:
        mask: Integer class labels with at least two dimensions; every value
            must lie in ``[0, num_classes)``.
        num_classes: Number of classes, i.e. the size of the new class axis.

    Returns:
        A ``torch.bool`` tensor with one more dimension than ``mask``.

    Raises:
        ValueError: If ``mask`` has fewer than two dimensions.
    """
    if mask.ndim < 2:
        raise ValueError(
            f"mask must have at least 2 dimensions (H, W), got shape {tuple(mask.shape)}"
        )
    onehot = F.one_hot(mask.type(torch.long), num_classes=num_classes).type(torch.bool)
    # one_hot appends the class axis last; move it in front of (H, W).
    return torch.movedim(onehot, -1, -3)
def get_overlay(image: torch.Tensor, preds: torch.Tensor, alpha: float) -> torch.Tensor:
    """Draw the predicted segmentation of a satellite image on top of it.

    Args:
        image: ``(3, H, W)`` image to draw on, in a dtype accepted by
            ``draw_segmentation_masks`` (``uint8`` here).
        preds: Predicted class labels of shape ``(H, W)``, optionally with
            leading singleton axes such as ``(1, H, W)``.
        alpha: Opacity of the drawn masks, from 0 (invisible) to 1 (opaque).

    Returns:
        A ``(3, H, W)`` tensor in which every pixel is tinted with
        ``colors[class]`` for its predicted class.
    """
    # Keep only the spatial axes. Unlike ``squeeze()``, this does not also drop
    # H or W when one of them happens to be 1.
    labels = preds.reshape(preds.shape[-2:])
    # draw_segmentation_masks pairs mask i with colors[i] and needs at least as
    # many colours as masks, so derive the class count from the palette.
    masks = label_to_onehot(labels, len(colors))
    # Masks must live on the same device as the image being drawn on.
    masks = masks.to(image.device)
    return draw_segmentation_masks(image, masks=masks, alpha=alpha, colors=colors)
def hwc_to_chw(image_tensor: torch.Tensor) -> torch.Tensor:
    """Reorder an ``(H, W, C)`` tensor to ``(C, H, W)`` (returns a view)."""
    height_axis, width_axis, channel_axis = 0, 1, 2
    return image_tensor.permute(channel_axis, height_axis, width_axis)
def chw_to_hwc(image_tensor: torch.Tensor) -> torch.Tensor:
    """Reorder a ``(C, H, W)`` tensor to ``(H, W, C)`` (returns a view)."""
    channel_axis, height_axis, width_axis = 0, 1, 2
    return image_tensor.permute(height_axis, width_axis, channel_axis)
def segment(satellite_image: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Segment the landcover types of a satellite image.

    Args:
        satellite_image: RGB image as an ``(H, W, 3)`` array, as delivered by
            the Gradio image input (assumed ``uint8`` — TODO confirm for other
            callers).

    Returns:
        ``(raw_segmentation, segmentation_overlay)``, both ``(H, W, 3)``
        arrays at the input resolution: the class colours drawn at full
        opacity on an all-zero image, and the class colours blended at 20%
        opacity over the input image.
    """
    image_tensor = hwc_to_chw(torch.from_numpy(satellite_image))

    # preprocess image: PIL round-trip + ``transform`` -> (1, 3, H, W) float batch
    pil_image = transforms.functional.to_pil_image(image_tensor)
    X = transform(pil_image).unsqueeze(0).to(device)
    # The model runs at a fixed square resolution (config["downsize_res"]).
    X_down = downsize_t(X)

    # forward pass; inference needs no autograd graph
    with torch.no_grad():
        logits = model(X_down)
    # Bring the labels to the CPU so they share a device with ``image_tensor``.
    preds = torch.argmax(logits, dim=1).cpu()

    # Resize the label map back to the original image size. NEAREST keeps every
    # pixel a valid class id; bilinear/antialiased resizing would blend
    # neighbouring ids into unrelated classes along region boundaries.
    preds = transforms.functional.resize(
        preds,
        list(X.shape[-2:]),
        interpolation=transforms.InterpolationMode.NEAREST,
    )

    # get rgb formatted images
    segmentation_overlay = chw_to_hwc(get_overlay(image_tensor, preds, 0.2)).numpy()
    raw_segmentation = chw_to_hwc(
        get_overlay(torch.zeros_like(image_tensor), preds, 1)
    ).numpy()
    return raw_segmentation, segmentation_overlay
# ``gr.inputs`` is deprecated (removed in Gradio 4); use the same component
# class as the outputs.
inputs = gr.Image(label="Input Image")
outputs = [gr.Image(label="Raw Segmentation"), gr.Image(label="Segmentation Overlay")]
images_dir = "sample_sat_images/"
# os.listdir order is arbitrary; sort so the examples are shown in a stable order.
examples = [os.path.join(images_dir, image_id) for image_id in sorted(os.listdir(images_dir))]
title = "Satellite Images Landcover Classification"
description = (
    "Upload a satellite image from your computer or select one of"
    " the examples. The model will automatically segment the landcover"
    " types from a preselected set of possible types."
)
# Read the article with a context manager so the file handle is closed.
with open("article.md", "r", encoding="utf-8") as article_file:
    article = article_file.read()
iface = gr.Interface(
    segment,
    inputs,
    outputs,
    examples=examples,
    title=title,
    description=description,
    cache_examples=True,  # runs ``segment`` on every example while building the interface
    article=article,
)
iface.launch()