# DavidFM43's picture
# upload demo
# 3a7ac79
# raw
# history blame
# 3.15 kB
import torch
from torchvision.utils import draw_segmentation_masks
import torch.nn.functional as F
from torchvision import transforms
import segmentation_models_pytorch as smp
import gradio as gr
import os
# Settings for the demo. Only "downsize_res", "model_architecture" and
# "model_config" are read elsewhere in this file; the remaining keys are not
# referenced here.
config = {
    # Target size handed to transforms.Resize before the forward pass.
    "downsize_res": 512,
    "batch_size": 6,
    "epochs": 30,
    "lr": 3e-4,
    # Attribute name looked up on segmentation_models_pytorch.
    "model_architecture": "Unet",
    # Keyword arguments forwarded verbatim to the architecture constructor.
    "model_config": {
        "encoder_name": "resnet34",
        "encoder_weights": "imagenet",
        "in_channels": 3,
        "classes": 7,
    },
}
# One RGB colour per class index (position in the list == class id).
# NOTE(review): the palette looks like the DeepGlobe Land Cover one (urban,
# agriculture, rangeland, forest, water, barren, unknown) - confirm against
# the training data before relying on those names.
class_rgb_colors = [
    (0, 255, 255),    # class 0
    (255, 255, 0),    # class 1
    (255, 0, 255),    # class 2
    (0, 255, 0),      # class 3
    (0, 0, 255),      # class 4
    (255, 255, 255),  # class 5
    (0, 0, 0),        # class 6
]
def label_to_onehot(mask, num_classes):
    """Turn a mask of class indices into a channel-first boolean one-hot tensor.

    Args:
        mask: Tensor of class indices. A 2-D mask is handled as one
            ``(H, W)`` image; any other rank is permuted as a batched
            ``(N, H, W)`` stack.
        num_classes: Size of the one-hot (class) axis.

    Returns:
        A ``torch.bool`` tensor shaped ``(num_classes, H, W)`` for a 2-D
        mask, or ``(N, num_classes, H, W)`` for a 3-D mask.
    """
    if mask.ndim == 2:
        channel_first = (2, 0, 1)
    else:
        channel_first = (0, 3, 1, 2)

    # F.one_hot needs int64 indices and appends the class axis last.
    onehot = F.one_hot(mask.long(), num_classes=num_classes).bool()
    return onehot.permute(*channel_first)
# Checkpoint holding the trained weights, resolved relative to the working dir.
cp_path = "CP_epoch20.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"

# load model
model_architecture = getattr(smp, config["model_architecture"])
# load_state_dict below is strict, so every parameter and buffer is replaced by
# the checkpoint. Requesting pretrained encoder weights here would only trigger
# a pointless download at startup, so the encoder is built uninitialised.
model = model_architecture(**{**config["model_config"], "encoder_weights": None})
# NOTE(security): torch.load unpickles the file - only load trusted checkpoints.
model.load_state_dict(torch.load(cp_path, map_location=torch.device(device)))
model.to(device)
model.eval()

# transforms
# Resize with an int target rescales the shorter image side to that length.
downsize_t = transforms.Resize(config["downsize_res"], antialias=True)
# Inputs are only scaled to [0, 1] by ToTensor; per-channel normalisation is
# disabled. Statistics kept for reference:
#   mean = [0.4085, 0.3798, 0.2822], std = [0.1410, 0.1051, 0.0927]
transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)
def get_overlay(sat_img, preds, alpha):
    """Paint the predicted class regions over an image.

    Args:
        sat_img: Image tensor handed to ``draw_segmentation_masks`` as the
            canvas (channels first, as that function requires).
        preds: Tensor of predicted class indices, ``(H, W)`` or with leading
            singleton axes such as ``(1, H, W)``.
        alpha: Mask transparency forwarded to ``draw_segmentation_masks``.

    Returns:
        The image tensor returned by ``draw_segmentation_masks``, with one
        colour from the module-level ``class_rgb_colors`` per class.
    """
    # Strip leading singleton (batch) axes only. A bare squeeze() would also
    # drop a height or width of 1 and corrupt the mask shape.
    labels = preds
    while labels.ndim > 2 and labels.shape[0] == 1:
        labels = labels[0]

    # Derive the class count from the palette so masks and colours always
    # agree, instead of repeating the palette and a literal 7 here.
    masks = label_to_onehot(labels, len(class_rgb_colors))
    # The canvas is indexed with the masks, so both must be on one device
    # (predictions may still be on the GPU while the image is on the CPU).
    masks = masks.to(sat_img.device)

    return draw_segmentation_masks(
        sat_img, masks=masks, alpha=alpha, colors=class_rgb_colors
    )
def segment(numpy_arr):
    """Segment one image and return colour-coded results.

    Args:
        numpy_arr: Channels-last image array ``(H, W, C)`` as delivered by
            the Gradio image input.

    Returns:
        A ``(raw_masks, overlay)`` tuple of channels-last NumPy arrays at the
        input resolution: the class colours on a black canvas, and the class
        colours blended over the input image.
    """
    sat_img_arr = torch.from_numpy(numpy_arr).permute(2, 0, 1)
    sat_img_pil = transforms.functional.to_pil_image(sat_img_arr)

    # preprocess image
    X = transform(sat_img_pil).unsqueeze(0).to(device)
    X_down = downsize_t(X)
    # The U-Net (default encoder depth) halves the resolution five times, so
    # both sides must be multiples of 32. Resizing the shorter side alone only
    # guarantees that for square images; snap the sides otherwise.
    height, width = X_down.shape[-2:]
    snapped = [
        max(32, 32 * round(height / 32)),
        max(32, 32 * round(width / 32)),
    ]
    if snapped != [height, width]:
        X_down = transforms.functional.resize(X_down, snapped, antialias=True)

    # forward pass (inference only, so skip building the autograd graph)
    with torch.no_grad():
        logits = model(X_down)
    # Move to the CPU: the results are composited with the CPU image and
    # converted to NumPy below.
    preds = torch.argmax(logits, 1).cpu()

    # Resize to the original resolution. Class ids are categorical, so use
    # nearest-neighbour; interpolating them would invent unrelated classes
    # along region borders.
    preds = transforms.functional.resize(
        preds,
        list(X.shape[-2:]),
        interpolation=transforms.InterpolationMode.NEAREST,
    )

    # get RGB formatted images
    overlay = get_overlay(sat_img_arr, preds, 0.2)
    raw_masks = get_overlay(torch.zeros_like(sat_img_arr), preds, 1)

    # Back to channels-last for display.
    raw_masks = raw_masks.permute(1, 2, 0)
    overlay = overlay.permute(1, 2, 0)
    return raw_masks.numpy(), overlay.numpy()
# Input/output components. The input uses the same top-level gr.Image
# component as the outputs, replacing the deprecated gr.inputs namespace.
i = gr.Image()
o = [gr.Image(), gr.Image()]

# Every file in the samples directory is offered as a clickable example.
images_dir = "sample_sat_images/"
# os.listdir order is arbitrary; sort so the examples appear in a stable order.
image_ids = sorted(os.listdir(images_dir))
# os.path.join avoids the doubled separator an f-string produces with the
# trailing slash in images_dir.
examples = [os.path.join(images_dir, image_id) for image_id in image_ids]

title = "Satellite Images Landcover Segmentation"
description = "Upload an image or select from examples to segment"
iface = gr.Interface(
    segment, i, o, examples=examples, title=title, description=description
)
iface.launch(debug=True)